1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3
4use agents_core::agent::AgentHandle;
5use agents_core::messaging::{
6 AgentMessage, CacheControl, MessageContent, MessageMetadata, MessageRole,
7};
8use agents_core::prompts::{
9 BASE_AGENT_PROMPT, FILESYSTEM_SYSTEM_PROMPT, TASK_SYSTEM_PROMPT, TASK_TOOL_DESCRIPTION,
10 WRITE_TODOS_SYSTEM_PROMPT,
11};
12use agents_core::state::AgentStateSnapshot;
13use agents_core::tools::{Tool, ToolBox, ToolContext, ToolResult};
14use agents_toolkit::create_filesystem_tools;
15use async_trait::async_trait;
16use serde::Deserialize;
17
18#[derive(Debug, Clone)]
21pub struct ModelRequest {
22 pub system_prompt: String,
23 pub messages: Vec<AgentMessage>,
24}
25
26impl ModelRequest {
27 pub fn new(system_prompt: impl Into<String>, messages: Vec<AgentMessage>) -> Self {
28 Self {
29 system_prompt: system_prompt.into(),
30 messages,
31 }
32 }
33
34 pub fn append_prompt(&mut self, fragment: &str) {
35 if !fragment.is_empty() {
36 self.system_prompt.push_str("\n\n");
37 self.system_prompt.push_str(fragment);
38 }
39 }
40}
41
42pub struct MiddlewareContext<'a> {
44 pub request: &'a mut ModelRequest,
45 pub state: Arc<RwLock<AgentStateSnapshot>>,
46}
47
48impl<'a> MiddlewareContext<'a> {
49 pub fn with_request(
50 request: &'a mut ModelRequest,
51 state: Arc<RwLock<AgentStateSnapshot>>,
52 ) -> Self {
53 Self { request, state }
54 }
55}
56
57#[async_trait]
61pub trait AgentMiddleware: Send + Sync {
62 fn id(&self) -> &'static str;
64
65 fn tools(&self) -> Vec<ToolBox> {
67 Vec::new()
68 }
69
70 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()>;
72
73 async fn before_tool_execution(
89 &self,
90 _tool_name: &str,
91 _tool_args: &serde_json::Value,
92 _call_id: &str,
93 ) -> anyhow::Result<Option<agents_core::hitl::AgentInterrupt>> {
94 Ok(None)
95 }
96}
97
98pub struct SummarizationMiddleware {
99 pub messages_to_keep: usize,
100 pub summary_note: String,
101}
102
103impl SummarizationMiddleware {
104 pub fn new(messages_to_keep: usize, summary_note: impl Into<String>) -> Self {
105 Self {
106 messages_to_keep,
107 summary_note: summary_note.into(),
108 }
109 }
110}
111
112#[async_trait]
113impl AgentMiddleware for SummarizationMiddleware {
114 fn id(&self) -> &'static str {
115 "summarization"
116 }
117
118 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
119 if ctx.request.messages.len() > self.messages_to_keep {
120 let dropped = ctx.request.messages.len() - self.messages_to_keep;
121 let mut truncated = ctx
122 .request
123 .messages
124 .split_off(ctx.request.messages.len() - self.messages_to_keep);
125 truncated.insert(
126 0,
127 AgentMessage {
128 role: MessageRole::System,
129 content: MessageContent::Text(format!(
130 "{} ({} earlier messages summarized)",
131 self.summary_note, dropped
132 )),
133 metadata: None,
134 },
135 );
136 ctx.request.messages = truncated;
137 }
138 Ok(())
139 }
140}
141
142pub struct PlanningMiddleware {
143 _state: Arc<RwLock<AgentStateSnapshot>>,
144}
145
146impl PlanningMiddleware {
147 pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
148 Self { _state: state }
149 }
150}
151
152#[async_trait]
153impl AgentMiddleware for PlanningMiddleware {
154 fn id(&self) -> &'static str {
155 "planning"
156 }
157
158 fn tools(&self) -> Vec<ToolBox> {
159 use agents_toolkit::create_todos_tools;
160 create_todos_tools()
161 }
162
163 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
164 ctx.request.append_prompt(WRITE_TODOS_SYSTEM_PROMPT);
165 Ok(())
166 }
167}
168
169pub struct FilesystemMiddleware {
170 _state: Arc<RwLock<AgentStateSnapshot>>,
171}
172
173impl FilesystemMiddleware {
174 pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
175 Self { _state: state }
176 }
177}
178
179#[async_trait]
180impl AgentMiddleware for FilesystemMiddleware {
181 fn id(&self) -> &'static str {
182 "filesystem"
183 }
184
185 fn tools(&self) -> Vec<ToolBox> {
186 create_filesystem_tools()
187 }
188
189 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
190 ctx.request.append_prompt(FILESYSTEM_SYSTEM_PROMPT);
191 Ok(())
192 }
193}
194
195#[derive(Clone)]
196pub struct SubAgentRegistration {
197 pub descriptor: SubAgentDescriptor,
198 pub agent: Arc<dyn AgentHandle>,
199}
200
201struct SubAgentRegistry {
202 agents: HashMap<String, Arc<dyn AgentHandle>>,
203}
204
205impl SubAgentRegistry {
206 fn new(registrations: Vec<SubAgentRegistration>) -> Self {
207 let mut agents = HashMap::new();
208 for reg in registrations {
209 agents.insert(reg.descriptor.name.clone(), reg.agent.clone());
210 }
211 Self { agents }
212 }
213
214 fn available_names(&self) -> Vec<String> {
215 self.agents.keys().cloned().collect()
216 }
217
218 fn get(&self, name: &str) -> Option<Arc<dyn AgentHandle>> {
219 self.agents.get(name).cloned()
220 }
221}
222
223pub struct SubAgentMiddleware {
224 task_tool: ToolBox,
225 descriptors: Vec<SubAgentDescriptor>,
226 _registry: Arc<SubAgentRegistry>,
227}
228
229impl SubAgentMiddleware {
230 pub fn new(registrations: Vec<SubAgentRegistration>) -> Self {
231 let descriptors = registrations.iter().map(|r| r.descriptor.clone()).collect();
232 let registry = Arc::new(SubAgentRegistry::new(registrations));
233 let task_tool: ToolBox = Arc::new(TaskRouterTool::new(registry.clone()));
234 Self {
235 task_tool,
236 descriptors,
237 _registry: registry,
238 }
239 }
240
241 fn prompt_fragment(&self) -> String {
242 let descriptions: Vec<String> = if self.descriptors.is_empty() {
243 vec![String::from("- general-purpose: Default reasoning agent")]
244 } else {
245 self.descriptors
246 .iter()
247 .map(|agent| format!("- {}: {}", agent.name, agent.description))
248 .collect()
249 };
250
251 TASK_TOOL_DESCRIPTION.replace("{other_agents}", &descriptions.join("\n"))
252 }
253}
254
255#[async_trait]
256impl AgentMiddleware for SubAgentMiddleware {
257 fn id(&self) -> &'static str {
258 "subagent"
259 }
260
261 fn tools(&self) -> Vec<ToolBox> {
262 vec![self.task_tool.clone()]
263 }
264
265 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
266 ctx.request.append_prompt(TASK_SYSTEM_PROMPT);
267 ctx.request.append_prompt(&self.prompt_fragment());
268 Ok(())
269 }
270}
271
272#[derive(Clone, Debug)]
273pub struct HitlPolicy {
274 pub allow_auto: bool,
275 pub note: Option<String>,
276}
277
278pub struct HumanInLoopMiddleware {
279 policies: HashMap<String, HitlPolicy>,
280}
281
282impl HumanInLoopMiddleware {
283 pub fn new(policies: HashMap<String, HitlPolicy>) -> Self {
284 Self { policies }
285 }
286
287 pub fn requires_approval(&self, tool_name: &str) -> Option<&HitlPolicy> {
288 self.policies
289 .get(tool_name)
290 .filter(|policy| !policy.allow_auto)
291 }
292
293 fn prompt_fragment(&self) -> Option<String> {
294 let pending: Vec<String> = self
295 .policies
296 .iter()
297 .filter(|(_, policy)| !policy.allow_auto)
298 .map(|(tool, policy)| match &policy.note {
299 Some(note) => format!("- {tool}: {note}"),
300 None => format!("- {tool}: Requires approval"),
301 })
302 .collect();
303 if pending.is_empty() {
304 None
305 } else {
306 Some(format!(
307 "The following tools require human approval before execution:\n{}",
308 pending.join("\n")
309 ))
310 }
311 }
312}
313
314#[async_trait]
315impl AgentMiddleware for HumanInLoopMiddleware {
316 fn id(&self) -> &'static str {
317 "human-in-loop"
318 }
319
320 async fn before_tool_execution(
321 &self,
322 tool_name: &str,
323 tool_args: &serde_json::Value,
324 call_id: &str,
325 ) -> anyhow::Result<Option<agents_core::hitl::AgentInterrupt>> {
326 if let Some(policy) = self.requires_approval(tool_name) {
327 tracing::warn!(
328 tool_name = %tool_name,
329 call_id = %call_id,
330 policy_note = ?policy.note,
331 "🔒 HITL: Tool execution requires human approval"
332 );
333
334 let interrupt = agents_core::hitl::HitlInterrupt::new(
335 tool_name,
336 tool_args.clone(),
337 call_id,
338 policy.note.clone(),
339 );
340
341 return Ok(Some(agents_core::hitl::AgentInterrupt::HumanInLoop(
342 interrupt,
343 )));
344 }
345
346 Ok(None)
347 }
348
349 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
350 if let Some(fragment) = self.prompt_fragment() {
351 ctx.request.append_prompt(&fragment);
352 }
353 ctx.request.messages.push(AgentMessage {
354 role: MessageRole::System,
355 content: MessageContent::Text(
356 "Tools marked for human approval will emit interrupts requiring external resolution."
357 .into(),
358 ),
359 metadata: None,
360 });
361 Ok(())
362 }
363}
364
365pub struct BaseSystemPromptMiddleware;
366
367#[async_trait]
368impl AgentMiddleware for BaseSystemPromptMiddleware {
369 fn id(&self) -> &'static str {
370 "base-system-prompt"
371 }
372
373 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
374 ctx.request.append_prompt(BASE_AGENT_PROMPT);
375 Ok(())
376 }
377}
378
379pub struct DeepAgentPromptMiddleware {
389 custom_instructions: String,
390}
391
392impl DeepAgentPromptMiddleware {
393 pub fn new(custom_instructions: impl Into<String>) -> Self {
394 Self {
395 custom_instructions: custom_instructions.into(),
396 }
397 }
398}
399
400#[async_trait]
401impl AgentMiddleware for DeepAgentPromptMiddleware {
402 fn id(&self) -> &'static str {
403 "deep-agent-prompt"
404 }
405
406 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
407 use crate::prompts::get_deep_agent_system_prompt;
408 let deep_prompt = get_deep_agent_system_prompt(&self.custom_instructions);
409 ctx.request.append_prompt(&deep_prompt);
410 Ok(())
411 }
412}
413
414pub struct AnthropicPromptCachingMiddleware {
417 pub ttl: String,
418 pub unsupported_model_behavior: String,
419}
420
421impl AnthropicPromptCachingMiddleware {
422 pub fn new(ttl: impl Into<String>, unsupported_model_behavior: impl Into<String>) -> Self {
423 Self {
424 ttl: ttl.into(),
425 unsupported_model_behavior: unsupported_model_behavior.into(),
426 }
427 }
428
429 pub fn with_defaults() -> Self {
430 Self::new("5m", "ignore")
431 }
432
433 fn should_enable_caching(&self) -> bool {
436 !self.ttl.is_empty() && self.ttl != "0" && self.ttl != "0s"
437 }
438}
439
440#[async_trait]
441impl AgentMiddleware for AnthropicPromptCachingMiddleware {
442 fn id(&self) -> &'static str {
443 "anthropic-prompt-caching"
444 }
445
446 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
447 if !self.should_enable_caching() {
448 return Ok(());
449 }
450
451 if !ctx.request.system_prompt.is_empty() {
453 let system_message = AgentMessage {
454 role: MessageRole::System,
455 content: MessageContent::Text(ctx.request.system_prompt.clone()),
456 metadata: Some(MessageMetadata {
457 tool_call_id: None,
458 cache_control: Some(CacheControl {
459 cache_type: "ephemeral".to_string(),
460 }),
461 }),
462 };
463
464 ctx.request.messages.insert(0, system_message);
466
467 ctx.request.system_prompt.clear();
469
470 tracing::debug!(
471 ttl = %self.ttl,
472 behavior = %self.unsupported_model_behavior,
473 "Applied Anthropic prompt caching to system message"
474 );
475 }
476
477 Ok(())
478 }
479}
480
481pub struct TaskRouterTool {
482 registry: Arc<SubAgentRegistry>,
483}
484
485impl TaskRouterTool {
486 fn new(registry: Arc<SubAgentRegistry>) -> Self {
487 Self { registry }
488 }
489
490 fn available_subagents(&self) -> Vec<String> {
491 self.registry.available_names()
492 }
493}
494
495#[derive(Debug, Clone, Deserialize)]
496struct TaskInvocationArgs {
497 #[serde(alias = "description")]
498 instruction: String,
499 #[serde(alias = "subagent_type")]
500 agent: String,
501}
502
503#[async_trait]
504impl Tool for TaskRouterTool {
505 fn schema(&self) -> agents_core::tools::ToolSchema {
506 use agents_core::tools::{ToolParameterSchema, ToolSchema};
507 use std::collections::HashMap;
508
509 let mut properties = HashMap::new();
510 properties.insert(
511 "agent".to_string(),
512 ToolParameterSchema::string("Name of the sub-agent to delegate to"),
513 );
514 properties.insert(
515 "instruction".to_string(),
516 ToolParameterSchema::string("Clear instruction for the sub-agent"),
517 );
518
519 ToolSchema::new(
520 "task",
521 "Delegate a task to a specialized sub-agent. Use this when you need specialized expertise or want to break down complex tasks.",
522 ToolParameterSchema::object(
523 "Task delegation parameters",
524 properties,
525 vec!["agent".to_string(), "instruction".to_string()],
526 ),
527 )
528 }
529
530 async fn execute(
531 &self,
532 args: serde_json::Value,
533 ctx: ToolContext,
534 ) -> anyhow::Result<ToolResult> {
535 let args: TaskInvocationArgs = serde_json::from_value(args)?;
536 let available = self.available_subagents();
537
538 if let Some(agent) = self.registry.get(&args.agent) {
539 tracing::warn!(
541 "🎯 DELEGATING to sub-agent: {} with instruction: {}",
542 args.agent,
543 args.instruction
544 );
545
546 let start_time = std::time::Instant::now();
547 let user_message = AgentMessage {
548 role: MessageRole::User,
549 content: MessageContent::Text(args.instruction.clone()),
550 metadata: None,
551 };
552
553 let response = agent
554 .handle_message(user_message, Arc::new(AgentStateSnapshot::default()))
555 .await?;
556
557 let duration = start_time.elapsed();
559 let response_preview = match &response.content {
560 MessageContent::Text(t) => {
561 if t.len() > 100 {
562 format!("{}... ({} chars)", &t[..100], t.len())
563 } else {
564 t.clone()
565 }
566 }
567 MessageContent::Json(v) => {
568 format!("JSON: {} bytes", v.to_string().len())
569 }
570 };
571
572 tracing::warn!(
573 "✅ SUB-AGENT {} COMPLETED in {:?} - Response: {}",
574 args.agent,
575 duration,
576 response_preview
577 );
578
579 let result_text = match response.content {
582 MessageContent::Text(text) => text,
583 MessageContent::Json(json) => json.to_string(),
584 };
585
586 return Ok(ToolResult::text(&ctx, result_text));
587 }
588
589 tracing::error!(
590 "❌ SUB-AGENT NOT FOUND: {} - Available: {:?}",
591 args.agent,
592 available
593 );
594
595 Ok(ToolResult::text(
596 &ctx,
597 format!(
598 "Sub-agent '{}' not found. Available sub-agents: {}",
599 args.agent,
600 available.join(", ")
601 ),
602 ))
603 }
604}
605
606#[derive(Debug, Clone)]
607pub struct SubAgentDescriptor {
608 pub name: String,
609 pub description: String,
610}
611
612#[cfg(test)]
613mod tests {
614 use super::*;
615 use agents_core::agent::{AgentDescriptor, AgentHandle};
616 use agents_core::messaging::{MessageContent, MessageRole};
617 use serde_json::json;
618
619 struct AppendPromptMiddleware;
620
621 #[async_trait]
622 impl AgentMiddleware for AppendPromptMiddleware {
623 fn id(&self) -> &'static str {
624 "append-prompt"
625 }
626
627 async fn modify_model_request(
628 &self,
629 ctx: &mut MiddlewareContext<'_>,
630 ) -> anyhow::Result<()> {
631 ctx.request.system_prompt.push_str("\nExtra directives.");
632 Ok(())
633 }
634 }
635
636 #[tokio::test]
637 async fn middleware_mutates_prompt() {
638 let mut request = ModelRequest::new(
639 "System",
640 vec![AgentMessage {
641 role: MessageRole::User,
642 content: MessageContent::Text("Hi".into()),
643 metadata: None,
644 }],
645 );
646 let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
647 let mut ctx = MiddlewareContext::with_request(&mut request, state);
648 let middleware = AppendPromptMiddleware;
649 middleware.modify_model_request(&mut ctx).await.unwrap();
650 assert!(ctx.request.system_prompt.contains("Extra directives"));
651 }
652
653 #[tokio::test]
654 async fn planning_middleware_registers_write_todos() {
655 let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
656 let middleware = PlanningMiddleware::new(state);
657 let tool_names: Vec<_> = middleware
658 .tools()
659 .iter()
660 .map(|t| t.schema().name.clone())
661 .collect();
662 assert!(tool_names.contains(&"write_todos".to_string()));
663
664 let mut request = ModelRequest::new("System", vec![]);
665 let mut ctx = MiddlewareContext::with_request(
666 &mut request,
667 Arc::new(RwLock::new(AgentStateSnapshot::default())),
668 );
669 middleware.modify_model_request(&mut ctx).await.unwrap();
670 assert!(ctx.request.system_prompt.contains("write_todos"));
671 }
672
673 #[tokio::test]
674 async fn filesystem_middleware_registers_tools() {
675 let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
676 let middleware = FilesystemMiddleware::new(state);
677 let tool_names: Vec<_> = middleware
678 .tools()
679 .iter()
680 .map(|t| t.schema().name.clone())
681 .collect();
682 for expected in ["ls", "read_file", "write_file", "edit_file"] {
683 assert!(tool_names.contains(&expected.to_string()));
684 }
685 }
686
687 #[tokio::test]
688 async fn summarization_middleware_trims_messages() {
689 let middleware = SummarizationMiddleware::new(2, "Summary note");
690 let mut request = ModelRequest::new(
691 "System",
692 vec![
693 AgentMessage {
694 role: MessageRole::User,
695 content: MessageContent::Text("one".into()),
696 metadata: None,
697 },
698 AgentMessage {
699 role: MessageRole::Agent,
700 content: MessageContent::Text("two".into()),
701 metadata: None,
702 },
703 AgentMessage {
704 role: MessageRole::User,
705 content: MessageContent::Text("three".into()),
706 metadata: None,
707 },
708 ],
709 );
710 let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
711 let mut ctx = MiddlewareContext::with_request(&mut request, state);
712 middleware.modify_model_request(&mut ctx).await.unwrap();
713 assert_eq!(ctx.request.messages.len(), 3);
714 match &ctx.request.messages[0].content {
715 MessageContent::Text(text) => assert!(text.contains("Summary note")),
716 other => panic!("expected text, got {other:?}"),
717 }
718 }
719
720 struct StubAgent;
721
722 #[async_trait]
723 impl AgentHandle for StubAgent {
724 async fn describe(&self) -> AgentDescriptor {
725 AgentDescriptor {
726 name: "stub".into(),
727 version: "0.0.1".into(),
728 description: None,
729 }
730 }
731
732 async fn handle_message(
733 &self,
734 _input: AgentMessage,
735 _state: Arc<AgentStateSnapshot>,
736 ) -> anyhow::Result<AgentMessage> {
737 Ok(AgentMessage {
738 role: MessageRole::Agent,
739 content: MessageContent::Text("stub-response".into()),
740 metadata: None,
741 })
742 }
743 }
744
745 #[tokio::test]
746 async fn task_router_reports_unknown_subagent() {
747 let registry = Arc::new(SubAgentRegistry::new(vec![]));
748 let task_tool = TaskRouterTool::new(registry.clone());
749 let state = Arc::new(AgentStateSnapshot::default());
750 let ctx = ToolContext::new(state);
751
752 let response = task_tool
753 .execute(
754 json!({
755 "instruction": "Do something",
756 "agent": "unknown"
757 }),
758 ctx,
759 )
760 .await
761 .unwrap();
762
763 match response {
764 ToolResult::Message(msg) => match msg.content {
765 MessageContent::Text(text) => {
766 assert!(text.contains("Sub-agent 'unknown' not found"))
767 }
768 other => panic!("expected text, got {other:?}"),
769 },
770 _ => panic!("expected message"),
771 }
772 }
773
774 #[tokio::test]
775 async fn subagent_middleware_appends_prompt() {
776 let subagents = vec![SubAgentRegistration {
777 descriptor: SubAgentDescriptor {
778 name: "research-agent".into(),
779 description: "Deep research specialist".into(),
780 },
781 agent: Arc::new(StubAgent),
782 }];
783 let middleware = SubAgentMiddleware::new(subagents);
784
785 let mut request = ModelRequest::new("System", vec![]);
786 let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
787 let mut ctx = MiddlewareContext::with_request(&mut request, state);
788 middleware.modify_model_request(&mut ctx).await.unwrap();
789
790 assert!(ctx.request.system_prompt.contains("research-agent"));
791 let tool_names: Vec<_> = middleware
792 .tools()
793 .iter()
794 .map(|t| t.schema().name.clone())
795 .collect();
796 assert!(tool_names.contains(&"task".to_string()));
797 }
798
799 #[tokio::test]
800 async fn task_router_invokes_registered_subagent() {
801 let registry = Arc::new(SubAgentRegistry::new(vec![SubAgentRegistration {
802 descriptor: SubAgentDescriptor {
803 name: "stub-agent".into(),
804 description: "Stub".into(),
805 },
806 agent: Arc::new(StubAgent),
807 }]));
808 let task_tool = TaskRouterTool::new(registry.clone());
809 let state = Arc::new(AgentStateSnapshot::default());
810 let ctx = ToolContext::new(state).with_call_id(Some("call-42".into()));
811 let response = task_tool
812 .execute(
813 json!({
814 "description": "do work",
815 "subagent_type": "stub-agent"
816 }),
817 ctx,
818 )
819 .await
820 .unwrap();
821
822 match response {
823 ToolResult::Message(msg) => {
824 assert_eq!(msg.metadata.unwrap().tool_call_id.unwrap(), "call-42");
825 match msg.content {
826 MessageContent::Text(text) => assert_eq!(text, "stub-response"),
827 other => panic!("expected text, got {other:?}"),
828 }
829 }
830 _ => panic!("expected message"),
831 }
832 }
833
834 #[tokio::test]
835 async fn human_in_loop_appends_prompt() {
836 let middleware = HumanInLoopMiddleware::new(HashMap::from([(
837 "danger-tool".into(),
838 HitlPolicy {
839 allow_auto: false,
840 note: Some("Requires security review".into()),
841 },
842 )]));
843 let mut request = ModelRequest::new("System", vec![]);
844 let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
845 let mut ctx = MiddlewareContext::with_request(&mut request, state);
846 middleware.modify_model_request(&mut ctx).await.unwrap();
847 assert!(ctx
848 .request
849 .system_prompt
850 .contains("danger-tool: Requires security review"));
851 }
852
853 #[tokio::test]
854 async fn anthropic_prompt_caching_moves_system_prompt_to_messages() {
855 let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
856 let mut request = ModelRequest::new(
857 "This is the system prompt",
858 vec![AgentMessage {
859 role: MessageRole::User,
860 content: MessageContent::Text("Hello".into()),
861 metadata: None,
862 }],
863 );
864 let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
865 let mut ctx = MiddlewareContext::with_request(&mut request, state);
866
867 middleware.modify_model_request(&mut ctx).await.unwrap();
869
870 assert!(ctx.request.system_prompt.is_empty());
872
873 assert_eq!(ctx.request.messages.len(), 2);
875
876 let system_message = &ctx.request.messages[0];
877 assert!(matches!(system_message.role, MessageRole::System));
878 assert_eq!(
879 system_message.content.as_text().unwrap(),
880 "This is the system prompt"
881 );
882
883 let metadata = system_message.metadata.as_ref().unwrap();
885 let cache_control = metadata.cache_control.as_ref().unwrap();
886 assert_eq!(cache_control.cache_type, "ephemeral");
887
888 let user_message = &ctx.request.messages[1];
890 assert!(matches!(user_message.role, MessageRole::User));
891 assert_eq!(user_message.content.as_text().unwrap(), "Hello");
892 }
893
894 #[tokio::test]
895 async fn anthropic_prompt_caching_disabled_with_zero_ttl() {
896 let middleware = AnthropicPromptCachingMiddleware::new("0", "ignore");
897 let mut request = ModelRequest::new("This is the system prompt", vec![]);
898 let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
899 let mut ctx = MiddlewareContext::with_request(&mut request, state);
900
901 middleware.modify_model_request(&mut ctx).await.unwrap();
903
904 assert_eq!(ctx.request.system_prompt, "This is the system prompt");
906 assert_eq!(ctx.request.messages.len(), 0);
907 }
908
909 #[tokio::test]
910 async fn anthropic_prompt_caching_no_op_with_empty_system_prompt() {
911 let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
912 let mut request = ModelRequest::new(
913 "",
914 vec![AgentMessage {
915 role: MessageRole::User,
916 content: MessageContent::Text("Hello".into()),
917 metadata: None,
918 }],
919 );
920 let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
921 let mut ctx = MiddlewareContext::with_request(&mut request, state);
922
923 middleware.modify_model_request(&mut ctx).await.unwrap();
925
926 assert!(ctx.request.system_prompt.is_empty());
928 assert_eq!(ctx.request.messages.len(), 1);
930 }
931
932 #[tokio::test]
935 async fn hitl_creates_interrupt_for_disallowed_tool() {
936 let mut policies = HashMap::new();
937 policies.insert(
938 "dangerous_tool".to_string(),
939 HitlPolicy {
940 allow_auto: false,
941 note: Some("Requires security review".to_string()),
942 },
943 );
944
945 let middleware = HumanInLoopMiddleware::new(policies);
946 let tool_args = json!({"action": "delete_all"});
947
948 let result = middleware
949 .before_tool_execution("dangerous_tool", &tool_args, "call_123")
950 .await
951 .unwrap();
952
953 assert!(result.is_some());
954 let interrupt = result.unwrap();
955
956 match interrupt {
957 agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
958 assert_eq!(hitl.tool_name, "dangerous_tool");
959 assert_eq!(hitl.tool_args, tool_args);
960 assert_eq!(hitl.call_id, "call_123");
961 assert_eq!(
962 hitl.policy_note,
963 Some("Requires security review".to_string())
964 );
965 }
966 }
967 }
968
969 #[tokio::test]
970 async fn hitl_no_interrupt_for_allowed_tool() {
971 let mut policies = HashMap::new();
972 policies.insert(
973 "safe_tool".to_string(),
974 HitlPolicy {
975 allow_auto: true,
976 note: None,
977 },
978 );
979
980 let middleware = HumanInLoopMiddleware::new(policies);
981 let tool_args = json!({"action": "read"});
982
983 let result = middleware
984 .before_tool_execution("safe_tool", &tool_args, "call_456")
985 .await
986 .unwrap();
987
988 assert!(result.is_none());
989 }
990
991 #[tokio::test]
992 async fn hitl_no_interrupt_for_unlisted_tool() {
993 let policies = HashMap::new();
994 let middleware = HumanInLoopMiddleware::new(policies);
995 let tool_args = json!({"action": "anything"});
996
997 let result = middleware
998 .before_tool_execution("unlisted_tool", &tool_args, "call_789")
999 .await
1000 .unwrap();
1001
1002 assert!(result.is_none());
1003 }
1004
1005 #[tokio::test]
1006 async fn hitl_interrupt_includes_correct_details() {
1007 let mut policies = HashMap::new();
1008 policies.insert(
1009 "critical_tool".to_string(),
1010 HitlPolicy {
1011 allow_auto: false,
1012 note: Some("Critical operation - requires approval".to_string()),
1013 },
1014 );
1015
1016 let middleware = HumanInLoopMiddleware::new(policies);
1017 let tool_args = json!({
1018 "database": "production",
1019 "operation": "drop_table"
1020 });
1021
1022 let result = middleware
1023 .before_tool_execution("critical_tool", &tool_args, "call_critical_1")
1024 .await
1025 .unwrap();
1026
1027 assert!(result.is_some());
1028 let interrupt = result.unwrap();
1029
1030 match interrupt {
1031 agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
1032 assert_eq!(hitl.tool_name, "critical_tool");
1033 assert_eq!(hitl.tool_args["database"], "production");
1034 assert_eq!(hitl.tool_args["operation"], "drop_table");
1035 assert_eq!(hitl.call_id, "call_critical_1");
1036 assert!(hitl.policy_note.is_some());
1037 assert!(hitl.policy_note.unwrap().contains("Critical operation"));
1038 }
1041 }
1042 }
1043
1044 #[tokio::test]
1045 async fn hitl_interrupt_without_policy_note() {
1046 let mut policies = HashMap::new();
1047 policies.insert(
1048 "tool_no_note".to_string(),
1049 HitlPolicy {
1050 allow_auto: false,
1051 note: None,
1052 },
1053 );
1054
1055 let middleware = HumanInLoopMiddleware::new(policies);
1056 let tool_args = json!({"param": "value"});
1057
1058 let result = middleware
1059 .before_tool_execution("tool_no_note", &tool_args, "call_no_note")
1060 .await
1061 .unwrap();
1062
1063 assert!(result.is_some());
1064 let interrupt = result.unwrap();
1065
1066 match interrupt {
1067 agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
1068 assert_eq!(hitl.tool_name, "tool_no_note");
1069 assert_eq!(hitl.policy_note, None);
1070 }
1071 }
1072 }
1073}