agents_runtime/
middleware.rs

1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3
4use agents_core::agent::AgentHandle;
5use agents_core::messaging::{
6    AgentMessage, CacheControl, MessageContent, MessageMetadata, MessageRole,
7};
8use agents_core::prompts::{
9    BASE_AGENT_PROMPT, FILESYSTEM_SYSTEM_PROMPT, TASK_SYSTEM_PROMPT, TASK_TOOL_DESCRIPTION,
10    WRITE_TODOS_SYSTEM_PROMPT,
11};
12use agents_core::state::AgentStateSnapshot;
13use agents_core::tools::{Tool, ToolBox, ToolContext, ToolResult};
14use agents_toolkit::create_filesystem_tools;
15use async_trait::async_trait;
16use serde::Deserialize;
17
18pub mod token_tracking;
19
20/// Request sent to the underlying language model. Middlewares can augment
21/// the system prompt or mutate the pending message list before the model call.
22#[derive(Debug, Clone)]
23pub struct ModelRequest {
24    pub system_prompt: String,
25    pub messages: Vec<AgentMessage>,
26}
27
28impl ModelRequest {
29    pub fn new(system_prompt: impl Into<String>, messages: Vec<AgentMessage>) -> Self {
30        Self {
31            system_prompt: system_prompt.into(),
32            messages,
33        }
34    }
35
36    pub fn append_prompt(&mut self, fragment: &str) {
37        if !fragment.is_empty() {
38            self.system_prompt.push_str("\n\n");
39            self.system_prompt.push_str(fragment);
40        }
41    }
42}
43
44/// Read/write state handle exposed to middleware implementations.
45pub struct MiddlewareContext<'a> {
46    pub request: &'a mut ModelRequest,
47    pub state: Arc<RwLock<AgentStateSnapshot>>,
48}
49
50impl<'a> MiddlewareContext<'a> {
51    pub fn with_request(
52        request: &'a mut ModelRequest,
53        state: Arc<RwLock<AgentStateSnapshot>>,
54    ) -> Self {
55        Self { request, state }
56    }
57}
58
59/// Middleware hook that can register additional tools and mutate the model request
60/// prior to execution. Mirrors the Python AgentMiddleware contracts but keeps the
61/// interface async-first for future network calls.
62#[async_trait]
63pub trait AgentMiddleware: Send + Sync {
64    /// Unique identifier for logging and diagnostics.
65    fn id(&self) -> &'static str;
66
67    /// Tools to expose when this middleware is active.
68    fn tools(&self) -> Vec<ToolBox> {
69        Vec::new()
70    }
71
72    /// Apply middleware-specific mutations to the pending model request.
73    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()>;
74
75    /// Hook called before tool execution - can return an interrupt to pause execution.
76    ///
77    /// This hook is invoked for each tool call before it executes, allowing middleware
78    /// to intercept and pause execution for human review. If an interrupt is returned,
79    /// the agent will save its state and wait for human approval before continuing.
80    ///
81    /// # Arguments
82    /// * `tool_name` - Name of the tool about to be executed
83    /// * `tool_args` - Arguments that will be passed to the tool
84    /// * `call_id` - Unique identifier for this tool call
85    ///
86    /// # Returns
87    /// * `Ok(Some(interrupt))` - Pause execution and wait for human response
88    /// * `Ok(None)` - Continue with tool execution normally
89    /// * `Err(e)` - Error occurred during interrupt check
90    async fn before_tool_execution(
91        &self,
92        _tool_name: &str,
93        _tool_args: &serde_json::Value,
94        _call_id: &str,
95    ) -> anyhow::Result<Option<agents_core::hitl::AgentInterrupt>> {
96        Ok(None)
97    }
98}
99
/// Middleware that caps the message history sent to the model and replaces
/// the removed prefix with a single note.
pub struct SummarizationMiddleware {
    /// Number of most recent messages that are retained.
    pub messages_to_keep: usize,
    /// Text used as the start of the placeholder message.
    pub summary_note: String,
}

