1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3
4use agents_core::agent::AgentHandle;
5use agents_core::messaging::{
6 AgentMessage, CacheControl, MessageContent, MessageMetadata, MessageRole,
7};
8use agents_core::prompts::{
9 BASE_AGENT_PROMPT, FILESYSTEM_SYSTEM_PROMPT, TASK_SYSTEM_PROMPT, TASK_TOOL_DESCRIPTION,
10 WRITE_TODOS_SYSTEM_PROMPT,
11};
12use agents_core::state::AgentStateSnapshot;
13use agents_core::tools::{Tool, ToolBox, ToolContext, ToolResult};
14use agents_toolkit::create_filesystem_tools;
15use async_trait::async_trait;
16use serde::Deserialize;
17
18pub mod token_tracking;
19
20#[derive(Debug, Clone)]
23pub struct ModelRequest {
24 pub system_prompt: String,
25 pub messages: Vec<AgentMessage>,
26}
27
28impl ModelRequest {
29 pub fn new(system_prompt: impl Into<String>, messages: Vec<AgentMessage>) -> Self {
30 Self {
31 system_prompt: system_prompt.into(),
32 messages,
33 }
34 }
35
36 pub fn append_prompt(&mut self, fragment: &str) {
37 if !fragment.is_empty() {
38 self.system_prompt.push_str("\n\n");
39 self.system_prompt.push_str(fragment);
40 }
41 }
42}
43
44pub struct MiddlewareContext<'a> {
46 pub request: &'a mut ModelRequest,
47 pub state: Arc<RwLock<AgentStateSnapshot>>,
48}
49
50impl<'a> MiddlewareContext<'a> {
51 pub fn with_request(
52 request: &'a mut ModelRequest,
53 state: Arc<RwLock<AgentStateSnapshot>>,
54 ) -> Self {
55 Self { request, state }
56 }
57}
58
59#[async_trait]
63pub trait AgentMiddleware: Send + Sync {
64 fn id(&self) -> &'static str;
66
67 fn tools(&self) -> Vec<ToolBox> {
69 Vec::new()
70 }
71
72 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()>;
74
75 async fn before_tool_execution(
91 &self,
92 _tool_name: &str,
93 _tool_args: &serde_json::Value,
94 _call_id: &str,
95 ) -> anyhow::Result<Option<agents_core::hitl::AgentInterrupt>> {
96 Ok(None)
97 }
98}
99
100pub struct SummarizationMiddleware {
101 pub messages_to_keep: usize,
102 pub summary_note: String,
103}
104
105impl SummarizationMiddleware {
106 pub fn new(messages_to_keep: usize, summary_note: impl Into<String>) -> Self {
107 Self {
108 messages_to_keep,
109 summary_note: summary_note.into(),
110 }
111 }
112}
113
114#[async_trait]
115impl AgentMiddleware for SummarizationMiddleware {
116 fn id(&self) -> &'static str {
117 "summarization"
118 }
119
120 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
121 if ctx.request.messages.len() > self.messages_to_keep {
122 let dropped = ctx.request.messages.len() - self.messages_to_keep;
123 let mut truncated = ctx
124 .request
125 .messages
126 .split_off(ctx.request.messages.len() - self.messages_to_keep);
127 truncated.insert(
128 0,
129 AgentMessage {
130 role: MessageRole::System,
131 content: MessageContent::Text(format!(
132 "{} ({} earlier messages summarized)",
133 self.summary_note, dropped
134 )),
135 metadata: None,
136 },
137 );
138 ctx.request.messages = truncated;
139 }
140 Ok(())
141 }
142}
143
144pub struct PlanningMiddleware {
145 _state: Arc<RwLock<AgentStateSnapshot>>,
146}
147
148impl PlanningMiddleware {
149 pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
150 Self { _state: state }
151 }
152}
153
154#[async_trait]
155impl AgentMiddleware for PlanningMiddleware {
156 fn id(&self) -> &'static str {
157 "planning"
158 }
159
160 fn tools(&self) -> Vec<ToolBox> {
161 use agents_toolkit::create_todos_tools;
162 create_todos_tools()
163 }
164
165 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
166 ctx.request.append_prompt(WRITE_TODOS_SYSTEM_PROMPT);
167 Ok(())
168 }
169}
170
171pub struct FilesystemMiddleware {
172 _state: Arc<RwLock<AgentStateSnapshot>>,
173}
174
175impl FilesystemMiddleware {
176 pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
177 Self { _state: state }
178 }
179}
180
181#[async_trait]
182impl AgentMiddleware for FilesystemMiddleware {
183 fn id(&self) -> &'static str {
184 "filesystem"
185 }
186
187 fn tools(&self) -> Vec<ToolBox> {
188 create_filesystem_tools()
189 }
190
191 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
192 ctx.request.append_prompt(FILESYSTEM_SYSTEM_PROMPT);
193 Ok(())
194 }
195}
196
197#[derive(Clone)]
198pub struct SubAgentRegistration {
199 pub descriptor: SubAgentDescriptor,
200 pub agent: Arc<dyn AgentHandle>,
201}
202
203struct SubAgentRegistry {
204 agents: HashMap<String, Arc<dyn AgentHandle>>,
205}
206
207impl SubAgentRegistry {
208 fn new(registrations: Vec<SubAgentRegistration>) -> Self {
209 let mut agents = HashMap::new();
210 for reg in registrations {
211 agents.insert(reg.descriptor.name.clone(), reg.agent.clone());
212 }
213 Self { agents }
214 }
215
216 fn available_names(&self) -> Vec<String> {
217 self.agents.keys().cloned().collect()
218 }
219
220 fn get(&self, name: &str) -> Option<Arc<dyn AgentHandle>> {
221 self.agents.get(name).cloned()
222 }
223}
224
225pub struct SubAgentMiddleware {
226 task_tool: ToolBox,
227 descriptors: Vec<SubAgentDescriptor>,
228 _registry: Arc<SubAgentRegistry>,
229}
230
231impl SubAgentMiddleware {
232 pub fn new(registrations: Vec<SubAgentRegistration>) -> Self {
233 let descriptors = registrations.iter().map(|r| r.descriptor.clone()).collect();
234 let registry = Arc::new(SubAgentRegistry::new(registrations));
235 let task_tool: ToolBox = Arc::new(TaskRouterTool::new(registry.clone(), None));
236 Self {
237 task_tool,
238 descriptors,
239 _registry: registry,
240 }
241 }
242
243 pub fn new_with_events(
244 registrations: Vec<SubAgentRegistration>,
245 event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
246 ) -> Self {
247 let descriptors = registrations.iter().map(|r| r.descriptor.clone()).collect();
248 let registry = Arc::new(SubAgentRegistry::new(registrations));
249 let task_tool: ToolBox = Arc::new(TaskRouterTool::new(registry.clone(), event_dispatcher));
250 Self {
251 task_tool,
252 descriptors,
253 _registry: registry,
254 }
255 }
256
257 fn prompt_fragment(&self) -> String {
258 let descriptions: Vec<String> = if self.descriptors.is_empty() {
259 vec![String::from("- general-purpose: Default reasoning agent")]
260 } else {
261 self.descriptors
262 .iter()
263 .map(|agent| format!("- {}: {}", agent.name, agent.description))
264 .collect()
265 };
266
267 TASK_TOOL_DESCRIPTION.replace("{other_agents}", &descriptions.join("\n"))
268 }
269}
270
271#[async_trait]
272impl AgentMiddleware for SubAgentMiddleware {
273 fn id(&self) -> &'static str {
274 "subagent"
275 }
276
277 fn tools(&self) -> Vec<ToolBox> {
278 vec![self.task_tool.clone()]
279 }
280
281 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
282 ctx.request.append_prompt(TASK_SYSTEM_PROMPT);
283 ctx.request.append_prompt(&self.prompt_fragment());
284 Ok(())
285 }
286}
287
/// Per-tool human-in-the-loop policy.
#[derive(Clone, Debug)]
pub struct HitlPolicy {
    /// When `true`, the tool may run without approval.
    pub allow_auto: bool,
    /// Optional explanation shown in the prompt and attached to interrupts.
    pub note: Option<String>,
}

