agents_runtime/providers/gemini.rs
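//! Gemini chat provider: converts an [`LlmRequest`]'s messages and tool
//! schemas into the `generateContent` wire format and maps the response
//! (plain text or `functionCall` parts) back into an [`AgentMessage`].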

use agents_core::llm::{LanguageModel, LlmRequest, LlmResponse};
use agents_core::messaging::{AgentMessage, MessageContent, MessageRole};
use agents_core::tools::ToolSchema;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::Value;

/// Connection settings for Gemini: API key, model name, and an optional base
/// URL override (defaults to the public `v1beta` endpoint when `None`).
#[derive(Clone)]
pub struct GeminiConfig {
    pub api_key: String,
    pub model: String,
    pub api_url: Option<String>,
}

/// Chat model backed by Gemini's `generateContent` endpoint.
pub struct GeminiChatModel {
    client: Client,
    config: GeminiConfig,
}

impl GeminiChatModel {
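    /// Builds a Gemini-backed [`LanguageModel`] from the given configuration.
    ///
    /// A minimal construction sketch (marked `ignore` because the environment
    /// variable and model name below are placeholders, not values this crate
    /// prescribes):
    ///
    /// ```ignore
    /// let model = GeminiChatModel::new(GeminiConfig {
    ///     api_key: std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set"),
    ///     model: "gemini-1.5-flash".into(),
    ///     api_url: None, // falls back to the public v1beta endpoint
    /// })?;
    /// ```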
    pub fn new(config: GeminiConfig) -> anyhow::Result<Self> {
        Ok(Self {
            client: Client::builder()
                .user_agent("rust-deep-agents-sdk/0.1")
                .build()?,
            config,
        })
    }
}

#[derive(Serialize)]
struct GeminiRequest {
    contents: Vec<GeminiContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    system_instruction: Option<GeminiContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<GeminiToolDeclaration>>,
}

#[derive(Clone, Serialize)]
struct GeminiToolDeclaration {
    function_declarations: Vec<GeminiFunctionDeclaration>,
}

#[derive(Clone, Serialize)]
struct GeminiFunctionDeclaration {
    name: String,
    description: String,
    parameters: Value,
}

#[derive(Serialize)]
struct GeminiContent {
    role: String,
    parts: Vec<GeminiPart>,
}

#[derive(Serialize)]
struct GeminiPart {
    text: String,
}

#[derive(Deserialize)]
struct GeminiResponse {
    candidates: Vec<GeminiCandidate>,
}

#[derive(Deserialize)]
struct GeminiCandidate {
    content: Option<GeminiContentResponse>,
}

#[derive(Deserialize)]
struct GeminiContentResponse {
    parts: Vec<GeminiPartResponse>,
}

#[derive(Deserialize)]
struct GeminiPartResponse {
    text: Option<String>,
    #[serde(rename = "functionCall")]
    function_call: Option<GeminiFunctionCall>,
}

#[derive(Deserialize)]
struct GeminiFunctionCall {
    name: String,
    args: Value,
}

/// Convert framework messages into Gemini `contents`, splitting the system
/// prompt out into a separate `system_instruction` block.
fn to_gemini_contents(request: &LlmRequest) -> (Vec<GeminiContent>, Option<GeminiContent>) {
    let mut contents = Vec::new();
    for message in &request.messages {
        let role = match message.role {
            MessageRole::User => "user",
            MessageRole::Agent => "model",
            // Gemini contents only accept "user"/"model" turns, so tool results
            // and system-role messages are folded into "user" turns here.
            MessageRole::Tool => "user",
            MessageRole::System => "user",
        };
        let text = match &message.content {
            MessageContent::Text(text) => text.clone(),
            MessageContent::Json(value) => value.to_string(),
        };
        contents.push(GeminiContent {
            role: role.into(),
            parts: vec![GeminiPart { text }],
        });
    }

    let system_instruction = if request.system_prompt.trim().is_empty() {
        None
    } else {
        Some(GeminiContent {
            role: "system".into(),
            parts: vec![GeminiPart {
                text: request.system_prompt.clone(),
            }],
        })
    };

    (contents, system_instruction)
}

/// Convert tool schemas to Gemini function declarations format
fn to_gemini_tools(tools: &[ToolSchema]) -> Option<Vec<GeminiToolDeclaration>> {
    if tools.is_empty() {
        return None;
    }

    Some(vec![GeminiToolDeclaration {
        function_declarations: tools
            .iter()
            .map(|tool| GeminiFunctionDeclaration {
                name: tool.name.clone(),
                description: tool.description.clone(),
                parameters: serde_json::to_value(&tool.parameters)
                    .unwrap_or_else(|_| serde_json::json!({})),
            })
            .collect(),
    }])
}

#[async_trait]
impl LanguageModel for GeminiChatModel {
    async fn generate(&self, request: LlmRequest) -> anyhow::Result<LlmResponse> {
        let (contents, system_instruction) = to_gemini_contents(&request);
        let tools = to_gemini_tools(&request.tools);

        // Debug logging (before moving contents)
        tracing::debug!(
            "Gemini request: model={}, contents={}, tools={}",
            self.config.model,
            contents.len(),
            tools
                .as_ref()
                .map(|t| t
                    .iter()
                    .map(|td| td.function_declarations.len())
                    .sum::<usize>())
                .unwrap_or(0)
        );

        let body = GeminiRequest {
            contents,
            system_instruction,
            tools,
        };

        let base_url = self
            .config
            .api_url
            .clone()
            .unwrap_or_else(|| "https://generativelanguage.googleapis.com/v1beta".into());
        let url = format!(
            "{}/models/{}:generateContent?key={}",
            base_url, self.config.model, self.config.api_key
        );

        let response = self
            .client
            .post(&url)
            .json(&body)
            .send()
            .await?
            .error_for_status()?;

        let data: GeminiResponse = response.json().await?;

        // Check if response contains function calls
        let function_calls: Vec<_> = data
            .candidates
            .iter()
            .filter_map(|candidate| candidate.content.as_ref())
            .flat_map(|content| &content.parts)
            .filter_map(|part| part.function_call.as_ref())
            .collect();

        if !function_calls.is_empty() {
            // Convert Gemini functionCall format to our JSON format
            let tool_calls: Vec<_> = function_calls
                .iter()
                .map(|fc| {
                    serde_json::json!({
                        "name": fc.name,
                        "args": fc.args
                    })
                })
                .collect();

            tracing::debug!(
                "Gemini response contains {} function calls",
                tool_calls.len()
            );

            return Ok(LlmResponse {
                message: AgentMessage {
                    role: MessageRole::Agent,
                    content: MessageContent::Json(serde_json::json!({
                        "tool_calls": tool_calls
                    })),
                    metadata: None,
                },
            });
        }

        // Regular text response
        let text = data
            .candidates
            .into_iter()
            .filter_map(|candidate| candidate.content)
            .flat_map(|content| content.parts)
            .find_map(|part| part.text)
            .unwrap_or_default();

        Ok(LlmResponse {
            message: AgentMessage {
                role: MessageRole::Agent,
                content: MessageContent::Text(text),
                metadata: None,
            },
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
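
    // Extra check sketched from the role mapping above: Agent turns should map
    // to "model" and tool results should come back as "user" turns.
    #[test]
    fn gemini_conversion_maps_agent_and_tool_roles() {
        let request = LlmRequest::new(
            "",
            vec![
                AgentMessage {
                    role: MessageRole::Agent,
                    content: MessageContent::Text("Working on it".into()),
                    metadata: None,
                },
                AgentMessage {
                    role: MessageRole::Tool,
                    content: MessageContent::Text("{\"ok\":true}".into()),
                    metadata: None,
                },
            ],
        );
        let (contents, system) = to_gemini_contents(&request);
        assert_eq!(contents.len(), 2);
        assert_eq!(contents[0].role, "model");
        assert_eq!(contents[1].role, "user");
        assert!(system.is_none());
    }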

    #[test]
    fn gemini_conversion_handles_system_prompt() {
        let request = LlmRequest::new(
            "You are concise",
            vec![AgentMessage {
                role: MessageRole::User,
                content: MessageContent::Text("Hello".into()),
                metadata: None,
            }],
        );
        let (contents, system) = to_gemini_contents(&request);
        assert_eq!(contents.len(), 1);
        assert_eq!(contents[0].role, "user");
        assert!(system.is_some());
        assert_eq!(system.unwrap().parts[0].text, "You are concise");
    }
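
    // A few more sketch-level checks on the helpers above; the key and model
    // values are placeholders, not real credentials or model names.
    #[test]
    fn gemini_tools_conversion_skips_empty_tool_list() {
        // No tools should mean no `tools` field at all, not an empty list.
        assert!(to_gemini_tools(&[]).is_none());
    }

    #[test]
    fn gemini_request_omits_missing_optional_fields() {
        // `system_instruction` and `tools` are skipped when absent, so the
        // body only carries `contents`.
        let body = GeminiRequest {
            contents: vec![],
            system_instruction: None,
            tools: None,
        };
        let json = serde_json::to_value(&body).expect("request serializes");
        assert_eq!(json, serde_json::json!({ "contents": [] }));
    }

    #[test]
    fn gemini_model_builds_with_placeholder_config() {
        // Construction smoke test; nothing is sent over the network here.
        let model = GeminiChatModel::new(GeminiConfig {
            api_key: "test-key".into(),
            model: "gemini-test".into(),
            api_url: None,
        });
        assert!(model.is_ok());
    }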
}