agents_runtime/providers/openai.rs

use agents_core::llm::{ChunkStream, LanguageModel, LlmRequest, LlmResponse, StreamChunk};
use agents_core::messaging::{AgentMessage, MessageContent, MessageRole};
use agents_core::tools::ToolSchema;
use async_trait::async_trait;
use futures::stream::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::sync::{Arc, Mutex};

#[derive(Clone)]
pub struct OpenAiConfig {
    pub api_key: String,
    pub model: String,
    pub api_url: Option<String>,
}

impl OpenAiConfig {
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            api_url: None,
        }
    }

    pub fn with_api_url(mut self, api_url: Option<String>) -> Self {
        self.api_url = api_url;
        self
    }
}

pub struct OpenAiChatModel {
    client: Client,
    config: OpenAiConfig,
}

impl OpenAiChatModel {
    pub fn new(config: OpenAiConfig) -> anyhow::Result<Self> {
        Ok(Self {
            client: Client::builder()
                .user_agent("rust-deep-agents-sdk/0.1")
                .build()?,
            config,
        })
    }
}
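
// Hedged usage sketch: build a model against the default endpoint, or point
// `api_url` at any OpenAI-compatible server. The key, model name, and local
// URL below are illustrative placeholders, not values defined by this crate.
#[allow(dead_code)]
fn example_build_model() -> anyhow::Result<OpenAiChatModel> {
    let config = OpenAiConfig::new("sk-your-key", "gpt-4o-mini")
        .with_api_url(Some("http://localhost:8080/v1/chat/completions".to_string()));
    OpenAiChatModel::new(config)
}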

#[derive(Serialize)]
struct ChatRequest<'a> {
    model: &'a str,
    messages: &'a [OpenAiMessage],
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OpenAiTool>>,
}
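
// Serialized request shape (illustrative; `stream` and `tools` are omitted
// entirely when `None`, per the skip_serializing_if attributes above):
//   {"model":"gpt-4o-mini","messages":[{"role":"user","content":"Hi"}],"stream":true}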

#[derive(Serialize)]
struct OpenAiMessage {
    role: &'static str,
    content: String,
}

#[derive(Clone, Serialize)]
struct OpenAiTool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAiFunction,
}

#[derive(Clone, Serialize)]
struct OpenAiFunction {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

#[derive(Deserialize)]
struct ChatResponse {
    choices: Vec<Choice>,
}

#[derive(Deserialize)]
struct Choice {
    message: ChoiceMessage,
}

#[derive(Deserialize)]
struct ChoiceMessage {
    content: Option<String>,
    #[serde(default)]
    tool_calls: Vec<OpenAiToolCall>,
}

#[derive(Deserialize)]
struct OpenAiToolCall {
    #[allow(dead_code)]
    id: String,
    #[serde(rename = "type")]
    #[allow(dead_code)]
    tool_type: String,
    function: OpenAiFunctionCall,
}

#[derive(Deserialize)]
struct OpenAiFunctionCall {
    name: String,
    arguments: String,
}

// Streaming response structures
#[derive(Deserialize)]
struct StreamResponse {
    choices: Vec<StreamChoice>,
}

#[derive(Deserialize)]
struct StreamChoice {
    delta: StreamDelta,
    finish_reason: Option<String>,
}

#[derive(Deserialize)]
struct StreamDelta {
    content: Option<String>,
}
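
// A typical server-sent event these structs deserialize, one `data:` line per
// chunk (shape follows the Chat Completions streaming format):
//   data: {"choices":[{"delta":{"content":"Hel"},"finish_reason":null}]}
// The stream terminates with the literal sentinel line `data: [DONE]`.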

fn to_openai_messages(request: &LlmRequest) -> Vec<OpenAiMessage> {
    let mut messages = Vec::with_capacity(request.messages.len() + 1);
    messages.push(OpenAiMessage {
        role: "system",
        content: request.system_prompt.clone(),
    });

    // Filter and validate message sequence for OpenAI compatibility
    let mut last_was_tool_call = false;

    for msg in &request.messages {
        let role = match msg.role {
            MessageRole::User => "user",
            MessageRole::Agent => "assistant",
            MessageRole::Tool => {
                // Only include tool messages if they follow a tool call
                if !last_was_tool_call {
                    tracing::warn!("Skipping tool message without preceding tool_calls");
                    continue;
                }
                "tool"
            }
            MessageRole::System => "system",
        };

        let content = match &msg.content {
            MessageContent::Text(text) => text.clone(),
            MessageContent::Json(value) => value.to_string(),
        };

        // Heuristic: `generate` serializes assistant tool calls as JSON whose
        // text contains a "tool_calls" key, so detect them by matching on the
        // serialized content of assistant messages.
        last_was_tool_call =
            matches!(msg.role, MessageRole::Agent) && content.contains("tool_calls");

        messages.push(OpenAiMessage { role, content });
    }
    messages
}
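
// Illustrative mapping, assuming only the message fields used above: a request
// with system prompt "You are helpful" plus one user message "Hi" becomes
//   [{ "role": "system", "content": "You are helpful" },
//    { "role": "user",   "content": "Hi" }]
// Tool results are dropped unless the previous assistant turn carried
// tool_calls, since the API rejects orphaned "tool" messages.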

/// Convert tool schemas to OpenAI function calling format
fn to_openai_tools(tools: &[ToolSchema]) -> Option<Vec<OpenAiTool>> {
    if tools.is_empty() {
        return None;
    }

    Some(
        tools
            .iter()
            .map(|tool| OpenAiTool {
                tool_type: "function".to_string(),
                function: OpenAiFunction {
                    name: tool.name.clone(),
                    description: tool.description.clone(),
                    parameters: serde_json::to_value(&tool.parameters)
                        .unwrap_or_else(|_| serde_json::json!({})),
                },
            })
            .collect(),
    )
}
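
// For a hypothetical ToolSchema named "web_search", the conversion above
// yields the standard function-calling payload:
//   {"type":"function","function":{"name":"web_search","description":"...",
//    "parameters":{"type":"object","properties":{...}}}}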

#[async_trait]
impl LanguageModel for OpenAiChatModel {
    async fn generate(&self, request: LlmRequest) -> anyhow::Result<LlmResponse> {
        let messages = to_openai_messages(&request);
        let tools = to_openai_tools(&request.tools);

        let body = ChatRequest {
            model: &self.config.model,
            messages: &messages,
            stream: None,
            tools: tools.clone(),
        };
        let url = self
            .config
            .api_url
            .as_deref()
            .unwrap_or("https://api.openai.com/v1/chat/completions");

        // Debug logging
        tracing::debug!(
            "OpenAI request: model={}, messages={}, tools={}",
            self.config.model,
            messages.len(),
            tools.as_ref().map(|t| t.len()).unwrap_or(0)
        );
        for (i, msg) in messages.iter().enumerate() {
            tracing::debug!(
                "Message {}: role={}, content_len={}",
                i,
                msg.role,
                msg.content.len()
            );
            if msg.content.len() < 500 {
                tracing::debug!("Message {} content: {}", i, msg.content);
            }
        }

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.config.api_key)
            .json(&body)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            tracing::error!("OpenAI API error: status={}, body={}", status, error_text);
            return Err(anyhow::anyhow!(
                "OpenAI API error: {} - {}",
                status,
                error_text
            ));
        }

        let data: ChatResponse = response.json().await?;
        let choice = data
            .choices
            .into_iter()
            .next()
            .ok_or_else(|| anyhow::anyhow!("OpenAI response missing choices"))?;

        // Handle tool calls if present
        if !choice.message.tool_calls.is_empty() {
            // Convert OpenAI tool_calls format to our JSON format
            let tool_calls: Vec<_> = choice
                .message
                .tool_calls
                .iter()
                .map(|tc| {
                    serde_json::json!({
                        "name": tc.function.name,
                        "args": serde_json::from_str::<serde_json::Value>(&tc.function.arguments)
                            .unwrap_or_else(|_| serde_json::json!({}))
                    })
                })
                .collect();

            // Enhanced logging for tool call detection
            let tool_names: Vec<&str> = choice
                .message
                .tool_calls
                .iter()
                .map(|tc| tc.function.name.as_str())
                .collect();

            tracing::warn!(
                "🔧 LLM CALLED {} TOOL(S): {:?}",
                tool_calls.len(),
                tool_names
            );

            // Log argument sizes for debugging
            for (i, tc) in choice.message.tool_calls.iter().enumerate() {
                tracing::debug!(
                    "Tool call {}: {} with {} bytes of arguments",
                    i + 1,
                    tc.function.name,
                    tc.function.arguments.len()
                );
            }

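            // Normalized shape handed back to the runtime (tool name and args
            // here are illustrative):
            //   {"tool_calls":[{"name":"web_search","args":{"query":"rust"}}]}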
            return Ok(LlmResponse {
                message: AgentMessage {
                    role: MessageRole::Agent,
                    content: MessageContent::Json(serde_json::json!({
                        "tool_calls": tool_calls
                    })),
                    metadata: None,
                },
            });
        }

        // Regular text response
        let content = choice.message.content.unwrap_or_default();

        Ok(LlmResponse {
            message: AgentMessage {
                role: MessageRole::Agent,
                content: MessageContent::Text(content),
                metadata: None,
            },
        })
    }

    async fn generate_stream(&self, request: LlmRequest) -> anyhow::Result<ChunkStream> {
        let messages = to_openai_messages(&request);
        let tools = to_openai_tools(&request.tools);

        let body = ChatRequest {
            model: &self.config.model,
            messages: &messages,
            stream: Some(true),
            tools,
        };
        let url = self
            .config
            .api_url
            .as_deref()
            .unwrap_or("https://api.openai.com/v1/chat/completions");

        tracing::debug!(
            "OpenAI streaming request: model={}, messages={}, tools={}",
            self.config.model,
            messages.len(),
            request.tools.len()
        );

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.config.api_key)
            .json(&body)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            tracing::error!("OpenAI API error: status={}, body={}", status, error_text);
            return Err(anyhow::anyhow!(
                "OpenAI API error: {} - {}",
                status,
                error_text
            ));
        }

        // Create stream from SSE response
        let stream = response.bytes_stream();
        let accumulated_content = Arc::new(Mutex::new(String::new()));
        let buffer = Arc::new(Mutex::new(String::new()));

        let is_done = Arc::new(Mutex::new(false));

        // Clone handles for the trailing finale chunk appended below
        let final_accumulated = accumulated_content.clone();
        let final_is_done = is_done.clone();

        let chunk_stream = stream.map(move |result| {
            let accumulated = accumulated_content.clone();
            let buffer = buffer.clone();
            let is_done = is_done.clone();

            // Check if we're already done
            if *is_done.lock().unwrap() {
                return Ok(StreamChunk::TextDelta(String::new()));
            }

            match result {
                Ok(bytes) => {
                    let text = String::from_utf8_lossy(&bytes);

                    // Take the buffer lock once: append this chunk, then
                    // process whatever is now buffered
                    let mut buf = buffer.lock().unwrap();
                    buf.push_str(&text);

                    // Process complete SSE messages (separated by \n\n)
                    let mut collected_deltas = String::new();
                    let mut found_done = false;
                    let mut found_finish = false;

                    // Split on double newline to get complete SSE messages
                    let parts: Vec<&str> = buf.split("\n\n").collect();
                    let complete_messages = if parts.len() > 1 {
                        &parts[..parts.len() - 1] // All but last (potentially incomplete)
                    } else {
                        &[] // No complete messages yet
                    };
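
                    // Worked example (illustrative): if the buffer holds
                    //   "data: {A}\n\ndata: {B}\n\nda"
                    // then complete_messages is ["data: {A}", "data: {B}"] and
                    // the incomplete tail "da" stays buffered for the next chunk.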

                    // Process each complete SSE message
                    for msg in complete_messages {
                        for line in msg.lines() {
                            if let Some(data) = line.strip_prefix("data: ") {
                                let json_str = data.trim();

                                // Check for [DONE] marker
                                if json_str == "[DONE]" {
                                    found_done = true;
                                    break;
                                }

                                // Parse JSON chunk
                                match serde_json::from_str::<StreamResponse>(json_str) {
                                    Ok(chunk) => {
                                        if let Some(choice) = chunk.choices.first() {
                                            // Collect delta content
                                            if let Some(content) = &choice.delta.content {
                                                if !content.is_empty() {
                                                    accumulated.lock().unwrap().push_str(content);
                                                    collected_deltas.push_str(content);
                                                }
                                            }

                                            // Check if stream is finished
                                            if choice.finish_reason.is_some() {
                                                found_finish = true;
                                            }
                                        }
                                    }
                                    Err(e) => {
                                        tracing::debug!("Failed to parse SSE message: {}", e);
                                    }
                                }
                            }
                        }
                        if found_done || found_finish {
                            break;
                        }
                    }

                    // Clear processed messages from buffer, keep only incomplete part
                    if !complete_messages.is_empty() {
                        *buf = parts.last().unwrap_or(&"").to_string();
                    }

                    // Handle completion
                    if found_done || found_finish {
                        let content = accumulated.lock().unwrap().clone();
                        let final_message = AgentMessage {
                            role: MessageRole::Agent,
                            content: MessageContent::Text(content),
                            metadata: None,
                        };
                        *is_done.lock().unwrap() = true;
                        buf.clear();
                        return Ok(StreamChunk::Done {
                            message: final_message,
                        });
                    }

                    // Return collected deltas (may be empty)
                    if !collected_deltas.is_empty() {
                        return Ok(StreamChunk::TextDelta(collected_deltas));
                    }

                    Ok(StreamChunk::TextDelta(String::new()))
                }
                Err(e) => {
                    // Transport error - salvage any accumulated content before failing
                    if !*is_done.lock().unwrap() {
                        let content = accumulated.lock().unwrap().clone();
                        if !content.is_empty() {
                            let final_message = AgentMessage {
                                role: MessageRole::Agent,
                                content: MessageContent::Text(content),
                                metadata: None,
                            };
                            *is_done.lock().unwrap() = true;
                            return Ok(StreamChunk::Done {
                                message: final_message,
                            });
                        }
                    }
                    Err(anyhow::anyhow!("Stream error: {}", e))
                }
            }
        });

        // Chain a final chunk to ensure Done is sent when the stream completes
        let stream_with_finale = chunk_stream.chain(futures::stream::once(async move {
            // Check if we already sent Done
            if !*final_is_done.lock().unwrap() {
                let content = final_accumulated.lock().unwrap().clone();
                if !content.is_empty() {
                    tracing::debug!(
                        "Stream ended naturally, sending final Done chunk with {} chars",
                        content.len()
                    );
                    let final_message = AgentMessage {
                        role: MessageRole::Agent,
                        content: MessageContent::Text(content),
                        metadata: None,
                    };
                    return Ok(StreamChunk::Done {
                        message: final_message,
                    });
                }
            }
            // Return empty delta if already done or no content
            Ok(StreamChunk::TextDelta(String::new()))
        }));

        Ok(Box::pin(stream_with_finale))
    }
}
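
// Hedged end-to-end sketch (illustrative, not part of this crate's tests;
// assumes a Tokio runtime and an `LlmRequest` built elsewhere):
//
//   let model = OpenAiChatModel::new(OpenAiConfig::new(key, "gpt-4o-mini"))?;
//   let response = model.generate(request.clone()).await?;
//   let mut stream = model.generate_stream(request).await?;
//   while let Some(chunk) = stream.next().await {
//       match chunk? {
//           StreamChunk::TextDelta(delta) => print!("{delta}"),
//           StreamChunk::Done { .. } => break,
//       }
//   }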