// agents_persistence/redis_checkpointer.rs
use agents_core::persistence::{Checkpointer, ThreadId};
use agents_core::state::AgentStateSnapshot;
use anyhow::Context;
use async_trait::async_trait;
use redis::{aio::ConnectionManager, AsyncCommands};
use std::time::Duration;

23#[derive(Clone)]
48pub struct RedisCheckpointer {
49 connection: ConnectionManager,
50 namespace: String,
51 ttl: Option<Duration>,
52}
53
54impl RedisCheckpointer {
55 pub async fn new(url: &str) -> anyhow::Result<Self> {
61 Self::builder().url(url).build().await
62 }
63
64 pub fn builder() -> RedisCheckpointerBuilder {
66 RedisCheckpointerBuilder::default()
67 }
68
69 fn key_for_thread(&self, thread_id: &ThreadId) -> String {
71 format!("{}:thread:{}", self.namespace, thread_id)
72 }
73
74 fn threads_index_key(&self) -> String {
76 format!("{}:threads", self.namespace)
77 }
78}
79
80#[async_trait]
81impl Checkpointer for RedisCheckpointer {
82 async fn save_state(
83 &self,
84 thread_id: &ThreadId,
85 state: &AgentStateSnapshot,
86 ) -> anyhow::Result<()> {
87 let key = self.key_for_thread(thread_id);
88 let index_key = self.threads_index_key();
89
90 let json =
91 serde_json::to_string(state).context("Failed to serialize agent state to JSON")?;
92
93 let mut conn = self.connection.clone();
94
95 if let Some(ttl) = self.ttl {
97 conn.set_ex::<_, _, ()>(&key, json, ttl.as_secs())
98 .await
99 .context("Failed to save state to Redis with TTL")?;
100 } else {
101 conn.set::<_, _, ()>(&key, json)
102 .await
103 .context("Failed to save state to Redis")?;
104 }
105
106 conn.sadd::<_, _, ()>(&index_key, thread_id)
108 .await
109 .context("Failed to update thread index")?;
110
111 tracing::debug!(
112 thread_id = %thread_id,
113 namespace = %self.namespace,
114 "Saved agent state to Redis"
115 );
116
117 Ok(())
118 }
119
120 async fn load_state(&self, thread_id: &ThreadId) -> anyhow::Result<Option<AgentStateSnapshot>> {
121 let key = self.key_for_thread(thread_id);
122 let mut conn = self.connection.clone();
123
124 let json: Option<String> = conn
125 .get(&key)
126 .await
127 .context("Failed to load state from Redis")?;
128
129 match json {
130 Some(data) => {
131 let state: AgentStateSnapshot = serde_json::from_str(&data)
132 .context("Failed to deserialize agent state from JSON")?;
133
134 tracing::debug!(
135 thread_id = %thread_id,
136 namespace = %self.namespace,
137 "Loaded agent state from Redis"
138 );
139
140 Ok(Some(state))
141 }
142 None => {
143 tracing::debug!(
144 thread_id = %thread_id,
145 namespace = %self.namespace,
146 "No saved state found in Redis"
147 );
148 Ok(None)
149 }
150 }
151 }
152
153 async fn delete_thread(&self, thread_id: &ThreadId) -> anyhow::Result<()> {
154 let key = self.key_for_thread(thread_id);
155 let index_key = self.threads_index_key();
156 let mut conn = self.connection.clone();
157
158 conn.del::<_, ()>(&key)
160 .await
161 .context("Failed to delete state from Redis")?;
162
163 conn.srem::<_, _, ()>(&index_key, thread_id)
165 .await
166 .context("Failed to update thread index")?;
167
168 tracing::debug!(
169 thread_id = %thread_id,
170 namespace = %self.namespace,
171 "Deleted thread from Redis"
172 );
173
174 Ok(())
175 }
176
177 async fn list_threads(&self) -> anyhow::Result<Vec<ThreadId>> {
178 let index_key = self.threads_index_key();
179 let mut conn = self.connection.clone();
180
181 let threads: Vec<String> = conn
182 .smembers(&index_key)
183 .await
184 .context("Failed to list threads from Redis")?;
185
186 Ok(threads)
187 }
188}
189
/// Step-by-step configuration for a [`RedisCheckpointer`].
///
/// Obtain one from [`RedisCheckpointer::builder`], set at least the URL, then
/// call `build`. Note: deriving `Debug` here would print the URL, which may
/// embed Redis credentials.
#[derive(Default)]
pub struct RedisCheckpointerBuilder {
    /// Connection URL; required by `build`.
    url: Option<String>,
    /// Key prefix; `build` falls back to `"agents"` when unset.
    namespace: Option<String>,
    /// Expiry for saved state keys; no expiry when unset.
    ttl: Option<Duration>,
}