/// Gates tool execution behind human approval according to per-tool policies.
pub struct HumanInLoopMiddleware {
    policies: HashMap<String, HitlPolicy>,
}

impl HumanInLoopMiddleware {
    /// Creates the middleware from a map of tool name to policy.
    pub fn new(policies: HashMap<String, HitlPolicy>) -> Self {
        Self { policies }
    }

    /// Returns the tool's policy if it requires approval (`allow_auto == false`).
    ///
    /// Tools with no policy entry, or with `allow_auto == true`, return `None`.
    pub fn requires_approval(&self, tool_name: &str) -> Option<&HitlPolicy> {
        self.policies
            .get(tool_name)
            .filter(|policy| !policy.allow_auto)
    }

    /// Builds the prompt section listing approval-gated tools, or `None` if
    /// no tool requires approval.
    ///
    /// Tools are listed in name order so that the system prompt is identical
    /// across runs (HashMap iteration order is randomized), which keeps
    /// prompts reproducible and cache-friendly.
    fn prompt_fragment(&self) -> Option<String> {
        let mut gated: Vec<(&String, &HitlPolicy)> = self
            .policies
            .iter()
            .filter(|(_, policy)| !policy.allow_auto)
            .collect();
        if gated.is_empty() {
            return None;
        }
        gated.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));

        let lines: Vec<String> = gated
            .into_iter()
            .map(|(tool, policy)| match &policy.note {
                Some(note) => format!("- {tool}: {note}"),
                None => format!("- {tool}: Requires approval"),
            })
            .collect();
        Some(format!(
            "The following tools require human approval before execution:\n{}",
            lines.join("\n")
        ))
    }
}
329
330#[async_trait]
331impl AgentMiddleware for HumanInLoopMiddleware {
332 fn id(&self) -> &'static str {
333 "human-in-loop"
334 }
335
336 async fn before_tool_execution(
337 &self,
338 tool_name: &str,
339 tool_args: &serde_json::Value,
340 call_id: &str,
341 ) -> anyhow::Result<Option<agents_core::hitl::AgentInterrupt>> {
342 if let Some(policy) = self.requires_approval(tool_name) {
343 tracing::warn!(
344 tool_name = %tool_name,
345 call_id = %call_id,
346 policy_note = ?policy.note,
347 "🔒 HITL: Tool execution requires human approval"
348 );
349
350 let interrupt = agents_core::hitl::HitlInterrupt::new(
351 tool_name,
352 tool_args.clone(),
353 call_id,
354 policy.note.clone(),
355 );
356
357 return Ok(Some(agents_core::hitl::AgentInterrupt::HumanInLoop(
358 interrupt,
359 )));
360 }
361
362 Ok(None)
363 }
364
365 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
366 if let Some(fragment) = self.prompt_fragment() {
367 ctx.request.append_prompt(&fragment);
368 }
369 ctx.request.messages.push(AgentMessage {
370 role: MessageRole::System,
371 content: MessageContent::Text(
372 "Tools marked for human approval will emit interrupts requiring external resolution."
373 .into(),
374 ),
375 metadata: None,
376 });
377 Ok(())
378 }
379}
380
381pub struct BaseSystemPromptMiddleware;
382
383#[async_trait]
384impl AgentMiddleware for BaseSystemPromptMiddleware {
385 fn id(&self) -> &'static str {
386 "base-system-prompt"
387 }
388
389 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
390 ctx.request.append_prompt(BASE_AGENT_PROMPT);
391 Ok(())
392 }
393}
394
395pub struct DeepAgentPromptMiddleware {
405 custom_instructions: String,
406}
407
408impl DeepAgentPromptMiddleware {
409 pub fn new(custom_instructions: impl Into<String>) -> Self {
410 Self {
411 custom_instructions: custom_instructions.into(),
412 }
413 }
414}
415
416#[async_trait]
417impl AgentMiddleware for DeepAgentPromptMiddleware {
418 fn id(&self) -> &'static str {
419 "deep-agent-prompt"
420 }
421
422 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
423 use crate::prompts::get_deep_agent_system_prompt;
424 let deep_prompt = get_deep_agent_system_prompt(&self.custom_instructions);
425 ctx.request.append_prompt(&deep_prompt);
426 Ok(())
427 }
428}
429
/// Configuration for Anthropic prompt caching of the system prompt.
pub struct AnthropicPromptCachingMiddleware {
    /// Cache TTL string such as `"5m"`. Empty, whitespace-only or zero values
    /// (`"0"`, `"0s"`, `"0m"`, …) disable caching.
    pub ttl: String,
    /// Behaviour label for unsupported models; this module only logs it.
    pub unsupported_model_behavior: String,
}