impl SummarizationMiddleware {
    /// Creates a middleware that keeps the last `messages_to_keep` messages and
    /// labels the removed ones with `summary_note`.
    pub fn new(messages_to_keep: usize, summary_note: impl Into<String>) -> Self {
        let summary_note = summary_note.into();
        Self {
            messages_to_keep,
            summary_note,
        }
    }
}
113
114#[async_trait]
115impl AgentMiddleware for SummarizationMiddleware {
116    fn id(&self) -> &'static str {
117        "summarization"
118    }
119
120    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
121        if ctx.request.messages.len() > self.messages_to_keep {
122            let dropped = ctx.request.messages.len() - self.messages_to_keep;
123            let mut truncated = ctx
124                .request
125                .messages
126                .split_off(ctx.request.messages.len() - self.messages_to_keep);
127            truncated.insert(
128                0,
129                AgentMessage {
130                    role: MessageRole::System,
131                    content: MessageContent::Text(format!(
132                        "{} ({} earlier messages summarized)",
133                        self.summary_note, dropped
134                    )),
135                    metadata: None,
136                },
137            );
138            ctx.request.messages = truncated;
139        }
140        Ok(())
141    }
142}
143
144pub struct PlanningMiddleware {
145    _state: Arc<RwLock<AgentStateSnapshot>>,
146}
147
148impl PlanningMiddleware {
149    pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
150        Self { _state: state }
151    }
152}
153
154#[async_trait]
155impl AgentMiddleware for PlanningMiddleware {
156    fn id(&self) -> &'static str {
157        "planning"
158    }
159
160    fn tools(&self) -> Vec<ToolBox> {
161        use agents_toolkit::create_todos_tools;
162        create_todos_tools()
163    }
164
165    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
166        ctx.request.append_prompt(WRITE_TODOS_SYSTEM_PROMPT);
167        Ok(())
168    }
169}
170
171pub struct FilesystemMiddleware {
172    _state: Arc<RwLock<AgentStateSnapshot>>,
173}
174
175impl FilesystemMiddleware {
176    pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
177        Self { _state: state }
178    }
179}
180
181#[async_trait]
182impl AgentMiddleware for FilesystemMiddleware {
183    fn id(&self) -> &'static str {
184        "filesystem"
185    }
186
187    fn tools(&self) -> Vec<ToolBox> {
188        create_filesystem_tools()
189    }
190
191    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
192        ctx.request.append_prompt(FILESYSTEM_SYSTEM_PROMPT);
193        Ok(())
194    }
195}
196
/// Pairs a sub-agent's descriptor with the handle used to invoke it.
#[derive(Clone)]
pub struct SubAgentRegistration {
    /// Name and description under which the sub-agent is advertised.
    pub descriptor: SubAgentDescriptor,
    /// Handle that receives messages delegated to this sub-agent.
    pub agent: Arc<dyn AgentHandle>,
}
202
203struct SubAgentRegistry {
204    agents: HashMap<String, Arc<dyn AgentHandle>>,
205}
206
207impl SubAgentRegistry {
208    fn new(registrations: Vec<SubAgentRegistration>) -> Self {
209        let mut agents = HashMap::new();
210        for reg in registrations {
211            agents.insert(reg.descriptor.name.clone(), reg.agent.clone());
212        }
213        Self { agents }
214    }
215
216    fn available_names(&self) -> Vec<String> {
217        self.agents.keys().cloned().collect()
218    }
219
220    fn get(&self, name: &str) -> Option<Arc<dyn AgentHandle>> {
221        self.agents.get(name).cloned()
222    }
223}
224
225pub struct SubAgentMiddleware {
226    task_tool: ToolBox,
227    descriptors: Vec<SubAgentDescriptor>,
228    _registry: Arc<SubAgentRegistry>,
229}
230
231impl SubAgentMiddleware {
232    pub fn new(registrations: Vec<SubAgentRegistration>) -> Self {
233        let descriptors = registrations.iter().map(|r| r.descriptor.clone()).collect();
234        let registry = Arc::new(SubAgentRegistry::new(registrations));
235        let task_tool: ToolBox = Arc::new(TaskRouterTool::new(registry.clone(), None));
236        Self {
237            task_tool,
238            descriptors,
239            _registry: registry,
240        }
241    }
242
243    pub fn new_with_events(
244        registrations: Vec<SubAgentRegistration>,
245        event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
246    ) -> Self {
247        let descriptors = registrations.iter().map(|r| r.descriptor.clone()).collect();
248        let registry = Arc::new(SubAgentRegistry::new(registrations));
249        let task_tool: ToolBox = Arc::new(TaskRouterTool::new(registry.clone(), event_dispatcher));
250        Self {
251            task_tool,
252            descriptors,
253            _registry: registry,
254        }
255    }
256
257    fn prompt_fragment(&self) -> String {
258        let descriptions: Vec<String> = if self.descriptors.is_empty() {
259            vec![String::from("- general-purpose: Default reasoning agent")]
260        } else {
261            self.descriptors
262                .iter()
263                .map(|agent| format!("- {}: {}", agent.name, agent.description))
264                .collect()
265        };
266
267        TASK_TOOL_DESCRIPTION.replace("{other_agents}", &descriptions.join("\n"))
268    }
269}
270
271#[async_trait]
272impl AgentMiddleware for SubAgentMiddleware {
273    fn id(&self) -> &'static str {
274        "subagent"
275    }
276
277    fn tools(&self) -> Vec<ToolBox> {
278        vec![self.task_tool.clone()]
279    }
280
281    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
282        ctx.request.append_prompt(TASK_SYSTEM_PROMPT);
283        ctx.request.append_prompt(&self.prompt_fragment());
284        Ok(())
285    }
286}
287
/// Per-tool human-in-the-loop policy.
#[derive(Clone, Debug)]
pub struct HitlPolicy {
    /// When `true` the tool is not held for approval; when `false` it is.
    pub allow_auto: bool,
    /// Optional explanation shown in the prompt and attached to the interrupt.
    pub note: Option<String>,
}
293
294pub struct HumanInLoopMiddleware {
295    policies: HashMap<String, HitlPolicy>,
296}
297
298impl HumanInLoopMiddleware {
299    pub fn new(policies: HashMap<String, HitlPolicy>) -> Self {
300        Self { policies }
301    }
302
303    pub fn requires_approval(&self, tool_name: &str) -> Option<&HitlPolicy> {
304        self.policies
305            .get(tool_name)
306            .filter(|policy| !policy.allow_auto)
307    }
308
309    fn prompt_fragment(&self) -> Option<String> {
310        let pending: Vec<String> = self
311            .policies
312            .iter()
313            .filter(|(_, policy)| !policy.allow_auto)
314            .map(|(tool, policy)| match &policy.note {
315                Some(note) => format!("- {tool}: {note}"),
316                None => format!("- {tool}: Requires approval"),
317            })
318            .collect();
319        if pending.is_empty() {
320            None
321        } else {
322            Some(format!(
323                "The following tools require human approval before execution:\n{}",
324                pending.join("\n")
325            ))
326        }
327    }
328}
329
330#[async_trait]
331impl AgentMiddleware for HumanInLoopMiddleware {
332    fn id(&self) -> &'static str {
333        "human-in-loop"
334    }
335
336    async fn before_tool_execution(
337        &self,
338        tool_name: &str,
339        tool_args: &serde_json::Value,
340        call_id: &str,
341    ) -> anyhow::Result<Option<agents_core::hitl::AgentInterrupt>> {
342        if let Some(policy) = self.requires_approval(tool_name) {
343            tracing::warn!(
344                tool_name = %tool_name,
345                call_id = %call_id,
346                policy_note = ?policy.note,
347                "🔒 HITL: Tool execution requires human approval"
348            );
349
350            let interrupt = agents_core::hitl::HitlInterrupt::new(
351                tool_name,
352                tool_args.clone(),
353                call_id,
354                policy.note.clone(),
355            );
356
357            return Ok(Some(agents_core::hitl::AgentInterrupt::HumanInLoop(
358                interrupt,
359            )));
360        }
361
362        Ok(None)
363    }
364
365    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
366        if let Some(fragment) = self.prompt_fragment() {
367            ctx.request.append_prompt(&fragment);
368        }
369        ctx.request.messages.push(AgentMessage {
370            role: MessageRole::System,
371            content: MessageContent::Text(
372                "Tools marked for human approval will emit interrupts requiring external resolution."
373                    .into(),
374            ),
375            metadata: None,
376        });
377        Ok(())
378    }
379}
380
381pub struct BaseSystemPromptMiddleware;
382
383#[async_trait]
384impl AgentMiddleware for BaseSystemPromptMiddleware {
385    fn id(&self) -> &'static str {
386        "base-system-prompt"
387    }
388
389    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
390        ctx.request.append_prompt(BASE_AGENT_PROMPT);
391        Ok(())
392    }
393}
394
/// Prompt middleware that injects detailed tool-usage instructions and
/// examples, pushing the LLM to call tools rather than merely describe them.
///
/// Inspired by Python's deepagents package and Claude Code's system prompt.
/// The injected prompt is intended to provide:
/// - Explicit, imperative tool usage rules
/// - JSON examples of tool calls
/// - Workflow guidance for multi-step tasks
/// - Few-shot examples for common patterns
pub struct DeepAgentPromptMiddleware {
    /// Caller-supplied instructions passed to the prompt builder.
    custom_instructions: String,
}

