agents_runtime/middleware/
token_tracking.rs

1//! Token tracking middleware for monitoring LLM usage and costs
2//!
3//! This middleware intercepts LLM requests and responses to track token usage,
4//! costs, and other usage metrics across different providers.
5
6use crate::middleware::{AgentMiddleware, MiddlewareContext};
7use agents_core::events::{AgentEvent, EventMetadata, TokenUsage, TokenUsageEvent};
8use agents_core::llm::{LanguageModel, LlmRequest, LlmResponse};
9use agents_core::messaging::AgentMessage;
10use async_trait::async_trait;
11use futures::StreamExt;
12use serde::{Deserialize, Serialize};
13use std::sync::{Arc, RwLock};
14use std::time::Instant;
15
16/// Configuration for token tracking middleware
17#[derive(Debug, Clone)]
18pub struct TokenTrackingConfig {
19    /// Whether to track token usage
20    pub enabled: bool,
21    /// Whether to emit token usage events
22    pub emit_events: bool,
23    /// Whether to log token usage to console
24    pub log_usage: bool,
25    /// Custom cost per token (overrides provider defaults)
26    pub custom_costs: Option<TokenCosts>,
27}
28
29impl Default for TokenTrackingConfig {
30    fn default() -> Self {
31        Self {
32            enabled: true,
33            emit_events: true,
34            log_usage: true,
35            custom_costs: None,
36        }
37    }
38}
39
/// Token cost configuration for different providers.
///
/// Both prices are expressed per single token, so an estimated cost is simply
/// `input_tokens * input_cost_per_token + output_tokens * output_cost_per_token`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenCosts {
    /// Cost per input token (in USD)
    pub input_cost_per_token: f64,
    /// Cost per output token (in USD)
    pub output_cost_per_token: f64,
    /// Provider name for reference (informational; not used in cost arithmetic)
    pub provider: String,
    /// Model name for reference (informational; not used in cost arithmetic)
    pub model: String,
}
52
53impl TokenCosts {
54    pub fn new(
55        provider: impl Into<String>,
56        model: impl Into<String>,
57        input_cost: f64,
58        output_cost: f64,
59    ) -> Self {
60        Self {
61            provider: provider.into(),
62            model: model.into(),
63            input_cost_per_token: input_cost,
64            output_cost_per_token: output_cost,
65        }
66    }
67
68    /// Predefined costs for common models
69    pub fn openai_gpt4o_mini() -> Self {
70        Self::new("openai", "gpt-4o-mini", 0.00015 / 1000.0, 0.0006 / 1000.0)
71    }
72
73    pub fn openai_gpt4o() -> Self {
74        Self::new("openai", "gpt-4o", 0.005 / 1000.0, 0.015 / 1000.0)
75    }
76
77    pub fn anthropic_claude_sonnet() -> Self {
78        Self::new(
79            "anthropic",
80            "claude-3-5-sonnet-20241022",
81            0.003 / 1000.0,
82            0.015 / 1000.0,
83        )
84    }
85
86    pub fn gemini_flash() -> Self {
87        Self::new(
88            "gemini",
89            "gemini-2.0-flash-exp",
90            0.000075 / 1000.0,
91            0.0003 / 1000.0,
92        )
93    }
94}
95
96// TokenUsage is now defined in agents_core::events
97
98// TokenUsageEvent is now defined in agents_core::events
99
/// Token tracking middleware that wraps an LLM to monitor usage
///
/// The wrapper implements `LanguageModel` itself, forwarding every call to
/// `inner_model` and recording an estimated `TokenUsage` for each one.
pub struct TokenTrackingMiddleware {
    /// Switches for tracking, event emission, logging and optional pricing.
    config: TokenTrackingConfig,
    /// The wrapped model that actually serves requests.
    inner_model: Arc<dyn LanguageModel>,
    /// Optional sink for `AgentEvent::TokenUsage` events; `None` disables dispatch.
    event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
    /// One entry per tracked call; shared with wrapped streams, which append
    /// to it when they complete.
    usage_stats: Arc<RwLock<Vec<TokenUsage>>>,
}
107
108impl TokenTrackingMiddleware {
109    pub fn new(
110        config: TokenTrackingConfig,
111        inner_model: Arc<dyn LanguageModel>,
112        event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
113    ) -> Self {
114        Self {
115            config,
116            inner_model,
117            event_dispatcher,
118            usage_stats: Arc::new(RwLock::new(Vec::new())),
119        }
120    }
121
122    /// Get accumulated usage statistics
123    pub fn get_usage_stats(&self) -> Vec<TokenUsage> {
124        self.usage_stats.read().unwrap().clone()
125    }
126
127    /// Get total usage summary
128    pub fn get_total_usage(&self) -> TokenUsageSummary {
129        let stats = self.get_usage_stats();
130        let mut total_input = 0;
131        let mut total_output = 0;
132        let mut total_cost = 0.0;
133        let mut total_duration = 0;
134
135        for usage in &stats {
136            total_input += usage.input_tokens;
137            total_output += usage.output_tokens;
138            total_cost += usage.estimated_cost;
139            total_duration += usage.duration_ms;
140        }
141
142        TokenUsageSummary {
143            total_input_tokens: total_input,
144            total_output_tokens: total_output,
145            total_tokens: total_input + total_output,
146            total_cost,
147            total_duration_ms: total_duration,
148            request_count: stats.len(),
149        }
150    }
151
152    /// Clear usage statistics
153    pub fn clear_stats(&self) {
154        self.usage_stats.write().unwrap().clear();
155    }
156
157    fn emit_token_event(&self, usage: TokenUsage) {
158        if self.config.emit_events {
159            if let Some(dispatcher) = &self.event_dispatcher {
160                let event = AgentEvent::TokenUsage(TokenUsageEvent {
161                    metadata: EventMetadata::new(
162                        "default".to_string(),
163                        uuid::Uuid::new_v4().to_string(),
164                        None,
165                    ),
166                    usage,
167                });
168
169                let dispatcher_clone = dispatcher.clone();
170                tokio::spawn(async move {
171                    dispatcher_clone.dispatch(event).await;
172                });
173            }
174        }
175    }
176
177    fn log_usage(&self, usage: &TokenUsage) {
178        if self.config.log_usage {
179            tracing::info!(
180                provider = %usage.provider,
181                model = %usage.model,
182                input_tokens = usage.input_tokens,
183                output_tokens = usage.output_tokens,
184                total_tokens = usage.total_tokens,
185                estimated_cost = usage.estimated_cost,
186                duration_ms = usage.duration_ms,
187                "🔢 Token usage tracked"
188            );
189        }
190    }
191
192    fn extract_token_usage(&self, request: &LlmRequest, response: &LlmResponse) -> TokenUsage {
193        let start_time = Instant::now();
194
195        // Estimate tokens based on text length (rough approximation)
196        let input_tokens = self.estimate_tokens(&request.system_prompt)
197            + request
198                .messages
199                .iter()
200                .map(|msg| self.estimate_tokens(&self.message_to_text(msg)))
201                .sum::<u32>();
202
203        let output_tokens = self.estimate_tokens(&self.message_to_text(&response.message));
204        let duration_ms = start_time.elapsed().as_millis() as u64;
205
206        // Try to determine provider and model from the inner model
207        let (provider, model) = self.detect_provider_model();
208
209        // Calculate estimated cost
210        let estimated_cost = if let Some(costs) = &self.config.custom_costs {
211            (input_tokens as f64 * costs.input_cost_per_token)
212                + (output_tokens as f64 * costs.output_cost_per_token)
213        } else {
214            0.0 // Unknown cost
215        };
216
217        TokenUsage::new(
218            input_tokens,
219            output_tokens,
220            provider,
221            model,
222            duration_ms,
223            estimated_cost,
224        )
225    }
226
227    fn estimate_tokens(&self, text: &str) -> u32 {
228        // Rough estimation: ~4 characters per token for English text
229        // This is a simplified approximation - real tokenization varies by provider
230        (text.len() as f32 / 4.0).ceil() as u32
231    }
232
233    fn message_to_text(&self, message: &AgentMessage) -> String {
234        match &message.content {
235            agents_core::messaging::MessageContent::Text(text) => text.clone(),
236            agents_core::messaging::MessageContent::Json(json) => json.to_string(),
237        }
238    }
239
240    fn detect_provider_model(&self) -> (String, String) {
241        // Try to detect provider/model from the inner model
242        // This is a simplified approach - in practice, you might want to store
243        // this information in the model wrapper or use type information
244        ("unknown".to_string(), "unknown".to_string())
245    }
246}
247
248#[async_trait]
249impl LanguageModel for TokenTrackingMiddleware {
250    async fn generate(&self, request: LlmRequest) -> anyhow::Result<LlmResponse> {
251        if !self.config.enabled {
252            return self.inner_model.generate(request).await;
253        }
254
255        let response = self.inner_model.generate(request.clone()).await?;
256
257        let usage = self.extract_token_usage(&request, &response);
258
259        // Store usage statistics
260        {
261            let mut stats = self.usage_stats.write().unwrap();
262            stats.push(usage.clone());
263        }
264
265        // Emit event and log
266        self.emit_token_event(usage.clone());
267        self.log_usage(&usage);
268
269        Ok(response)
270    }
271
272    async fn generate_stream(
273        &self,
274        request: LlmRequest,
275    ) -> anyhow::Result<agents_core::llm::ChunkStream> {
276        if !self.config.enabled {
277            return self.inner_model.generate_stream(request).await;
278        }
279
280        // For streaming, we'll track usage when the stream completes
281        // This is a simplified implementation - in practice, you might want
282        // to track partial usage as chunks arrive
283        let response = self.inner_model.generate_stream(request).await?;
284
285        // Wrap the stream to track usage when it completes
286        let config = self.config.clone();
287        let usage_stats = self.usage_stats.clone();
288        let event_dispatcher = self.event_dispatcher.clone();
289
290        Ok(Box::pin(futures::stream::unfold(
291            (response, Instant::now()),
292            move |(mut stream, start_time)| {
293                let config = config.clone();
294                let usage_stats = usage_stats.clone();
295                let event_dispatcher = event_dispatcher.clone();
296                async move {
297                    match stream.next().await {
298                        Some(Ok(chunk)) => {
299                            match chunk {
300                                agents_core::llm::StreamChunk::Done { ref message } => {
301                                    // Stream completed - track usage
302                                    let _response = LlmResponse {
303                                        message: message.clone(),
304                                    };
305                                    let duration_ms = start_time.elapsed().as_millis() as u64;
306
307                                    // Calculate estimated cost (simplified)
308                                    let input_tokens = 100; // Simplified estimation
309                                    let output_tokens = 50; // Simplified estimation
310
311                                    let estimated_cost = if let Some(costs) = &config.custom_costs {
312                                        (input_tokens as f64 * costs.input_cost_per_token)
313                                            + (output_tokens as f64 * costs.output_cost_per_token)
314                                    } else {
315                                        0.0 // Unknown cost
316                                    };
317
318                                    let usage = TokenUsage::new(
319                                        input_tokens,
320                                        output_tokens,
321                                        "unknown",
322                                        "unknown",
323                                        duration_ms,
324                                        estimated_cost,
325                                    );
326
327                                    // Store and emit usage
328                                    {
329                                        let mut stats = usage_stats.write().unwrap();
330                                        stats.push(usage.clone());
331                                    }
332
333                                    if config.emit_events {
334                                        if let Some(dispatcher) = &event_dispatcher {
335                                            let event = AgentEvent::TokenUsage(TokenUsageEvent {
336                                                metadata: EventMetadata::new(
337                                                    "default".to_string(),
338                                                    uuid::Uuid::new_v4().to_string(),
339                                                    None,
340                                                ),
341                                                usage,
342                                            });
343
344                                            let dispatcher_clone = dispatcher.clone();
345                                            tokio::spawn(async move {
346                                                dispatcher_clone.dispatch(event).await;
347                                            });
348                                        }
349                                    }
350
351                                    if config.log_usage {
352                                        tracing::info!(
353                                            provider = "unknown",
354                                            model = "unknown",
355                                            input_tokens = input_tokens,
356                                            output_tokens = output_tokens,
357                                            total_tokens = input_tokens + output_tokens,
358                                            estimated_cost = estimated_cost,
359                                            duration_ms = duration_ms,
360                                            "🔢 Token usage tracked"
361                                        );
362                                    }
363
364                                    Some((Ok(chunk), (stream, start_time)))
365                                }
366                                _ => Some((Ok(chunk), (stream, start_time))),
367                            }
368                        }
369                        Some(Err(e)) => Some((Err(e), (stream, start_time))),
370                        None => None,
371                    }
372                }
373            },
374        )))
375    }
376}
377
378#[async_trait]
379impl AgentMiddleware for TokenTrackingMiddleware {
380    fn id(&self) -> &'static str {
381        "token_tracking"
382    }
383
384    async fn modify_model_request(&self, _ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
385        // Token tracking doesn't modify requests, just monitors them
386        Ok(())
387    }
388}
389
/// Summary of token usage across all requests
///
/// Produced by `TokenTrackingMiddleware::get_total_usage`, which totals the
/// individual usage records collected so far.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenUsageSummary {
    /// Sum of the input token counts of all recorded requests.
    pub total_input_tokens: u32,
    /// Sum of the output token counts of all recorded requests.
    pub total_output_tokens: u32,
    /// `total_input_tokens + total_output_tokens`.
    pub total_tokens: u32,
    /// Sum of the estimated costs of all recorded requests.
    pub total_cost: f64,
    /// Sum of the recorded durations, in milliseconds.
    pub total_duration_ms: u64,
    /// Number of usage records that went into this summary.
    pub request_count: usize,
}
400
401impl TokenUsageSummary {
402    pub fn average_tokens_per_request(&self) -> f64 {
403        if self.request_count > 0 {
404            self.total_tokens as f64 / self.request_count as f64
405        } else {
406            0.0
407        }
408    }
409
410    pub fn average_cost_per_request(&self) -> f64 {
411        if self.request_count > 0 {
412            self.total_cost / self.request_count as f64
413        } else {
414            0.0
415        }
416    }
417}