agents_persistence/
postgres_checkpointer.rs1use agents_core::persistence::{Checkpointer, ThreadId};
23use agents_core::state::AgentStateSnapshot;
24use anyhow::Context;
25use async_trait::async_trait;
26use sqlx::{postgres::PgPoolOptions, PgPool, Row};
27
28#[derive(Clone)]
54pub struct PostgresCheckpointer {
55 pool: PgPool,
56 table_name: String,
57}
58
59impl PostgresCheckpointer {
60 pub async fn new(database_url: &str) -> anyhow::Result<Self> {
68 Self::builder().url(database_url).build().await
69 }
70
71 pub fn builder() -> PostgresCheckpointerBuilder {
73 PostgresCheckpointerBuilder::default()
74 }
75
76 async fn ensure_table(&self) -> anyhow::Result<()> {
78 let create_table_sql = format!(
80 r#"
81 CREATE TABLE IF NOT EXISTS {} (
82 thread_id TEXT PRIMARY KEY,
83 state JSONB NOT NULL,
84 created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
85 updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
86 )
87 "#,
88 self.table_name
89 );
90
91 sqlx::query(&create_table_sql)
92 .execute(&self.pool)
93 .await
94 .context("Failed to create checkpoints table")?;
95
96 let create_index_sql = format!(
98 r#"
99 CREATE INDEX IF NOT EXISTS idx_{}_updated_at
100 ON {} (updated_at DESC)
101 "#,
102 self.table_name, self.table_name
103 );
104
105 sqlx::query(&create_index_sql)
106 .execute(&self.pool)
107 .await
108 .context("Failed to create index")?;
109
110 Ok(())
111 }
112}
113
114#[async_trait]
115impl Checkpointer for PostgresCheckpointer {
116 async fn save_state(
117 &self,
118 thread_id: &ThreadId,
119 state: &AgentStateSnapshot,
120 ) -> anyhow::Result<()> {
121 let json =
122 serde_json::to_value(state).context("Failed to serialize agent state to JSON")?;
123
124 let query = format!(
125 r#"
126 INSERT INTO {} (thread_id, state, created_at, updated_at)
127 VALUES ($1, $2, NOW(), NOW())
128 ON CONFLICT (thread_id)
129 DO UPDATE SET state = $2, updated_at = NOW()
130 "#,
131 self.table_name
132 );
133
134 sqlx::query(&query)
135 .bind(thread_id)
136 .bind(&json)
137 .execute(&self.pool)
138 .await
139 .context("Failed to save state to PostgreSQL")?;
140
141 tracing::debug!(
142 thread_id = %thread_id,
143 table = %self.table_name,
144 "Saved agent state to PostgreSQL"
145 );
146
147 Ok(())
148 }
149
150 async fn load_state(&self, thread_id: &ThreadId) -> anyhow::Result<Option<AgentStateSnapshot>> {
151 let query = format!(
152 r#"
153 SELECT state FROM {} WHERE thread_id = $1
154 "#,
155 self.table_name
156 );
157
158 let row: Option<(serde_json::Value,)> = sqlx::query_as(&query)
159 .bind(thread_id)
160 .fetch_optional(&self.pool)
161 .await
162 .context("Failed to load state from PostgreSQL")?;
163
164 match row {
165 Some((json,)) => {
166 let state: AgentStateSnapshot = serde_json::from_value(json)
167 .context("Failed to deserialize agent state from JSON")?;
168
169 tracing::debug!(
170 thread_id = %thread_id,
171 table = %self.table_name,
172 "Loaded agent state from PostgreSQL"
173 );
174
175 Ok(Some(state))
176 }
177 None => {
178 tracing::debug!(
179 thread_id = %thread_id,
180 table = %self.table_name,
181 "No saved state found in PostgreSQL"
182 );
183 Ok(None)
184 }
185 }
186 }
187
188 async fn delete_thread(&self, thread_id: &ThreadId) -> anyhow::Result<()> {
189 let query = format!(
190 r#"
191 DELETE FROM {} WHERE thread_id = $1
192 "#,
193 self.table_name
194 );
195
196 sqlx::query(&query)
197 .bind(thread_id)
198 .execute(&self.pool)
199 .await
200 .context("Failed to delete thread from PostgreSQL")?;
201
202 tracing::debug!(
203 thread_id = %thread_id,
204 table = %self.table_name,
205 "Deleted thread from PostgreSQL"
206 );
207
208 Ok(())
209 }
210
211 async fn list_threads(&self) -> anyhow::Result<Vec<ThreadId>> {
212 let query = format!(
213 r#"
214 SELECT thread_id FROM {} ORDER BY updated_at DESC
215 "#,
216 self.table_name
217 );
218
219 let rows = sqlx::query(&query)
220 .fetch_all(&self.pool)
221 .await
222 .context("Failed to list threads from PostgreSQL")?;
223
224 let threads = rows
225 .into_iter()
226 .map(|row| row.get::<String, _>("thread_id"))
227 .collect();
228
229 Ok(threads)
230 }
231}
232
/// Step-by-step configuration for a [`PostgresCheckpointer`].
///
/// Every setting starts out unset. `build` requires the URL and substitutes
/// defaults for the table name and the maximum pool size.
#[derive(Default)]
pub struct PostgresCheckpointerBuilder {
    /// PostgreSQL connection URL; mandatory.
    url: Option<String>,
    /// Checkpoint table name; `None` selects the default table.
    table_name: Option<String>,
    /// Lower bound on pooled connections; `None` keeps the pool's own default.
    min_connections: Option<u32>,
    /// Upper bound on pooled connections; `None` selects the default cap.
    max_connections: Option<u32>,
}

