agents_runtime/providers/
openai.rs1use agents_core::llm::{ChunkStream, LanguageModel, LlmRequest, LlmResponse, StreamChunk};
2use agents_core::messaging::{AgentMessage, MessageContent, MessageRole};
3use agents_core::tools::ToolSchema;
4use async_trait::async_trait;
5use futures::stream::StreamExt;
6use reqwest::Client;
7use serde::{Deserialize, Serialize};
8use std::sync::{Arc, Mutex};
9
/// Connection settings for [`OpenAiChatModel`].
///
/// `Debug` is implemented by hand so the API key is never written to logs.
#[derive(Clone)]
pub struct OpenAiConfig {
    /// Secret API key, sent as a bearer token on every request.
    pub api_key: String,
    /// Model identifier forwarded verbatim in the request body.
    pub model: String,
    /// Optional override for the chat-completions endpoint; `None` selects the public
    /// OpenAI URL.
    pub api_url: Option<String>,
}

impl std::fmt::Debug for OpenAiConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAiConfig")
            // Never print the secret itself.
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("api_url", &self.api_url)
            .finish()
    }
}

impl OpenAiConfig {
    /// Creates a config for `model` authenticated with `api_key`, using the default endpoint.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            api_url: None,
        }
    }

    /// Replaces the endpoint override (`None` restores the default endpoint).
    #[must_use]
    pub fn with_api_url(mut self, api_url: Option<String>) -> Self {
        self.api_url = api_url;
        self
    }
}
31
32pub struct OpenAiChatModel {
33 client: Client,
34 config: OpenAiConfig,
35}
36
37impl OpenAiChatModel {
38 pub fn new(config: OpenAiConfig) -> anyhow::Result<Self> {
39 Ok(Self {
40 client: Client::builder()
41 .user_agent("rust-deep-agents-sdk/0.1")
42 .build()?,
43 config,
44 })
45 }
46}
47
48#[derive(Serialize)]
49struct ChatRequest<'a> {
50 model: &'a str,
51 messages: &'a [OpenAiMessage],
52 #[serde(skip_serializing_if = "Option::is_none")]
53 stream: Option<bool>,
54 #[serde(skip_serializing_if = "Option::is_none")]
55 tools: Option<Vec<OpenAiTool>>,
56}
57
58#[derive(Serialize)]
59struct OpenAiMessage {
60 role: &'static str,
61 content: String,
62}
63
64#[derive(Clone, Serialize)]
65struct OpenAiTool {
66 #[serde(rename = "type")]
67 tool_type: String,
68 function: OpenAiFunction,
69}
70
71#[derive(Clone, Serialize)]
72struct OpenAiFunction {
73 name: String,
74 description: String,
75 parameters: serde_json::Value,
76}
77
78#[derive(Deserialize)]
79struct ChatResponse {
80 choices: Vec<Choice>,
81}
82
83#[derive(Deserialize)]
84struct Choice {
85 message: ChoiceMessage,
86}
87
88#[derive(Deserialize)]
89struct ChoiceMessage {
90 content: Option<String>,
91 #[serde(default)]
92 tool_calls: Vec<OpenAiToolCall>,
93}
94
95#[derive(Deserialize)]
96struct OpenAiToolCall {
97 #[allow(dead_code)]
98 id: String,
99 #[serde(rename = "type")]
100 #[allow(dead_code)]
101 tool_type: String,
102 function: OpenAiFunctionCall,
103}
104
105#[derive(Deserialize)]
106struct OpenAiFunctionCall {
107 name: String,
108 arguments: String,
109}
110
111#[derive(Deserialize)]
113struct StreamResponse {
114 choices: Vec<StreamChoice>,
115}
116
117#[derive(Deserialize)]
118struct StreamChoice {
119 delta: StreamDelta,
120 finish_reason: Option<String>,
121}
122
123#[derive(Deserialize)]
124struct StreamDelta {
125 content: Option<String>,
126}
127
128fn to_openai_messages(request: &LlmRequest) -> Vec<OpenAiMessage> {
129 let mut messages = Vec::with_capacity(request.messages.len() + 1);
130 messages.push(OpenAiMessage {
131 role: "system",
132 content: request.system_prompt.clone(),
133 });
134
135 let mut last_was_tool_call = false;
137
138 for msg in &request.messages {
139 let role = match msg.role {
140 MessageRole::User => "user",
141 MessageRole::Agent => "assistant",
142 MessageRole::Tool => {
143 if !last_was_tool_call {
145 tracing::warn!("Skipping tool message without preceding tool_calls");
146 continue;
147 }
148 "tool"
149 }
150 MessageRole::System => "system",
151 };
152
153 let content = match &msg.content {
154 MessageContent::Text(text) => text.clone(),
155 MessageContent::Json(value) => value.to_string(),
156 };
157
158 last_was_tool_call =
160 matches!(msg.role, MessageRole::Agent) && content.contains("tool_calls");
161
162 messages.push(OpenAiMessage { role, content });
163 }
164 messages
165}
166
167fn to_openai_tools(tools: &[ToolSchema]) -> Option<Vec<OpenAiTool>> {
169 if tools.is_empty() {
170 return None;
171 }
172
173 Some(
174 tools
175 .iter()
176 .map(|tool| OpenAiTool {
177 tool_type: "function".to_string(),
178 function: OpenAiFunction {
179 name: tool.name.clone(),
180 description: tool.description.clone(),
181 parameters: serde_json::to_value(&tool.parameters)
182 .unwrap_or_else(|_| serde_json::json!({})),
183 },
184 })
185 .collect(),
186 )
187}
188
189#[async_trait]
190impl LanguageModel for OpenAiChatModel {
191 async fn generate(&self, request: LlmRequest) -> anyhow::Result<LlmResponse> {
192 let messages = to_openai_messages(&request);
193 let tools = to_openai_tools(&request.tools);
194
195 let body = ChatRequest {
196 model: &self.config.model,
197 messages: &messages,
198 stream: None,
199 tools: tools.clone(),
200 };
201 let url = self
202 .config
203 .api_url
204 .as_deref()
205 .unwrap_or("https://api.openai.com/v1/chat/completions");
206
207 tracing::debug!(
209 "OpenAI request: model={}, messages={}, tools={}",
210 self.config.model,
211 messages.len(),
212 tools.as_ref().map(|t| t.len()).unwrap_or(0)
213 );
214 for (i, msg) in messages.iter().enumerate() {
215 tracing::debug!(
216 "Message {}: role={}, content_len={}",
217 i,
218 msg.role,
219 msg.content.len()
220 );
221 if msg.content.len() < 500 {
222 tracing::debug!("Message {} content: {}", i, msg.content);
223 }
224 }
225
226 let response = self
227 .client
228 .post(url)
229 .bearer_auth(&self.config.api_key)
230 .json(&body)
231 .send()
232 .await?;
233
234 if !response.status().is_success() {
235 let status = response.status();
236 let error_text = response.text().await.unwrap_or_default();
237 tracing::error!("OpenAI API error: status={}, body={}", status, error_text);
238 return Err(anyhow::anyhow!(
239 "OpenAI API error: {} - {}",
240 status,
241 error_text
242 ));
243 }
244
245 let data: ChatResponse = response.json().await?;
246 let choice = data
247 .choices
248 .into_iter()
249 .next()
250 .ok_or_else(|| anyhow::anyhow!("OpenAI response missing choices"))?;
251
252 if !choice.message.tool_calls.is_empty() {
254 let tool_calls: Vec<_> = choice
256 .message
257 .tool_calls
258 .iter()
259 .map(|tc| {
260 serde_json::json!({
261 "name": tc.function.name,
262 "args": serde_json::from_str::<serde_json::Value>(&tc.function.arguments)
263 .unwrap_or_else(|_| serde_json::json!({}))
264 })
265 })
266 .collect();
267
268 let tool_names: Vec<&str> = choice
270 .message
271 .tool_calls
272 .iter()
273 .map(|tc| tc.function.name.as_str())
274 .collect();
275
276 tracing::warn!(
277 "🔧 LLM CALLED {} TOOL(S): {:?}",
278 tool_calls.len(),
279 tool_names
280 );
281
282 for (i, tc) in choice.message.tool_calls.iter().enumerate() {
284 tracing::debug!(
285 "Tool call {}: {} with {} bytes of arguments",
286 i + 1,
287 tc.function.name,
288 tc.function.arguments.len()
289 );
290 }
291
292 return Ok(LlmResponse {
293 message: AgentMessage {
294 role: MessageRole::Agent,
295 content: MessageContent::Json(serde_json::json!({
296 "tool_calls": tool_calls
297 })),
298 metadata: None,
299 },
300 });
301 }
302
303 let content = choice.message.content.unwrap_or_else(|| "".to_string());
305
306 Ok(LlmResponse {
307 message: AgentMessage {
308 role: MessageRole::Agent,
309 content: MessageContent::Text(content),
310 metadata: None,
311 },
312 })
313 }
314
315 async fn generate_stream(&self, request: LlmRequest) -> anyhow::Result<ChunkStream> {
316 let messages = to_openai_messages(&request);
317 let tools = to_openai_tools(&request.tools);
318
319 let body = ChatRequest {
320 model: &self.config.model,
321 messages: &messages,
322 stream: Some(true),
323 tools,
324 };
325 let url = self
326 .config
327 .api_url
328 .as_deref()
329 .unwrap_or("https://api.openai.com/v1/chat/completions");
330
331 tracing::debug!(
332 "OpenAI streaming request: model={}, messages={}, tools={}",
333 self.config.model,
334 messages.len(),
335 request.tools.len()
336 );
337
338 let response = self
339 .client
340 .post(url)
341 .bearer_auth(&self.config.api_key)
342 .json(&body)
343 .send()
344 .await?;
345
346 if !response.status().is_success() {
347 let status = response.status();
348 let error_text = response.text().await.unwrap_or_default();
349 tracing::error!("OpenAI API error: status={}, body={}", status, error_text);
350 return Err(anyhow::anyhow!(
351 "OpenAI API error: {} - {}",
352 status,
353 error_text
354 ));
355 }
356
357 let stream = response.bytes_stream();
359 let accumulated_content = Arc::new(Mutex::new(String::new()));
360 let buffer = Arc::new(Mutex::new(String::new()));
361
362 let is_done = Arc::new(Mutex::new(false));
363
364 let final_accumulated = accumulated_content.clone();
366 let final_is_done = is_done.clone();
367
368 let chunk_stream = stream.map(move |result| {
369 let accumulated = accumulated_content.clone();
370 let buffer = buffer.clone();
371 let is_done = is_done.clone();
372
373 if *is_done.lock().unwrap() {
375 return Ok(StreamChunk::TextDelta(String::new()));
376 }
377
378 match result {
379 Ok(bytes) => {
380 let text = String::from_utf8_lossy(&bytes);
381
382 buffer.lock().unwrap().push_str(&text);
384
385 let mut buf = buffer.lock().unwrap();
386
387 let mut collected_deltas = String::new();
389 let mut found_done = false;
390 let mut found_finish = false;
391
392 let parts: Vec<&str> = buf.split("\n\n").collect();
394 let complete_messages = if parts.len() > 1 {
395 &parts[..parts.len() - 1] } else {
397 &[] };
399
400 for msg in complete_messages {
402 for line in msg.lines() {
403 if let Some(data) = line.strip_prefix("data: ") {
404 let json_str = data.trim();
405
406 if json_str == "[DONE]" {
408 found_done = true;
409 break;
410 }
411
412 match serde_json::from_str::<StreamResponse>(json_str) {
414 Ok(chunk) => {
415 if let Some(choice) = chunk.choices.first() {
416 if let Some(content) = &choice.delta.content {
418 if !content.is_empty() {
419 accumulated.lock().unwrap().push_str(content);
420 collected_deltas.push_str(content);
421 }
422 }
423
424 if choice.finish_reason.is_some() {
426 found_finish = true;
427 }
428 }
429 }
430 Err(e) => {
431 tracing::debug!("Failed to parse SSE message: {}", e);
432 }
433 }
434 }
435 }
436 if found_done || found_finish {
437 break;
438 }
439 }
440
441 if !complete_messages.is_empty() {
443 *buf = parts.last().unwrap_or(&"").to_string();
444 }
445
446 if found_done || found_finish {
448 let content = accumulated.lock().unwrap().clone();
449 let final_message = AgentMessage {
450 role: MessageRole::Agent,
451 content: MessageContent::Text(content),
452 metadata: None,
453 };
454 *is_done.lock().unwrap() = true;
455 buf.clear();
456 return Ok(StreamChunk::Done {
457 message: final_message,
458 });
459 }
460
461 if !collected_deltas.is_empty() {
463 return Ok(StreamChunk::TextDelta(collected_deltas));
464 }
465
466 Ok(StreamChunk::TextDelta(String::new()))
467 }
468 Err(e) => {
469 if !*is_done.lock().unwrap() {
471 let content = accumulated.lock().unwrap().clone();
472 if !content.is_empty() {
473 let final_message = AgentMessage {
474 role: MessageRole::Agent,
475 content: MessageContent::Text(content),
476 metadata: None,
477 };
478 *is_done.lock().unwrap() = true;
479 return Ok(StreamChunk::Done {
480 message: final_message,
481 });
482 }
483 }
484 Err(anyhow::anyhow!("Stream error: {}", e))
485 }
486 }
487 });
488
489 let stream_with_finale = chunk_stream.chain(futures::stream::once(async move {
491 if !*final_is_done.lock().unwrap() {
493 let content = final_accumulated.lock().unwrap().clone();
494 if !content.is_empty() {
495 let final_message = AgentMessage {
496 role: MessageRole::Agent,
497 content: MessageContent::Text(content),
498 metadata: None,
499 };
500 let content_text = match &final_message.content {
501 MessageContent::Text(t) => t.as_str(),
502 _ => "non-text",
503 };
504 tracing::debug!(
505 "Stream ended naturally, sending final Done chunk with {} chars",
506 content_text.len()
507 );
508 return Ok(StreamChunk::Done {
509 message: final_message,
510 });
511 }
512 }
513 Ok(StreamChunk::TextDelta(String::new()))
515 }));
516
517 Ok(Box::pin(stream_with_finale))
518 }
519}