1use proc_macro::TokenStream;
7use quote::quote;
8use syn::{parse_macro_input, FnArg, ItemFn, LitStr, Pat, Type};
9
10#[proc_macro_attribute]
29pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
30 let description = parse_macro_input!(attr as LitStr);
31 let input_fn = parse_macro_input!(item as ItemFn);
32
33 let fn_name = &input_fn.sig.ident;
34 let fn_name_str = fn_name.to_string();
35 let description_str = description.value();
36 let is_async = input_fn.sig.asyncness.is_some();
37
38 let mut param_schemas = Vec::new();
40 let mut param_idents = Vec::new();
41 let mut required_params = Vec::new();
42 let mut param_extractions = Vec::new();
43
44 for input in &input_fn.sig.inputs {
45 if let FnArg::Typed(pat_type) = input {
46 if let Pat::Ident(pat_ident) = &*pat_type.pat {
47 let param_name = pat_ident.ident.to_string();
48 let param_ident = &pat_ident.ident;
49 let param_type = &*pat_type.ty;
50
51 let is_optional = is_option_type(param_type);
53
54 if !is_optional {
55 required_params.push(param_name.clone());
56 }
57
58 param_idents.push(param_ident.clone());
59
60 let schema_gen = generate_param_schema(¶m_name, param_type, is_optional);
62 param_schemas.push(quote! {
63 properties.insert(
64 #param_name.to_string(),
65 #schema_gen
66 );
67 });
68
69 let extraction = generate_param_extraction(¶m_name, param_type, is_optional);
71 param_extractions.push(extraction);
72 }
73 }
74 }
75
76 let tool_struct_name = syn::Ident::new(
78 &format!("{}Tool", to_pascal_case(&fn_name_str)),
79 fn_name.span(),
80 );
81
82 let execute_body = if is_async {
83 quote! {
84 let result = #fn_name(#(#param_idents),*).await;
85 let output = serde_json::to_string(&result)
86 .unwrap_or_else(|_| format!("{:?}", result));
87 Ok(::agents_core::tools::ToolResult::text(&ctx, output))
88 }
89 } else {
90 quote! {
91 let result = #fn_name(#(#param_idents),*);
92 let output = serde_json::to_string(&result)
93 .unwrap_or_else(|_| format!("{:?}", result));
94 Ok(::agents_core::tools::ToolResult::text(&ctx, output))
95 }
96 };
97
98 let expanded = quote! {
99 #input_fn
100
101 pub struct #tool_struct_name;
102
103 impl #tool_struct_name {
104 pub fn as_tool() -> ::std::sync::Arc<dyn ::agents_core::tools::Tool> {
105 ::std::sync::Arc::new(#tool_struct_name)
106 }
107 }
108
109 #[::async_trait::async_trait]
110 impl ::agents_core::tools::Tool for #tool_struct_name {
111 fn schema(&self) -> ::agents_core::tools::ToolSchema {
112 use ::std::collections::HashMap;
113 use ::agents_core::tools::{ToolSchema, ToolParameterSchema};
114
115 let mut properties = HashMap::new();
116 #(#param_schemas)*
117
118 ToolSchema::new(
119 #fn_name_str,
120 #description_str,
121 ToolParameterSchema::object(
122 concat!(#fn_name_str, " parameters"),
123 properties,
124 vec![#(#required_params.to_string()),*],
125 ),
126 )
127 }
128
129 async fn execute(
130 &self,
131 args: ::serde_json::Value,
132 ctx: ::agents_core::tools::ToolContext,
133 ) -> ::anyhow::Result<::agents_core::tools::ToolResult> {
134 #(#param_extractions)*
135 #execute_body
136 }
137 }
138 };
139
140 TokenStream::from(expanded)
141}
142
143fn is_option_type(ty: &Type) -> bool {
144 if let Type::Path(type_path) = ty {
145 if let Some(segment) = type_path.path.segments.last() {
146 return segment.ident == "Option";
147 }
148 }
149 false
150}
151
152fn generate_param_schema(
153 param_name: &str,
154 param_type: &Type,
155 is_optional: bool,
156) -> proc_macro2::TokenStream {
157 let description = format!("Parameter: {}", param_name);
158
159 let inner_type = if is_optional {
161 extract_option_inner_type(param_type)
162 } else {
163 param_type
164 };
165
166 match type_to_string(inner_type).as_str() {
168 "String" | "str" => quote! {
169 ::agents_core::tools::ToolParameterSchema::string(#description)
170 },
171 "i32" | "i64" | "u32" | "u64" | "isize" | "usize" => quote! {
172 ::agents_core::tools::ToolParameterSchema::integer(#description)
173 },
174 "f32" | "f64" => quote! {
175 ::agents_core::tools::ToolParameterSchema::number(#description)
176 },
177 "bool" => quote! {
178 ::agents_core::tools::ToolParameterSchema::boolean(#description)
179 },
180 _ => {
181 quote! {
183 ::agents_core::tools::ToolParameterSchema::string(#description)
184 }
185 }
186 }
187}
188
189fn generate_param_extraction(
190 param_name: &str,
191 param_type: &Type,
192 is_optional: bool,
193) -> proc_macro2::TokenStream {
194 let param_ident = syn::Ident::new(param_name, proc_macro2::Span::call_site());
195
196 if is_optional {
197 let inner_type = extract_option_inner_type(param_type);
198 let conversion = generate_type_conversion(inner_type);
199 quote! {
200 let #param_ident: Option<_> = args.get(#param_name)
201 .and_then(|v| #conversion);
202 }
203 } else {
204 let conversion = generate_type_conversion(param_type);
205 quote! {
206 let #param_ident: #param_type = args.get(#param_name)
207 .and_then(|v| #conversion)
208 .ok_or_else(|| ::anyhow::anyhow!(concat!("Missing required parameter: ", #param_name)))?;
209 }
210 }
211}
212
213fn generate_type_conversion(ty: &Type) -> proc_macro2::TokenStream {
214 match type_to_string(ty).as_str() {
215 "String" => quote! { Some(v.as_str()?.to_string()) },
216 "str" => quote! { Some(v.as_str()?.to_string()) },
217 "i32" | "i64" => quote! { Some(v.as_i64()? as _) },
218 "u32" | "u64" => quote! { Some(v.as_u64()? as _) },
219 "f32" | "f64" => quote! { Some(v.as_f64()? as _) },
220 "bool" => quote! { v.as_bool() },
221 _ => quote! { ::serde_json::from_value(v.clone()).ok() },
222 }
223}
224
225fn extract_option_inner_type(ty: &Type) -> &Type {
226 if let Type::Path(type_path) = ty {
227 if let Some(segment) = type_path.path.segments.last() {
228 if segment.ident == "Option" {
229 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
230 if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
231 return inner;
232 }
233 }
234 }
235 }
236 }
237 ty
238}
239
240fn type_to_string(ty: &Type) -> String {
241 if let Type::Path(type_path) = ty {
242 if let Some(segment) = type_path.path.segments.last() {
243 return segment.ident.to_string();
244 }
245 }
246 "Unknown".to_string()
247}
248
/// Converts a `snake_case` name to `PascalCase` (`get_weather` -> `GetWeather`).
///
/// Each `_`-separated segment has its first character upper-cased with
/// Unicode-aware `char::to_uppercase` (which may expand, e.g. `ß` -> `SS`).
/// The remainder of the segment is copied unchanged. Empty segments from
/// leading, trailing, or doubled underscores contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut rest = segment.chars();
        if let Some(head) = rest.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}