agents_persistence/
redis_checkpointer.rs

//! Redis-backed checkpointer implementation using a managed, automatically
//! reconnecting connection.
//!
//! This checkpointer stores agent state in Redis as JSON, serializing and
//! deserializing it automatically. It is well suited to:
//! - High-performance applications requiring fast state access
//! - Distributed systems where multiple agent instances share state
//! - Applications already using Redis for caching or session management
//!
//! ## Features
//!
//! - Automatic JSON serialization/deserialization
//! - A single managed connection (redis-rs `ConnectionManager`) shared by all
//!   clones of the checkpointer
//! - TTL support for automatic state expiration
//! - Namespace support for multi-tenant applications
15
16use agents_core::persistence::{Checkpointer, ThreadId};
17use agents_core::state::AgentStateSnapshot;
18use anyhow::Context;
19use async_trait::async_trait;
20use redis::{aio::ConnectionManager, AsyncCommands};
21use std::time::Duration;
22
23/// Redis-backed checkpointer with connection pooling and TTL support.
24///
25/// # Examples
26///
27/// ```rust,no_run
28/// use agents_persistence::RedisCheckpointer;
29/// use std::time::Duration;
30///
31/// #[tokio::main]
32/// async fn main() -> anyhow::Result<()> {
33///     // Basic usage
34///     let checkpointer = RedisCheckpointer::new("redis://127.0.0.1:6379").await?;
35///
36///     // With namespace and TTL
37///     let checkpointer = RedisCheckpointer::builder()
38///         .url("redis://127.0.0.1:6379")
39///         .namespace("myapp")
40///         .ttl(Duration::from_secs(86400)) // 24 hours
41///         .build()
42///         .await?;
43///
44///     Ok(())
45/// }
46/// ```
47#[derive(Clone)]
48pub struct RedisCheckpointer {
49    connection: ConnectionManager,
50    namespace: String,
51    ttl: Option<Duration>,
52}
53
54impl RedisCheckpointer {
55    /// Create a new Redis checkpointer with the default namespace.
56    ///
57    /// # Arguments
58    ///
59    /// * `url` - Redis connection URL (e.g., "redis://127.0.0.1:6379")
60    pub async fn new(url: &str) -> anyhow::Result<Self> {
61        Self::builder().url(url).build().await
62    }
63
64    /// Create a builder for configuring the Redis checkpointer.
65    pub fn builder() -> RedisCheckpointerBuilder {
66        RedisCheckpointerBuilder::default()
67    }
68
69    /// Generate the full Redis key for a thread.
70    fn key_for_thread(&self, thread_id: &ThreadId) -> String {
71        format!("{}:thread:{}", self.namespace, thread_id)
72    }
73
74    /// Generate the Redis key for the thread index.
75    fn threads_index_key(&self) -> String {
76        format!("{}:threads", self.namespace)
77    }
78}
79
80#[async_trait]
81impl Checkpointer for RedisCheckpointer {
82    async fn save_state(
83        &self,
84        thread_id: &ThreadId,
85        state: &AgentStateSnapshot,
86    ) -> anyhow::Result<()> {
87        let key = self.key_for_thread(thread_id);
88        let index_key = self.threads_index_key();
89
90        let json =
91            serde_json::to_string(state).context("Failed to serialize agent state to JSON")?;
92
93        let mut conn = self.connection.clone();
94
95        // Save the state
96        if let Some(ttl) = self.ttl {
97            conn.set_ex::<_, _, ()>(&key, json, ttl.as_secs())
98                .await
99                .context("Failed to save state to Redis with TTL")?;
100        } else {
101            conn.set::<_, _, ()>(&key, json)
102                .await
103                .context("Failed to save state to Redis")?;
104        }
105
106        // Add to thread index
107        conn.sadd::<_, _, ()>(&index_key, thread_id)
108            .await
109            .context("Failed to update thread index")?;
110
111        tracing::debug!(
112            thread_id = %thread_id,
113            namespace = %self.namespace,
114            "Saved agent state to Redis"
115        );
116
117        Ok(())
118    }
119
120    async fn load_state(&self, thread_id: &ThreadId) -> anyhow::Result<Option<AgentStateSnapshot>> {
121        let key = self.key_for_thread(thread_id);
122        let mut conn = self.connection.clone();
123
124        let json: Option<String> = conn
125            .get(&key)
126            .await
127            .context("Failed to load state from Redis")?;
128
129        match json {
130            Some(data) => {
131                let state: AgentStateSnapshot = serde_json::from_str(&data)
132                    .context("Failed to deserialize agent state from JSON")?;
133
134                tracing::debug!(
135                    thread_id = %thread_id,
136                    namespace = %self.namespace,
137                    "Loaded agent state from Redis"
138                );
139
140                Ok(Some(state))
141            }
142            None => {
143                tracing::debug!(
144                    thread_id = %thread_id,
145                    namespace = %self.namespace,
146                    "No saved state found in Redis"
147                );
148                Ok(None)
149            }
150        }
151    }
152
153    async fn delete_thread(&self, thread_id: &ThreadId) -> anyhow::Result<()> {
154        let key = self.key_for_thread(thread_id);
155        let index_key = self.threads_index_key();
156        let mut conn = self.connection.clone();
157
158        // Delete the state
159        conn.del::<_, ()>(&key)
160            .await
161            .context("Failed to delete state from Redis")?;
162
163        // Remove from thread index
164        conn.srem::<_, _, ()>(&index_key, thread_id)
165            .await
166            .context("Failed to update thread index")?;
167
168        tracing::debug!(
169            thread_id = %thread_id,
170            namespace = %self.namespace,
171            "Deleted thread from Redis"
172        );
173
174        Ok(())
175    }
176
177    async fn list_threads(&self) -> anyhow::Result<Vec<ThreadId>> {
178        let index_key = self.threads_index_key();
179        let mut conn = self.connection.clone();
180
181        let threads: Vec<String> = conn
182            .smembers(&index_key)
183            .await
184            .context("Failed to list threads from Redis")?;
185
186        Ok(threads)
187    }
188}
189
/// Step-by-step configuration for a [`RedisCheckpointer`].
///
/// Only the URL is mandatory; unset fields fall back to their defaults when the
/// checkpointer is built.
#[derive(Default)]
pub struct RedisCheckpointerBuilder {
    /// Redis connection URL; building fails while this is `None`.
    url: Option<String>,
    /// Key prefix; `None` selects the default namespace.
    namespace: Option<String>,
    /// Expiry for saved states; `None` means states never expire.
    ttl: Option<Duration>,
}
197
198impl RedisCheckpointerBuilder {
199    /// Set the Redis connection URL.
200    pub fn url(mut self, url: impl Into<String>) -> Self {
201        self.url = Some(url.into());
202        self
203    }
204
205    /// Set the namespace for Redis keys (default: "agents").
206    ///
207    /// This is useful for multi-tenant applications or when multiple
208    /// agent systems share the same Redis instance.
209    pub fn namespace(mut self, namespace: impl Into<String>) -> Self {
210        self.namespace = Some(namespace.into());
211        self
212    }
213
214    /// Set the TTL (time-to-live) for stored states.
215    ///
216    /// After this duration, Redis will automatically delete the state.
217    /// Useful for implementing automatic cleanup policies.
218    pub fn ttl(mut self, ttl: Duration) -> Self {
219        self.ttl = Some(ttl);
220        self
221    }
222
223    /// Build the Redis checkpointer.
224    pub async fn build(self) -> anyhow::Result<RedisCheckpointer> {
225        let url = self
226            .url
227            .ok_or_else(|| anyhow::anyhow!("Redis URL is required"))?;
228
229        let client = redis::Client::open(url.as_str()).context("Failed to create Redis client")?;
230
231        let connection = ConnectionManager::new(client)
232            .await
233            .context("Failed to establish Redis connection")?;
234
235        Ok(RedisCheckpointer {
236            connection,
237            namespace: self.namespace.unwrap_or_else(|| "agents".to_string()),
238            ttl: self.ttl,
239        })
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use agents_core::state::TodoItem;

    /// Redis instance used by the `#[ignore]`d integration tests.
    const REDIS_URL: &str = "redis://127.0.0.1:6379";

    fn sample_state() -> AgentStateSnapshot {
        let mut state = AgentStateSnapshot::default();
        state.todos.push(TodoItem::pending("Test todo"));
        state
            .files
            .insert("test.txt".to_string(), "content".to_string());
        state
            .scratchpad
            .insert("key".to_string(), serde_json::json!("value"));
        state
    }

    /// Connect to the test Redis under a dedicated namespace so integration tests never read
    /// or clobber keys in the default `"agents"` namespace.
    async fn connect(namespace: &str) -> RedisCheckpointer {
        RedisCheckpointer::builder()
            .url(REDIS_URL)
            .namespace(namespace)
            .build()
            .await
            .expect("Failed to connect to Redis")
    }

    #[tokio::test]
    async fn test_builder_requires_url() {
        let result = RedisCheckpointer::builder().build().await;
        let err = result.err().expect("building without a URL must fail");
        assert!(err.to_string().contains("Redis URL is required"));
    }

    #[tokio::test]
    async fn test_builder_rejects_malformed_url() {
        // URL parsing happens before any connection attempt, so no Redis is needed.
        assert!(RedisCheckpointer::new("not a redis url").await.is_err());
    }

    #[tokio::test]
    #[ignore] // Requires Redis instance running
    async fn test_redis_save_and_load() {
        let checkpointer = connect("agents-test-save-load").await;
        let thread_id = "test-thread".to_string();
        let state = sample_state();

        checkpointer
            .save_state(&thread_id, &state)
            .await
            .expect("Failed to save state");

        let loaded = checkpointer
            .load_state(&thread_id)
            .await
            .expect("Failed to load state");

        // Clean up before asserting so a failed assertion does not leave keys behind.
        checkpointer
            .delete_thread(&thread_id)
            .await
            .expect("Failed to delete thread");

        let loaded_state = loaded.expect("saved state should be loadable");
        assert_eq!(loaded_state.todos.len(), 1);
        assert_eq!(loaded_state.files.get("test.txt").unwrap(), "content");
        assert_eq!(
            loaded_state.scratchpad.get("key"),
            Some(&serde_json::json!("value"))
        );
    }

    #[tokio::test]
    #[ignore] // Requires Redis instance running
    async fn test_redis_list_threads() {
        let checkpointer = connect("agents-test-list").await;
        let state = sample_state();
        let ids = ["thread1".to_string(), "thread2".to_string()];

        for id in &ids {
            checkpointer.save_state(id, &state).await.unwrap();
        }

        let threads = checkpointer.list_threads().await.unwrap();

        for id in &ids {
            checkpointer.delete_thread(id).await.unwrap();
        }

        for id in &ids {
            assert!(threads.contains(id), "missing thread {id}");
        }
    }

    #[tokio::test]
    #[ignore] // Requires Redis instance running
    async fn test_redis_delete_thread_removes_state_and_index_entry() {
        let checkpointer = connect("agents-test-delete").await;
        let thread_id = "thread-to-delete".to_string();

        checkpointer
            .save_state(&thread_id, &sample_state())
            .await
            .expect("Failed to save state");
        checkpointer
            .delete_thread(&thread_id)
            .await
            .expect("Failed to delete thread");

        let loaded = checkpointer
            .load_state(&thread_id)
            .await
            .expect("Failed to load state");
        assert!(loaded.is_none());

        let threads = checkpointer
            .list_threads()
            .await
            .expect("Failed to list threads");
        assert!(!threads.contains(&thread_id));
    }

    #[tokio::test]
    #[ignore] // Requires Redis instance running
    async fn test_redis_load_missing_thread_returns_none() {
        let checkpointer = connect("agents-test-missing").await;

        let loaded = checkpointer
            .load_state(&"never-saved".to_string())
            .await
            .expect("Failed to load state");

        assert!(loaded.is_none());
    }
}