agents_runtime/
middleware.rs

1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3
4use agents_core::agent::AgentHandle;
5use agents_core::messaging::{
6    AgentMessage, CacheControl, MessageContent, MessageMetadata, MessageRole,
7};
8use agents_core::prompts::{
9    BASE_AGENT_PROMPT, FILESYSTEM_SYSTEM_PROMPT, TASK_SYSTEM_PROMPT, TASK_TOOL_DESCRIPTION,
10    WRITE_TODOS_SYSTEM_PROMPT,
11};
12use agents_core::state::AgentStateSnapshot;
13use agents_core::tools::{Tool, ToolBox, ToolContext, ToolResult};
14use agents_toolkit::create_filesystem_tools;
15use async_trait::async_trait;
16use serde::Deserialize;
17
18/// Request sent to the underlying language model. Middlewares can augment
19/// the system prompt or mutate the pending message list before the model call.
20#[derive(Debug, Clone)]
21pub struct ModelRequest {
22    pub system_prompt: String,
23    pub messages: Vec<AgentMessage>,
24}
25
26impl ModelRequest {
27    pub fn new(system_prompt: impl Into<String>, messages: Vec<AgentMessage>) -> Self {
28        Self {
29            system_prompt: system_prompt.into(),
30            messages,
31        }
32    }
33
34    pub fn append_prompt(&mut self, fragment: &str) {
35        if !fragment.is_empty() {
36            self.system_prompt.push_str("\n\n");
37            self.system_prompt.push_str(fragment);
38        }
39    }
40}
41
42/// Read/write state handle exposed to middleware implementations.
43pub struct MiddlewareContext<'a> {
44    pub request: &'a mut ModelRequest,
45    pub state: Arc<RwLock<AgentStateSnapshot>>,
46}
47
48impl<'a> MiddlewareContext<'a> {
49    pub fn with_request(
50        request: &'a mut ModelRequest,
51        state: Arc<RwLock<AgentStateSnapshot>>,
52    ) -> Self {
53        Self { request, state }
54    }
55}
56
/// Middleware hook that can register additional tools and mutate the model request
/// prior to execution. Mirrors the Python AgentMiddleware contracts but keeps the
/// interface async-first for future network calls.
///
/// Only `id` and `modify_model_request` are required; `tools` and
/// `before_tool_execution` have no-op defaults.
#[async_trait]
pub trait AgentMiddleware: Send + Sync {
    /// Unique identifier for logging and diagnostics.
    fn id(&self) -> &'static str;

    /// Tools to expose when this middleware is active.
    ///
    /// The default implementation exposes no tools (an empty vector).
    fn tools(&self) -> Vec<ToolBox> {
        Vec::new()
    }

    /// Apply middleware-specific mutations to the pending model request.
    ///
    /// Implementations receive mutable access to the request through `ctx.request`
    /// and may change its system prompt and/or its message list.
    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()>;

    /// Hook called before tool execution - can return an interrupt to pause execution.
    ///
    /// This hook is invoked for each tool call before it executes, allowing middleware
    /// to intercept and pause execution for human review. If an interrupt is returned,
    /// the agent will save its state and wait for human approval before continuing.
    ///
    /// The default implementation ignores its arguments and returns `Ok(None)`.
    ///
    /// # Arguments
    /// * `tool_name` - Name of the tool about to be executed
    /// * `tool_args` - Arguments that will be passed to the tool
    /// * `call_id` - Unique identifier for this tool call
    ///
    /// # Returns
    /// * `Ok(Some(interrupt))` - Pause execution and wait for human response
    /// * `Ok(None)` - Continue with tool execution normally
    /// * `Err(e)` - Error occurred during interrupt check
    async fn before_tool_execution(
        &self,
        _tool_name: &str,
        _tool_args: &serde_json::Value,
        _call_id: &str,
    ) -> anyhow::Result<Option<agents_core::hitl::AgentInterrupt>> {
        Ok(None)
    }
}
97
/// Settings for the middleware that trims long message histories.
pub struct SummarizationMiddleware {
    /// Number of most recent messages retained when the history is trimmed.
    pub messages_to_keep: usize,
    /// Text that introduces the placeholder message standing in for trimmed history.
    pub summary_note: String,
}

