95 changes: 82 additions & 13 deletions langchain_litellm/chat_models/litellm.py
@@ -105,6 +105,10 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
if _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = _dict["tool_calls"]

# Add reasoning_content support for thinking-enabled models
if _dict.get("reasoning_content"):
additional_kwargs["reasoning_content"] = _dict["reasoning_content"]

return AIMessage(content=content, additional_kwargs=additional_kwargs)
elif role == "system":
return SystemMessage(content=_dict["content"])
@@ -148,7 +152,7 @@ def _convert_delta_to_message_chunk(
function_call = delta.function_call
raw_tool_calls = delta.tool_calls
reasoning_content = getattr(delta, "reasoning_content", None)

if function_call:
additional_kwargs = {"function_call": dict(function_call)}
# The hasattr check is necessary because litellm explicitly deletes the
@@ -173,7 +177,7 @@
)
for rtc in raw_tool_calls
]
except KeyError:
except (KeyError, AttributeError):
pass

if role == "user" or default_class == HumanMessageChunk:
@@ -288,7 +292,7 @@ def _default_params(self) -> Dict[str, Any]:
set_model_value = self.model
if self.model_name is not None:
set_model_value = self.model_name
return {
params = {
"model": set_model_value,
"force_timeout": self.request_timeout,
"max_tokens": self.max_tokens,
@@ -299,6 +303,13 @@ def _default_params(self) -> Dict[str, Any]:
**self.model_kwargs,
}

# Add stream_options for usage tracking in streaming responses
# This enables token usage metadata in streaming chunks
if self.streaming:
params["stream_options"] = {"include_usage": True}

return params

@property
def _client_params(self) -> Dict[str, Any]:
"""Get the parameters used for the openai client."""
@@ -453,6 +464,10 @@ def _stream(
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}

# Ensure stream_options is set for usage tracking
if "stream_options" not in params:
params["stream_options"] = {"include_usage": True}

default_chunk_class = AIMessageChunk
for chunk in self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
@@ -461,12 +476,23 @@
chunk = chunk.model_dump()
if len(chunk["choices"]) == 0:
continue

# Extract usage metadata from chunk if present
usage_metadata = None
if "usage" in chunk and chunk["usage"]:
usage_metadata = _create_usage_metadata(chunk["usage"])

delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
cg_chunk = ChatGenerationChunk(message=chunk)
message_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = message_chunk.__class__

# Attach usage metadata to the message chunk
if usage_metadata:
message_chunk.usage_metadata = usage_metadata

cg_chunk = ChatGenerationChunk(message=message_chunk)
if run_manager:
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
run_manager.on_llm_new_token(message_chunk.content, chunk=cg_chunk)
yield cg_chunk

async def _astream(
@@ -479,20 +505,37 @@ async def _astream(
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}

# Ensure stream_options is set for usage tracking
if "stream_options" not in params:
params["stream_options"] = {"include_usage": True}

default_chunk_class = AIMessageChunk
async for chunk in await acompletion_with_retry(
self, messages=message_dicts, run_manager=run_manager, **params

# For async streaming, we need to use the async completion method properly
async for chunk in self.acompletion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
):
if not isinstance(chunk, dict):
chunk = chunk.model_dump()
if len(chunk["choices"]) == 0:
continue

# Extract usage metadata from chunk if present
usage_metadata = None
if "usage" in chunk and chunk["usage"]:
usage_metadata = _create_usage_metadata(chunk["usage"])

delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
cg_chunk = ChatGenerationChunk(message=chunk)
message_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = message_chunk.__class__

# Attach usage metadata to the message chunk
if usage_metadata:
message_chunk.usage_metadata = usage_metadata

cg_chunk = ChatGenerationChunk(message=message_chunk)
if run_manager:
await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
await run_manager.on_llm_new_token(message_chunk.content, chunk=cg_chunk)
yield cg_chunk

async def _agenerate(
@@ -598,8 +641,34 @@ def _llm_type(self) -> str:
def _create_usage_metadata(token_usage: Mapping[str, Any]) -> UsageMetadata:
input_tokens = token_usage.get("prompt_tokens", 0)
output_tokens = token_usage.get("completion_tokens", 0)

# Extract advanced usage details
input_token_details = {}
output_token_details = {}

# Cache tokens (for providers that support it like OpenAI, Anthropic)
if "cache_read_input_tokens" in token_usage:
input_token_details["cache_read"] = token_usage["cache_read_input_tokens"]

if "cache_creation_input_tokens" in token_usage:
input_token_details["cache_creation"] = token_usage["cache_creation_input_tokens"]

# Audio tokens (for multimodal models)
if "audio_input_tokens" in token_usage:
input_token_details["audio"] = token_usage["audio_input_tokens"]

if "audio_output_tokens" in token_usage:
output_token_details["audio"] = token_usage["audio_output_tokens"]

# Reasoning tokens (for o1 models, Claude thinking, etc.)
completion_tokens_details = token_usage.get("completion_tokens_details", {})
if completion_tokens_details and "reasoning_tokens" in completion_tokens_details:
output_token_details["reasoning"] = completion_tokens_details["reasoning_tokens"]

return UsageMetadata(
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=input_tokens + output_tokens,
input_token_details=input_token_details if input_token_details else {},
output_token_details=output_token_details if output_token_details else {},
)
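
For reference, a minimal sketch of how the extended _create_usage_metadata helper maps a provider-style usage payload into UsageMetadata. The payload values are illustrative rather than taken from a real response, and the import path simply mirrors the file changed above.

from langchain_litellm.chat_models.litellm import _create_usage_metadata

# Illustrative, hand-written usage payload shaped like what litellm surfaces
# for an OpenAI-style response with prompt caching and reasoning tokens.
token_usage = {
    "prompt_tokens": 120,
    "completion_tokens": 45,
    "cache_read_input_tokens": 100,
    "completion_tokens_details": {"reasoning_tokens": 20},
}

usage = _create_usage_metadata(token_usage)
assert usage["input_tokens"] == 120
assert usage["output_tokens"] == 45
assert usage["total_tokens"] == 165
assert usage["input_token_details"] == {"cache_read": 100}
assert usage["output_token_details"] == {"reasoning": 20}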
186 changes: 186 additions & 0 deletions tests/integration_tests/test_streaming_usage_metadata.py
@@ -0,0 +1,186 @@
"""Integration tests for streaming usage metadata functionality."""

import os
import pytest
from typing import List

from langchain_core.messages import HumanMessage, AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk
from langchain_litellm.chat_models import ChatLiteLLM


class TestStreamingUsageMetadata:
"""Test streaming usage metadata with real API calls."""

def test_openai_streaming_usage_metadata(self):
"""Test OpenAI streaming with usage metadata."""
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
pytest.skip("OPENAI_API_KEY not set")

llm = ChatLiteLLM(
model="gpt-3.5-turbo",
openai_api_key=api_key,
streaming=True,
max_retries=1
)

messages = [HumanMessage(content="Say hello in exactly 5 words.")]

chunks = []
usage_metadata_found = False

for chunk in llm.stream(messages):
chunks.append(chunk)
# chunk is an AIMessageChunk directly, not ChatGenerationChunk
if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
usage_metadata_found = True
usage = chunk.usage_metadata
assert usage["input_tokens"] > 0
assert usage["output_tokens"] > 0
assert usage["total_tokens"] > 0
assert usage["total_tokens"] == usage["input_tokens"] + usage["output_tokens"]

assert len(chunks) > 0
assert usage_metadata_found, "No usage metadata found in streaming chunks"

def test_openai_streaming_usage_metadata_with_cache(self):
"""Test OpenAI streaming with cache tokens (if supported)."""
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
pytest.skip("OPENAI_API_KEY not set")

llm = ChatLiteLLM(
model="gpt-4o-mini", # Use a model that supports caching
openai_api_key=api_key,
streaming=True,
max_retries=1
)

# Send the same message twice to potentially trigger caching
messages = [HumanMessage(content="What is the capital of France? Please answer in exactly one word.")]

# First call
chunks1 = list(llm.stream(messages))

# Second call (might use cache)
chunks2 = list(llm.stream(messages))

# Check if any chunks have cache information
for chunks in [chunks1, chunks2]:
for chunk in chunks:
if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
usage = chunk.usage_metadata
if usage.get("input_token_details") and "cache_read" in usage["input_token_details"]:
assert usage["input_token_details"]["cache_read"] >= 0

def test_anthropic_streaming_usage_metadata(self):
"""Test Anthropic streaming with usage metadata."""
api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key:
pytest.skip("ANTHROPIC_API_KEY not set")

llm = ChatLiteLLM(
model="claude-3-haiku-20240307",
anthropic_api_key=api_key,
streaming=True,
max_retries=1
)

messages = [HumanMessage(content="Say hello in exactly 3 words.")]

chunks = []
usage_metadata_found = False

for chunk in llm.stream(messages):
chunks.append(chunk)
if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
usage_metadata_found = True
usage = chunk.usage_metadata
assert usage["input_tokens"] > 0
assert usage["output_tokens"] > 0
assert usage["total_tokens"] > 0

assert len(chunks) > 0
assert usage_metadata_found, "No usage metadata found in Anthropic streaming chunks"

@pytest.mark.asyncio
async def test_openai_async_streaming_usage_metadata(self):
"""Test OpenAI async streaming with usage metadata."""
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
pytest.skip("OPENAI_API_KEY not set")

llm = ChatLiteLLM(
model="gpt-3.5-turbo",
openai_api_key=api_key,
streaming=True,
max_retries=1
)

messages = [HumanMessage(content="Count from 1 to 3.")]

chunks = []
usage_metadata_found = False

async for chunk in llm.astream(messages):
chunks.append(chunk)
# astream also yields AIMessageChunk directly, like stream
if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
usage_metadata_found = True
usage = chunk.usage_metadata
assert usage["input_tokens"] > 0
assert usage["output_tokens"] > 0
assert usage["total_tokens"] > 0

assert len(chunks) > 0
assert usage_metadata_found, "No usage metadata found in async streaming chunks"

def test_stream_options_override(self):
"""Test that stream_options can be overridden in kwargs."""
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
pytest.skip("OPENAI_API_KEY not set")

llm = ChatLiteLLM(
model="gpt-3.5-turbo",
openai_api_key=api_key,
streaming=False, # Not streaming by default
max_retries=1
)

messages = [HumanMessage(content="Say hi.")]

chunks = []
usage_metadata_found = False

# Override streaming and stream_options in kwargs
for chunk in llm.stream(messages, stream_options={"include_usage": True}):
chunks.append(chunk)
if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
usage_metadata_found = True

assert len(chunks) > 0
# Usage metadata should be found even though streaming=False initially
# because we override with stream_options

def test_non_streaming_usage_metadata_still_works(self):
"""Test that non-streaming usage metadata still works after our changes."""
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
pytest.skip("OPENAI_API_KEY not set")

llm = ChatLiteLLM(
model="gpt-3.5-turbo",
openai_api_key=api_key,
streaming=False,
max_retries=1
)

messages = [HumanMessage(content="Say hello.")]
result = llm.invoke(messages)

assert hasattr(result, 'usage_metadata')
assert result.usage_metadata is not None
assert result.usage_metadata["input_tokens"] > 0
assert result.usage_metadata["output_tokens"] > 0
assert result.usage_metadata["total_tokens"] > 0
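
As a usage note rather than part of the change set, here is a minimal consumer-side sketch of how the streamed usage metadata is expected to surface once stream_options={"include_usage": True} is in effect. The model name and environment setup are assumptions for illustration.

from langchain_core.messages import HumanMessage
from langchain_litellm.chat_models import ChatLiteLLM

# Assumes OPENAI_API_KEY is exported; the model name is illustrative.
llm = ChatLiteLLM(model="gpt-4o-mini", streaming=True)

final_usage = None
for chunk in llm.stream([HumanMessage(content="Say hello.")]):
    # Providers typically report usage on the final chunk; earlier chunks
    # carry no usage, so keep the most recent non-empty value.
    if chunk.usage_metadata:
        final_usage = chunk.usage_metadata

if final_usage:
    print(final_usage["input_tokens"], final_usage["output_tokens"], final_usage["total_tokens"])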