// agents_runtime/planner.rs
use std::sync::Arc;

use agents_core::agent::{PlannerAction, PlannerContext, PlannerDecision, PlannerHandle};
use agents_core::llm::{LanguageModel, LlmRequest};
use agents_core::messaging::{AgentMessage, MessageContent, MessageRole};
use agents_core::state::AgentStateSnapshot;
use anyhow::Context;
use async_trait::async_trait;
use serde::Deserialize;
use serde_json::Value;
10
11#[derive(Clone)]
12pub struct LlmBackedPlanner {
13 model: Arc<dyn LanguageModel>,
14}
15
16impl LlmBackedPlanner {
17 pub fn new(model: Arc<dyn LanguageModel>) -> Self {
18 Self { model }
19 }
20
21 pub fn model(&self) -> &Arc<dyn LanguageModel> {
23 &self.model
24 }
25}
26
27#[derive(Debug, Deserialize)]
28struct ToolCall {
29 name: String,
30 #[serde(default)]
31 args: Value,
32}
33
34#[derive(Debug, Deserialize)]
35struct PlannerOutput {
36 #[serde(default)]
37 tool_calls: Vec<ToolCall>,
38 #[serde(default)]
39 response: Option<String>,
40}
41
42#[async_trait]
43impl PlannerHandle for LlmBackedPlanner {
44 async fn plan(
45 &self,
46 context: PlannerContext,
47 _state: Arc<AgentStateSnapshot>,
48 ) -> anyhow::Result<PlannerDecision> {
49 let request = LlmRequest::new(context.system_prompt.clone(), context.history.clone())
50 .with_tools(context.tools.clone());
51 let response = self.model.generate(request).await?;
52 let message = response.message;
53
54 match parse_planner_output(&message)? {
55 PlannerOutputVariant::ToolCall { name, args } => Ok(PlannerDecision {
56 next_action: PlannerAction::CallTool {
57 tool_name: name,
58 payload: args,
59 },
60 }),
61 PlannerOutputVariant::Respond(text) => Ok(PlannerDecision {
62 next_action: PlannerAction::Respond {
63 message: AgentMessage {
64 role: MessageRole::Agent,
65 content: MessageContent::Text(text),
66 metadata: message.metadata,
67 },
68 },
69 }),
70 }
71 }
72
73 fn as_any(&self) -> &dyn std::any::Any {
74 self
75 }
76}
77
78enum PlannerOutputVariant {
79 ToolCall { name: String, args: Value },
80 Respond(String),
81}
82
83fn parse_planner_output(message: &AgentMessage) -> anyhow::Result<PlannerOutputVariant> {
84 match &message.content {
85 MessageContent::Json(value) => parse_from_value(value.clone()),
86 MessageContent::Text(text) => {
87 if let Some(parsed) = parse_from_text(text) {
89 if let Some(tc) = parsed.tool_calls.first() {
90 return Ok(PlannerOutputVariant::ToolCall {
91 name: tc.name.clone(),
92 args: tc.args.clone(),
93 });
94 }
95 if let Some(resp) = parsed.response {
96 return Ok(PlannerOutputVariant::Respond(resp));
97 }
98 }
99 Ok(PlannerOutputVariant::Respond(text.clone()))
100 }
101 }
102}
103
104fn parse_from_value(value: Value) -> anyhow::Result<PlannerOutputVariant> {
105 let parsed: PlannerOutput = serde_json::from_value(value)?;
106 if let Some(tool_call) = parsed.tool_calls.first() {
107 Ok(PlannerOutputVariant::ToolCall {
108 name: tool_call.name.clone(),
109 args: tool_call.args.clone(),
110 })
111 } else if let Some(response) = parsed.response {
112 Ok(PlannerOutputVariant::Respond(response))
113 } else {
114 anyhow::bail!("LLM response missing tool call and response fields")
115 }
116}
117
118fn parse_from_text(text: &str) -> Option<PlannerOutput> {
119 if let Some(parsed) = decode_output_from_str(text) {
121 return Some(parsed);
122 }
123 let trimmed = text.trim();
125 if trimmed.starts_with("```") {
126 let without_ticks = trimmed.trim_start_matches("```");
127 let without_lang = without_ticks
129 .trim_start_matches(|c: char| c.is_alphabetic())
130 .trim_start();
131 let inner = if let Some(end) = without_lang.rfind("```") {
132 &without_lang[..end]
133 } else {
134 without_lang
135 };
136 if let Some(parsed) = decode_output_from_str(inner) {
137 return Some(parsed);
138 }
139 }
140 None
141}
142
143fn decode_output_from_str(s: &str) -> Option<PlannerOutput> {
145 serde_json::from_str::<Value>(s)
146 .ok()
147 .and_then(|v| serde_json::from_value::<PlannerOutput>(v).ok())
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use agents_core::llm::{LanguageModel, LlmResponse};
    use agents_core::messaging::MessageMetadata;
    use async_trait::async_trait;

    /// Model that replies with the last request message, or empty text when
    /// the request has no messages.
    struct EchoModel;

    #[async_trait]
    impl LanguageModel for EchoModel {
        async fn generate(&self, request: LlmRequest) -> anyhow::Result<LlmResponse> {
            Ok(LlmResponse {
                message: request.messages.last().cloned().unwrap_or(AgentMessage {
                    role: MessageRole::Agent,
                    content: MessageContent::Text("".into()),
                    metadata: None,
                }),
            })
        }
    }

    /// Builds an agent-role plain-text message with no metadata.
    fn agent_text(text: &str) -> AgentMessage {
        AgentMessage {
            role: MessageRole::Agent,
            content: MessageContent::Text(text.into()),
            metadata: None,
        }
    }

    #[tokio::test]
    async fn planner_falls_back_to_text_response() {
        let planner = LlmBackedPlanner::new(Arc::new(EchoModel));
        let context = PlannerContext {
            history: vec![AgentMessage {
                role: MessageRole::User,
                content: MessageContent::Text("Hi".into()),
                metadata: None,
            }],
            system_prompt: "Be helpful".into(),
            tools: vec![],
        };

        let decision = planner
            .plan(context, Arc::new(AgentStateSnapshot::default()))
            .await
            .unwrap();

        match decision.next_action {
            PlannerAction::Respond { message } => match message.content {
                MessageContent::Text(text) => assert_eq!(text, "Hi"),
                other => panic!("expected text, got {other:?}"),
            },
            _ => panic!("expected respond"),
        }
    }

    /// Model that always requests a single `write_file` tool call as JSON content.
    struct ToolCallModel;

    #[async_trait]
    impl LanguageModel for ToolCallModel {
        async fn generate(&self, _request: LlmRequest) -> anyhow::Result<LlmResponse> {
            Ok(LlmResponse {
                message: AgentMessage {
                    role: MessageRole::Agent,
                    content: MessageContent::Json(serde_json::json!({
                        "tool_calls": [
                            {
                                "name": "write_file",
                                "args": { "path": "notes.txt" }
                            }
                        ]
                    })),
                    metadata: Some(MessageMetadata {
                        tool_call_id: Some("call-1".into()),
                        cache_control: None,
                    }),
                },
            })
        }
    }

    #[tokio::test]
    async fn planner_parses_tool_call() {
        let planner = LlmBackedPlanner::new(Arc::new(ToolCallModel));
        let decision = planner
            .plan(
                PlannerContext {
                    history: vec![],
                    system_prompt: "System".into(),
                    tools: vec![],
                },
                Arc::new(AgentStateSnapshot::default()),
            )
            .await
            .unwrap();

        match decision.next_action {
            PlannerAction::CallTool { tool_name, payload } => {
                assert_eq!(tool_name, "write_file");
                assert_eq!(payload["path"], "notes.txt");
            }
            _ => panic!("expected tool call"),
        }
    }

    #[test]
    fn fenced_json_tool_call_in_text_is_parsed() {
        let message = agent_text(
            "```json\n{\"tool_calls\":[{\"name\":\"ls\",\"args\":{\"dir\":\".\"}}]}\n```",
        );
        match parse_planner_output(&message).unwrap() {
            PlannerOutputVariant::ToolCall { name, args } => {
                assert_eq!(name, "ls");
                assert_eq!(args["dir"], ".");
            }
            PlannerOutputVariant::Respond(text) => {
                panic!("expected tool call, got response {text:?}")
            }
        }
    }

    #[test]
    fn json_response_field_in_text_is_unwrapped() {
        let message = agent_text(r#"{"response": "done"}"#);
        match parse_planner_output(&message).unwrap() {
            PlannerOutputVariant::Respond(text) => assert_eq!(text, "done"),
            PlannerOutputVariant::ToolCall { name, .. } => {
                panic!("expected response, got tool call {name:?}")
            }
        }
    }

    #[test]
    fn unrecognised_json_text_falls_back_to_raw_text() {
        let raw = r#"{"foo": 1}"#;
        match parse_planner_output(&agent_text(raw)).unwrap() {
            PlannerOutputVariant::Respond(text) => assert_eq!(text, raw),
            PlannerOutputVariant::ToolCall { name, .. } => {
                panic!("expected response, got tool call {name:?}")
            }
        }
    }

    #[test]
    fn json_content_without_tool_call_or_response_is_an_error() {
        let message = AgentMessage {
            role: MessageRole::Agent,
            content: MessageContent::Json(serde_json::json!({})),
            metadata: None,
        };
        assert!(parse_planner_output(&message).is_err());
    }
}