impl SummarizationMiddleware {
    /// Creates the middleware from a retention count and a summary note.
    pub fn new(messages_to_keep: usize, summary_note: impl Into<String>) -> Self {
        let summary_note: String = summary_note.into();
        SummarizationMiddleware {
            messages_to_keep,
            summary_note,
        }
    }
}
111
112#[async_trait]
113impl AgentMiddleware for SummarizationMiddleware {
114    fn id(&self) -> &'static str {
115        "summarization"
116    }
117
118    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
119        if ctx.request.messages.len() > self.messages_to_keep {
120            let dropped = ctx.request.messages.len() - self.messages_to_keep;
121            let mut truncated = ctx
122                .request
123                .messages
124                .split_off(ctx.request.messages.len() - self.messages_to_keep);
125            truncated.insert(
126                0,
127                AgentMessage {
128                    role: MessageRole::System,
129                    content: MessageContent::Text(format!(
130                        "{} ({} earlier messages summarized)",
131                        self.summary_note, dropped
132                    )),
133                    metadata: None,
134                },
135            );
136            ctx.request.messages = truncated;
137        }
138        Ok(())
139    }
140}
141
142pub struct PlanningMiddleware {
143    _state: Arc<RwLock<AgentStateSnapshot>>,
144}
145
146impl PlanningMiddleware {
147    pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
148        Self { _state: state }
149    }
150}
151
152#[async_trait]
153impl AgentMiddleware for PlanningMiddleware {
154    fn id(&self) -> &'static str {
155        "planning"
156    }
157
158    fn tools(&self) -> Vec<ToolBox> {
159        use agents_toolkit::create_todos_tools;
160        create_todos_tools()
161    }
162
163    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
164        ctx.request.append_prompt(WRITE_TODOS_SYSTEM_PROMPT);
165        Ok(())
166    }
167}
168
169pub struct FilesystemMiddleware {
170    _state: Arc<RwLock<AgentStateSnapshot>>,
171}
172
173impl FilesystemMiddleware {
174    pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
175        Self { _state: state }
176    }
177}
178
179#[async_trait]
180impl AgentMiddleware for FilesystemMiddleware {
181    fn id(&self) -> &'static str {
182        "filesystem"
183    }
184
185    fn tools(&self) -> Vec<ToolBox> {
186        create_filesystem_tools()
187    }
188
189    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
190        ctx.request.append_prompt(FILESYSTEM_SYSTEM_PROMPT);
191        Ok(())
192    }
193}
194
/// Pairs a sub-agent's descriptor with the handle used to invoke it.
#[derive(Clone)]
pub struct SubAgentRegistration {
    /// Name and description of the sub-agent; the name is the lookup key in the registry.
    pub descriptor: SubAgentDescriptor,
    /// Handle that receives delegated messages for this sub-agent.
    pub agent: Arc<dyn AgentHandle>,
}
200
201struct SubAgentRegistry {
202    agents: HashMap<String, Arc<dyn AgentHandle>>,
203}
204
205impl SubAgentRegistry {
206    fn new(registrations: Vec<SubAgentRegistration>) -> Self {
207        let mut agents = HashMap::new();
208        for reg in registrations {
209            agents.insert(reg.descriptor.name.clone(), reg.agent.clone());
210        }
211        Self { agents }
212    }
213
214    fn available_names(&self) -> Vec<String> {
215        self.agents.keys().cloned().collect()
216    }
217
218    fn get(&self, name: &str) -> Option<Arc<dyn AgentHandle>> {
219        self.agents.get(name).cloned()
220    }
221}
222
223pub struct SubAgentMiddleware {
224    task_tool: ToolBox,
225    descriptors: Vec<SubAgentDescriptor>,
226    _registry: Arc<SubAgentRegistry>,
227}
228
229impl SubAgentMiddleware {
230    pub fn new(registrations: Vec<SubAgentRegistration>) -> Self {
231        let descriptors = registrations.iter().map(|r| r.descriptor.clone()).collect();
232        let registry = Arc::new(SubAgentRegistry::new(registrations));
233        let task_tool: ToolBox = Arc::new(TaskRouterTool::new(registry.clone()));
234        Self {
235            task_tool,
236            descriptors,
237            _registry: registry,
238        }
239    }
240
241    fn prompt_fragment(&self) -> String {
242        let descriptions: Vec<String> = if self.descriptors.is_empty() {
243            vec![String::from("- general-purpose: Default reasoning agent")]
244        } else {
245            self.descriptors
246                .iter()
247                .map(|agent| format!("- {}: {}", agent.name, agent.description))
248                .collect()
249        };
250
251        TASK_TOOL_DESCRIPTION.replace("{other_agents}", &descriptions.join("\n"))
252    }
253}
254
255#[async_trait]
256impl AgentMiddleware for SubAgentMiddleware {
257    fn id(&self) -> &'static str {
258        "subagent"
259    }
260
261    fn tools(&self) -> Vec<ToolBox> {
262        vec![self.task_tool.clone()]
263    }
264
265    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
266        ctx.request.append_prompt(TASK_SYSTEM_PROMPT);
267        ctx.request.append_prompt(&self.prompt_fragment());
268        Ok(())
269    }
270}
271
/// Approval policy for a single tool.
#[derive(Clone, Debug)]
pub struct HitlPolicy {
    /// When `true` the tool runs without human approval.
    pub allow_auto: bool,
    /// Optional explanation shown alongside the tool in the approval prompt.
    pub note: Option<String>,
}