impl AnthropicPromptCachingMiddleware {
    /// Creates the middleware from a TTL string and an unsupported-model behaviour label.
    pub fn new(ttl: impl Into<String>, unsupported_model_behavior: impl Into<String>) -> Self {
        Self {
            ttl: ttl.into(),
            unsupported_model_behavior: unsupported_model_behavior.into(),
        }
    }

    /// Defaults: TTL `"5m"`, behaviour `"ignore"`.
    pub fn with_defaults() -> Self {
        Self::new("5m", "ignore")
    }

    /// Returns `false` when the TTL is empty/blank or its numeric part is zero.
    ///
    /// Trailing ASCII letters are treated as a unit suffix (`s`, `m`, `h`, …),
    /// so `"0"`, `"0s"`, `"0m"`, `"00h"` and `" 0 "` all disable caching.
    /// Any other value, including ones without digits, keeps caching enabled,
    /// as before.
    fn should_enable_caching(&self) -> bool {
        let ttl = self.ttl.trim();
        if ttl.is_empty() {
            return false;
        }
        let magnitude = ttl.trim_end_matches(|c: char| c.is_ascii_alphabetic());
        let is_zero = !magnitude.is_empty() && magnitude.bytes().all(|b| b == b'0');
        !is_zero
    }
}
455
456#[async_trait]
457impl AgentMiddleware for AnthropicPromptCachingMiddleware {
458 fn id(&self) -> &'static str {
459 "anthropic-prompt-caching"
460 }
461
462 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
463 if !self.should_enable_caching() {
464 return Ok(());
465 }
466
467 if !ctx.request.system_prompt.is_empty() {
469 let system_message = AgentMessage {
470 role: MessageRole::System,
471 content: MessageContent::Text(ctx.request.system_prompt.clone()),
472 metadata: Some(MessageMetadata {
473 tool_call_id: None,
474 cache_control: Some(CacheControl {
475 cache_type: "ephemeral".to_string(),
476 }),
477 }),
478 };
479
480 ctx.request.messages.insert(0, system_message);
482
483 ctx.request.system_prompt.clear();
485
486 tracing::debug!(
487 ttl = %self.ttl,
488 behavior = %self.unsupported_model_behavior,
489 "Applied Anthropic prompt caching to system message"
490 );
491 }
492
493 Ok(())
494 }
495}
496
497pub struct TaskRouterTool {
498 registry: Arc<SubAgentRegistry>,
499 event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
500 delegation_depth: Arc<RwLock<u32>>,
501}
502
503impl TaskRouterTool {
504 fn new(
505 registry: Arc<SubAgentRegistry>,
506 event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
507 ) -> Self {
508 Self {
509 registry,
510 event_dispatcher,
511 delegation_depth: Arc::new(RwLock::new(0)),
512 }
513 }
514
515 fn available_subagents(&self) -> Vec<String> {
516 self.registry.available_names()
517 }
518
519 fn emit_event(&self, event: agents_core::events::AgentEvent) {
520 if let Some(dispatcher) = &self.event_dispatcher {
521 let dispatcher_clone = dispatcher.clone();
522 tokio::spawn(async move {
523 dispatcher_clone.dispatch(event).await;
524 });
525 }
526 }
527
528 fn create_event_metadata(&self) -> agents_core::events::EventMetadata {
529 agents_core::events::EventMetadata::new(
530 "default".to_string(),
531 uuid::Uuid::new_v4().to_string(),
532 None,
533 )
534 }
535
536 fn get_delegation_depth(&self) -> u32 {
537 *self.delegation_depth.read().unwrap_or_else(|_| {
538 tracing::warn!("Failed to read delegation depth, defaulting to 0");
539 panic!("RwLock poisoned")
540 })
541 }
542
543 fn increment_delegation_depth(&self) {
544 if let Ok(mut depth) = self.delegation_depth.write() {
545 *depth += 1;
546 }
547 }
548
549 fn decrement_delegation_depth(&self) {
550 if let Ok(mut depth) = self.delegation_depth.write() {
551 if *depth > 0 {
552 *depth -= 1;
553 }
554 }
555 }
556}
557
558#[derive(Debug, Clone, Deserialize)]
559struct TaskInvocationArgs {
560 #[serde(alias = "description")]
561 instruction: String,
562 #[serde(alias = "subagent_type")]
563 agent: String,
564}
565
566#[async_trait]
567impl Tool for TaskRouterTool {
568 fn schema(&self) -> agents_core::tools::ToolSchema {
569 use agents_core::tools::{ToolParameterSchema, ToolSchema};
570 use std::collections::HashMap;
571
572 let mut properties = HashMap::new();
573 properties.insert(
574 "agent".to_string(),
575 ToolParameterSchema::string("Name of the sub-agent to delegate to"),
576 );
577 properties.insert(
578 "instruction".to_string(),
579 ToolParameterSchema::string("Clear instruction for the sub-agent"),
580 );
581
582 ToolSchema::new(
583 "task",
584 "Delegate a task to a specialized sub-agent. Use this when you need specialized expertise or want to break down complex tasks.",
585 ToolParameterSchema::object(
586 "Task delegation parameters",
587 properties,
588 vec!["agent".to_string(), "instruction".to_string()],
589 ),
590 )
591 }
592
593 async fn execute(
594 &self,
595 args: serde_json::Value,
596 ctx: ToolContext,
597 ) -> anyhow::Result<ToolResult> {
598 let args: TaskInvocationArgs = serde_json::from_value(args)?;
599 let available = self.available_subagents();
600
601 if let Some(agent) = self.registry.get(&args.agent) {
602 self.increment_delegation_depth();
604 let current_depth = self.get_delegation_depth();
605
606 let instruction_summary = if args.instruction.len() > 100 {
608 format!("{}...", &args.instruction[..100])
609 } else {
610 args.instruction.clone()
611 };
612
613 self.emit_event(agents_core::events::AgentEvent::SubAgentStarted(
615 agents_core::events::SubAgentStartedEvent {
616 metadata: self.create_event_metadata(),
617 agent_name: args.agent.clone(),
618 instruction_summary: instruction_summary.clone(),
619 delegation_depth: current_depth,
620 },
621 ));
622
623 tracing::warn!(
625 "🎯 DELEGATING to sub-agent: {} (depth: {}) with instruction: {}",
626 args.agent,
627 current_depth,
628 args.instruction
629 );
630
631 let start_time = std::time::Instant::now();
632 let user_message = AgentMessage {
633 role: MessageRole::User,
634 content: MessageContent::Text(args.instruction.clone()),
635 metadata: None,
636 };
637
638 let response = agent
639 .handle_message(user_message, Arc::new(AgentStateSnapshot::default()))
640 .await?;
641
642 let duration = start_time.elapsed();
644 let duration_ms = duration.as_millis() as u64;
645
646 let response_preview = match &response.content {
648 MessageContent::Text(t) => {
649 if t.len() > 100 {
650 format!("{}...", &t[..100])
651 } else {
652 t.clone()
653 }
654 }
655 MessageContent::Json(v) => {
656 let json_str = v.to_string();
657 if json_str.len() > 100 {
658 format!("{}...", &json_str[..100])
659 } else {
660 json_str
661 }
662 }
663 };
664
665 self.emit_event(agents_core::events::AgentEvent::SubAgentCompleted(
667 agents_core::events::SubAgentCompletedEvent {
668 metadata: self.create_event_metadata(),
669 agent_name: args.agent.clone(),
670 duration_ms,
671 result_summary: response_preview.clone(),
672 },
673 ));
674
675 tracing::warn!(
677 "✅ SUB-AGENT {} COMPLETED in {:?} - Response: {}",
678 args.agent,
679 duration,
680 response_preview
681 );
682
683 self.decrement_delegation_depth();
685
686 let result_text = match response.content {
689 MessageContent::Text(text) => text,
690 MessageContent::Json(json) => json.to_string(),
691 };
692
693 return Ok(ToolResult::text(&ctx, result_text));
694 }
695
696 tracing::error!(
697 "❌ SUB-AGENT NOT FOUND: {} - Available: {:?}",
698 args.agent,
699 available
700 );
701
702 Ok(ToolResult::text(
703 &ctx,
704 format!(
705 "Sub-agent '{}' not found. Available sub-agents: {}",
706 args.agent,
707 available.join(", ")
708 ),
709 ))
710 }
711}
712
/// Public description of a sub-agent as advertised to the model.
#[derive(Debug, Clone)]
pub struct SubAgentDescriptor {
    /// Unique name; used as the registry key and the `agent` tool argument.
    pub name: String,
    /// Short description rendered into the task-tool prompt.
    pub description: String,
}
718
// Unit tests for the middleware implementations and the `task` sub-agent router.
#[cfg(test)]
mod tests {
    use super::*;
    use agents_core::agent::{AgentDescriptor, AgentHandle};
    use agents_core::messaging::{MessageContent, MessageRole};
    use serde_json::json;