impl DeepAgentPromptMiddleware {
    /// Creates the middleware with the given custom instructions.
    pub fn new(custom_instructions: impl Into<String>) -> Self {
        let custom_instructions = custom_instructions.into();
        Self {
            custom_instructions,
        }
    }
}
415
416#[async_trait]
417impl AgentMiddleware for DeepAgentPromptMiddleware {
418    fn id(&self) -> &'static str {
419        "deep-agent-prompt"
420    }
421
422    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
423        use crate::prompts::get_deep_agent_system_prompt;
424        let deep_prompt = get_deep_agent_system_prompt(&self.custom_instructions);
425        ctx.request.append_prompt(&deep_prompt);
426        Ok(())
427    }
428}
429
/// Anthropic-specific prompt caching middleware. Marks the system prompt for
/// caching to reduce latency on later requests that share the same base prompt.
pub struct AnthropicPromptCachingMiddleware {
    /// Requested cache lifetime, e.g. `"5m"`. Only used to decide whether
    /// caching is enabled and for logging.
    pub ttl: String,
    /// Behaviour label for unsupported models; only logged by this module.
    pub unsupported_model_behavior: String,
}

impl AnthropicPromptCachingMiddleware {
    /// Creates the middleware from a TTL string and an unsupported-model
    /// behaviour label.
    pub fn new(ttl: impl Into<String>, unsupported_model_behavior: impl Into<String>) -> Self {
        let ttl = ttl.into();
        let unsupported_model_behavior = unsupported_model_behavior.into();
        Self {
            ttl,
            unsupported_model_behavior,
        }
    }