/// Middleware that gates selected tools behind human approval.
pub struct HumanInLoopMiddleware {
    /// Approval policies keyed by tool name.
    policies: HashMap<String, HitlPolicy>,
}

impl HumanInLoopMiddleware {
    /// Creates the middleware from a tool-name → policy map.
    pub fn new(policies: HashMap<String, HitlPolicy>) -> Self {
        Self { policies }
    }

    /// Returns the policy for `tool_name` when that tool needs human approval.
    ///
    /// Tools without a policy, or whose policy has `allow_auto` set, yield `None`.
    pub fn requires_approval(&self, tool_name: &str) -> Option<&HitlPolicy> {
        self.policies
            .get(tool_name)
            .filter(|policy| !policy.allow_auto)
    }

    /// Builds the prompt section listing every tool that needs approval.
    ///
    /// Returns `None` when no tool is gated. Entries are ordered by tool name so
    /// the generated prompt is identical across runs instead of following the
    /// hash map's arbitrary iteration order.
    fn prompt_fragment(&self) -> Option<String> {
        let mut gated: Vec<(&String, &HitlPolicy)> = self
            .policies
            .iter()
            .filter(|(_, policy)| !policy.allow_auto)
            .collect();
        if gated.is_empty() {
            return None;
        }
        gated.sort_unstable_by(|left, right| left.0.cmp(right.0));

        let pending: Vec<String> = gated
            .into_iter()
            .map(|(tool, policy)| {
                let reason = policy.note.as_deref().unwrap_or("Requires approval");
                format!("- {tool}: {reason}")
            })
            .collect();

        Some(format!(
            "The following tools require human approval before execution:\n{}",
            pending.join("\n")
        ))
    }
}
313
314#[async_trait]
315impl AgentMiddleware for HumanInLoopMiddleware {
316    fn id(&self) -> &'static str {
317        "human-in-loop"
318    }
319
320    async fn before_tool_execution(
321        &self,
322        tool_name: &str,
323        tool_args: &serde_json::Value,
324        call_id: &str,
325    ) -> anyhow::Result<Option<agents_core::hitl::AgentInterrupt>> {
326        if let Some(policy) = self.requires_approval(tool_name) {
327            tracing::warn!(
328                tool_name = %tool_name,
329                call_id = %call_id,
330                policy_note = ?policy.note,
331                "🔒 HITL: Tool execution requires human approval"
332            );
333
334            let interrupt = agents_core::hitl::HitlInterrupt::new(
335                tool_name,
336                tool_args.clone(),
337                call_id,
338                policy.note.clone(),
339            );
340
341            return Ok(Some(agents_core::hitl::AgentInterrupt::HumanInLoop(
342                interrupt,
343            )));
344        }
345
346        Ok(None)
347    }
348
349    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
350        if let Some(fragment) = self.prompt_fragment() {
351            ctx.request.append_prompt(&fragment);
352        }
353        ctx.request.messages.push(AgentMessage {
354            role: MessageRole::System,
355            content: MessageContent::Text(
356                "Tools marked for human approval will emit interrupts requiring external resolution."
357                    .into(),
358            ),
359            metadata: None,
360        });
361        Ok(())
362    }
363}
364
365pub struct BaseSystemPromptMiddleware;
366
367#[async_trait]
368impl AgentMiddleware for BaseSystemPromptMiddleware {
369    fn id(&self) -> &'static str {
370        "base-system-prompt"
371    }
372
373    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
374        ctx.request.append_prompt(BASE_AGENT_PROMPT);
375        Ok(())
376    }
377}
378
/// Deep Agent prompt middleware that injects comprehensive tool usage instructions
/// and examples to force the LLM to actually call tools instead of just talking about them.
///
/// This middleware is inspired by Python's deepagents package and Claude Code's system prompt.
/// It provides:
/// - Explicit tool usage rules with imperative language
/// - JSON examples of tool calling
/// - Workflow guidance for multi-step tasks
/// - Few-shot examples for common patterns
pub struct DeepAgentPromptMiddleware {
    /// Caller-supplied instructions passed to the deep-agent prompt builder.
    custom_instructions: String,
}

