agents_persistence/
postgres_checkpointer.rs

1//! PostgreSQL-backed checkpointer implementation with ACID guarantees.
2//!
3//! This checkpointer stores agent state in a PostgreSQL database, providing:
4//! - ACID transaction guarantees
5//! - Persistent storage with backup capabilities
6//! - SQL querying for analytics and debugging
7//! - Multi-region replication support
8//!
9//! ## Schema
10//!
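//! When it is built, the checkpointer creates the following table if it does not
//! already exist (named `agent_checkpoints` unless a different name is given to the
//! builder), together with an index named `idx_<table>_updated_at` on
//! `updated_at DESC`: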
12//!
13//! ```sql
14//! CREATE TABLE IF NOT EXISTS agent_checkpoints (
15//!     thread_id TEXT PRIMARY KEY,
16//!     state JSONB NOT NULL,
17//!     created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
18//!     updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
19//! );
20//! ```
21
22use agents_core::persistence::{Checkpointer, ThreadId};
23use agents_core::state::AgentStateSnapshot;
24use anyhow::Context;
25use async_trait::async_trait;
26use sqlx::{postgres::PgPoolOptions, PgPool, Row};
27
28/// PostgreSQL-backed checkpointer with connection pooling.
29///
30/// # Examples
31///
32/// ```rust,no_run
33/// use agents_persistence::PostgresCheckpointer;
34///
35/// #[tokio::main]
36/// async fn main() -> anyhow::Result<()> {
37///     // Basic usage
38///     let checkpointer = PostgresCheckpointer::new(
39///         "postgresql://user:pass@localhost/agents"
40///     ).await?;
41///
42///     // With custom pool configuration
43///     let checkpointer = PostgresCheckpointer::builder()
44///         .url("postgresql://user:pass@localhost/agents")
45///         .table_name("my_checkpoints")
46///         .max_connections(20)
47///         .build()
48///         .await?;
49///
50///     Ok(())
51/// }
52/// ```
53#[derive(Clone)]
54pub struct PostgresCheckpointer {
55    pool: PgPool,
56    table_name: String,
57}
58
59impl PostgresCheckpointer {
60    /// Create a new PostgreSQL checkpointer with default settings.
61    ///
62    /// This will automatically create the checkpoints table if it doesn't exist.
63    ///
64    /// # Arguments
65    ///
66    /// * `database_url` - PostgreSQL connection string
67    pub async fn new(database_url: &str) -> anyhow::Result<Self> {
68        Self::builder().url(database_url).build().await
69    }
70
71    /// Create a builder for configuring the PostgreSQL checkpointer.
72    pub fn builder() -> PostgresCheckpointerBuilder {
73        PostgresCheckpointerBuilder::default()
74    }
75
76    /// Ensure the checkpoints table exists.
77    async fn ensure_table(&self) -> anyhow::Result<()> {
78        // Create table
79        let create_table_sql = format!(
80            r#"
81            CREATE TABLE IF NOT EXISTS {} (
82                thread_id TEXT PRIMARY KEY,
83                state JSONB NOT NULL,
84                created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
85                updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
86            )
87            "#,
88            self.table_name
89        );
90
91        sqlx::query(&create_table_sql)
92            .execute(&self.pool)
93            .await
94            .context("Failed to create checkpoints table")?;
95
96        // Create index (separate query)
97        let create_index_sql = format!(
98            r#"
99            CREATE INDEX IF NOT EXISTS idx_{}_updated_at 
100            ON {} (updated_at DESC)
101            "#,
102            self.table_name, self.table_name
103        );
104
105        sqlx::query(&create_index_sql)
106            .execute(&self.pool)
107            .await
108            .context("Failed to create index")?;
109
110        Ok(())
111    }
112}
113
114#[async_trait]
115impl Checkpointer for PostgresCheckpointer {
116    async fn save_state(
117        &self,
118        thread_id: &ThreadId,
119        state: &AgentStateSnapshot,
120    ) -> anyhow::Result<()> {
121        let json =
122            serde_json::to_value(state).context("Failed to serialize agent state to JSON")?;
123
124        let query = format!(
125            r#"
126            INSERT INTO {} (thread_id, state, created_at, updated_at)
127            VALUES ($1, $2, NOW(), NOW())
128            ON CONFLICT (thread_id) 
129            DO UPDATE SET state = $2, updated_at = NOW()
130            "#,
131            self.table_name
132        );
133
134        sqlx::query(&query)
135            .bind(thread_id)
136            .bind(&json)
137            .execute(&self.pool)
138            .await
139            .context("Failed to save state to PostgreSQL")?;
140
141        tracing::debug!(
142            thread_id = %thread_id,
143            table = %self.table_name,
144            "Saved agent state to PostgreSQL"
145        );
146
147        Ok(())
148    }
149
150    async fn load_state(&self, thread_id: &ThreadId) -> anyhow::Result<Option<AgentStateSnapshot>> {
151        let query = format!(
152            r#"
153            SELECT state FROM {} WHERE thread_id = $1
154            "#,
155            self.table_name
156        );
157
158        let row: Option<(serde_json::Value,)> = sqlx::query_as(&query)
159            .bind(thread_id)
160            .fetch_optional(&self.pool)
161            .await
162            .context("Failed to load state from PostgreSQL")?;
163
164        match row {
165            Some((json,)) => {
166                let state: AgentStateSnapshot = serde_json::from_value(json)
167                    .context("Failed to deserialize agent state from JSON")?;
168
169                tracing::debug!(
170                    thread_id = %thread_id,
171                    table = %self.table_name,
172                    "Loaded agent state from PostgreSQL"
173                );
174
175                Ok(Some(state))
176            }
177            None => {
178                tracing::debug!(
179                    thread_id = %thread_id,
180                    table = %self.table_name,
181                    "No saved state found in PostgreSQL"
182                );
183                Ok(None)
184            }
185        }
186    }
187
188    async fn delete_thread(&self, thread_id: &ThreadId) -> anyhow::Result<()> {
189        let query = format!(
190            r#"
191            DELETE FROM {} WHERE thread_id = $1
192            "#,
193            self.table_name
194        );
195
196        sqlx::query(&query)
197            .bind(thread_id)
198            .execute(&self.pool)
199            .await
200            .context("Failed to delete thread from PostgreSQL")?;
201
202        tracing::debug!(
203            thread_id = %thread_id,
204            table = %self.table_name,
205            "Deleted thread from PostgreSQL"
206        );
207
208        Ok(())
209    }
210
211    async fn list_threads(&self) -> anyhow::Result<Vec<ThreadId>> {
212        let query = format!(
213            r#"
214            SELECT thread_id FROM {} ORDER BY updated_at DESC
215            "#,
216            self.table_name
217        );
218
219        let rows = sqlx::query(&query)
220            .fetch_all(&self.pool)
221            .await
222            .context("Failed to list threads from PostgreSQL")?;
223
224        let threads = rows
225            .into_iter()
226            .map(|row| row.get::<String, _>("thread_id"))
227            .collect();
228
229        Ok(threads)
230    }
231}
232
233/// Builder for configuring a PostgreSQL checkpointer.
234#[derive(Default)]
235pub struct PostgresCheckpointerBuilder {
236    url: Option<String>,
237    table_name: Option<String>,
238    max_connections: Option<u32>,
239    min_connections: Option<u32>,
240}
241
242impl PostgresCheckpointerBuilder {
243    /// Set the PostgreSQL connection URL.
244    pub fn url(mut self, url: impl Into<String>) -> Self {
245        self.url = Some(url.into());
246        self
247    }
248
249    /// Set the table name for storing checkpoints (default: "agent_checkpoints").
250    pub fn table_name(mut self, table_name: impl Into<String>) -> Self {
251        self.table_name = Some(table_name.into());
252        self
253    }
254
255    /// Set the maximum number of connections in the pool (default: 10).
256    pub fn max_connections(mut self, max: u32) -> Self {
257        self.max_connections = Some(max);
258        self
259    }
260
261    /// Set the minimum number of connections in the pool (default: 2).
262    pub fn min_connections(mut self, min: u32) -> Self {
263        self.min_connections = Some(min);
264        self
265    }
266
267    /// Build the PostgreSQL checkpointer and initialize the table.
268    pub async fn build(self) -> anyhow::Result<PostgresCheckpointer> {
269        let url = self
270            .url
271            .ok_or_else(|| anyhow::anyhow!("PostgreSQL URL is required"))?;
272
273        let mut pool_options = PgPoolOptions::new();
274
275        if let Some(max) = self.max_connections {
276            pool_options = pool_options.max_connections(max);
277        } else {
278            pool_options = pool_options.max_connections(10);
279        }
280
281        if let Some(min) = self.min_connections {
282            pool_options = pool_options.min_connections(min);
283        }
284
285        let pool = pool_options
286            .connect(&url)
287            .await
288            .context("Failed to connect to PostgreSQL")?;
289
290        let checkpointer = PostgresCheckpointer {
291            pool,
292            table_name: self
293                .table_name
294                .unwrap_or_else(|| "agent_checkpoints".to_string()),
295        };
296
297        // Ensure table exists
298        checkpointer
299            .ensure_table()
300            .await
301            .context("Failed to initialize database schema")?;
302
303        Ok(checkpointer)
304    }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use agents_core::state::TodoItem;

    /// Snapshot fixture holding one pending todo, one file and one scratchpad entry.
    fn sample_state() -> AgentStateSnapshot {
        let mut snapshot = AgentStateSnapshot::default();

        snapshot.todos.push(TodoItem::pending("Test todo"));
        snapshot
            .files
            .insert(String::from("test.txt"), String::from("content"));
        snapshot
            .scratchpad
            .insert(String::from("key"), serde_json::json!("value"));

        snapshot
    }

    /// Round-trips one snapshot through `save_state` / `load_state`, then removes it.
    #[tokio::test]
    #[ignore] // Requires PostgreSQL instance running
    async fn test_postgres_save_and_load() {
        let checkpointer = PostgresCheckpointer::new("postgresql://localhost/agents_test")
            .await
            .expect("Failed to connect to PostgreSQL");

        let thread_id = String::from("test-thread");
        let original = sample_state();

        // Write the snapshot, then read it straight back.
        checkpointer
            .save_state(&thread_id, &original)
            .await
            .expect("Failed to save state");
        let restored = checkpointer
            .load_state(&thread_id)
            .await
            .expect("Failed to load state");

        assert!(restored.is_some());
        let restored = restored.unwrap();

        assert_eq!(restored.todos.len(), 1);
        assert_eq!(restored.files.get("test.txt").unwrap(), "content");

        // Leave the table as it was found.
        checkpointer
            .delete_thread(&thread_id)
            .await
            .expect("Failed to delete thread");
    }

    /// Saves two threads in a custom table and checks that both are listed.
    #[tokio::test]
    #[ignore] // Requires PostgreSQL instance running
    async fn test_postgres_list_threads() {
        let checkpointer = PostgresCheckpointer::builder()
            .url("postgresql://localhost/agents_test")
            .table_name("test_checkpoints")
            .build()
            .await
            .expect("Failed to connect to PostgreSQL");

        let state = sample_state();
        let thread_ids = ["thread1".to_string(), "thread2".to_string()];

        // Save multiple threads
        for id in &thread_ids {
            checkpointer.save_state(id, &state).await.unwrap();
        }

        // List threads
        let threads = checkpointer.list_threads().await.unwrap();
        for id in &thread_ids {
            assert!(threads.contains(id));
        }

        // Cleanup
        for id in &thread_ids {
            checkpointer.delete_thread(id).await.unwrap();
        }
    }
}