    /// Creates the middleware with a `"5m"` TTL and the `"ignore"` behaviour.
    pub fn with_defaults() -> Self {
        Self::new("5m", "ignore")
    }

    /// Reports whether caching is requested.
    ///
    /// The TTL is not parsed: caching is disabled only for the exact values
    /// `""`, `"0"` and `"0s"`, and enabled for everything else.
    fn should_enable_caching(&self) -> bool {
        !matches!(self.ttl.as_str(), "" | "0" | "0s")
    }
}
455
456#[async_trait]
457impl AgentMiddleware for AnthropicPromptCachingMiddleware {
458    fn id(&self) -> &'static str {
459        "anthropic-prompt-caching"
460    }
461
462    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
463        if !self.should_enable_caching() {
464            return Ok(());
465        }
466
467        // Mark system prompt for caching by converting it to a system message with cache control
468        if !ctx.request.system_prompt.is_empty() {
469            let system_message = AgentMessage {
470                role: MessageRole::System,
471                content: MessageContent::Text(ctx.request.system_prompt.clone()),
472                metadata: Some(MessageMetadata {
473                    tool_call_id: None,
474                    cache_control: Some(CacheControl {
475                        cache_type: "ephemeral".to_string(),
476                    }),
477                }),
478            };
479
480            // Insert system message at the beginning of the messages
481            ctx.request.messages.insert(0, system_message);
482
483            // Clear the system_prompt since it's now in messages
484            ctx.request.system_prompt.clear();
485
486            tracing::debug!(
487                ttl = %self.ttl,
488                behavior = %self.unsupported_model_behavior,
489                "Applied Anthropic prompt caching to system message"
490            );
491        }
492
493        Ok(())
494    }
495}
496
497pub struct TaskRouterTool {
498    registry: Arc<SubAgentRegistry>,
499    event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
500    delegation_depth: Arc<RwLock<u32>>,
501}
502
503impl TaskRouterTool {
504    fn new(
505        registry: Arc<SubAgentRegistry>,
506        event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
507    ) -> Self {
508        Self {
509            registry,
510            event_dispatcher,
511            delegation_depth: Arc::new(RwLock::new(0)),
512        }
513    }
514
515    fn available_subagents(&self) -> Vec<String> {
516        self.registry.available_names()
517    }
518
519    fn emit_event(&self, event: agents_core::events::AgentEvent) {
520        if let Some(dispatcher) = &self.event_dispatcher {
521            let dispatcher_clone = dispatcher.clone();
522            tokio::spawn(async move {
523                dispatcher_clone.dispatch(event).await;
524            });
525        }
526    }
527
528    fn create_event_metadata(&self) -> agents_core::events::EventMetadata {
529        agents_core::events::EventMetadata::new(
530            "default".to_string(),
531            uuid::Uuid::new_v4().to_string(),
532            None,
533        )
534    }
535
536    fn get_delegation_depth(&self) -> u32 {
537        *self.delegation_depth.read().unwrap_or_else(|_| {
538            tracing::warn!("Failed to read delegation depth, defaulting to 0");
539            panic!("RwLock poisoned")
540        })
541    }
542
543    fn increment_delegation_depth(&self) {
544        if let Ok(mut depth) = self.delegation_depth.write() {
545            *depth += 1;
546        }
547    }
548
549    fn decrement_delegation_depth(&self) {
550        if let Ok(mut depth) = self.delegation_depth.write() {
551            if *depth > 0 {
552                *depth -= 1;
553            }
554        }
555    }
556}
557
/// Arguments accepted by the `task` tool, deserialized from the call's JSON.
#[derive(Debug, Clone, Deserialize)]
struct TaskInvocationArgs {
    /// Instruction forwarded to the sub-agent; also accepted as `description`.
    #[serde(alias = "description")]
    instruction: String,
    /// Name of the target sub-agent; also accepted as `subagent_type`.
    #[serde(alias = "subagent_type")]
    agent: String,
}
565
566#[async_trait]
567impl Tool for TaskRouterTool {
568    fn schema(&self) -> agents_core::tools::ToolSchema {
569        use agents_core::tools::{ToolParameterSchema, ToolSchema};
570        use std::collections::HashMap;
571
572        let mut properties = HashMap::new();
573        properties.insert(
574            "agent".to_string(),
575            ToolParameterSchema::string("Name of the sub-agent to delegate to"),
576        );
577        properties.insert(
578            "instruction".to_string(),
579            ToolParameterSchema::string("Clear instruction for the sub-agent"),
580        );
581
582        ToolSchema::new(
583            "task",
584            "Delegate a task to a specialized sub-agent. Use this when you need specialized expertise or want to break down complex tasks.",
585            ToolParameterSchema::object(
586                "Task delegation parameters",
587                properties,
588                vec!["agent".to_string(), "instruction".to_string()],
589            ),
590        )
591    }
592
593    async fn execute(
594        &self,
595        args: serde_json::Value,
596        ctx: ToolContext,
597    ) -> anyhow::Result<ToolResult> {
598        let args: TaskInvocationArgs = serde_json::from_value(args)?;
599        let available = self.available_subagents();
600
601        if let Some(agent) = self.registry.get(&args.agent) {
602            // Increment delegation depth
603            self.increment_delegation_depth();
604            let current_depth = self.get_delegation_depth();
605
606            // Truncate instruction for event
607            let instruction_summary = if args.instruction.len() > 100 {
608                format!("{}...", &args.instruction[..100])
609            } else {
610                args.instruction.clone()
611            };
612
613            // Emit: SubAgentStarted event
614            self.emit_event(agents_core::events::AgentEvent::SubAgentStarted(
615                agents_core::events::SubAgentStartedEvent {
616                    metadata: self.create_event_metadata(),
617                    agent_name: args.agent.clone(),
618                    instruction_summary: instruction_summary.clone(),
619                    delegation_depth: current_depth,
620                },
621            ));
622
623            // Log delegation start
624            tracing::warn!(
625                "🎯 DELEGATING to sub-agent: {} (depth: {}) with instruction: {}",
626                args.agent,
627                current_depth,
628                args.instruction
629            );
630
631            let start_time = std::time::Instant::now();
632            let user_message = AgentMessage {
633                role: MessageRole::User,
634                content: MessageContent::Text(args.instruction.clone()),
635                metadata: None,
636            };
637
638            let response = agent
639                .handle_message(user_message, Arc::new(AgentStateSnapshot::default()))
640                .await?;
641
642            // Calculate duration
643            let duration = start_time.elapsed();
644            let duration_ms = duration.as_millis() as u64;
645
646            // Create response preview
647            let response_preview = match &response.content {
648                MessageContent::Text(t) => {
649                    if t.len() > 100 {
650                        format!("{}...", &t[..100])
651                    } else {
652                        t.clone()
653                    }
654                }
655                MessageContent::Json(v) => {
656                    let json_str = v.to_string();
657                    if json_str.len() > 100 {
658                        format!("{}...", &json_str[..100])
659                    } else {
660                        json_str
661                    }
662                }
663            };
664
665            // Emit: SubAgentCompleted event
666            self.emit_event(agents_core::events::AgentEvent::SubAgentCompleted(
667                agents_core::events::SubAgentCompletedEvent {
668                    metadata: self.create_event_metadata(),
669                    agent_name: args.agent.clone(),
670                    duration_ms,
671                    result_summary: response_preview.clone(),
672                },
673            ));
674
675            // Log delegation completion
676            tracing::warn!(
677                "✅ SUB-AGENT {} COMPLETED in {:?} - Response: {}",
678                args.agent,
679                duration,
680                response_preview
681            );
682
683            // Decrement delegation depth
684            self.decrement_delegation_depth();
685
686            // Return sub-agent response as text content, not as a separate tool message
687            // This will be incorporated into the LLM's next response naturally
688            let result_text = match response.content {
689                MessageContent::Text(text) => text,
690                MessageContent::Json(json) => json.to_string(),
691            };
692
693            return Ok(ToolResult::text(&ctx, result_text));
694        }
695
696        tracing::error!(
697            "❌ SUB-AGENT NOT FOUND: {} - Available: {:?}",
698            args.agent,
699            available
700        );
701
702        Ok(ToolResult::text(
703            &ctx,
704            format!(
705                "Sub-agent '{}' not found. Available sub-agents: {}",
706                args.agent,
707                available.join(", ")
708            ),
709        ))
710    }
711}
712
/// Name and description under which a sub-agent is advertised.
#[derive(Debug, Clone)]
pub struct SubAgentDescriptor {
    /// Sub-agent name; used as the registry key and shown in the prompt listing.
    pub name: String,
    /// Short description shown next to the name in the prompt listing.
    pub description: String,
}
718
719#[cfg(test)]
720mod tests {
721    use super::*;
722    use agents_core::agent::{AgentDescriptor, AgentHandle};
723    use agents_core::messaging::{MessageContent, MessageRole};
724    use serde_json::json;
725
726    struct AppendPromptMiddleware;
727
728    #[async_trait]
729    impl AgentMiddleware for AppendPromptMiddleware {
730        fn id(&self) -> &'static str {
731            "append-prompt"
732        }
733
734        async fn modify_model_request(
735            &self,
736            ctx: &mut MiddlewareContext<'_>,
737        ) -> anyhow::Result<()> {
738            ctx.request.system_prompt.push_str("\nExtra directives.");
739            Ok(())
740        }
741    }
742
743    #[tokio::test]
744    async fn middleware_mutates_prompt() {
745        let mut request = ModelRequest::new(
746            "System",
747            vec![AgentMessage {
748                role: MessageRole::User,
749                content: MessageContent::Text("Hi".into()),
750                metadata: None,
751            }],
752        );
753        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
754        let mut ctx = MiddlewareContext::with_request(&mut request, state);
755        let middleware = AppendPromptMiddleware;
756        middleware.modify_model_request(&mut ctx).await.unwrap();
757        assert!(ctx.request.system_prompt.contains("Extra directives"));
758    }
759
760    #[tokio::test]
761    async fn planning_middleware_registers_write_todos() {
762        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
763        let middleware = PlanningMiddleware::new(state);
764        let tool_names: Vec<_> = middleware
765            .tools()
766            .iter()
767            .map(|t| t.schema().name.clone())
768            .collect();
769        assert!(tool_names.contains(&"write_todos".to_string()));
770
771        let mut request = ModelRequest::new("System", vec![]);
772        let mut ctx = MiddlewareContext::with_request(
773            &mut request,
774            Arc::new(RwLock::new(AgentStateSnapshot::default())),
775        );
776        middleware.modify_model_request(&mut ctx).await.unwrap();
777        assert!(ctx.request.system_prompt.contains("write_todos"));
778    }
779
780    #[tokio::test]
781    async fn filesystem_middleware_registers_tools() {
782        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
783        let middleware = FilesystemMiddleware::new(state);
784        let tool_names: Vec<_> = middleware
785            .tools()
786            .iter()
787            .map(|t| t.schema().name.clone())
788            .collect();
789        for expected in ["ls", "read_file", "write_file", "edit_file"] {
790            assert!(tool_names.contains(&expected.to_string()));
791        }
792    }
793
794    #[tokio::test]
795    async fn summarization_middleware_trims_messages() {
796        let middleware = SummarizationMiddleware::new(2, "Summary note");
797        let mut request = ModelRequest::new(
798            "System",
799            vec![
800                AgentMessage {
801                    role: MessageRole::User,
802                    content: MessageContent::Text("one".into()),
803                    metadata: None,
804                },
805                AgentMessage {
806                    role: MessageRole::Agent,
807                    content: MessageContent::Text("two".into()),
808                    metadata: None,
809                },
810                AgentMessage {
811                    role: MessageRole::User,
812                    content: MessageContent::Text("three".into()),
813                    metadata: None,
814                },
815            ],
816        );
817        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
818        let mut ctx = MiddlewareContext::with_request(&mut request, state);
819        middleware.modify_model_request(&mut ctx).await.unwrap();
820        assert_eq!(ctx.request.messages.len(), 3);
821        match &ctx.request.messages[0].content {
822            MessageContent::Text(text) => assert!(text.contains("Summary note")),
823            other => panic!("expected text, got {other:?}"),
824        }
825    }
826
827    struct StubAgent;
828
829    #[async_trait]
830    impl AgentHandle for StubAgent {
831        async fn describe(&self) -> AgentDescriptor {
832            AgentDescriptor {
833                name: "stub".into(),
834                version: "0.0.1".into(),
835                description: None,
836            }
837        }
838
839        async fn handle_message(
840            &self,
841            _input: AgentMessage,
842            _state: Arc<AgentStateSnapshot>,
843        ) -> anyhow::Result<AgentMessage> {
844            Ok(AgentMessage {
845                role: MessageRole::Agent,
846                content: MessageContent::Text("stub-response".into()),
847                metadata: None,
848            })
849        }
850    }
851
852    #[tokio::test]
853    async fn task_router_reports_unknown_subagent() {
854        let registry = Arc::new(SubAgentRegistry::new(vec![]));
855        let task_tool = TaskRouterTool::new(registry.clone(), None);
856        let state = Arc::new(AgentStateSnapshot::default());
857        let ctx = ToolContext::new(state);
858
859        let response = task_tool
860            .execute(
861                json!({
862                    "instruction": "Do something",
863                    "agent": "unknown"
864                }),
865                ctx,
866            )
867            .await
868            .unwrap();
869
870        match response {
871            ToolResult::Message(msg) => match msg.content {
872                MessageContent::Text(text) => {
873                    assert!(text.contains("Sub-agent 'unknown' not found"))
874                }
875                other => panic!("expected text, got {other:?}"),
876            },
877            _ => panic!("expected message"),
878        }
879    }
880
881    #[tokio::test]
882    async fn subagent_middleware_appends_prompt() {
883        let subagents = vec![SubAgentRegistration {
884            descriptor: SubAgentDescriptor {
885                name: "research-agent".into(),
886                description: "Deep research specialist".into(),
887            },
888            agent: Arc::new(StubAgent),
889        }];
890        let middleware = SubAgentMiddleware::new(subagents);
891
892        let mut request = ModelRequest::new("System", vec![]);
893        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
894        let mut ctx = MiddlewareContext::with_request(&mut request, state);
895        middleware.modify_model_request(&mut ctx).await.unwrap();
896
897        assert!(ctx.request.system_prompt.contains("research-agent"));
898        let tool_names: Vec<_> = middleware
899            .tools()
900            .iter()
901            .map(|t| t.schema().name.clone())
902            .collect();
903        assert!(tool_names.contains(&"task".to_string()));
904    }
905
906    #[tokio::test]
907    async fn task_router_invokes_registered_subagent() {
908        let registry = Arc::new(SubAgentRegistry::new(vec![SubAgentRegistration {
909            descriptor: SubAgentDescriptor {
910                name: "stub-agent".into(),
911                description: "Stub".into(),
912            },
913            agent: Arc::new(StubAgent),
914        }]));
915        let task_tool = TaskRouterTool::new(registry.clone(), None);
916        let state = Arc::new(AgentStateSnapshot::default());
917        let ctx = ToolContext::new(state).with_call_id(Some("call-42".into()));
918        let response = task_tool
919            .execute(
920                json!({
921                    "description": "do work",
922                    "subagent_type": "stub-agent"
923                }),
924                ctx,
925            )
926            .await
927            .unwrap();
928
929        match response {
930            ToolResult::Message(msg) => {
931                assert_eq!(msg.metadata.unwrap().tool_call_id.unwrap(), "call-42");
932                match msg.content {
933                    MessageContent::Text(text) => assert_eq!(text, "stub-response"),
934                    other => panic!("expected text, got {other:?}"),
935                }
936            }
937            _ => panic!("expected message"),
938        }
939    }
940
941    #[tokio::test]
942    async fn human_in_loop_appends_prompt() {
943        let middleware = HumanInLoopMiddleware::new(HashMap::from([(
944            "danger-tool".into(),
945            HitlPolicy {
946                allow_auto: false,
947                note: Some("Requires security review".into()),
948            },
949        )]));
950        let mut request = ModelRequest::new("System", vec![]);
951        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
952        let mut ctx = MiddlewareContext::with_request(&mut request, state);
953        middleware.modify_model_request(&mut ctx).await.unwrap();
954        assert!(ctx
955            .request
956            .system_prompt
957            .contains("danger-tool: Requires security review"));
958    }
959
960    #[tokio::test]
961    async fn anthropic_prompt_caching_moves_system_prompt_to_messages() {
962        let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
963        let mut request = ModelRequest::new(
964            "This is the system prompt",
965            vec![AgentMessage {
966                role: MessageRole::User,
967                content: MessageContent::Text("Hello".into()),
968                metadata: None,
969            }],
970        );
971        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
972        let mut ctx = MiddlewareContext::with_request(&mut request, state);
973
974        // Apply the middleware
975        middleware.modify_model_request(&mut ctx).await.unwrap();
976
977        // System prompt should be cleared
978        assert!(ctx.request.system_prompt.is_empty());
979
980        // Should have added a system message with cache control at the beginning
981        assert_eq!(ctx.request.messages.len(), 2);
982
983        let system_message = &ctx.request.messages[0];
984        assert!(matches!(system_message.role, MessageRole::System));
985        assert_eq!(
986            system_message.content.as_text().unwrap(),
987            "This is the system prompt"
988        );
989
990        // Check cache control metadata
991        let metadata = system_message.metadata.as_ref().unwrap();
992        let cache_control = metadata.cache_control.as_ref().unwrap();
993        assert_eq!(cache_control.cache_type, "ephemeral");
994
995        // Original user message should still be there
996        let user_message = &ctx.request.messages[1];
997        assert!(matches!(user_message.role, MessageRole::User));
998        assert_eq!(user_message.content.as_text().unwrap(), "Hello");
999    }
1000
1001    #[tokio::test]
1002    async fn anthropic_prompt_caching_disabled_with_zero_ttl() {
1003        let middleware = AnthropicPromptCachingMiddleware::new("0", "ignore");
1004        let mut request = ModelRequest::new("This is the system prompt", vec![]);
1005        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
1006        let mut ctx = MiddlewareContext::with_request(&mut request, state);
1007
1008        // Apply the middleware
1009        middleware.modify_model_request(&mut ctx).await.unwrap();
1010
1011        // System prompt should be unchanged
1012        assert_eq!(ctx.request.system_prompt, "This is the system prompt");
1013        assert_eq!(ctx.request.messages.len(), 0);
1014    }
1015
1016    #[tokio::test]
1017    async fn anthropic_prompt_caching_no_op_with_empty_system_prompt() {
1018        let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
1019        let mut request = ModelRequest::new(
1020            "",
1021            vec![AgentMessage {
1022                role: MessageRole::User,
1023                content: MessageContent::Text("Hello".into()),
1024                metadata: None,
1025            }],
1026        );
1027        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
1028        let mut ctx = MiddlewareContext::with_request(&mut request, state);
1029
1030        // Apply the middleware
1031        middleware.modify_model_request(&mut ctx).await.unwrap();
1032
1033        // System prompt should remain empty
1034        assert!(ctx.request.system_prompt.is_empty());
1035        // No system message should be added
1036        assert_eq!(ctx.request.messages.len(), 1);
1037    }
1038
1039    // ========== HITL Interrupt Creation Tests ==========
1040
1041    #[tokio::test]
1042    async fn hitl_creates_interrupt_for_disallowed_tool() {
1043        let mut policies = HashMap::new();
1044        policies.insert(
1045            "dangerous_tool".to_string(),
1046            HitlPolicy {
1047                allow_auto: false,
1048                note: Some("Requires security review".to_string()),
1049            },
1050        );
1051
1052        let middleware = HumanInLoopMiddleware::new(policies);
1053        let tool_args = json!({"action": "delete_all"});
1054
1055        let result = middleware
1056            .before_tool_execution("dangerous_tool", &tool_args, "call_123")
1057            .await
1058            .unwrap();
1059
1060        assert!(result.is_some());
1061        let interrupt = result.unwrap();
1062
1063        match interrupt {
1064            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
1065                assert_eq!(hitl.tool_name, "dangerous_tool");
1066                assert_eq!(hitl.tool_args, tool_args);
1067                assert_eq!(hitl.call_id, "call_123");
1068                assert_eq!(
1069                    hitl.policy_note,
1070                    Some("Requires security review".to_string())
1071                );
1072            }
1073        }
1074    }
1075
1076    #[tokio::test]
1077    async fn hitl_no_interrupt_for_allowed_tool() {
1078        let mut policies = HashMap::new();
1079        policies.insert(
1080            "safe_tool".to_string(),
1081            HitlPolicy {
1082                allow_auto: true,
1083                note: None,
1084            },
1085        );
1086
1087        let middleware = HumanInLoopMiddleware::new(policies);
1088        let tool_args = json!({"action": "read"});
1089
1090        let result = middleware
1091            .before_tool_execution("safe_tool", &tool_args, "call_456")
1092            .await
1093            .unwrap();
1094
1095        assert!(result.is_none());
1096    }
1097
1098    #[tokio::test]
1099    async fn hitl_no_interrupt_for_unlisted_tool() {
1100        let policies = HashMap::new();
1101        let middleware = HumanInLoopMiddleware::new(policies);
1102        let tool_args = json!({"action": "anything"});
1103
1104        let result = middleware
1105            .before_tool_execution("unlisted_tool", &tool_args, "call_789")
1106            .await
1107            .unwrap();
1108
1109        assert!(result.is_none());
1110    }
1111
1112    #[tokio::test]
1113    async fn hitl_interrupt_includes_correct_details() {
1114        let mut policies = HashMap::new();
1115        policies.insert(
1116            "critical_tool".to_string(),
1117            HitlPolicy {
1118                allow_auto: false,
1119                note: Some("Critical operation - requires approval".to_string()),
1120            },
1121        );
1122
1123        let middleware = HumanInLoopMiddleware::new(policies);
1124        let tool_args = json!({
1125            "database": "production",
1126            "operation": "drop_table"
1127        });
1128
1129        let result = middleware
1130            .before_tool_execution("critical_tool", &tool_args, "call_critical_1")
1131            .await
1132            .unwrap();
1133
1134        assert!(result.is_some());
1135        let interrupt = result.unwrap();
1136
1137        match interrupt {
1138            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
1139                assert_eq!(hitl.tool_name, "critical_tool");
1140                assert_eq!(hitl.tool_args["database"], "production");
1141                assert_eq!(hitl.tool_args["operation"], "drop_table");
1142                assert_eq!(hitl.call_id, "call_critical_1");
1143                assert!(hitl.policy_note.is_some());
1144                assert!(hitl.policy_note.unwrap().contains("Critical operation"));
1145                // Verify timestamp exists (created_at field is populated)
1146                // The actual timestamp value is tested in agents-core/hitl.rs
1147            }
1148        }
1149    }
1150
1151    #[tokio::test]
1152    async fn hitl_interrupt_without_policy_note() {
1153        let mut policies = HashMap::new();
1154        policies.insert(
1155            "tool_no_note".to_string(),
1156            HitlPolicy {
1157                allow_auto: false,
1158                note: None,
1159            },
1160        );
1161
1162        let middleware = HumanInLoopMiddleware::new(policies);
1163        let tool_args = json!({"param": "value"});
1164
1165        let result = middleware
1166            .before_tool_execution("tool_no_note", &tool_args, "call_no_note")
1167            .await
1168            .unwrap();
1169
1170        assert!(result.is_some());
1171        let interrupt = result.unwrap();
1172
1173        match interrupt {
1174            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
1175                assert_eq!(hitl.tool_name, "tool_no_note");
1176                assert_eq!(hitl.policy_note, None);
1177            }
1178        }
1179    }
1180}