// agents_runtime/middleware/token_tracking.rs

use crate::middleware::{AgentMiddleware, MiddlewareContext};
use agents_core::events::{AgentEvent, EventMetadata, TokenUsage, TokenUsageEvent};
use agents_core::llm::{LanguageModel, LlmRequest, LlmResponse};
use agents_core::messaging::AgentMessage;
use async_trait::async_trait;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use std::sync::{Arc, RwLock};
use std::time::Instant;

/// Configuration switches for [`TokenTrackingMiddleware`].
#[derive(Debug, Clone)]
pub struct TokenTrackingConfig {
    /// Master switch; when `false` the middleware is a transparent pass-through.
    pub enabled: bool,
    /// Emit an `AgentEvent::TokenUsage` through the event dispatcher per request.
    pub emit_events: bool,
    /// Log each usage record via `tracing::info!`.
    pub log_usage: bool,
    /// Optional per-token pricing used to estimate cost; when `None`, cost is 0.0.
    pub custom_costs: Option<TokenCosts>,
}

29impl Default for TokenTrackingConfig {
30 fn default() -> Self {
31 Self {
32 enabled: true,
33 emit_events: true,
34 log_usage: true,
35 custom_costs: None,
36 }
37 }
38}
39
/// Per-token pricing table for a specific provider/model pair.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenCosts {
    /// Cost in USD for a single input (prompt) token.
    pub input_cost_per_token: f64,
    /// Cost in USD for a single output (completion) token.
    pub output_cost_per_token: f64,
    /// Provider identifier, e.g. "openai".
    pub provider: String,
    /// Model identifier, e.g. "gpt-4o-mini".
    pub model: String,
}

53impl TokenCosts {
54 pub fn new(
55 provider: impl Into<String>,
56 model: impl Into<String>,
57 input_cost: f64,
58 output_cost: f64,
59 ) -> Self {
60 Self {
61 provider: provider.into(),
62 model: model.into(),
63 input_cost_per_token: input_cost,
64 output_cost_per_token: output_cost,
65 }
66 }
67
68 pub fn openai_gpt4o_mini() -> Self {
70 Self::new("openai", "gpt-4o-mini", 0.00015 / 1000.0, 0.0006 / 1000.0)
71 }
72
73 pub fn openai_gpt4o() -> Self {
74 Self::new("openai", "gpt-4o", 0.005 / 1000.0, 0.015 / 1000.0)
75 }
76
77 pub fn anthropic_claude_sonnet() -> Self {
78 Self::new(
79 "anthropic",
80 "claude-3-5-sonnet-20241022",
81 0.003 / 1000.0,
82 0.015 / 1000.0,
83 )
84 }
85
86 pub fn gemini_flash() -> Self {
87 Self::new(
88 "gemini",
89 "gemini-2.0-flash-exp",
90 0.000075 / 1000.0,
91 0.0003 / 1000.0,
92 )
93 }
94}
95
/// [`LanguageModel`] wrapper that records token usage for every request
/// flowing through the inner model.
pub struct TokenTrackingMiddleware {
    // Behavior switches (enable/emit/log/pricing).
    config: TokenTrackingConfig,
    // The wrapped model that actually performs generation.
    inner_model: Arc<dyn LanguageModel>,
    // Optional sink for `AgentEvent::TokenUsage` notifications.
    event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
    // Accumulated per-request usage records; cleared via `clear_stats`.
    usage_stats: Arc<RwLock<Vec<TokenUsage>>>,
}