242impl PostgresCheckpointerBuilder {
243 pub fn url(mut self, url: impl Into<String>) -> Self {
245 self.url = Some(url.into());
246 self
247 }
248
249 pub fn table_name(mut self, table_name: impl Into<String>) -> Self {
251 self.table_name = Some(table_name.into());
252 self
253 }
254
255 pub fn max_connections(mut self, max: u32) -> Self {
257 self.max_connections = Some(max);
258 self
259 }
260
261 pub fn min_connections(mut self, min: u32) -> Self {
263 self.min_connections = Some(min);
264 self
265 }
266
267 pub async fn build(self) -> anyhow::Result<PostgresCheckpointer> {
269 let url = self
270 .url
271 .ok_or_else(|| anyhow::anyhow!("PostgreSQL URL is required"))?;
272
273 let mut pool_options = PgPoolOptions::new();
274
275 if let Some(max) = self.max_connections {
276 pool_options = pool_options.max_connections(max);
277 } else {
278 pool_options = pool_options.max_connections(10);
279 }
280
281 if let Some(min) = self.min_connections {
282 pool_options = pool_options.min_connections(min);
283 }
284
285 let pool = pool_options
286 .connect(&url)
287 .await
288 .context("Failed to connect to PostgreSQL")?;
289
290 let checkpointer = PostgresCheckpointer {
291 pool,
292 table_name: self
293 .table_name
294 .unwrap_or_else(|| "agent_checkpoints".to_string()),
295 };
296
297 checkpointer
299 .ensure_table()
300 .await
301 .context("Failed to initialize database schema")?;
302
303 Ok(checkpointer)
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use agents_core::state::TodoItem;

    /// Database the ignored integration tests expect to reach.
    const TEST_DATABASE_URL: &str = "postgresql://localhost/agents_test";

    /// Builds a snapshot containing one todo, one file and one scratchpad entry.
    fn sample_state() -> AgentStateSnapshot {
        let mut snapshot = AgentStateSnapshot::default();
        snapshot.todos.push(TodoItem::pending("Test todo"));
        snapshot
            .files
            .insert("test.txt".to_string(), "content".to_string());
        snapshot
            .scratchpad
            .insert("key".to_string(), serde_json::json!("value"));
        snapshot
    }

    // These tests need a live PostgreSQL server; run them with `cargo test -- --ignored`.

    /// Round-trips a snapshot through the default table, then removes it.
    #[tokio::test]
    #[ignore]
    async fn test_postgres_save_and_load() {
        let checkpointer = PostgresCheckpointer::new(TEST_DATABASE_URL)
            .await
            .expect("Failed to connect to PostgreSQL");

        let thread_id = "test-thread".to_string();

        checkpointer
            .save_state(&thread_id, &sample_state())
            .await
            .expect("Failed to save state");

        let loaded = checkpointer
            .load_state(&thread_id)
            .await
            .expect("Failed to load state");

        let Some(loaded_state) = loaded else {
            panic!("no state was stored for {thread_id}");
        };
        assert_eq!(loaded_state.todos.len(), 1);
        assert_eq!(loaded_state.files.get("test.txt").unwrap(), "content");

        checkpointer
            .delete_thread(&thread_id)
            .await
            .expect("Failed to delete thread");
    }

    /// Saves two threads into a custom table and checks both are listed.
    #[tokio::test]
    #[ignore]
    async fn test_postgres_list_threads() {
        let checkpointer = PostgresCheckpointer::builder()
            .url(TEST_DATABASE_URL)
            .table_name("test_checkpoints")
            .build()
            .await
            .expect("Failed to connect to PostgreSQL");

        let state = sample_state();
        let thread_ids = ["thread1", "thread2"].map(String::from);

        for id in &thread_ids {
            checkpointer.save_state(id, &state).await.unwrap();
        }

        let listed = checkpointer.list_threads().await.unwrap();
        for id in &thread_ids {
            assert!(listed.contains(id));
        }

        for id in &thread_ids {
            checkpointer.delete_thread(id).await.unwrap();
        }
    }
}