    // Test-only middleware that appends a fixed line directly to the system prompt.
    struct AppendPromptMiddleware;

    #[async_trait]
    impl AgentMiddleware for AppendPromptMiddleware {
        fn id(&self) -> &'static str {
            "append-prompt"
        }

        async fn modify_model_request(
            &self,
            ctx: &mut MiddlewareContext<'_>,
        ) -> anyhow::Result<()> {
            ctx.request.system_prompt.push_str("\nExtra directives.");
            Ok(())
        }
    }

    // A middleware can mutate the request through `MiddlewareContext`.
    #[tokio::test]
    async fn middleware_mutates_prompt() {
        let mut request = ModelRequest::new(
            "System",
            vec![AgentMessage {
                role: MessageRole::User,
                content: MessageContent::Text("Hi".into()),
                metadata: None,
            }],
        );
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);
        let middleware = AppendPromptMiddleware;
        middleware.modify_model_request(&mut ctx).await.unwrap();
        assert!(ctx.request.system_prompt.contains("Extra directives"));
    }

    // Planning middleware exposes a `write_todos` tool and its prompt mentions it.
    #[tokio::test]
    async fn planning_middleware_registers_write_todos() {
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let middleware = PlanningMiddleware::new(state);
        let tool_names: Vec<_> = middleware
            .tools()
            .iter()
            .map(|t| t.schema().name.clone())
            .collect();
        assert!(tool_names.contains(&"write_todos".to_string()));

        let mut request = ModelRequest::new("System", vec![]);
        let mut ctx = MiddlewareContext::with_request(
            &mut request,
            Arc::new(RwLock::new(AgentStateSnapshot::default())),
        );
        middleware.modify_model_request(&mut ctx).await.unwrap();
        assert!(ctx.request.system_prompt.contains("write_todos"));
    }

    // Filesystem middleware exposes the ls / read_file / write_file / edit_file tools.
    #[tokio::test]
    async fn filesystem_middleware_registers_tools() {
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let middleware = FilesystemMiddleware::new(state);
        let tool_names: Vec<_> = middleware
            .tools()
            .iter()
            .map(|t| t.schema().name.clone())
            .collect();
        for expected in ["ls", "read_file", "write_file", "edit_file"] {
            assert!(tool_names.contains(&expected.to_string()));
        }
    }

    // With messages_to_keep = 2 and three messages, the oldest one is replaced
    // by the summary note: 1 note + 2 kept = 3 messages, with the note first.
    #[tokio::test]
    async fn summarization_middleware_trims_messages() {
        let middleware = SummarizationMiddleware::new(2, "Summary note");
        let mut request = ModelRequest::new(
            "System",
            vec![
                AgentMessage {
                    role: MessageRole::User,
                    content: MessageContent::Text("one".into()),
                    metadata: None,
                },
                AgentMessage {
                    role: MessageRole::Agent,
                    content: MessageContent::Text("two".into()),
                    metadata: None,
                },
                AgentMessage {
                    role: MessageRole::User,
                    content: MessageContent::Text("three".into()),
                    metadata: None,
                },
            ],
        );
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);
        middleware.modify_model_request(&mut ctx).await.unwrap();
        assert_eq!(ctx.request.messages.len(), 3);
        match &ctx.request.messages[0].content {
            MessageContent::Text(text) => assert!(text.contains("Summary note")),
            other => panic!("expected text, got {other:?}"),
        }
    }

    // Sub-agent test double that always replies with "stub-response".
    struct StubAgent;

    #[async_trait]
    impl AgentHandle for StubAgent {
        async fn describe(&self) -> AgentDescriptor {
            AgentDescriptor {
                name: "stub".into(),
                version: "0.0.1".into(),
                description: None,
            }
        }

        async fn handle_message(
            &self,
            _input: AgentMessage,
            _state: Arc<AgentStateSnapshot>,
        ) -> anyhow::Result<AgentMessage> {
            Ok(AgentMessage {
                role: MessageRole::Agent,
                content: MessageContent::Text("stub-response".into()),
                metadata: None,
            })
        }
    }

    // An unknown agent name yields a "not found" text result, not an error.
    #[tokio::test]
    async fn task_router_reports_unknown_subagent() {
        let registry = Arc::new(SubAgentRegistry::new(vec![]));
        let task_tool = TaskRouterTool::new(registry.clone(), None);
        let state = Arc::new(AgentStateSnapshot::default());
        let ctx = ToolContext::new(state);

        let response = task_tool
            .execute(
                json!({
                    "instruction": "Do something",
                    "agent": "unknown"
                }),
                ctx,
            )
            .await
            .unwrap();

        match response {
            ToolResult::Message(msg) => match msg.content {
                MessageContent::Text(text) => {
                    assert!(text.contains("Sub-agent 'unknown' not found"))
                }
                other => panic!("expected text, got {other:?}"),
            },
            _ => panic!("expected message"),
        }
    }

    // The sub-agent middleware lists registered agents in the prompt and exposes `task`.
    #[tokio::test]
    async fn subagent_middleware_appends_prompt() {
        let subagents = vec![SubAgentRegistration {
            descriptor: SubAgentDescriptor {
                name: "research-agent".into(),
                description: "Deep research specialist".into(),
            },
            agent: Arc::new(StubAgent),
        }];
        let middleware = SubAgentMiddleware::new(subagents);

        let mut request = ModelRequest::new("System", vec![]);
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);
        middleware.modify_model_request(&mut ctx).await.unwrap();

        assert!(ctx.request.system_prompt.contains("research-agent"));
        let tool_names: Vec<_> = middleware
            .tools()
            .iter()
            .map(|t| t.schema().name.clone())
            .collect();
        assert!(tool_names.contains(&"task".to_string()));
    }

    // The aliased argument names (`description`, `subagent_type`) route to the
    // registered agent, and the result carries the context's call id.
    #[tokio::test]
    async fn task_router_invokes_registered_subagent() {
        let registry = Arc::new(SubAgentRegistry::new(vec![SubAgentRegistration {
            descriptor: SubAgentDescriptor {
                name: "stub-agent".into(),
                description: "Stub".into(),
            },
            agent: Arc::new(StubAgent),
        }]));
        let task_tool = TaskRouterTool::new(registry.clone(), None);
        let state = Arc::new(AgentStateSnapshot::default());
        let ctx = ToolContext::new(state).with_call_id(Some("call-42".into()));
        let response = task_tool
            .execute(
                json!({
                    "description": "do work",
                    "subagent_type": "stub-agent"
                }),
                ctx,
            )
            .await
            .unwrap();

        match response {
            ToolResult::Message(msg) => {
                assert_eq!(msg.metadata.unwrap().tool_call_id.unwrap(), "call-42");
                match msg.content {
                    MessageContent::Text(text) => assert_eq!(text, "stub-response"),
                    other => panic!("expected text, got {other:?}"),
                }
            }
            _ => panic!("expected message"),
        }
    }

    // Tools with allow_auto = false are listed with their note in the system prompt.
    #[tokio::test]
    async fn human_in_loop_appends_prompt() {
        let middleware = HumanInLoopMiddleware::new(HashMap::from([(
            "danger-tool".into(),
            HitlPolicy {
                allow_auto: false,
                note: Some("Requires security review".into()),
            },
        )]));
        let mut request = ModelRequest::new("System", vec![]);
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);
        middleware.modify_model_request(&mut ctx).await.unwrap();
        assert!(ctx
            .request
            .system_prompt
            .contains("danger-tool: Requires security review"));
    }

    // Caching moves the system prompt into a leading System message tagged
    // "ephemeral" and clears `system_prompt`; existing messages follow it.
    #[tokio::test]
    async fn anthropic_prompt_caching_moves_system_prompt_to_messages() {
        let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
        let mut request = ModelRequest::new(
            "This is the system prompt",
            vec![AgentMessage {
                role: MessageRole::User,
                content: MessageContent::Text("Hello".into()),
                metadata: None,
            }],
        );
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);

        middleware.modify_model_request(&mut ctx).await.unwrap();

        assert!(ctx.request.system_prompt.is_empty());

        assert_eq!(ctx.request.messages.len(), 2);

        let system_message = &ctx.request.messages[0];
        assert!(matches!(system_message.role, MessageRole::System));
        assert_eq!(
            system_message.content.as_text().unwrap(),
            "This is the system prompt"
        );

        let metadata = system_message.metadata.as_ref().unwrap();
        let cache_control = metadata.cache_control.as_ref().unwrap();
        assert_eq!(cache_control.cache_type, "ephemeral");

        let user_message = &ctx.request.messages[1];
        assert!(matches!(user_message.role, MessageRole::User));
        assert_eq!(user_message.content.as_text().unwrap(), "Hello");
    }

    // A TTL of "0" disables caching: the request is left unchanged.
    #[tokio::test]
    async fn anthropic_prompt_caching_disabled_with_zero_ttl() {
        let middleware = AnthropicPromptCachingMiddleware::new("0", "ignore");
        let mut request = ModelRequest::new("This is the system prompt", vec![]);
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);

        middleware.modify_model_request(&mut ctx).await.unwrap();

        assert_eq!(ctx.request.system_prompt, "This is the system prompt");
        assert_eq!(ctx.request.messages.len(), 0);
    }

    // With an empty system prompt there is nothing to cache; messages are untouched.
    #[tokio::test]
    async fn anthropic_prompt_caching_no_op_with_empty_system_prompt() {
        let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
        let mut request = ModelRequest::new(
            "",
            vec![AgentMessage {
                role: MessageRole::User,
                content: MessageContent::Text("Hello".into()),
                metadata: None,
            }],
        );
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);

        middleware.modify_model_request(&mut ctx).await.unwrap();

        assert!(ctx.request.system_prompt.is_empty());
        assert_eq!(ctx.request.messages.len(), 1);
    }

    // A disallowed tool produces a HumanInLoop interrupt that echoes the name,
    // arguments, call id and policy note.
    #[tokio::test]
    async fn hitl_creates_interrupt_for_disallowed_tool() {
        let mut policies = HashMap::new();
        policies.insert(
            "dangerous_tool".to_string(),
            HitlPolicy {
                allow_auto: false,
                note: Some("Requires security review".to_string()),
            },
        );

        let middleware = HumanInLoopMiddleware::new(policies);
        let tool_args = json!({"action": "delete_all"});

        let result = middleware
            .before_tool_execution("dangerous_tool", &tool_args, "call_123")
            .await
            .unwrap();

        assert!(result.is_some());
        let interrupt = result.unwrap();

        match interrupt {
            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
                assert_eq!(hitl.tool_name, "dangerous_tool");
                assert_eq!(hitl.tool_args, tool_args);
                assert_eq!(hitl.call_id, "call_123");
                assert_eq!(
                    hitl.policy_note,
                    Some("Requires security review".to_string())
                );
            }
        }
    }

    // allow_auto = true means no interrupt.
    #[tokio::test]
    async fn hitl_no_interrupt_for_allowed_tool() {
        let mut policies = HashMap::new();
        policies.insert(
            "safe_tool".to_string(),
            HitlPolicy {
                allow_auto: true,
                note: None,
            },
        );

        let middleware = HumanInLoopMiddleware::new(policies);
        let tool_args = json!({"action": "read"});

        let result = middleware
            .before_tool_execution("safe_tool", &tool_args, "call_456")
            .await
            .unwrap();

        assert!(result.is_none());
    }

    // Tools with no policy entry are never interrupted.
    #[tokio::test]
    async fn hitl_no_interrupt_for_unlisted_tool() {
        let policies = HashMap::new();
        let middleware = HumanInLoopMiddleware::new(policies);
        let tool_args = json!({"action": "anything"});

        let result = middleware
            .before_tool_execution("unlisted_tool", &tool_args, "call_789")
            .await
            .unwrap();

        assert!(result.is_none());
    }

    // Nested JSON arguments and the policy note survive into the interrupt.
    #[tokio::test]
    async fn hitl_interrupt_includes_correct_details() {
        let mut policies = HashMap::new();
        policies.insert(
            "critical_tool".to_string(),
            HitlPolicy {
                allow_auto: false,
                note: Some("Critical operation - requires approval".to_string()),
            },
        );

        let middleware = HumanInLoopMiddleware::new(policies);
        let tool_args = json!({
            "database": "production",
            "operation": "drop_table"
        });

        let result = middleware
            .before_tool_execution("critical_tool", &tool_args, "call_critical_1")
            .await
            .unwrap();

        assert!(result.is_some());
        let interrupt = result.unwrap();

        match interrupt {
            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
                assert_eq!(hitl.tool_name, "critical_tool");
                assert_eq!(hitl.tool_args["database"], "production");
                assert_eq!(hitl.tool_args["operation"], "drop_table");
                assert_eq!(hitl.call_id, "call_critical_1");
                assert!(hitl.policy_note.is_some());
                assert!(hitl.policy_note.unwrap().contains("Critical operation"));
            }
        }
    }

    // A gated tool without a note still interrupts, with `policy_note == None`.
    #[tokio::test]
    async fn hitl_interrupt_without_policy_note() {
        let mut policies = HashMap::new();
        policies.insert(
            "tool_no_note".to_string(),
            HitlPolicy {
                allow_auto: false,
                note: None,
            },
        );

        let middleware = HumanInLoopMiddleware::new(policies);
        let tool_args = json!({"param": "value"});

        let result = middleware
            .before_tool_execution("tool_no_note", &tool_args, "call_no_note")
            .await
            .unwrap();

        assert!(result.is_some());
        let interrupt = result.unwrap();

        match interrupt {
            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
                assert_eq!(hitl.tool_name, "tool_no_note");
                assert_eq!(hitl.policy_note, None);
            }
        }
    }
}