impl DeepAgentPromptMiddleware {
    /// Creates the middleware with the given custom instructions.
    pub fn new(custom_instructions: impl Into<String>) -> Self {
        let custom_instructions: String = custom_instructions.into();
        DeepAgentPromptMiddleware {
            custom_instructions,
        }
    }
}
399
400#[async_trait]
401impl AgentMiddleware for DeepAgentPromptMiddleware {
402    fn id(&self) -> &'static str {
403        "deep-agent-prompt"
404    }
405
406    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
407        use crate::prompts::get_deep_agent_system_prompt;
408        let deep_prompt = get_deep_agent_system_prompt(&self.custom_instructions);
409        ctx.request.append_prompt(&deep_prompt);
410        Ok(())
411    }
412}
413
/// Anthropic-specific prompt caching middleware. Marks system prompts for caching
/// to reduce latency on subsequent requests with the same base prompt.
pub struct AnthropicPromptCachingMiddleware {
    /// Requested cache lifetime such as `"5m"`; an empty or zero value disables caching.
    pub ttl: String,
    /// Behavior label for unsupported models; in this file it is only logged.
    pub unsupported_model_behavior: String,
}

impl AnthropicPromptCachingMiddleware {
    /// Creates the middleware with an explicit TTL and unsupported-model behavior.
    pub fn new(ttl: impl Into<String>, unsupported_model_behavior: impl Into<String>) -> Self {
        Self {
            ttl: ttl.into(),
            unsupported_model_behavior: unsupported_model_behavior.into(),
        }
    }

    /// Creates the middleware with a `"5m"` TTL and the `"ignore"` behavior.
    pub fn with_defaults() -> Self {
        Self::new("5m", "ignore")
    }

