agents_runtime/
planner.rs

1use std::sync::Arc;
2
3use agents_core::agent::{PlannerAction, PlannerContext, PlannerDecision, PlannerHandle};
4use agents_core::llm::{LanguageModel, LlmRequest};
5use agents_core::messaging::{AgentMessage, MessageContent, MessageRole};
6use agents_core::state::AgentStateSnapshot;
7use async_trait::async_trait;
8use serde::Deserialize;
9use serde_json::Value;
10
11#[derive(Clone)]
12pub struct LlmBackedPlanner {
13    model: Arc<dyn LanguageModel>,
14}
15
16impl LlmBackedPlanner {
17    pub fn new(model: Arc<dyn LanguageModel>) -> Self {
18        Self { model }
19    }
20
21    /// Get the underlying language model for direct access (e.g., streaming)
22    pub fn model(&self) -> &Arc<dyn LanguageModel> {
23        &self.model
24    }
25}
26
27#[derive(Debug, Deserialize)]
28struct ToolCall {
29    name: String,
30    #[serde(default)]
31    args: Value,
32}
33
34#[derive(Debug, Deserialize)]
35struct PlannerOutput {
36    #[serde(default)]
37    tool_calls: Vec<ToolCall>,
38    #[serde(default)]
39    response: Option<String>,
40}
41
42#[async_trait]
43impl PlannerHandle for LlmBackedPlanner {
44    async fn plan(
45        &self,
46        context: PlannerContext,
47        _state: Arc<AgentStateSnapshot>,
48    ) -> anyhow::Result<PlannerDecision> {
49        let request = LlmRequest::new(context.system_prompt.clone(), context.history.clone())
50            .with_tools(context.tools.clone());
51        let response = self.model.generate(request).await?;
52        let message = response.message;
53
54        match parse_planner_output(&message)? {
55            PlannerOutputVariant::ToolCall { name, args } => Ok(PlannerDecision {
56                next_action: PlannerAction::CallTool {
57                    tool_name: name,
58                    payload: args,
59                },
60            }),
61            PlannerOutputVariant::Respond(text) => Ok(PlannerDecision {
62                next_action: PlannerAction::Respond {
63                    message: AgentMessage {
64                        role: MessageRole::Agent,
65                        content: MessageContent::Text(text),
66                        metadata: message.metadata,
67                    },
68                },
69            }),
70        }
71    }
72
73    fn as_any(&self) -> &dyn std::any::Any {
74        self
75    }
76}
77
78enum PlannerOutputVariant {
79    ToolCall { name: String, args: Value },
80    Respond(String),
81}
82
83fn parse_planner_output(message: &AgentMessage) -> anyhow::Result<PlannerOutputVariant> {
84    match &message.content {
85        MessageContent::Json(value) => parse_from_value(value.clone()),
86        MessageContent::Text(text) => {
87            // Try to parse JSON even when returned as text, optionally in code fences.
88            if let Some(parsed) = parse_from_text(text) {
89                if let Some(tc) = parsed.tool_calls.first() {
90                    return Ok(PlannerOutputVariant::ToolCall {
91                        name: tc.name.clone(),
92                        args: tc.args.clone(),
93                    });
94                }
95                if let Some(resp) = parsed.response {
96                    return Ok(PlannerOutputVariant::Respond(resp));
97                }
98            }
99            Ok(PlannerOutputVariant::Respond(text.clone()))
100        }
101    }
102}
103
104fn parse_from_value(value: Value) -> anyhow::Result<PlannerOutputVariant> {
105    let parsed: PlannerOutput = serde_json::from_value(value)?;
106    if let Some(tool_call) = parsed.tool_calls.first() {
107        Ok(PlannerOutputVariant::ToolCall {
108            name: tool_call.name.clone(),
109            args: tool_call.args.clone(),
110        })
111    } else if let Some(response) = parsed.response {
112        Ok(PlannerOutputVariant::Respond(response))
113    } else {
114        anyhow::bail!("LLM response missing tool call and response fields")
115    }
116}
117
118fn parse_from_text(text: &str) -> Option<PlannerOutput> {
119    // 1) Raw JSON
120    if let Some(parsed) = decode_output_from_str(text) {
121        return Some(parsed);
122    }
123    // 2) Remove common code fences ```json ... ``` or ``` ... ```
124    let trimmed = text.trim();
125    if trimmed.starts_with("```") {
126        let without_ticks = trimmed.trim_start_matches("```");
127        // optional language tag (e.g., json)
128        let without_lang = without_ticks
129            .trim_start_matches(|c: char| c.is_alphabetic())
130            .trim_start();
131        let inner = if let Some(end) = without_lang.rfind("```") {
132            &without_lang[..end]
133        } else {
134            without_lang
135        };
136        if let Some(parsed) = decode_output_from_str(inner) {
137            return Some(parsed);
138        }
139    }
140    None
141}
142
143/// Attempt to decode PlannerOutput from a JSON string; returns None on failure.
144fn decode_output_from_str(s: &str) -> Option<PlannerOutput> {
145    serde_json::from_str::<Value>(s)
146        .ok()
147        .and_then(|v| serde_json::from_value::<PlannerOutput>(v).ok())
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use agents_core::llm::{LanguageModel, LlmResponse};
    use agents_core::messaging::MessageMetadata;
    use async_trait::async_trait;

    /// Wraps `text` in an agent-role text message without metadata.
    fn text_message(text: &str) -> AgentMessage {
        AgentMessage {
            role: MessageRole::Agent,
            content: MessageContent::Text(text.into()),
            metadata: None,
        }
    }

    /// Model that echoes the last request message (or an empty text message).
    struct EchoModel;

    #[async_trait]
    impl LanguageModel for EchoModel {
        async fn generate(&self, request: LlmRequest) -> anyhow::Result<LlmResponse> {
            Ok(LlmResponse {
                message: request
                    .messages
                    .last()
                    .cloned()
                    .unwrap_or_else(|| text_message("")),
            })
        }
    }

    #[tokio::test]
    async fn planner_falls_back_to_text_response() {
        let planner = LlmBackedPlanner::new(Arc::new(EchoModel));
        let context = PlannerContext {
            history: vec![AgentMessage {
                role: MessageRole::User,
                content: MessageContent::Text("Hi".into()),
                metadata: None,
            }],
            system_prompt: "Be helpful".into(),
            tools: vec![],
        };

        let decision = planner
            .plan(context, Arc::new(AgentStateSnapshot::default()))
            .await
            .unwrap();

        match decision.next_action {
            PlannerAction::Respond { message } => match message.content {
                MessageContent::Text(text) => assert_eq!(text, "Hi"),
                other => panic!("expected text, got {other:?}"),
            },
            _ => panic!("expected respond"),
        }
    }

    /// Model that always answers with a single structured `write_file` tool call.
    struct ToolCallModel;

    #[async_trait]
    impl LanguageModel for ToolCallModel {
        async fn generate(&self, _request: LlmRequest) -> anyhow::Result<LlmResponse> {
            Ok(LlmResponse {
                message: AgentMessage {
                    role: MessageRole::Agent,
                    content: MessageContent::Json(serde_json::json!({
                        "tool_calls": [
                            {
                                "name": "write_file",
                                "args": { "path": "notes.txt" }
                            }
                        ]
                    })),
                    metadata: Some(MessageMetadata {
                        tool_call_id: Some("call-1".into()),
                        cache_control: None,
                    }),
                },
            })
        }
    }

    #[tokio::test]
    async fn planner_parses_tool_call() {
        let planner = LlmBackedPlanner::new(Arc::new(ToolCallModel));
        let decision = planner
            .plan(
                PlannerContext {
                    history: vec![],
                    system_prompt: "System".into(),
                    tools: vec![],
                },
                Arc::new(AgentStateSnapshot::default()),
            )
            .await
            .unwrap();

        match decision.next_action {
            PlannerAction::CallTool { tool_name, payload } => {
                assert_eq!(tool_name, "write_file");
                assert_eq!(payload["path"], "notes.txt");
            }
            _ => panic!("expected tool call"),
        }
    }

    #[test]
    fn text_with_raw_json_response_is_unwrapped() {
        let message = text_message(r#"{"response":"done"}"#);

        match parse_planner_output(&message).unwrap() {
            PlannerOutputVariant::Respond(text) => assert_eq!(text, "done"),
            PlannerOutputVariant::ToolCall { name, .. } => {
                panic!("expected response, got tool call {name:?}")
            }
        }
    }

    #[test]
    fn text_with_fenced_json_tool_call_is_parsed() {
        let message = text_message(
            "```json\n{\"tool_calls\":[{\"name\":\"ls\",\"args\":{\"path\":\"/tmp\"}}]}\n```",
        );

        match parse_planner_output(&message).unwrap() {
            PlannerOutputVariant::ToolCall { name, args } => {
                assert_eq!(name, "ls");
                assert_eq!(args["path"], "/tmp");
            }
            PlannerOutputVariant::Respond(text) => {
                panic!("expected tool call, got response {text:?}")
            }
        }
    }

    #[test]
    fn json_without_tool_call_or_response_is_an_error() {
        let message = AgentMessage {
            role: MessageRole::Agent,
            content: MessageContent::Json(serde_json::json!({})),
            metadata: None,
        };

        assert!(parse_planner_output(&message).is_err());
    }
}