108impl TokenTrackingMiddleware {
109 pub fn new(
110 config: TokenTrackingConfig,
111 inner_model: Arc<dyn LanguageModel>,
112 event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
113 ) -> Self {
114 Self {
115 config,
116 inner_model,
117 event_dispatcher,
118 usage_stats: Arc::new(RwLock::new(Vec::new())),
119 }
120 }
121
122 pub fn get_usage_stats(&self) -> Vec<TokenUsage> {
124 self.usage_stats.read().unwrap().clone()
125 }
126
127 pub fn get_total_usage(&self) -> TokenUsageSummary {
129 let stats = self.get_usage_stats();
130 let mut total_input = 0;
131 let mut total_output = 0;
132 let mut total_cost = 0.0;
133 let mut total_duration = 0;
134
135 for usage in &stats {
136 total_input += usage.input_tokens;
137 total_output += usage.output_tokens;
138 total_cost += usage.estimated_cost;
139 total_duration += usage.duration_ms;
140 }
141
142 TokenUsageSummary {
143 total_input_tokens: total_input,
144 total_output_tokens: total_output,
145 total_tokens: total_input + total_output,
146 total_cost,
147 total_duration_ms: total_duration,
148 request_count: stats.len(),
149 }
150 }
151
152 pub fn clear_stats(&self) {
154 self.usage_stats.write().unwrap().clear();
155 }
156
157 fn emit_token_event(&self, usage: TokenUsage) {
158 if self.config.emit_events {
159 if let Some(dispatcher) = &self.event_dispatcher {
160 let event = AgentEvent::TokenUsage(TokenUsageEvent {
161 metadata: EventMetadata::new(
162 "default".to_string(),
163 uuid::Uuid::new_v4().to_string(),
164 None,
165 ),
166 usage,
167 });
168
169 let dispatcher_clone = dispatcher.clone();
170 tokio::spawn(async move {
171 dispatcher_clone.dispatch(event).await;
172 });
173 }
174 }
175 }
176
177 fn log_usage(&self, usage: &TokenUsage) {
178 if self.config.log_usage {
179 tracing::info!(
180 provider = %usage.provider,
181 model = %usage.model,
182 input_tokens = usage.input_tokens,
183 output_tokens = usage.output_tokens,
184 total_tokens = usage.total_tokens,
185 estimated_cost = usage.estimated_cost,
186 duration_ms = usage.duration_ms,
187 "🔢 Token usage tracked"
188 );
189 }
190 }
191
192 fn extract_token_usage(&self, request: &LlmRequest, response: &LlmResponse) -> TokenUsage {
193 let start_time = Instant::now();
194
195 let input_tokens = self.estimate_tokens(&request.system_prompt)
197 + request
198 .messages
199 .iter()
200 .map(|msg| self.estimate_tokens(&self.message_to_text(msg)))
201 .sum::<u32>();
202
203 let output_tokens = self.estimate_tokens(&self.message_to_text(&response.message));
204 let duration_ms = start_time.elapsed().as_millis() as u64;
205
206 let (provider, model) = self.detect_provider_model();
208
209 let estimated_cost = if let Some(costs) = &self.config.custom_costs {
211 (input_tokens as f64 * costs.input_cost_per_token)
212 + (output_tokens as f64 * costs.output_cost_per_token)
213 } else {
214 0.0 };
216
217 TokenUsage::new(
218 input_tokens,
219 output_tokens,
220 provider,
221 model,
222 duration_ms,
223 estimated_cost,
224 )
225 }
226
227 fn estimate_tokens(&self, text: &str) -> u32 {
228 (text.len() as f32 / 4.0).ceil() as u32
231 }
232
233 fn message_to_text(&self, message: &AgentMessage) -> String {
234 match &message.content {
235 agents_core::messaging::MessageContent::Text(text) => text.clone(),
236 agents_core::messaging::MessageContent::Json(json) => json.to_string(),
237 }
238 }
239
240 fn detect_provider_model(&self) -> (String, String) {
241 ("unknown".to_string(), "unknown".to_string())
245 }
246}
247
248#[async_trait]
249impl LanguageModel for TokenTrackingMiddleware {
250 async fn generate(&self, request: LlmRequest) -> anyhow::Result<LlmResponse> {
251 if !self.config.enabled {
252 return self.inner_model.generate(request).await;
253 }
254
255 let response = self.inner_model.generate(request.clone()).await?;
256
257 let usage = self.extract_token_usage(&request, &response);
258
259 {
261 let mut stats = self.usage_stats.write().unwrap();
262 stats.push(usage.clone());
263 }
264
265 self.emit_token_event(usage.clone());
267 self.log_usage(&usage);
268
269 Ok(response)
270 }
271
272 async fn generate_stream(
273 &self,
274 request: LlmRequest,
275 ) -> anyhow::Result<agents_core::llm::ChunkStream> {
276 if !self.config.enabled {
277 return self.inner_model.generate_stream(request).await;
278 }
279
280 let response = self.inner_model.generate_stream(request).await?;
284
285 let config = self.config.clone();
287 let usage_stats = self.usage_stats.clone();
288 let event_dispatcher = self.event_dispatcher.clone();
289
290 Ok(Box::pin(futures::stream::unfold(
291 (response, Instant::now()),
292 move |(mut stream, start_time)| {
293 let config = config.clone();
294 let usage_stats = usage_stats.clone();
295 let event_dispatcher = event_dispatcher.clone();
296 async move {
297 match stream.next().await {
298 Some(Ok(chunk)) => {
299 match chunk {
300 agents_core::llm::StreamChunk::Done { ref message } => {
301 let _response = LlmResponse {
303 message: message.clone(),
304 };
305 let duration_ms = start_time.elapsed().as_millis() as u64;
306
307 let input_tokens = 100; let output_tokens = 50; let estimated_cost = if let Some(costs) = &config.custom_costs {
312 (input_tokens as f64 * costs.input_cost_per_token)
313 + (output_tokens as f64 * costs.output_cost_per_token)
314 } else {
315 0.0 };
317
318 let usage = TokenUsage::new(
319 input_tokens,
320 output_tokens,
321 "unknown",
322 "unknown",
323 duration_ms,
324 estimated_cost,
325 );
326
327 {
329 let mut stats = usage_stats.write().unwrap();
330 stats.push(usage.clone());
331 }
332
333 if config.emit_events {
334 if let Some(dispatcher) = &event_dispatcher {
335 let event = AgentEvent::TokenUsage(TokenUsageEvent {
336 metadata: EventMetadata::new(
337 "default".to_string(),
338 uuid::Uuid::new_v4().to_string(),
339 None,
340 ),
341 usage,
342 });
343
344 let dispatcher_clone = dispatcher.clone();
345 tokio::spawn(async move {
346 dispatcher_clone.dispatch(event).await;
347 });
348 }
349 }
350
351 if config.log_usage {
352 tracing::info!(
353 provider = "unknown",
354 model = "unknown",
355 input_tokens = input_tokens,
356 output_tokens = output_tokens,
357 total_tokens = input_tokens + output_tokens,
358 estimated_cost = estimated_cost,
359 duration_ms = duration_ms,
360 "🔢 Token usage tracked"
361 );
362 }
363
364 Some((Ok(chunk), (stream, start_time)))
365 }
366 _ => Some((Ok(chunk), (stream, start_time))),
367 }
368 }
369 Some(Err(e)) => Some((Err(e), (stream, start_time))),
370 None => None,
371 }
372 }
373 },
374 )))
375 }
376}
377
#[async_trait]
impl AgentMiddleware for TokenTrackingMiddleware {
    /// Stable identifier used to register and look up this middleware.
    fn id(&self) -> &'static str {
        "token_tracking"
    }

    /// No-op: tracking happens in the [`LanguageModel`] wrapper, so no
    /// request mutation is needed at this hook.
    async fn modify_model_request(&self, _ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
        Ok(())
    }
}

/// Aggregated totals over every recorded request (see
/// `TokenTrackingMiddleware::get_total_usage`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenUsageSummary {
    /// Sum of estimated input (prompt) tokens across all requests.
    pub total_input_tokens: u32,
    /// Sum of estimated output (completion) tokens across all requests.
    pub total_output_tokens: u32,
    /// `total_input_tokens + total_output_tokens`.
    pub total_tokens: u32,
    /// Sum of estimated costs in USD (0.0 unless custom costs were configured).
    pub total_cost: f64,
    /// Sum of per-request durations in milliseconds.
    pub total_duration_ms: u64,
    /// Number of requests that contributed to these totals.
    pub request_count: usize,
}

401impl TokenUsageSummary {
402 pub fn average_tokens_per_request(&self) -> f64 {
403 if self.request_count > 0 {
404 self.total_tokens as f64 / self.request_count as f64
405 } else {
406 0.0
407 }
408 }
409
410 pub fn average_cost_per_request(&self) -> f64 {
411 if self.request_count > 0 {
412 self.total_cost / self.request_count as f64
413 } else {
414 0.0
415 }
416 }
417}