    /// Decides from the TTL string whether caching is requested.
    ///
    /// Caching is disabled when the TTL is empty or whitespace-only, or when its
    /// numeric part is zero regardless of the unit suffix (`"0"`, `"0s"`, `"0m"`,
    /// `"00"`, `" 0 "`). Any other non-empty TTL, including one without a
    /// parsable number, enables ephemeral caching.
    fn should_enable_caching(&self) -> bool {
        let ttl = self.ttl.trim();
        if ttl.is_empty() {
            return false;
        }

        // Strip a trailing unit such as "s", "m" or "h" to isolate the magnitude.
        let magnitude = ttl
            .trim_end_matches(|c: char| c.is_ascii_alphabetic())
            .trim_end();

        // Only an explicit zero disables caching; unparsable magnitudes stay enabled.
        magnitude
            .parse::<f64>()
            .map_or(true, |value| value != 0.0)
    }
}
439
440#[async_trait]
441impl AgentMiddleware for AnthropicPromptCachingMiddleware {
442    fn id(&self) -> &'static str {
443        "anthropic-prompt-caching"
444    }
445
446    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
447        if !self.should_enable_caching() {
448            return Ok(());
449        }
450
451        // Mark system prompt for caching by converting it to a system message with cache control
452        if !ctx.request.system_prompt.is_empty() {
453            let system_message = AgentMessage {
454                role: MessageRole::System,
455                content: MessageContent::Text(ctx.request.system_prompt.clone()),
456                metadata: Some(MessageMetadata {
457                    tool_call_id: None,
458                    cache_control: Some(CacheControl {
459                        cache_type: "ephemeral".to_string(),
460                    }),
461                }),
462            };
463
464            // Insert system message at the beginning of the messages
465            ctx.request.messages.insert(0, system_message);
466
467            // Clear the system_prompt since it's now in messages
468            ctx.request.system_prompt.clear();
469
470            tracing::debug!(
471                ttl = %self.ttl,
472                behavior = %self.unsupported_model_behavior,
473                "Applied Anthropic prompt caching to system message"
474            );
475        }
476
477        Ok(())
478    }
479}
480
481pub struct TaskRouterTool {
482    registry: Arc<SubAgentRegistry>,
483}
484
485impl TaskRouterTool {
486    fn new(registry: Arc<SubAgentRegistry>) -> Self {
487        Self { registry }
488    }
489
490    fn available_subagents(&self) -> Vec<String> {
491        self.registry.available_names()
492    }
493}
494
/// Arguments accepted by the `task` tool, deserialized from the tool-call JSON.
#[derive(Debug, Clone, Deserialize)]
struct TaskInvocationArgs {
    /// Instruction forwarded to the sub-agent; also accepted under the key `description`.
    #[serde(alias = "description")]
    instruction: String,
    /// Name of the sub-agent to delegate to; also accepted under the key `subagent_type`.
    #[serde(alias = "subagent_type")]
    agent: String,
}
502
503#[async_trait]
504impl Tool for TaskRouterTool {
505    fn schema(&self) -> agents_core::tools::ToolSchema {
506        use agents_core::tools::{ToolParameterSchema, ToolSchema};
507        use std::collections::HashMap;
508
509        let mut properties = HashMap::new();
510        properties.insert(
511            "agent".to_string(),
512            ToolParameterSchema::string("Name of the sub-agent to delegate to"),
513        );
514        properties.insert(
515            "instruction".to_string(),
516            ToolParameterSchema::string("Clear instruction for the sub-agent"),
517        );
518
519        ToolSchema::new(
520            "task",
521            "Delegate a task to a specialized sub-agent. Use this when you need specialized expertise or want to break down complex tasks.",
522            ToolParameterSchema::object(
523                "Task delegation parameters",
524                properties,
525                vec!["agent".to_string(), "instruction".to_string()],
526            ),
527        )
528    }
529
530    async fn execute(
531        &self,
532        args: serde_json::Value,
533        ctx: ToolContext,
534    ) -> anyhow::Result<ToolResult> {
535        let args: TaskInvocationArgs = serde_json::from_value(args)?;
536        let available = self.available_subagents();
537
538        if let Some(agent) = self.registry.get(&args.agent) {
539            // Log delegation start
540            tracing::warn!(
541                "🎯 DELEGATING to sub-agent: {} with instruction: {}",
542                args.agent,
543                args.instruction
544            );
545
546            let start_time = std::time::Instant::now();
547            let user_message = AgentMessage {
548                role: MessageRole::User,
549                content: MessageContent::Text(args.instruction.clone()),
550                metadata: None,
551            };
552
553            let response = agent
554                .handle_message(user_message, Arc::new(AgentStateSnapshot::default()))
555                .await?;
556
557            // Log delegation completion
558            let duration = start_time.elapsed();
559            let response_preview = match &response.content {
560                MessageContent::Text(t) => {
561                    if t.len() > 100 {
562                        format!("{}... ({} chars)", &t[..100], t.len())
563                    } else {
564                        t.clone()
565                    }
566                }
567                MessageContent::Json(v) => {
568                    format!("JSON: {} bytes", v.to_string().len())
569                }
570            };
571
572            tracing::warn!(
573                "✅ SUB-AGENT {} COMPLETED in {:?} - Response: {}",
574                args.agent,
575                duration,
576                response_preview
577            );
578
579            // Return sub-agent response as text content, not as a separate tool message
580            // This will be incorporated into the LLM's next response naturally
581            let result_text = match response.content {
582                MessageContent::Text(text) => text,
583                MessageContent::Json(json) => json.to_string(),
584            };
585
586            return Ok(ToolResult::text(&ctx, result_text));
587        }
588
589        tracing::error!(
590            "❌ SUB-AGENT NOT FOUND: {} - Available: {:?}",
591            args.agent,
592            available
593        );
594
595        Ok(ToolResult::text(
596            &ctx,
597            format!(
598                "Sub-agent '{}' not found. Available sub-agents: {}",
599                args.agent,
600                available.join(", ")
601            ),
602        ))
603    }
604}
605
/// Human-readable identity of a sub-agent.
#[derive(Debug, Clone)]
pub struct SubAgentDescriptor {
    /// Name used as the registry key and shown in the sub-agent prompt listing.
    pub name: String,
    /// Short description shown next to the name in the sub-agent prompt listing.
    pub description: String,
}
611
612#[cfg(test)]
613mod tests {
614    use super::*;
615    use agents_core::agent::{AgentDescriptor, AgentHandle};
616    use agents_core::messaging::{MessageContent, MessageRole};
617    use serde_json::json;
618
619    struct AppendPromptMiddleware;
620
621    #[async_trait]
622    impl AgentMiddleware for AppendPromptMiddleware {
623        fn id(&self) -> &'static str {
624            "append-prompt"
625        }
626
627        async fn modify_model_request(
628            &self,
629            ctx: &mut MiddlewareContext<'_>,
630        ) -> anyhow::Result<()> {
631            ctx.request.system_prompt.push_str("\nExtra directives.");
632            Ok(())
633        }
634    }
635
636    #[tokio::test]
637    async fn middleware_mutates_prompt() {
638        let mut request = ModelRequest::new(
639            "System",
640            vec![AgentMessage {
641                role: MessageRole::User,
642                content: MessageContent::Text("Hi".into()),
643                metadata: None,
644            }],
645        );
646        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
647        let mut ctx = MiddlewareContext::with_request(&mut request, state);
648        let middleware = AppendPromptMiddleware;
649        middleware.modify_model_request(&mut ctx).await.unwrap();
650        assert!(ctx.request.system_prompt.contains("Extra directives"));
651    }
652
653    #[tokio::test]
654    async fn planning_middleware_registers_write_todos() {
655        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
656        let middleware = PlanningMiddleware::new(state);
657        let tool_names: Vec<_> = middleware
658            .tools()
659            .iter()
660            .map(|t| t.schema().name.clone())
661            .collect();
662        assert!(tool_names.contains(&"write_todos".to_string()));
663
664        let mut request = ModelRequest::new("System", vec![]);
665        let mut ctx = MiddlewareContext::with_request(
666            &mut request,
667            Arc::new(RwLock::new(AgentStateSnapshot::default())),
668        );
669        middleware.modify_model_request(&mut ctx).await.unwrap();
670        assert!(ctx.request.system_prompt.contains("write_todos"));
671    }
672
673    #[tokio::test]
674    async fn filesystem_middleware_registers_tools() {
675        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
676        let middleware = FilesystemMiddleware::new(state);
677        let tool_names: Vec<_> = middleware
678            .tools()
679            .iter()
680            .map(|t| t.schema().name.clone())
681            .collect();
682        for expected in ["ls", "read_file", "write_file", "edit_file"] {
683            assert!(tool_names.contains(&expected.to_string()));
684        }
685    }
686
687    #[tokio::test]
688    async fn summarization_middleware_trims_messages() {
689        let middleware = SummarizationMiddleware::new(2, "Summary note");
690        let mut request = ModelRequest::new(
691            "System",
692            vec![
693                AgentMessage {
694                    role: MessageRole::User,
695                    content: MessageContent::Text("one".into()),
696                    metadata: None,
697                },
698                AgentMessage {
699                    role: MessageRole::Agent,
700                    content: MessageContent::Text("two".into()),
701                    metadata: None,
702                },
703                AgentMessage {
704                    role: MessageRole::User,
705                    content: MessageContent::Text("three".into()),
706                    metadata: None,
707                },
708            ],
709        );
710        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
711        let mut ctx = MiddlewareContext::with_request(&mut request, state);
712        middleware.modify_model_request(&mut ctx).await.unwrap();
713        assert_eq!(ctx.request.messages.len(), 3);
714        match &ctx.request.messages[0].content {
715            MessageContent::Text(text) => assert!(text.contains("Summary note")),
716            other => panic!("expected text, got {other:?}"),
717        }
718    }
719
720    struct StubAgent;
721
722    #[async_trait]
723    impl AgentHandle for StubAgent {
724        async fn describe(&self) -> AgentDescriptor {
725            AgentDescriptor {
726                name: "stub".into(),
727                version: "0.0.1".into(),
728                description: None,
729            }
730        }
731
732        async fn handle_message(
733            &self,
734            _input: AgentMessage,
735            _state: Arc<AgentStateSnapshot>,
736        ) -> anyhow::Result<AgentMessage> {
737            Ok(AgentMessage {
738                role: MessageRole::Agent,
739                content: MessageContent::Text("stub-response".into()),
740                metadata: None,
741            })
742        }
743    }
744
745    #[tokio::test]
746    async fn task_router_reports_unknown_subagent() {
747        let registry = Arc::new(SubAgentRegistry::new(vec![]));
748        let task_tool = TaskRouterTool::new(registry.clone());
749        let state = Arc::new(AgentStateSnapshot::default());
750        let ctx = ToolContext::new(state);
751
752        let response = task_tool
753            .execute(
754                json!({
755                    "instruction": "Do something",
756                    "agent": "unknown"
757                }),
758                ctx,
759            )
760            .await
761            .unwrap();
762
763        match response {
764            ToolResult::Message(msg) => match msg.content {
765                MessageContent::Text(text) => {
766                    assert!(text.contains("Sub-agent 'unknown' not found"))
767                }
768                other => panic!("expected text, got {other:?}"),
769            },
770            _ => panic!("expected message"),
771        }
772    }
773
774    #[tokio::test]
775    async fn subagent_middleware_appends_prompt() {
776        let subagents = vec![SubAgentRegistration {
777            descriptor: SubAgentDescriptor {
778                name: "research-agent".into(),
779                description: "Deep research specialist".into(),
780            },
781            agent: Arc::new(StubAgent),
782        }];
783        let middleware = SubAgentMiddleware::new(subagents);
784
785        let mut request = ModelRequest::new("System", vec![]);
786        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
787        let mut ctx = MiddlewareContext::with_request(&mut request, state);
788        middleware.modify_model_request(&mut ctx).await.unwrap();
789
790        assert!(ctx.request.system_prompt.contains("research-agent"));
791        let tool_names: Vec<_> = middleware
792            .tools()
793            .iter()
794            .map(|t| t.schema().name.clone())
795            .collect();
796        assert!(tool_names.contains(&"task".to_string()));
797    }
798
799    #[tokio::test]
800    async fn task_router_invokes_registered_subagent() {
801        let registry = Arc::new(SubAgentRegistry::new(vec![SubAgentRegistration {
802            descriptor: SubAgentDescriptor {
803                name: "stub-agent".into(),
804                description: "Stub".into(),
805            },
806            agent: Arc::new(StubAgent),
807        }]));
808        let task_tool = TaskRouterTool::new(registry.clone());
809        let state = Arc::new(AgentStateSnapshot::default());
810        let ctx = ToolContext::new(state).with_call_id(Some("call-42".into()));
811        let response = task_tool
812            .execute(
813                json!({
814                    "description": "do work",
815                    "subagent_type": "stub-agent"
816                }),
817                ctx,
818            )
819            .await
820            .unwrap();
821
822        match response {
823            ToolResult::Message(msg) => {
824                assert_eq!(msg.metadata.unwrap().tool_call_id.unwrap(), "call-42");
825                match msg.content {
826                    MessageContent::Text(text) => assert_eq!(text, "stub-response"),
827                    other => panic!("expected text, got {other:?}"),
828                }
829            }
830            _ => panic!("expected message"),
831        }
832    }
833
834    #[tokio::test]
835    async fn human_in_loop_appends_prompt() {
836        let middleware = HumanInLoopMiddleware::new(HashMap::from([(
837            "danger-tool".into(),
838            HitlPolicy {
839                allow_auto: false,
840                note: Some("Requires security review".into()),
841            },
842        )]));
843        let mut request = ModelRequest::new("System", vec![]);
844        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
845        let mut ctx = MiddlewareContext::with_request(&mut request, state);
846        middleware.modify_model_request(&mut ctx).await.unwrap();
847        assert!(ctx
848            .request
849            .system_prompt
850            .contains("danger-tool: Requires security review"));
851    }
852
853    #[tokio::test]
854    async fn anthropic_prompt_caching_moves_system_prompt_to_messages() {
855        let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
856        let mut request = ModelRequest::new(
857            "This is the system prompt",
858            vec![AgentMessage {
859                role: MessageRole::User,
860                content: MessageContent::Text("Hello".into()),
861                metadata: None,
862            }],
863        );
864        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
865        let mut ctx = MiddlewareContext::with_request(&mut request, state);
866
867        // Apply the middleware
868        middleware.modify_model_request(&mut ctx).await.unwrap();
869
870        // System prompt should be cleared
871        assert!(ctx.request.system_prompt.is_empty());
872
873        // Should have added a system message with cache control at the beginning
874        assert_eq!(ctx.request.messages.len(), 2);
875
876        let system_message = &ctx.request.messages[0];
877        assert!(matches!(system_message.role, MessageRole::System));
878        assert_eq!(
879            system_message.content.as_text().unwrap(),
880            "This is the system prompt"
881        );
882
883        // Check cache control metadata
884        let metadata = system_message.metadata.as_ref().unwrap();
885        let cache_control = metadata.cache_control.as_ref().unwrap();
886        assert_eq!(cache_control.cache_type, "ephemeral");
887
888        // Original user message should still be there
889        let user_message = &ctx.request.messages[1];
890        assert!(matches!(user_message.role, MessageRole::User));
891        assert_eq!(user_message.content.as_text().unwrap(), "Hello");
892    }
893
894    #[tokio::test]
895    async fn anthropic_prompt_caching_disabled_with_zero_ttl() {
896        let middleware = AnthropicPromptCachingMiddleware::new("0", "ignore");
897        let mut request = ModelRequest::new("This is the system prompt", vec![]);
898        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
899        let mut ctx = MiddlewareContext::with_request(&mut request, state);
900
901        // Apply the middleware
902        middleware.modify_model_request(&mut ctx).await.unwrap();
903
904        // System prompt should be unchanged
905        assert_eq!(ctx.request.system_prompt, "This is the system prompt");
906        assert_eq!(ctx.request.messages.len(), 0);
907    }
908
909    #[tokio::test]
910    async fn anthropic_prompt_caching_no_op_with_empty_system_prompt() {
911        let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
912        let mut request = ModelRequest::new(
913            "",
914            vec![AgentMessage {
915                role: MessageRole::User,
916                content: MessageContent::Text("Hello".into()),
917                metadata: None,
918            }],
919        );
920        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
921        let mut ctx = MiddlewareContext::with_request(&mut request, state);
922
923        // Apply the middleware
924        middleware.modify_model_request(&mut ctx).await.unwrap();
925
926        // System prompt should remain empty
927        assert!(ctx.request.system_prompt.is_empty());
928        // No system message should be added
929        assert_eq!(ctx.request.messages.len(), 1);
930    }
931
932    // ========== HITL Interrupt Creation Tests ==========
933
934    #[tokio::test]
935    async fn hitl_creates_interrupt_for_disallowed_tool() {
936        let mut policies = HashMap::new();
937        policies.insert(
938            "dangerous_tool".to_string(),
939            HitlPolicy {
940                allow_auto: false,
941                note: Some("Requires security review".to_string()),
942            },
943        );
944
945        let middleware = HumanInLoopMiddleware::new(policies);
946        let tool_args = json!({"action": "delete_all"});
947
948        let result = middleware
949            .before_tool_execution("dangerous_tool", &tool_args, "call_123")
950            .await
951            .unwrap();
952
953        assert!(result.is_some());
954        let interrupt = result.unwrap();
955
956        match interrupt {
957            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
958                assert_eq!(hitl.tool_name, "dangerous_tool");
959                assert_eq!(hitl.tool_args, tool_args);
960                assert_eq!(hitl.call_id, "call_123");
961                assert_eq!(
962                    hitl.policy_note,
963                    Some("Requires security review".to_string())
964                );
965            }
966        }
967    }
968
969    #[tokio::test]
970    async fn hitl_no_interrupt_for_allowed_tool() {
971        let mut policies = HashMap::new();
972        policies.insert(
973            "safe_tool".to_string(),
974            HitlPolicy {
975                allow_auto: true,
976                note: None,
977            },
978        );
979
980        let middleware = HumanInLoopMiddleware::new(policies);
981        let tool_args = json!({"action": "read"});
982
983        let result = middleware
984            .before_tool_execution("safe_tool", &tool_args, "call_456")
985            .await
986            .unwrap();
987
988        assert!(result.is_none());
989    }
990
991    #[tokio::test]
992    async fn hitl_no_interrupt_for_unlisted_tool() {
993        let policies = HashMap::new();
994        let middleware = HumanInLoopMiddleware::new(policies);
995        let tool_args = json!({"action": "anything"});
996
997        let result = middleware
998            .before_tool_execution("unlisted_tool", &tool_args, "call_789")
999            .await
1000            .unwrap();
1001
1002        assert!(result.is_none());
1003    }
1004
1005    #[tokio::test]
1006    async fn hitl_interrupt_includes_correct_details() {
1007        let mut policies = HashMap::new();
1008        policies.insert(
1009            "critical_tool".to_string(),
1010            HitlPolicy {
1011                allow_auto: false,
1012                note: Some("Critical operation - requires approval".to_string()),
1013            },
1014        );
1015
1016        let middleware = HumanInLoopMiddleware::new(policies);
1017        let tool_args = json!({
1018            "database": "production",
1019            "operation": "drop_table"
1020        });
1021
1022        let result = middleware
1023            .before_tool_execution("critical_tool", &tool_args, "call_critical_1")
1024            .await
1025            .unwrap();
1026
1027        assert!(result.is_some());
1028        let interrupt = result.unwrap();
1029
1030        match interrupt {
1031            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
1032                assert_eq!(hitl.tool_name, "critical_tool");
1033                assert_eq!(hitl.tool_args["database"], "production");
1034                assert_eq!(hitl.tool_args["operation"], "drop_table");
1035                assert_eq!(hitl.call_id, "call_critical_1");
1036                assert!(hitl.policy_note.is_some());
1037                assert!(hitl.policy_note.unwrap().contains("Critical operation"));
1038                // Verify timestamp exists (created_at field is populated)
1039                // The actual timestamp value is tested in agents-core/hitl.rs
1040            }
1041        }
1042    }
1043
1044    #[tokio::test]
1045    async fn hitl_interrupt_without_policy_note() {
1046        let mut policies = HashMap::new();
1047        policies.insert(
1048            "tool_no_note".to_string(),
1049            HitlPolicy {
1050                allow_auto: false,
1051                note: None,
1052            },
1053        );
1054
1055        let middleware = HumanInLoopMiddleware::new(policies);
1056        let tool_args = json!({"param": "value"});
1057
1058        let result = middleware
1059            .before_tool_execution("tool_no_note", &tool_args, "call_no_note")
1060            .await
1061            .unwrap();
1062
1063        assert!(result.is_some());
1064        let interrupt = result.unwrap();
1065
1066        match interrupt {
1067            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
1068                assert_eq!(hitl.tool_name, "tool_no_note");
1069                assert_eq!(hitl.policy_note, None);
1070            }
1071        }
1072    }
1073}