1use super::config::DeepAgentConfig;
7use crate::middleware::{
8 AgentMiddleware, AnthropicPromptCachingMiddleware, BaseSystemPromptMiddleware,
9 DeepAgentPromptMiddleware, FilesystemMiddleware, HumanInLoopMiddleware, MiddlewareContext,
10 ModelRequest, PlanningMiddleware, SubAgentDescriptor, SubAgentMiddleware, SubAgentRegistration,
11 SummarizationMiddleware,
12};
13use crate::planner::LlmBackedPlanner;
14use agents_core::agent::{
15 AgentDescriptor, AgentHandle, PlannerAction, PlannerContext, PlannerHandle,
16};
17use agents_core::hitl::{AgentInterrupt, HitlAction};
18use agents_core::messaging::{AgentMessage, MessageContent, MessageMetadata, MessageRole};
19use agents_core::persistence::{Checkpointer, ThreadId};
20use agents_core::state::AgentStateSnapshot;
21use agents_core::tools::{ToolBox, ToolContext, ToolResult};
22use async_trait::async_trait;
23use serde_json::Value;
24use std::collections::{HashMap, HashSet};
25use std::sync::{Arc, RwLock};
26
/// Tool names that `DeepAgent::should_include` treats as built-ins.
///
/// Middleware tools with one of these names are exposed only if the agent's
/// `builtin_tools` allow-list is unset or contains the name; every other tool
/// name always passes that filter.
const BUILTIN_TOOL_NAMES: &[&str] = &[
    "write_todos",
    "ls",
    "read_file",
    "write_file",
    "edit_file",
];
29
30pub struct DeepAgent {
37 descriptor: AgentDescriptor,
38 instructions: String,
39 planner: Arc<dyn PlannerHandle>,
40 middlewares: Vec<Arc<dyn AgentMiddleware>>,
41 base_tools: Vec<ToolBox>,
42 state: Arc<RwLock<AgentStateSnapshot>>,
43 history: Arc<RwLock<Vec<AgentMessage>>>,
44 _summarization: Option<Arc<SummarizationMiddleware>>,
45 _hitl: Option<Arc<HumanInLoopMiddleware>>,
46 builtin_tools: Option<HashSet<String>>,
47 checkpointer: Option<Arc<dyn Checkpointer>>,
48}
49
50impl DeepAgent {
51 fn collect_tools(&self) -> HashMap<String, ToolBox> {
52 let mut tools: HashMap<String, ToolBox> = HashMap::new();
53 for tool in &self.base_tools {
54 tools.insert(tool.schema().name.clone(), tool.clone());
55 }
56 for middleware in &self.middlewares {
57 for tool in middleware.tools() {
58 let tool_name = tool.schema().name.clone();
59 if self.should_include(&tool_name) {
60 tools.insert(tool_name, tool);
61 }
62 }
63 }
64 tools
65 }
66 fn should_include(&self, name: &str) -> bool {
69 let is_builtin = BUILTIN_TOOL_NAMES.contains(&name);
70 if !is_builtin {
71 return true;
72 }
73 match &self.builtin_tools {
74 None => true,
75 Some(selected) => selected.contains(name),
76 }
77 }
78
79 fn append_history(&self, message: AgentMessage) {
80 if let Ok(mut history) = self.history.write() {
81 history.push(message);
82 }
83 }
84
85 fn current_history(&self) -> Vec<AgentMessage> {
86 self.history.read().map(|h| h.clone()).unwrap_or_default()
87 }
88
89 pub async fn save_state(&self, thread_id: &ThreadId) -> anyhow::Result<()> {
91 if let Some(ref checkpointer) = self.checkpointer {
92 let state = self
93 .state
94 .read()
95 .map_err(|_| anyhow::anyhow!("Failed to read agent state"))?
96 .clone();
97 checkpointer.save_state(thread_id, &state).await
98 } else {
99 tracing::warn!("Attempted to save state but no checkpointer is configured");
100 Ok(())
101 }
102 }
103
104 pub async fn load_state(&self, thread_id: &ThreadId) -> anyhow::Result<bool> {
106 if let Some(ref checkpointer) = self.checkpointer {
107 if let Some(saved_state) = checkpointer.load_state(thread_id).await? {
108 *self
109 .state
110 .write()
111 .map_err(|_| anyhow::anyhow!("Failed to write agent state"))? = saved_state;
112 tracing::info!(thread_id = %thread_id, "Loaded agent state from checkpointer");
113 Ok(true)
114 } else {
115 tracing::debug!(thread_id = %thread_id, "No saved state found for thread");
116 Ok(false)
117 }
118 } else {
119 tracing::warn!("Attempted to load state but no checkpointer is configured");
120 Ok(false)
121 }
122 }
123
124 pub async fn delete_thread(&self, thread_id: &ThreadId) -> anyhow::Result<()> {
126 if let Some(ref checkpointer) = self.checkpointer {
127 checkpointer.delete_thread(thread_id).await
128 } else {
129 tracing::warn!("Attempted to delete thread state but no checkpointer is configured");
130 Ok(())
131 }
132 }
133
134 pub async fn list_threads(&self) -> anyhow::Result<Vec<ThreadId>> {
136 if let Some(ref checkpointer) = self.checkpointer {
137 checkpointer.list_threads().await
138 } else {
139 Ok(Vec::new())
140 }
141 }
142
143 async fn execute_tool(
144 &self,
145 tool: ToolBox,
146 _tool_name: String,
147 payload: Value,
148 ) -> anyhow::Result<AgentMessage> {
149 let state_snapshot = self.state.read().unwrap().clone();
150 let ctx = ToolContext::with_mutable_state(Arc::new(state_snapshot), self.state.clone());
151
152 let result = tool.execute(payload, ctx).await?;
153 Ok(self.apply_tool_result(result))
154 }
155
156 fn apply_tool_result(&self, result: ToolResult) -> AgentMessage {
157 match result {
158 ToolResult::Message(message) => {
159 message
162 }
163 ToolResult::WithStateUpdate {
164 message,
165 state_diff,
166 } => {
167 if let Ok(mut state) = self.state.write() {
168 let command = agents_core::command::Command::with_state(state_diff);
169 command.apply_to(&mut state);
170 }
171 message
174 }
175 }
176 }
177
178 pub fn current_interrupt(&self) -> Option<AgentInterrupt> {
180 self.state
181 .read()
182 .ok()
183 .and_then(|guard| guard.pending_interrupts.first().cloned())
184 }
185
186 pub async fn resume_with_approval(&self, action: HitlAction) -> anyhow::Result<AgentMessage> {
188 let interrupt = {
190 let state_guard = self
191 .state
192 .read()
193 .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on state"))?;
194 state_guard
195 .pending_interrupts
196 .first()
197 .cloned()
198 .ok_or_else(|| anyhow::anyhow!("No pending interrupts"))?
199 };
200
201 let result_message = match action {
202 HitlAction::Accept => {
203 let AgentInterrupt::HumanInLoop(hitl) = interrupt;
205 tracing::info!(
206 tool_name = %hitl.tool_name,
207 call_id = %hitl.call_id,
208 "β
HITL: Tool approved, executing with original arguments"
209 );
210
211 let tools = self.collect_tools();
212 let tool = tools
213 .get(&hitl.tool_name)
214 .cloned()
215 .ok_or_else(|| anyhow::anyhow!("Tool '{}' not found", hitl.tool_name))?;
216
217 self.execute_tool(tool, hitl.tool_name, hitl.tool_args)
218 .await?
219 }
220
221 HitlAction::Edit {
222 tool_name,
223 tool_args,
224 } => {
225 tracing::info!(
227 tool_name = %tool_name,
228 "βοΈ HITL: Tool edited, executing with modified arguments"
229 );
230
231 let tools = self.collect_tools();
232 let tool = tools
233 .get(&tool_name)
234 .cloned()
235 .ok_or_else(|| anyhow::anyhow!("Tool '{}' not found", tool_name))?;
236
237 self.execute_tool(tool, tool_name, tool_args).await?
238 }
239
240 HitlAction::Reject { reason } => {
241 tracing::info!("β HITL: Tool rejected");
243
244 let text = reason
245 .unwrap_or_else(|| "Tool execution rejected by human reviewer.".to_string());
246
247 let message = AgentMessage {
248 role: MessageRole::Tool,
249 content: MessageContent::Text(text),
250 metadata: None,
251 };
252
253 self.append_history(message.clone());
254 message
255 }
256
257 HitlAction::Respond { message } => {
258 tracing::info!("π¬ HITL: Custom response provided");
260
261 self.append_history(message.clone());
262 message
263 }
264 };
265
266 {
268 let mut state_guard = self
269 .state
270 .write()
271 .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on state"))?;
272 state_guard.clear_interrupts();
273 }
274
275 if let Some(checkpointer) = &self.checkpointer {
277 let state_clone = self
278 .state
279 .read()
280 .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on state"))?
281 .clone();
282 checkpointer
283 .save_state(&ThreadId::default(), &state_clone)
284 .await?;
285 }
286
287 Ok(result_message)
288 }
289
290 pub async fn handle_message(
292 &self,
293 input: impl AsRef<str>,
294 state: Arc<AgentStateSnapshot>,
295 ) -> anyhow::Result<AgentMessage> {
296 self.handle_message_with_metadata(input, None, state).await
297 }
298
299 pub async fn handle_message_with_metadata(
301 &self,
302 input: impl AsRef<str>,
303 metadata: Option<MessageMetadata>,
304 state: Arc<AgentStateSnapshot>,
305 ) -> anyhow::Result<AgentMessage> {
306 let agent_message = AgentMessage {
307 role: MessageRole::User,
308 content: MessageContent::Text(input.as_ref().to_string()),
309 metadata,
310 };
311 self.handle_message_internal(agent_message, state).await
312 }
313
314 async fn handle_message_internal(
316 &self,
317 input: AgentMessage,
318 _state: Arc<AgentStateSnapshot>,
319 ) -> anyhow::Result<AgentMessage> {
320 self.append_history(input.clone());
321
322 let mut request = ModelRequest::new(&self.instructions, self.current_history());
324 let tools = self.collect_tools();
325 for middleware in &self.middlewares {
326 let mut ctx = MiddlewareContext::with_request(&mut request, self.state.clone());
327 middleware.modify_model_request(&mut ctx).await?;
328 }
329
330 let tool_schemas: Vec<_> = tools.values().map(|t| t.schema()).collect();
331 let context = PlannerContext {
332 history: request.messages.clone(),
333 system_prompt: request.system_prompt.clone(),
334 tools: tool_schemas,
335 };
336 let state_snapshot = Arc::new(self.state.read().map(|s| s.clone()).unwrap_or_default());
337
338 let decision = self.planner.plan(context, state_snapshot).await?;
340
341 match decision.next_action {
342 PlannerAction::Respond { message } => {
343 self.append_history(message.clone());
344 Ok(message)
345 }
346 PlannerAction::CallTool { tool_name, payload } => {
347 if let Some(tool) = tools.get(&tool_name).cloned() {
348 let call_id = format!("call_{}", uuid::Uuid::new_v4());
350 for middleware in &self.middlewares {
351 if let Some(interrupt) = middleware
352 .before_tool_execution(&tool_name, &payload, &call_id)
353 .await?
354 {
355 {
357 let mut state_guard = self.state.write().map_err(|_| {
358 anyhow::anyhow!("Failed to acquire write lock on state")
359 })?;
360 state_guard.add_interrupt(interrupt.clone());
361 }
362
363 if let Some(checkpointer) = &self.checkpointer {
365 let state_clone = self
366 .state
367 .read()
368 .map_err(|_| {
369 anyhow::anyhow!("Failed to acquire read lock on state")
370 })?
371 .clone();
372 checkpointer
373 .save_state(&ThreadId::default(), &state_clone)
374 .await?;
375 }
376
377 let interrupt_message = AgentMessage {
379 role: MessageRole::System,
380 content: MessageContent::Text(format!(
381 "βΈοΈ Execution paused: Tool '{}' requires human approval",
382 tool_name
383 )),
384 metadata: None,
385 };
386 self.append_history(interrupt_message.clone());
387 return Ok(interrupt_message);
388 }
389 }
390
391 let start_time = std::time::Instant::now();
393 tracing::warn!(
394 "βοΈ EXECUTING TOOL: {} with payload: {}",
395 tool_name,
396 serde_json::to_string(&payload)
397 .unwrap_or_else(|_| "invalid json".to_string())
398 );
399
400 let result = self
401 .execute_tool(tool.clone(), tool_name.clone(), payload.clone())
402 .await;
403
404 let duration = start_time.elapsed();
405 match result {
406 Ok(tool_result_message) => {
407 let content_preview = match &tool_result_message.content {
408 MessageContent::Text(t) => {
409 if t.len() > 100 {
410 format!("{}... ({} chars)", &t[..100], t.len())
411 } else {
412 t.clone()
413 }
414 }
415 MessageContent::Json(v) => {
416 format!("JSON: {} bytes", v.to_string().len())
417 }
418 };
419 tracing::warn!(
420 "β
TOOL COMPLETED: {} in {:?} - Result: {}",
421 tool_name,
422 duration,
423 content_preview
424 );
425
426 let natural_response = match &tool_result_message.content {
429 MessageContent::Text(text) => {
430 if text.is_empty() {
431 format!(
432 "I've executed the {} tool successfully.",
433 tool_name
434 )
435 } else {
436 text.clone()
438 }
439 }
440 MessageContent::Json(json) => {
441 format!("Tool result: {}", json)
442 }
443 };
444
445 let response = AgentMessage {
446 role: MessageRole::Agent,
447 content: MessageContent::Text(natural_response),
448 metadata: None,
449 };
450 self.append_history(response.clone());
451 Ok(response)
452 }
453 Err(e) => {
454 tracing::error!(
455 "β TOOL FAILED: {} in {:?} - Error: {}",
456 tool_name,
457 duration,
458 e
459 );
460
461 let error_response = AgentMessage {
463 role: MessageRole::Agent,
464 content: MessageContent::Text(format!(
465 "I encountered an error while executing {}: {}",
466 tool_name, e
467 )),
468 metadata: None,
469 };
470 self.append_history(error_response.clone());
471 Ok(error_response)
472 }
473 }
474 } else {
475 tracing::warn!("β οΈ Tool '{}' not found", tool_name);
477 let error_response = AgentMessage {
478 role: MessageRole::Agent,
479 content: MessageContent::Text(format!(
480 "I don't have access to the '{}' tool.",
481 tool_name
482 )),
483 metadata: None,
484 };
485 self.append_history(error_response.clone());
486 Ok(error_response)
487 }
488 }
489 PlannerAction::Terminate => {
490 tracing::debug!("π Agent terminated");
491 let message = AgentMessage {
492 role: MessageRole::Agent,
493 content: MessageContent::Text("Task completed.".into()),
494 metadata: None,
495 };
496 self.append_history(message.clone());
497 Ok(message)
498 }
499 }
500 }
501}
502
503#[async_trait]
504impl AgentHandle for DeepAgent {
505 async fn describe(&self) -> AgentDescriptor {
506 self.descriptor.clone()
507 }
508
509 async fn handle_message(
510 &self,
511 input: AgentMessage,
512 _state: Arc<AgentStateSnapshot>,
513 ) -> anyhow::Result<AgentMessage> {
514 self.handle_message_internal(input, _state).await
515 }
516
517 async fn handle_message_stream(
518 &self,
519 input: AgentMessage,
520 _state: Arc<AgentStateSnapshot>,
521 ) -> anyhow::Result<agents_core::agent::AgentStream> {
522 use crate::planner::LlmBackedPlanner;
523 use agents_core::llm::{LlmRequest, StreamChunk};
524
525 self.append_history(input.clone());
527
528 let mut request = ModelRequest::new(&self.instructions, self.current_history());
530 let tools = self.collect_tools();
531
532 for middleware in &self.middlewares {
534 let mut ctx = MiddlewareContext::with_request(&mut request, self.state.clone());
535 middleware.modify_model_request(&mut ctx).await?;
536 }
537
538 let tool_schemas: Vec<_> = tools.values().map(|t| t.schema()).collect();
540 let llm_request = LlmRequest {
541 system_prompt: request.system_prompt.clone(),
542 messages: request.messages.clone(),
543 tools: tool_schemas,
544 };
545
546 let planner_any = self.planner.as_any();
548
549 if let Some(llm_planner) = planner_any.downcast_ref::<LlmBackedPlanner>() {
550 let model = llm_planner.model().clone();
552 let stream = model.generate_stream(llm_request).await?;
553 Ok(stream)
554 } else {
555 let response = self.handle_message_internal(input, _state).await?;
557 Ok(Box::pin(futures::stream::once(async move {
558 Ok(StreamChunk::Done { message: response })
559 })))
560 }
561 }
562
563 async fn current_interrupt(&self) -> anyhow::Result<Option<AgentInterrupt>> {
564 let state_guard = self
565 .state
566 .read()
567 .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on state"))?;
568 Ok(state_guard.pending_interrupts.first().cloned())
569 }
570
571 async fn resume_with_approval(
572 &self,
573 action: agents_core::hitl::HitlAction,
574 ) -> anyhow::Result<AgentMessage> {
575 self.resume_with_approval(action).await
576 }
577}
578
579pub fn create_deep_agent_from_config(config: DeepAgentConfig) -> DeepAgent {
584 let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
585 let history = Arc::new(RwLock::new(Vec::<AgentMessage>::new()));
586
587 let planning = Arc::new(PlanningMiddleware::new(state.clone()));
588 let filesystem = Arc::new(FilesystemMiddleware::new(state.clone()));
589
590 let mut registrations: Vec<SubAgentRegistration> = Vec::new();
592
593 for subagent_config in &config.subagent_configs {
595 let sub_planner = if let Some(ref model) = subagent_config.model {
597 Arc::new(LlmBackedPlanner::new(model.clone())) as Arc<dyn PlannerHandle>
599 } else {
600 config.planner.clone()
602 };
603
604 let mut sub_cfg = DeepAgentConfig::new(subagent_config.instructions.clone(), sub_planner);
606
607 if let Some(ref tools) = subagent_config.tools {
609 for tool in tools {
610 sub_cfg = sub_cfg.with_tool(tool.clone());
611 }
612 }
613
614 if let Some(ref builtin) = subagent_config.builtin_tools {
616 sub_cfg = sub_cfg.with_builtin_tools(builtin.iter().cloned());
617 }
618
619 sub_cfg = sub_cfg.with_auto_general_purpose(false);
621
622 sub_cfg = sub_cfg.with_prompt_caching(subagent_config.enable_prompt_caching);
624
625 let sub_agent = create_deep_agent_from_config(sub_cfg);
627
628 registrations.push(SubAgentRegistration {
630 descriptor: SubAgentDescriptor {
631 name: subagent_config.name.clone(),
632 description: subagent_config.description.clone(),
633 },
634 agent: Arc::new(sub_agent),
635 });
636 }
637
638 if config.auto_general_purpose {
640 let has_gp = registrations
641 .iter()
642 .any(|r| r.descriptor.name == "general-purpose");
643 if !has_gp {
644 let mut sub_cfg =
646 DeepAgentConfig::new(config.instructions.clone(), config.planner.clone())
647 .with_auto_general_purpose(false)
648 .with_prompt_caching(config.enable_prompt_caching);
649 if let Some(ref selected) = config.builtin_tools {
650 sub_cfg = sub_cfg.with_builtin_tools(selected.iter().cloned());
651 }
652 if let Some(ref sum) = config.summarization {
653 sub_cfg = sub_cfg.with_summarization(sum.clone());
654 }
655 for t in &config.tools {
656 sub_cfg = sub_cfg.with_tool(t.clone());
657 }
658
659 let gp = create_deep_agent_from_config(sub_cfg);
660 registrations.push(SubAgentRegistration {
661 descriptor: SubAgentDescriptor {
662 name: "general-purpose".into(),
663 description: "Default reasoning agent".into(),
664 },
665 agent: Arc::new(gp),
666 });
667 }
668 }
669
670 let subagent = Arc::new(SubAgentMiddleware::new(registrations));
671 let base_prompt = Arc::new(BaseSystemPromptMiddleware);
672 let deep_agent_prompt = Arc::new(DeepAgentPromptMiddleware::new(config.instructions.clone()));
673 let summarization = config.summarization.as_ref().map(|cfg| {
674 Arc::new(SummarizationMiddleware::new(
675 cfg.messages_to_keep,
676 cfg.summary_note.clone(),
677 ))
678 });
679 let hitl = if config.tool_interrupts.is_empty() {
680 None
681 } else {
682 if config.checkpointer.is_none() {
684 tracing::error!(
685 "β οΈ HITL middleware requires a checkpointer to persist interrupt state. \
686 HITL will be disabled. Please configure a checkpointer to enable HITL."
687 );
688 None
689 } else {
690 tracing::info!("π HITL enabled for {} tools", config.tool_interrupts.len());
691 Some(Arc::new(HumanInLoopMiddleware::new(
692 config.tool_interrupts.clone(),
693 )))
694 }
695 };
696
697 let mut middlewares: Vec<Arc<dyn AgentMiddleware>> = vec![
700 base_prompt,
701 deep_agent_prompt,
702 planning,
703 filesystem,
704 subagent,
705 ];
706 if let Some(ref summary) = summarization {
707 middlewares.push(summary.clone());
708 }
709 if config.enable_prompt_caching {
710 middlewares.push(Arc::new(AnthropicPromptCachingMiddleware::with_defaults()));
711 }
712 if let Some(ref hitl_mw) = hitl {
713 middlewares.push(hitl_mw.clone());
714 }
715
716 DeepAgent {
717 descriptor: AgentDescriptor {
718 name: "deep-agent".into(),
719 version: "0.0.1".into(),
720 description: Some("Rust deep agent".into()),
721 },
722 instructions: config.instructions,
723 planner: config.planner,
724 middlewares,
725 base_tools: config.tools,
726 state,
727 history,
728 _summarization: summarization,
729 _hitl: hitl,
730 builtin_tools: config.builtin_tools,
731 checkpointer: config.checkpointer,
732 }
733}