1use crate::messaging::AgentMessage;
8use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
13#[serde(tag = "type")]
14pub enum AgentInterrupt {
15 #[serde(rename = "human_in_loop")]
17 HumanInLoop(HitlInterrupt),
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
22pub struct HitlInterrupt {
23 pub tool_name: String,
25
26 pub tool_args: serde_json::Value,
28
29 #[serde(skip_serializing_if = "Option::is_none")]
31 pub policy_note: Option<String>,
32
33 pub created_at: DateTime<Utc>,
35
36 pub call_id: String,
38}
39
40impl HitlInterrupt {
41 pub fn new(
43 tool_name: impl Into<String>,
44 tool_args: serde_json::Value,
45 call_id: impl Into<String>,
46 policy_note: Option<String>,
47 ) -> Self {
48 Self {
49 tool_name: tool_name.into(),
50 tool_args,
51 policy_note,
52 created_at: Utc::now(),
53 call_id: call_id.into(),
54 }
55 }
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
60#[serde(tag = "action", rename_all = "lowercase")]
61pub enum HitlAction {
62 Accept,
64
65 Edit {
67 tool_name: String,
69 tool_args: serde_json::Value,
71 },
72
73 Reject {
75 #[serde(skip_serializing_if = "Option::is_none")]
77 reason: Option<String>,
78 },
79
80 Respond {
82 message: AgentMessage,
84 },
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use crate::messaging::{MessageContent, MessageRole};
    use serde_json::json;

    /// Serializes `value` and re-parses it as generic JSON so tests can
    /// assert on individual keys instead of matching substrings.
    fn to_json_value<T: Serialize>(value: &T) -> serde_json::Value {
        let text = serde_json::to_string(value).unwrap();
        serde_json::from_str(&text).unwrap()
    }

    #[test]
    fn test_hitl_interrupt_creation() {
        let interrupt = HitlInterrupt::new(
            "test_tool",
            json!({"arg": "value"}),
            "call_123",
            Some("Test note".to_string()),
        );

        assert_eq!(interrupt.tool_name, "test_tool");
        assert_eq!(interrupt.tool_args, json!({"arg": "value"}));
        assert_eq!(interrupt.call_id, "call_123");
        assert_eq!(interrupt.policy_note, Some("Test note".to_string()));
    }

    #[test]
    fn test_hitl_interrupt_serialization() {
        let interrupt = HitlInterrupt::new(
            "test_tool",
            json!({"arg": "value"}),
            "call_123",
            Some("Test note".to_string()),
        );

        // The interrupt is not needed afterwards, so move it rather than clone.
        let agent_interrupt = AgentInterrupt::HumanInLoop(interrupt);

        let json = serde_json::to_string(&agent_interrupt).unwrap();

        // Check the tag and payload under their actual keys.
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["type"], "human_in_loop");
        assert_eq!(value["tool_name"], "test_tool");
        assert_eq!(value["tool_args"], json!({"arg": "value"}));
        assert_eq!(value["call_id"], "call_123");
        assert_eq!(value["policy_note"], "Test note");

        let deserialized: AgentInterrupt = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, agent_interrupt);
    }

    #[test]
    fn test_hitl_action_accept() {
        let action = HitlAction::Accept;
        let json = serde_json::to_string(&action).unwrap();
        // A unit variant of an internally tagged enum is just the tag.
        assert_eq!(json, r#"{"action":"accept"}"#);

        let deserialized: HitlAction = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, action);
    }

    #[test]
    fn test_hitl_action_edit() {
        let action = HitlAction::Edit {
            tool_name: "modified_tool".to_string(),
            tool_args: json!({"new_arg": "new_value"}),
        };

        let value = to_json_value(&action);
        assert_eq!(value["action"], "edit");
        assert_eq!(value["tool_name"], "modified_tool");
        assert_eq!(value["tool_args"], json!({"new_arg": "new_value"}));

        let deserialized: HitlAction = serde_json::from_value(value).unwrap();
        assert_eq!(deserialized, action);
    }

    #[test]
    fn test_hitl_action_reject() {
        let action = HitlAction::Reject {
            reason: Some("Not safe".to_string()),
        };

        let value = to_json_value(&action);
        assert_eq!(value["action"], "reject");
        assert_eq!(value["reason"], "Not safe");

        let deserialized: HitlAction = serde_json::from_value(value).unwrap();
        assert_eq!(deserialized, action);
    }

    #[test]
    fn test_hitl_action_reject_without_reason() {
        let action = HitlAction::Reject { reason: None };

        // `reason` is skipped when `None`, leaving only the tag.
        let json = serde_json::to_string(&action).unwrap();
        assert_eq!(json, r#"{"action":"reject"}"#);

        let deserialized: HitlAction = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, action);
    }

    #[test]
    fn test_hitl_action_respond() {
        let message = AgentMessage {
            role: MessageRole::Agent,
            content: MessageContent::Text("Custom response".to_string()),
            metadata: None,
        };

        // `message` is not used again, so it is moved into the action.
        let action = HitlAction::Respond { message };

        let json = serde_json::to_string(&action).unwrap();
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["action"], "respond");
        // The wire layout of `AgentMessage` is defined elsewhere, so only
        // check that the text made it into the output.
        assert!(json.contains("Custom response"));

        let deserialized: HitlAction = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, action);
    }

    #[test]
    fn test_interrupt_without_policy_note() {
        let interrupt = HitlInterrupt::new("test_tool", json!({}), "call_123", None);

        assert_eq!(interrupt.policy_note, None);

        let json = serde_json::to_string(&interrupt).unwrap();
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        // The key must be absent, not merely null.
        assert!(value.get("policy_note").is_none());

        // A payload without the key must still deserialize, yielding `None`.
        let deserialized: HitlInterrupt = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, interrupt);
    }
}