agents_macros/
lib.rs

1//! Procedural macros for Rust Deep Agents SDK
2//!
3//! This crate provides the `#[tool]` macro that converts regular Rust functions
4//! into AI agent tools with automatic JSON Schema generation.
5
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{parse_macro_input, FnArg, ItemFn, LitStr, Pat, Type};
9
10/// Converts a Rust function into an AI agent tool with automatic schema generation.
11///
12/// # Examples
13///
14/// ```rust
15/// use agents_macros::tool;
16///
17/// #[tool("Greets a person by name")]
18/// fn greet(name: String) -> String {
19///     format!("Hello, {}!", name)
20/// }
21///
22/// #[tool("Searches the web for information")]
23/// async fn web_search(query: String, max_results: Option<u32>) -> Vec<String> {
24///     // Implementation
25///     vec![]
26/// }
27/// ```
28#[proc_macro_attribute]
29pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
30    let description = parse_macro_input!(attr as LitStr);
31    let input_fn = parse_macro_input!(item as ItemFn);
32
33    let fn_name = &input_fn.sig.ident;
34    let fn_name_str = fn_name.to_string();
35    let description_str = description.value();
36    let is_async = input_fn.sig.asyncness.is_some();
37
38    // Extract parameters
39    let mut param_schemas = Vec::new();
40    let mut param_idents = Vec::new();
41    let mut required_params = Vec::new();
42    let mut param_extractions = Vec::new();
43
44    for input in &input_fn.sig.inputs {
45        if let FnArg::Typed(pat_type) = input {
46            if let Pat::Ident(pat_ident) = &*pat_type.pat {
47                let param_name = pat_ident.ident.to_string();
48                let param_ident = &pat_ident.ident;
49                let param_type = &*pat_type.ty;
50
51                // Check if it's Option<T> (optional parameter)
52                let is_optional = is_option_type(param_type);
53
54                if !is_optional {
55                    required_params.push(param_name.clone());
56                }
57
58                param_idents.push(param_ident.clone());
59
60                // Generate schema for this parameter
61                let schema_gen = generate_param_schema(&param_name, param_type, is_optional);
62                param_schemas.push(quote! {
63                    properties.insert(
64                        #param_name.to_string(),
65                        #schema_gen
66                    );
67                });
68
69                // Generate extraction code
70                let extraction = generate_param_extraction(&param_name, param_type, is_optional);
71                param_extractions.push(extraction);
72            }
73        }
74    }
75
76    // Generate the tool wrapper
77    let tool_struct_name = syn::Ident::new(
78        &format!("{}Tool", to_pascal_case(&fn_name_str)),
79        fn_name.span(),
80    );
81
82    let execute_body = if is_async {
83        quote! {
84            let result = #fn_name(#(#param_idents),*).await;
85            let output = serde_json::to_string(&result)
86                .unwrap_or_else(|_| format!("{:?}", result));
87            Ok(::agents_core::tools::ToolResult::text(&ctx, output))
88        }
89    } else {
90        quote! {
91            let result = #fn_name(#(#param_idents),*);
92            let output = serde_json::to_string(&result)
93                .unwrap_or_else(|_| format!("{:?}", result));
94            Ok(::agents_core::tools::ToolResult::text(&ctx, output))
95        }
96    };
97
98    let expanded = quote! {
99        #input_fn
100
101        pub struct #tool_struct_name;
102
103        impl #tool_struct_name {
104            pub fn as_tool() -> ::std::sync::Arc<dyn ::agents_core::tools::Tool> {
105                ::std::sync::Arc::new(#tool_struct_name)
106            }
107        }
108
109        #[::async_trait::async_trait]
110        impl ::agents_core::tools::Tool for #tool_struct_name {
111            fn schema(&self) -> ::agents_core::tools::ToolSchema {
112                use ::std::collections::HashMap;
113                use ::agents_core::tools::{ToolSchema, ToolParameterSchema};
114
115                let mut properties = HashMap::new();
116                #(#param_schemas)*
117
118                ToolSchema::new(
119                    #fn_name_str,
120                    #description_str,
121                    ToolParameterSchema::object(
122                        concat!(#fn_name_str, " parameters"),
123                        properties,
124                        vec![#(#required_params.to_string()),*],
125                    ),
126                )
127            }
128
129            async fn execute(
130                &self,
131                args: ::serde_json::Value,
132                ctx: ::agents_core::tools::ToolContext,
133            ) -> ::anyhow::Result<::agents_core::tools::ToolResult> {
134                #(#param_extractions)*
135                #execute_body
136            }
137        }
138    };
139
140    TokenStream::from(expanded)
141}
142
143fn is_option_type(ty: &Type) -> bool {
144    if let Type::Path(type_path) = ty {
145        if let Some(segment) = type_path.path.segments.last() {
146            return segment.ident == "Option";
147        }
148    }
149    false
150}
151
152fn generate_param_schema(
153    param_name: &str,
154    param_type: &Type,
155    is_optional: bool,
156) -> proc_macro2::TokenStream {
157    let description = format!("Parameter: {}", param_name);
158
159    // Extract the inner type if it's Option<T>
160    let inner_type = if is_optional {
161        extract_option_inner_type(param_type)
162    } else {
163        param_type
164    };
165
166    // Generate schema based on type
167    match type_to_string(inner_type).as_str() {
168        "String" | "str" => quote! {
169            ::agents_core::tools::ToolParameterSchema::string(#description)
170        },
171        "i32" | "i64" | "u32" | "u64" | "isize" | "usize" => quote! {
172            ::agents_core::tools::ToolParameterSchema::integer(#description)
173        },
174        "f32" | "f64" => quote! {
175            ::agents_core::tools::ToolParameterSchema::number(#description)
176        },
177        "bool" => quote! {
178            ::agents_core::tools::ToolParameterSchema::boolean(#description)
179        },
180        _ => {
181            // For complex types (Vec, custom structs), default to string
182            quote! {
183                ::agents_core::tools::ToolParameterSchema::string(#description)
184            }
185        }
186    }
187}
188
189fn generate_param_extraction(
190    param_name: &str,
191    param_type: &Type,
192    is_optional: bool,
193) -> proc_macro2::TokenStream {
194    let param_ident = syn::Ident::new(param_name, proc_macro2::Span::call_site());
195
196    if is_optional {
197        let inner_type = extract_option_inner_type(param_type);
198        let conversion = generate_type_conversion(inner_type);
199        quote! {
200            let #param_ident: Option<_> = args.get(#param_name)
201                .and_then(|v| #conversion);
202        }
203    } else {
204        let conversion = generate_type_conversion(param_type);
205        quote! {
206            let #param_ident: #param_type = args.get(#param_name)
207                .and_then(|v| #conversion)
208                .ok_or_else(|| ::anyhow::anyhow!(concat!("Missing required parameter: ", #param_name)))?;
209        }
210    }
211}
212
213fn generate_type_conversion(ty: &Type) -> proc_macro2::TokenStream {
214    match type_to_string(ty).as_str() {
215        "String" => quote! { Some(v.as_str()?.to_string()) },
216        "str" => quote! { Some(v.as_str()?.to_string()) },
217        "i32" | "i64" => quote! { Some(v.as_i64()? as _) },
218        "u32" | "u64" => quote! { Some(v.as_u64()? as _) },
219        "f32" | "f64" => quote! { Some(v.as_f64()? as _) },
220        "bool" => quote! { v.as_bool() },
221        _ => quote! { ::serde_json::from_value(v.clone()).ok() },
222    }
223}
224
225fn extract_option_inner_type(ty: &Type) -> &Type {
226    if let Type::Path(type_path) = ty {
227        if let Some(segment) = type_path.path.segments.last() {
228            if segment.ident == "Option" {
229                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
230                    if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
231                        return inner;
232                    }
233                }
234            }
235        }
236    }
237    ty
238}
239
240fn type_to_string(ty: &Type) -> String {
241    if let Type::Path(type_path) = ty {
242        if let Some(segment) = type_path.path.segments.last() {
243            return segment.ident.to_string();
244        }
245    }
246    "Unknown".to_string()
247}
248
/// Converts a `snake_case` identifier to `PascalCase`.
///
/// The input is split on `_`; the first character of each piece is
/// upper-cased, the remainder is copied verbatim, and the pieces are
/// concatenated. Empty pieces (from leading, trailing or doubled underscores)
/// contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::new();
    for piece in s.split('_') {
        let mut rest = piece.chars();
        if let Some(initial) = rest.next() {
            // `to_uppercase` may expand to more than one character.
            pascal.extend(initial.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}