agents_core/
persistence.rs

//! Persistence traits for checkpointing agent state between runs.
2
3use crate::state::AgentStateSnapshot;
4use async_trait::async_trait;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
/// Identifier naming a single conversation thread (session).
///
/// This is a plain alias for [`String`], so any owned string can be used directly.
pub type ThreadId = std::string::String;
10
11/// Configuration for a checkpointer instance.
12#[derive(Debug, Clone, Serialize, Deserialize, Default)]
13pub struct CheckpointerConfig {
14    /// Additional configuration parameters specific to the checkpointer implementation.
15    pub params: HashMap<String, serde_json::Value>,
16}
17
18/// Trait for persisting and retrieving agent state between conversation runs.
19/// This mirrors the LangGraph Checkpointer interface used in the Python implementation.
20#[async_trait]
21pub trait Checkpointer: Send + Sync {
22    /// Save the current agent state for a given thread.
23    async fn save_state(
24        &self,
25        thread_id: &ThreadId,
26        state: &AgentStateSnapshot,
27    ) -> anyhow::Result<()>;
28
29    /// Load the last saved state for a given thread.
30    /// Returns None if no state exists for this thread.
31    async fn load_state(&self, thread_id: &ThreadId) -> anyhow::Result<Option<AgentStateSnapshot>>;
32
33    /// Delete all saved state for a given thread.
34    async fn delete_thread(&self, thread_id: &ThreadId) -> anyhow::Result<()>;
35
36    /// List all thread IDs that have saved state.
37    async fn list_threads(&self) -> anyhow::Result<Vec<ThreadId>>;
38}
39
40/// In-memory checkpointer for testing and development.
41/// State is not persisted between process restarts.
42#[derive(Debug, Default)]
43pub struct InMemoryCheckpointer {
44    states: std::sync::RwLock<HashMap<ThreadId, AgentStateSnapshot>>,
45}
46
47impl InMemoryCheckpointer {
48    pub fn new() -> Self {
49        Self::default()
50    }
51}
52
53#[async_trait]
54impl Checkpointer for InMemoryCheckpointer {
55    async fn save_state(
56        &self,
57        thread_id: &ThreadId,
58        state: &AgentStateSnapshot,
59    ) -> anyhow::Result<()> {
60        let mut states = self.states.write().map_err(|_| {
61            anyhow::anyhow!("Failed to acquire write lock on in-memory checkpointer")
62        })?;
63        states.insert(thread_id.clone(), state.clone());
64        tracing::debug!(thread_id = %thread_id, "Saved agent state to memory");
65        Ok(())
66    }
67
68    async fn load_state(&self, thread_id: &ThreadId) -> anyhow::Result<Option<AgentStateSnapshot>> {
69        let states = self.states.read().map_err(|_| {
70            anyhow::anyhow!("Failed to acquire read lock on in-memory checkpointer")
71        })?;
72        let state = states.get(thread_id).cloned();
73        if state.is_some() {
74            tracing::debug!(thread_id = %thread_id, "Loaded agent state from memory");
75        }
76        Ok(state)
77    }
78
79    async fn delete_thread(&self, thread_id: &ThreadId) -> anyhow::Result<()> {
80        let mut states = self.states.write().map_err(|_| {
81            anyhow::anyhow!("Failed to acquire write lock on in-memory checkpointer")
82        })?;
83        states.remove(thread_id);
84        tracing::debug!(thread_id = %thread_id, "Deleted thread from memory");
85        Ok(())
86    }
87
88    async fn list_threads(&self) -> anyhow::Result<Vec<ThreadId>> {
89        let states = self.states.read().map_err(|_| {
90            anyhow::anyhow!("Failed to acquire read lock on in-memory checkpointer")
91        })?;
92        Ok(states.keys().cloned().collect())
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::{TodoItem, TodoStatus};

    /// Builds a snapshot with one entry each in `todos`, `files`, and `scratchpad`.
    fn sample_state() -> AgentStateSnapshot {
        let mut snapshot = AgentStateSnapshot::default();
        snapshot.todos.push(TodoItem {
            content: String::from("Test todo"),
            status: TodoStatus::Pending,
        });
        snapshot
            .files
            .insert(String::from("test.txt"), String::from("content"));
        snapshot
            .scratchpad
            .insert(String::from("key"), serde_json::json!("value"));
        snapshot
    }

    #[tokio::test]
    async fn in_memory_checkpointer_save_and_load() {
        let checkpointer = InMemoryCheckpointer::new();
        let thread_id: ThreadId = String::from("test-thread");

        checkpointer
            .save_state(&thread_id, &sample_state())
            .await
            .unwrap();

        let loaded = checkpointer
            .load_state(&thread_id)
            .await
            .unwrap()
            .expect("state saved above should be loadable");

        assert_eq!(loaded.todos.len(), 1);
        assert_eq!(loaded.todos[0].content, "Test todo");
        assert_eq!(
            loaded.files.get("test.txt").map(String::as_str),
            Some("content")
        );
        assert_eq!(
            loaded.scratchpad.get("key"),
            Some(&serde_json::json!("value"))
        );
    }

    #[tokio::test]
    async fn in_memory_checkpointer_nonexistent_thread() {
        let checkpointer = InMemoryCheckpointer::new();
        let missing: ThreadId = String::from("nonexistent");

        let result = checkpointer.load_state(&missing).await.unwrap();

        assert!(result.is_none());
    }

    #[tokio::test]
    async fn in_memory_checkpointer_delete_thread() {
        let checkpointer = InMemoryCheckpointer::new();
        let thread_id: ThreadId = String::from("test-thread");

        checkpointer
            .save_state(&thread_id, &sample_state())
            .await
            .unwrap();
        let before = checkpointer.load_state(&thread_id).await.unwrap();
        assert!(before.is_some(), "state should exist after saving");

        checkpointer.delete_thread(&thread_id).await.unwrap();
        let after = checkpointer.load_state(&thread_id).await.unwrap();
        assert!(after.is_none(), "state should be gone after deleting");
    }

    #[tokio::test]
    async fn in_memory_checkpointer_list_threads() {
        let checkpointer = InMemoryCheckpointer::new();
        let state = sample_state();
        let names = ["thread1", "thread2"];

        for name in names {
            checkpointer
                .save_state(&name.to_string(), &state)
                .await
                .unwrap();
        }

        let threads = checkpointer.list_threads().await.unwrap();
        assert_eq!(threads.len(), names.len());
        for name in names {
            assert!(threads.contains(&name.to_string()));
        }
    }
}