Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/codex_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -497,9 +497,13 @@ impl Agent for CodexAgent {

// Get the session state
let thread = self.get_thread(&request.session_id)?;
let stop_reason = thread.prompt(request).await?;
let (stop_reason, usage) = thread.prompt(request).await?;

Ok(PromptResponse::new(stop_reason))
let mut response = PromptResponse::new(stop_reason);
if let Some(usage) = usage {
response = response.usage(usage);
}
Ok(response)
}

async fn cancel(&self, args: CancelNotification) -> Result<(), Error> {
Expand Down
84 changes: 60 additions & 24 deletions src/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::{

use agent_client_protocol::{
AvailableCommand, AvailableCommandInput, AvailableCommandsUpdate, Client, ClientCapabilities,
ConfigOptionUpdate, Content, ContentBlock, ContentChunk, Diff, EmbeddedResource,
ConfigOptionUpdate, Content, ContentBlock, ContentChunk, Cost, Diff, EmbeddedResource,
EmbeddedResourceResource, Error, LoadSessionResponse, Meta, ModelId, ModelInfo,
PermissionOption, PermissionOptionKind, Plan, PlanEntry, PlanEntryPriority, PlanEntryStatus,
PromptRequest, RequestPermissionOutcome, RequestPermissionRequest, RequestPermissionResponse,
Expand All @@ -18,7 +18,7 @@ use agent_client_protocol::{
SessionInfoUpdate, SessionMode, SessionModeId, SessionModeState, SessionModelState,
SessionNotification, SessionUpdate, StopReason, Terminal, TextResourceContents, ToolCall,
ToolCallContent, ToolCallId, ToolCallLocation, ToolCallStatus, ToolCallUpdate,
ToolCallUpdateFields, ToolKind, UnstructuredCommandInput,
ToolCallUpdateFields, ToolKind, UnstructuredCommandInput, Usage, UsageUpdate,
};
use codex_apply_patch::parse_patch;
use codex_core::{
Expand All @@ -38,7 +38,8 @@ use codex_core::{
ModelRerouteEvent, Op, PatchApplyBeginEvent, PatchApplyEndEvent, PatchApplyStatus,
ReasoningContentDeltaEvent, ReasoningRawContentDeltaEvent, ReviewDecision,
ReviewOutputEvent, ReviewRequest, ReviewTarget, SandboxPolicy, StreamErrorEvent,
TerminalInteractionEvent, TurnAbortedEvent, TurnCompleteEvent, TurnStartedEvent,
TerminalInteractionEvent, TokenCountEvent, TurnAbortedEvent, TurnCompleteEvent,
TurnStartedEvent,
UserMessageEvent, ViewImageToolCallEvent, WarningEvent, WebSearchBeginEvent,
WebSearchEndEvent,
},
Expand Down Expand Up @@ -130,7 +131,8 @@ enum ThreadMessage {
},
Prompt {
request: PromptRequest,
response_tx: oneshot::Sender<Result<oneshot::Receiver<Result<StopReason, Error>>, Error>>,
response_tx:
oneshot::Sender<Result<oneshot::Receiver<Result<(StopReason, Option<Usage>), Error>>, Error>>,
},
SetMode {
mode: SessionModeId,
Expand Down Expand Up @@ -210,7 +212,7 @@ impl Thread {
.map_err(|e| Error::internal_error().data(e.to_string()))?
}

pub async fn prompt(&self, request: PromptRequest) -> Result<StopReason, Error> {
pub async fn prompt(&self, request: PromptRequest) -> Result<(StopReason, Option<Usage>), Error> {
let (response_tx, response_rx) = oneshot::channel();

let message = ThreadMessage::Prompt {
Expand Down Expand Up @@ -362,15 +364,17 @@ struct PromptState {
active_web_search: Option<String>,
thread: Arc<dyn CodexThreadImpl>,
event_count: usize,
response_tx: Option<oneshot::Sender<Result<StopReason, Error>>>,
response_tx: Option<oneshot::Sender<Result<(StopReason, Option<Usage>), Error>>>,
seen_message_deltas: bool,
seen_reasoning_deltas: bool,
/// Last token usage info received from a TokenCount event.
last_usage: Option<Usage>,
}

impl PromptState {
fn new(
thread: Arc<dyn CodexThreadImpl>,
response_tx: oneshot::Sender<Result<StopReason, Error>>,
response_tx: oneshot::Sender<Result<(StopReason, Option<Usage>), Error>>,
) -> Self {
Self {
active_commands: HashMap::new(),
Expand All @@ -380,6 +384,7 @@ impl PromptState {
response_tx: Some(response_tx),
seen_message_deltas: false,
seen_reasoning_deltas: false,
last_usage: None,
}
}

Expand Down Expand Up @@ -620,7 +625,8 @@ impl PromptState {
self.event_count
);
if let Some(response_tx) = self.response_tx.take() {
response_tx.send(Ok(StopReason::EndTurn)).ok();
let usage = self.last_usage.take();
response_tx.send(Ok((StopReason::EndTurn, usage))).ok();
}
}
EventMsg::UndoStarted(event) => {
Expand Down Expand Up @@ -665,13 +671,13 @@ impl PromptState {
EventMsg::TurnAborted(TurnAbortedEvent { reason, turn_id }) => {
info!("Turn {turn_id:?} aborted: {reason:?}");
if let Some(response_tx) = self.response_tx.take() {
response_tx.send(Ok(StopReason::Cancelled)).ok();
response_tx.send(Ok((StopReason::Cancelled, None))).ok();
}
}
EventMsg::ShutdownComplete => {
info!("Agent shutting down");
if let Some(response_tx) = self.response_tx.take() {
response_tx.send(Ok(StopReason::Cancelled)).ok();
response_tx.send(Ok((StopReason::Cancelled, None))).ok();
}
}
EventMsg::ViewImageToolCall(ViewImageToolCallEvent { call_id, path }) => {
Expand Down Expand Up @@ -726,11 +732,41 @@ impl PromptState {
info!("Model reroute: from={from_model}, to={to_model}, reason={reason:?}");
}

EventMsg::TokenCount(TokenCountEvent { info, .. }) => {
if let Some(info) = info {
let total = &info.total_token_usage;
let input_tokens = total.input_tokens.max(0) as u64;
let output_tokens = total.output_tokens.max(0) as u64;
let cached_input = total.cached_input_tokens.max(0) as u64;
let reasoning = total.reasoning_output_tokens.max(0) as u64;
let total_tokens = total.total_tokens.max(0) as u64;
let context_window = info.model_context_window.unwrap_or(0).max(0) as u64;

// PromptResponse.usage: cumulative session totals (per ACP spec)
let usage = Usage::new(total_tokens, input_tokens, output_tokens)
.cached_read_tokens(cached_input)
.thought_tokens(reasoning);
self.last_usage = Some(usage);

// UsageUpdate.used: current context window usage from the last turn.
// Use last_token_usage (per-turn) not total_token_usage (cumulative),
// and only count input tokens (input + cached_input) since those
// represent the actual context sent to the model. Output and reasoning
// tokens are the model's response, not context window consumption.
let last = &info.last_token_usage;
let context_used =
(last.input_tokens.max(0) as u64) + (last.cached_input_tokens.max(0) as u64);

self.send_notification(SessionUpdate::UsageUpdate(
UsageUpdate::new(context_used, context_window),
))
.await;
}
}

// Ignore these events
EventMsg::AgentReasoningRawContent(..)
| EventMsg::ThreadRolledBack(..)
// In the future we can use this to update usage stats
| EventMsg::TokenCount(..)
// we already have a way to diff the turn, so ignore
| EventMsg::TurnDiff(..)
// Revisit when we can emit status updates
Expand Down Expand Up @@ -2192,7 +2228,7 @@ impl<A: Auth> ThreadActor<A> {
async fn handle_prompt(
&mut self,
request: PromptRequest,
) -> Result<oneshot::Receiver<Result<StopReason, Error>>, Error> {
) -> Result<oneshot::Receiver<Result<(StopReason, Option<Usage>), Error>>, Error> {
let (response_tx, response_rx) = oneshot::channel();

let items = build_prompt_items(request.prompt);
Expand Down Expand Up @@ -2891,7 +2927,7 @@ mod tests {

tokio::try_join!(
async {
let stop_reason = prompt_response_rx.await??.await??;
let (stop_reason, _usage) = prompt_response_rx.await??.await??;
assert_eq!(stop_reason, StopReason::EndTurn);
drop(message_tx);
anyhow::Ok(())
Expand Down Expand Up @@ -2927,7 +2963,7 @@ mod tests {

tokio::try_join!(
async {
let stop_reason = prompt_response_rx.await??.await??;
let (stop_reason, _usage) = prompt_response_rx.await??.await??;
assert_eq!(stop_reason, StopReason::EndTurn);
drop(message_tx);
anyhow::Ok(())
Expand Down Expand Up @@ -2965,7 +3001,7 @@ mod tests {

tokio::try_join!(
async {
let stop_reason = prompt_response_rx.await??.await??;
let (stop_reason, _usage) = prompt_response_rx.await??.await??;
assert_eq!(stop_reason, StopReason::EndTurn);
drop(message_tx);
anyhow::Ok(())
Expand Down Expand Up @@ -3015,7 +3051,7 @@ mod tests {

tokio::try_join!(
async {
let stop_reason = prompt_response_rx.await??.await??;
let (stop_reason, _usage) = prompt_response_rx.await??.await??;
assert_eq!(stop_reason, StopReason::EndTurn);
drop(message_tx);
anyhow::Ok(())
Expand Down Expand Up @@ -3065,7 +3101,7 @@ mod tests {

tokio::try_join!(
async {
let stop_reason = prompt_response_rx.await??.await??;
let (stop_reason, _usage) = prompt_response_rx.await??.await??;
assert_eq!(stop_reason, StopReason::EndTurn);
drop(message_tx);
anyhow::Ok(())
Expand Down Expand Up @@ -3120,7 +3156,7 @@ mod tests {

tokio::try_join!(
async {
let stop_reason = prompt_response_rx.await??.await??;
let (stop_reason, _usage) = prompt_response_rx.await??.await??;
assert_eq!(stop_reason, StopReason::EndTurn);
drop(message_tx);
anyhow::Ok(())
Expand Down Expand Up @@ -3175,7 +3211,7 @@ mod tests {

tokio::try_join!(
async {
let stop_reason = prompt_response_rx.await??.await??;
let (stop_reason, _usage) = prompt_response_rx.await??.await??;
assert_eq!(stop_reason, StopReason::EndTurn);
drop(message_tx);
anyhow::Ok(())
Expand Down Expand Up @@ -3232,7 +3268,7 @@ mod tests {

tokio::try_join!(
async {
let stop_reason = prompt_response_rx.await??.await??;
let (stop_reason, _usage) = prompt_response_rx.await??.await??;
assert_eq!(stop_reason, StopReason::EndTurn);
drop(message_tx);
anyhow::Ok(())
Expand Down Expand Up @@ -3294,7 +3330,7 @@ mod tests {

tokio::try_join!(
async {
let stop_reason = prompt_response_rx.await??.await??;
let (stop_reason, _usage) = prompt_response_rx.await??.await??;
assert_eq!(stop_reason, StopReason::EndTurn);
drop(message_tx);
anyhow::Ok(())
Expand Down Expand Up @@ -3346,7 +3382,7 @@ mod tests {

tokio::try_join!(
async {
let stop_reason = prompt_response_rx.await??.await??;
let (stop_reason, _usage) = prompt_response_rx.await??.await??;
assert_eq!(stop_reason, StopReason::EndTurn);
drop(message_tx);
anyhow::Ok(())
Expand Down Expand Up @@ -3723,7 +3759,7 @@ mod tests {

tokio::try_join!(
async {
let stop_reason = prompt_response_rx.await??.await??;
let (stop_reason, _usage) = prompt_response_rx.await??.await??;
assert_eq!(stop_reason, StopReason::EndTurn);
drop(message_tx);
anyhow::Ok(())
Expand Down