agents_core/
persistence.rs1use crate::state::AgentStateSnapshot;
4use async_trait::async_trait;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
/// Identifier of a thread whose agent state is persisted by a [`Checkpointer`].
pub type ThreadId = String;
10
11#[derive(Debug, Clone, Serialize, Deserialize, Default)]
13pub struct CheckpointerConfig {
14 pub params: HashMap<String, serde_json::Value>,
16}
17
18#[async_trait]
21pub trait Checkpointer: Send + Sync {
22 async fn save_state(
24 &self,
25 thread_id: &ThreadId,
26 state: &AgentStateSnapshot,
27 ) -> anyhow::Result<()>;
28
29 async fn load_state(&self, thread_id: &ThreadId) -> anyhow::Result<Option<AgentStateSnapshot>>;
32
33 async fn delete_thread(&self, thread_id: &ThreadId) -> anyhow::Result<()>;
35
36 async fn list_threads(&self) -> anyhow::Result<Vec<ThreadId>>;
38}
39
40#[derive(Debug, Default)]
43pub struct InMemoryCheckpointer {
44 states: std::sync::RwLock<HashMap<ThreadId, AgentStateSnapshot>>,
45}
46
47impl InMemoryCheckpointer {
48 pub fn new() -> Self {
49 Self::default()
50 }
51}
52
53#[async_trait]
54impl Checkpointer for InMemoryCheckpointer {
55 async fn save_state(
56 &self,
57 thread_id: &ThreadId,
58 state: &AgentStateSnapshot,
59 ) -> anyhow::Result<()> {
60 let mut states = self.states.write().map_err(|_| {
61 anyhow::anyhow!("Failed to acquire write lock on in-memory checkpointer")
62 })?;
63 states.insert(thread_id.clone(), state.clone());
64 tracing::debug!(thread_id = %thread_id, "Saved agent state to memory");
65 Ok(())
66 }
67
68 async fn load_state(&self, thread_id: &ThreadId) -> anyhow::Result<Option<AgentStateSnapshot>> {
69 let states = self.states.read().map_err(|_| {
70 anyhow::anyhow!("Failed to acquire read lock on in-memory checkpointer")
71 })?;
72 let state = states.get(thread_id).cloned();
73 if state.is_some() {
74 tracing::debug!(thread_id = %thread_id, "Loaded agent state from memory");
75 }
76 Ok(state)
77 }
78
79 async fn delete_thread(&self, thread_id: &ThreadId) -> anyhow::Result<()> {
80 let mut states = self.states.write().map_err(|_| {
81 anyhow::anyhow!("Failed to acquire write lock on in-memory checkpointer")
82 })?;
83 states.remove(thread_id);
84 tracing::debug!(thread_id = %thread_id, "Deleted thread from memory");
85 Ok(())
86 }
87
88 async fn list_threads(&self) -> anyhow::Result<Vec<ThreadId>> {
89 let states = self.states.read().map_err(|_| {
90 anyhow::anyhow!("Failed to acquire read lock on in-memory checkpointer")
91 })?;
92 Ok(states.keys().cloned().collect())
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::{TodoItem, TodoStatus};

    /// Builds a snapshot with one todo, one file, and one scratchpad entry.
    fn sample_state() -> AgentStateSnapshot {
        let mut state = AgentStateSnapshot::default();
        state.todos.push(TodoItem {
            content: "Test todo".to_string(),
            status: TodoStatus::Pending,
        });
        state
            .files
            .insert("test.txt".to_string(), "content".to_string());
        state
            .scratchpad
            .insert("key".to_string(), serde_json::json!("value"));
        state
    }

    #[tokio::test]
    async fn in_memory_checkpointer_save_and_load() {
        let checkpointer = InMemoryCheckpointer::new();
        let thread_id = "test-thread".to_string();

        checkpointer
            .save_state(&thread_id, &sample_state())
            .await
            .unwrap();

        let loaded_state = checkpointer
            .load_state(&thread_id)
            .await
            .unwrap()
            .expect("state saved above should be loadable");

        assert_eq!(loaded_state.todos.len(), 1);
        assert_eq!(loaded_state.todos[0].content, "Test todo");
        assert_eq!(loaded_state.files.get("test.txt").unwrap(), "content");
        assert_eq!(
            loaded_state.scratchpad.get("key").unwrap(),
            &serde_json::json!("value")
        );
    }

    #[tokio::test]
    async fn in_memory_checkpointer_nonexistent_thread() {
        let checkpointer = InMemoryCheckpointer::new();
        let result = checkpointer
            .load_state(&"nonexistent".to_string())
            .await
            .unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn in_memory_checkpointer_overwrites_existing_state() {
        let checkpointer = InMemoryCheckpointer::new();
        let thread_id = "test-thread".to_string();

        checkpointer
            .save_state(&thread_id, &sample_state())
            .await
            .unwrap();
        checkpointer
            .save_state(&thread_id, &AgentStateSnapshot::default())
            .await
            .unwrap();

        let loaded_state = checkpointer
            .load_state(&thread_id)
            .await
            .unwrap()
            .expect("overwritten state should still be present");
        // The second save must fully replace the first snapshot.
        assert!(loaded_state.todos.is_empty());
        assert_eq!(checkpointer.list_threads().await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn in_memory_checkpointer_delete_thread() {
        let checkpointer = InMemoryCheckpointer::new();
        let thread_id = "test-thread".to_string();

        checkpointer
            .save_state(&thread_id, &sample_state())
            .await
            .unwrap();
        assert!(checkpointer.load_state(&thread_id).await.unwrap().is_some());

        checkpointer.delete_thread(&thread_id).await.unwrap();
        assert!(checkpointer.load_state(&thread_id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn in_memory_checkpointer_delete_missing_thread_is_ok() {
        let checkpointer = InMemoryCheckpointer::new();
        checkpointer
            .delete_thread(&"missing".to_string())
            .await
            .unwrap();
        assert!(checkpointer.list_threads().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn in_memory_checkpointer_list_threads() {
        let checkpointer = InMemoryCheckpointer::new();
        assert!(checkpointer.list_threads().await.unwrap().is_empty());

        let state = sample_state();
        checkpointer
            .save_state(&"thread1".to_string(), &state)
            .await
            .unwrap();
        checkpointer
            .save_state(&"thread2".to_string(), &state)
            .await
            .unwrap();

        let mut threads = checkpointer.list_threads().await.unwrap();
        // Sort locally so the assertion does not depend on the listing order.
        threads.sort();
        assert_eq!(threads, vec!["thread1".to_string(), "thread2".to_string()]);
    }
}