198impl RedisCheckpointerBuilder {
199 pub fn url(mut self, url: impl Into<String>) -> Self {
201 self.url = Some(url.into());
202 self
203 }
204
205 pub fn namespace(mut self, namespace: impl Into<String>) -> Self {
210 self.namespace = Some(namespace.into());
211 self
212 }
213
214 pub fn ttl(mut self, ttl: Duration) -> Self {
219 self.ttl = Some(ttl);
220 self
221 }
222
223 pub async fn build(self) -> anyhow::Result<RedisCheckpointer> {
225 let url = self
226 .url
227 .ok_or_else(|| anyhow::anyhow!("Redis URL is required"))?;
228
229 let client = redis::Client::open(url.as_str()).context("Failed to create Redis client")?;
230
231 let connection = ConnectionManager::new(client)
232 .await
233 .context("Failed to establish Redis connection")?;
234
235 Ok(RedisCheckpointer {
236 connection,
237 namespace: self.namespace.unwrap_or_else(|| "agents".to_string()),
238 ttl: self.ttl,
239 })
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use agents_core::state::TodoItem;

    /// Builds a snapshot with one todo, one file and one scratchpad entry.
    fn sample_state() -> AgentStateSnapshot {
        let mut state = AgentStateSnapshot::default();
        state.todos.push(TodoItem::pending("Test todo"));
        state
            .files
            .insert("test.txt".to_string(), "content".to_string());
        state
            .scratchpad
            .insert("key".to_string(), serde_json::json!("value"));
        state
    }

    #[tokio::test]
    #[ignore = "requires a Redis server at redis://127.0.0.1:6379"]
    async fn test_redis_save_and_load() {
        let checkpointer = RedisCheckpointer::new("redis://127.0.0.1:6379")
            .await
            .expect("Failed to connect to Redis");

        let thread_id = "test-thread".to_string();
        let state = sample_state();

        checkpointer
            .save_state(&thread_id, &state)
            .await
            .expect("Failed to save state");

        let loaded_state = checkpointer
            .load_state(&thread_id)
            .await
            .expect("Failed to load state")
            .expect("Saved state should be found");

        // Every field populated by sample_state must survive the JSON round trip.
        assert_eq!(loaded_state.todos.len(), 1);
        assert_eq!(loaded_state.files.get("test.txt").unwrap(), "content");
        assert_eq!(
            loaded_state.scratchpad.get("key"),
            Some(&serde_json::json!("value"))
        );

        checkpointer
            .delete_thread(&thread_id)
            .await
            .expect("Failed to delete thread");

        // Deletion must remove the stored snapshot, not only the index entry.
        let after_delete = checkpointer
            .load_state(&thread_id)
            .await
            .expect("Failed to load state after delete");
        assert!(after_delete.is_none());
    }

    #[tokio::test]
    #[ignore = "requires a Redis server at redis://127.0.0.1:6379"]
    async fn test_redis_list_threads() {
        let checkpointer = RedisCheckpointer::builder()
            .url("redis://127.0.0.1:6379")
            .namespace("test-namespace")
            .build()
            .await
            .expect("Failed to connect to Redis");

        let state = sample_state();
        let first = "thread1".to_string();
        let second = "thread2".to_string();

        checkpointer.save_state(&first, &state).await.unwrap();
        checkpointer.save_state(&second, &state).await.unwrap();

        let threads = checkpointer.list_threads().await.unwrap();
        assert!(threads.contains(&first));
        assert!(threads.contains(&second));

        checkpointer.delete_thread(&first).await.unwrap();
        checkpointer.delete_thread(&second).await.unwrap();

        // Deleted threads must disappear from the index as well.
        let remaining = checkpointer.list_threads().await.unwrap();
        assert!(!remaining.contains(&first));
        assert!(!remaining.contains(&second));
    }
}