diff --git a/langchain_litellm/chat_models/litellm.py b/langchain_litellm/chat_models/litellm.py
index 143e3a3..ac386d3 100644
--- a/langchain_litellm/chat_models/litellm.py
+++ b/langchain_litellm/chat_models/litellm.py
@@ -105,6 +105,10 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
         if _dict.get("tool_calls"):
             additional_kwargs["tool_calls"] = _dict["tool_calls"]
 
+        # Add reasoning_content support for thinking-enabled models
+        if _dict.get("reasoning_content"):
+            additional_kwargs["reasoning_content"] = _dict["reasoning_content"]
+
         return AIMessage(content=content, additional_kwargs=additional_kwargs)
     elif role == "system":
         return SystemMessage(content=_dict["content"])
@@ -148,7 +152,7 @@ def _convert_delta_to_message_chunk(
     function_call = delta.function_call
     raw_tool_calls = delta.tool_calls
    reasoning_content = getattr(delta, "reasoning_content", None)
-    
+
     if function_call:
         additional_kwargs = {"function_call": dict(function_call)}
         # The hasattr check is necessary because litellm explicitly deletes the
@@ -173,7 +177,7 @@
                 )
                 for rtc in raw_tool_calls
             ]
-        except KeyError:
+        except (KeyError, AttributeError):
             pass
 
     if role == "user" or default_class == HumanMessageChunk:
@@ -288,7 +292,7 @@ def _default_params(self) -> Dict[str, Any]:
         set_model_value = self.model
         if self.model_name is not None:
             set_model_value = self.model_name
-        return {
+        params = {
             "model": set_model_value,
             "force_timeout": self.request_timeout,
             "max_tokens": self.max_tokens,
@@ -299,6 +303,13 @@
             **self.model_kwargs,
         }
 
+        # Add stream_options for usage tracking in streaming responses
+        # This enables token usage metadata in streaming chunks
+        if self.streaming:
+            params["stream_options"] = {"include_usage": True}
+
+        return params
+
     @property
     def _client_params(self) -> Dict[str, Any]:
         """Get the parameters used for the openai client."""
@@ -453,6 +464,10 @@ def _stream(
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}
 
+        # Ensure stream_options is set for usage tracking
+        if "stream_options" not in params:
+            params["stream_options"] = {"include_usage": True}
+
         default_chunk_class = AIMessageChunk
         for chunk in self.completion_with_retry(
             messages=message_dicts, run_manager=run_manager, **params
@@ -461,12 +476,23 @@
                 chunk = chunk.model_dump()
             if len(chunk["choices"]) == 0:
                 continue
+
+            # Extract usage metadata from chunk if present
+            usage_metadata = None
+            if "usage" in chunk and chunk["usage"]:
+                usage_metadata = _create_usage_metadata(chunk["usage"])
+
             delta = chunk["choices"][0]["delta"]
-            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
-            default_chunk_class = chunk.__class__
-            cg_chunk = ChatGenerationChunk(message=chunk)
+            message_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+            default_chunk_class = message_chunk.__class__
+
+            # Attach usage metadata to the message chunk
+            if usage_metadata:
+                message_chunk.usage_metadata = usage_metadata
+
+            cg_chunk = ChatGenerationChunk(message=message_chunk)
             if run_manager:
-                run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
+                run_manager.on_llm_new_token(message_chunk.content, chunk=cg_chunk)
             yield cg_chunk
 
     async def _astream(
@@ -479,20 +505,37 @@
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}
 
+        # Ensure stream_options is set for usage tracking
+        if "stream_options" not in params:
+            params["stream_options"] = {"include_usage": True}
+
         default_chunk_class = AIMessageChunk
-        async for chunk in await acompletion_with_retry(
-            self, messages=message_dicts, run_manager=run_manager, **params
+
+        # For async streaming, we need to use the async completion method properly
+        async for chunk in self.acompletion_with_retry(
+            messages=message_dicts, run_manager=run_manager, **params
         ):
             if not isinstance(chunk, dict):
                 chunk = chunk.model_dump()
             if len(chunk["choices"]) == 0:
                 continue
+
+            # Extract usage metadata from chunk if present
+            usage_metadata = None
+            if "usage" in chunk and chunk["usage"]:
+                usage_metadata = _create_usage_metadata(chunk["usage"])
+
             delta = chunk["choices"][0]["delta"]
-            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
-            default_chunk_class = chunk.__class__
-            cg_chunk = ChatGenerationChunk(message=chunk)
+            message_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+            default_chunk_class = message_chunk.__class__
+
+            # Attach usage metadata to the message chunk
+            if usage_metadata:
+                message_chunk.usage_metadata = usage_metadata
+
+            cg_chunk = ChatGenerationChunk(message=message_chunk)
             if run_manager:
-                await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
+                await run_manager.on_llm_new_token(message_chunk.content, chunk=cg_chunk)
             yield cg_chunk
 
     async def _agenerate(
@@ -598,8 +641,34 @@ def _llm_type(self) -> str:
 def _create_usage_metadata(token_usage: Mapping[str, Any]) -> UsageMetadata:
     input_tokens = token_usage.get("prompt_tokens", 0)
     output_tokens = token_usage.get("completion_tokens", 0)
+
+    # Extract advanced usage details
+    input_token_details = {}
+    output_token_details = {}
+
+    # Cache tokens (for providers that support it like OpenAI, Anthropic)
+    if "cache_read_input_tokens" in token_usage:
+        input_token_details["cache_read"] = token_usage["cache_read_input_tokens"]
+
+    if "cache_creation_input_tokens" in token_usage:
+        input_token_details["cache_creation"] = token_usage["cache_creation_input_tokens"]
+
+    # Audio tokens (for multimodal models)
+    if "audio_input_tokens" in token_usage:
+        input_token_details["audio"] = token_usage["audio_input_tokens"]
+
+    if "audio_output_tokens" in token_usage:
+        output_token_details["audio"] = token_usage["audio_output_tokens"]
+
+    # Reasoning tokens (for o1 models, Claude thinking, etc.)
+    completion_tokens_details = token_usage.get("completion_tokens_details", {})
+    if completion_tokens_details and "reasoning_tokens" in completion_tokens_details:
+        output_token_details["reasoning"] = completion_tokens_details["reasoning_tokens"]
+
     return UsageMetadata(
         input_tokens=input_tokens,
         output_tokens=output_tokens,
         total_tokens=input_tokens + output_tokens,
+        input_token_details=input_token_details if input_token_details else {},
+        output_token_details=output_token_details if output_token_details else {},
     )
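The streaming changes above inject `stream_options={"include_usage": True}` and copy the provider's usage block onto the yielded `AIMessageChunk`. A minimal consumption sketch, assuming an OpenAI-style model and an `OPENAI_API_KEY` in the environment (both illustrative, not part of this diff):

```python
import os

from langchain_core.messages import HumanMessage

from langchain_litellm.chat_models import ChatLiteLLM

llm = ChatLiteLLM(
    model="gpt-3.5-turbo",  # illustrative model choice
    openai_api_key=os.environ["OPENAI_API_KEY"],
    streaming=True,
)

for chunk in llm.stream([HumanMessage(content="Say hello in five words.")]):
    # ChatLiteLLM.stream() yields AIMessageChunk objects; with include_usage
    # set, the chunk that carries usage exposes it as a UsageMetadata dict.
    if chunk.usage_metadata:
        usage = chunk.usage_metadata
        print(usage["input_tokens"], usage["output_tokens"], usage["total_tokens"])
```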
diff --git a/tests/integration_tests/test_streaming_usage_metadata.py b/tests/integration_tests/test_streaming_usage_metadata.py
new file mode 100644
index 0000000..fd986d7
--- /dev/null
+++ b/tests/integration_tests/test_streaming_usage_metadata.py
@@ -0,0 +1,186 @@
+"""Integration tests for streaming usage metadata functionality."""
+
+import os
+import pytest
+from typing import List
+
+from langchain_core.messages import HumanMessage, AIMessageChunk
+from langchain_core.outputs import ChatGenerationChunk
+from langchain_litellm.chat_models import ChatLiteLLM
+
+
+class TestStreamingUsageMetadata:
+    """Test streaming usage metadata with real API calls."""
+
+    def test_openai_streaming_usage_metadata(self):
+        """Test OpenAI streaming with usage metadata."""
+        api_key = os.getenv("OPENAI_API_KEY")
+        if not api_key:
+            pytest.skip("OPENAI_API_KEY not set")
+
+        llm = ChatLiteLLM(
+            model="gpt-3.5-turbo",
+            openai_api_key=api_key,
+            streaming=True,
+            max_retries=1
+        )
+
+        messages = [HumanMessage(content="Say hello in exactly 5 words.")]
+
+        chunks = []
+        usage_metadata_found = False
+
+        for chunk in llm.stream(messages):
+            chunks.append(chunk)
+            # chunk is an AIMessageChunk directly, not ChatGenerationChunk
+            if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
+                usage_metadata_found = True
+                usage = chunk.usage_metadata
+                assert usage["input_tokens"] > 0
+                assert usage["output_tokens"] > 0
+                assert usage["total_tokens"] > 0
+                assert usage["total_tokens"] == usage["input_tokens"] + usage["output_tokens"]
+
+        assert len(chunks) > 0
+        assert usage_metadata_found, "No usage metadata found in streaming chunks"
+
+    def test_openai_streaming_usage_metadata_with_cache(self):
+        """Test OpenAI streaming with cache tokens (if supported)."""
+        api_key = os.getenv("OPENAI_API_KEY")
+        if not api_key:
+            pytest.skip("OPENAI_API_KEY not set")
+
+        llm = ChatLiteLLM(
+            model="gpt-4o-mini",  # Use a model that supports caching
+            openai_api_key=api_key,
+            streaming=True,
+            max_retries=1
+        )
+
+        # Send the same message twice to potentially trigger caching
+        messages = [HumanMessage(content="What is the capital of France? Please answer in exactly one word.")]
+
+        # First call
+        chunks1 = list(llm.stream(messages))
+
+        # Second call (might use cache)
+        chunks2 = list(llm.stream(messages))
+
+        # Check if any chunks have cache information
+        for chunks in [chunks1, chunks2]:
+            for chunk in chunks:
+                if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
+                    usage = chunk.usage_metadata
+                    if usage.get("input_token_details") and "cache_read" in usage["input_token_details"]:
+                        assert usage["input_token_details"]["cache_read"] >= 0
+
+    def test_anthropic_streaming_usage_metadata(self):
+        """Test Anthropic streaming with usage metadata."""
+        api_key = os.getenv("ANTHROPIC_API_KEY")
+        if not api_key:
+            pytest.skip("ANTHROPIC_API_KEY not set")
+
+        llm = ChatLiteLLM(
+            model="claude-3-haiku-20240307",
+            anthropic_api_key=api_key,
+            streaming=True,
+            max_retries=1
+        )
+
+        messages = [HumanMessage(content="Say hello in exactly 3 words.")]
+
+        chunks = []
+        usage_metadata_found = False
+
+        for chunk in llm.stream(messages):
+            chunks.append(chunk)
+            if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
+                usage_metadata_found = True
+                usage = chunk.usage_metadata
+                assert usage["input_tokens"] > 0
+                assert usage["output_tokens"] > 0
+                assert usage["total_tokens"] > 0
+
+        assert len(chunks) > 0
+        assert usage_metadata_found, "No usage metadata found in Anthropic streaming chunks"
+
+    @pytest.mark.asyncio
+    async def test_openai_async_streaming_usage_metadata(self):
+        """Test OpenAI async streaming with usage metadata."""
+        api_key = os.getenv("OPENAI_API_KEY")
+        if not api_key:
+            pytest.skip("OPENAI_API_KEY not set")
+
+        llm = ChatLiteLLM(
+            model="gpt-3.5-turbo",
+            openai_api_key=api_key,
+            streaming=True,
+            max_retries=1
+        )
+
+        messages = [HumanMessage(content="Count from 1 to 3.")]
+
+        chunks = []
+        usage_metadata_found = False
+
+        async for chunk in llm.astream(messages):
+            chunks.append(chunk)
+            # astream also yields AIMessageChunk objects directly
+            if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
+                usage_metadata_found = True
+                usage = chunk.usage_metadata
+                assert usage["input_tokens"] > 0
+                assert usage["output_tokens"] > 0
+                assert usage["total_tokens"] > 0
+
+        assert len(chunks) > 0
+        assert usage_metadata_found, "No usage metadata found in async streaming chunks"
+
+    def test_stream_options_override(self):
+        """Test that stream_options can be overridden in kwargs."""
+        api_key = os.getenv("OPENAI_API_KEY")
+        if not api_key:
+            pytest.skip("OPENAI_API_KEY not set")
+
+        llm = ChatLiteLLM(
+            model="gpt-3.5-turbo",
+            openai_api_key=api_key,
+            streaming=False,  # Not streaming by default
+            max_retries=1
+        )
+
+        messages = [HumanMessage(content="Say hi.")]
+
+        chunks = []
+        usage_metadata_found = False
+
+        # Override streaming and stream_options in kwargs
+        for chunk in llm.stream(messages, stream_options={"include_usage": True}):
+            chunks.append(chunk)
+            if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
+                usage_metadata_found = True
+
+        assert len(chunks) > 0
+        # Usage metadata should be found even though streaming=False initially
+        # because we override with stream_options
+
+    def test_non_streaming_usage_metadata_still_works(self):
+        """Test that non-streaming usage metadata still works after our changes."""
+        api_key = os.getenv("OPENAI_API_KEY")
+        if not api_key:
+            pytest.skip("OPENAI_API_KEY not set")
+
+        llm = ChatLiteLLM(
+            model="gpt-3.5-turbo",
+            openai_api_key=api_key,
+            streaming=False,
+            max_retries=1
+        )
+
+        messages = [HumanMessage(content="Say hello.")]
+        result = llm.invoke(messages)
+
+        assert hasattr(result, 'usage_metadata')
+        assert result.usage_metadata is not None
+        # usage_metadata is a UsageMetadata TypedDict, so use dict indexing
+        assert result.usage_metadata["input_tokens"] > 0
+        assert result.usage_metadata["output_tokens"] > 0
+        assert result.usage_metadata["total_tokens"] > 0
\ No newline at end of file
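The async path is symmetric: `astream` also yields `AIMessageChunk` objects, so usage is read the same way as in the synchronous tests above. A sketch under the same illustrative assumptions (model name, API key handling):

```python
import asyncio
import os

from langchain_core.messages import HumanMessage

from langchain_litellm.chat_models import ChatLiteLLM


async def main() -> None:
    llm = ChatLiteLLM(
        model="gpt-3.5-turbo",  # illustrative
        openai_api_key=os.environ["OPENAI_API_KEY"],
        streaming=True,
    )
    async for chunk in llm.astream([HumanMessage(content="Count from 1 to 3.")]):
        # The usage-bearing chunk exposes a UsageMetadata dict, as above.
        if chunk.usage_metadata:
            print(chunk.usage_metadata)


asyncio.run(main())
```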
diff --git a/tests/unit_tests/test_litellm.py b/tests/unit_tests/test_litellm.py
index be75013..0f45b9f 100644
--- a/tests/unit_tests/test_litellm.py
+++ b/tests/unit_tests/test_litellm.py
@@ -1,100 +1,274 @@
-"""Test chat model integration."""
-
-from typing import Type
-
-from langchain_core.messages import AIMessageChunk
-from langchain_tests.unit_tests import ChatModelUnitTests
-from litellm.types.utils import ChatCompletionDeltaToolCall, Delta, Function
-
-from langchain_litellm.chat_models import ChatLiteLLM
-from langchain_litellm.chat_models.litellm import _convert_delta_to_message_chunk
-
-
-class TestChatLiteLLMUnit(ChatModelUnitTests):
-    @property
-    def chat_model_class(self) -> Type[ChatLiteLLM]:
-        return ChatLiteLLM
-
-    @property
-    def chat_model_params(self) -> dict:
-        # These should be parameters used to initialize your integration for testing
-        return {
-            "custom_llm_provider": "openai",
-            "model": "gpt-3.5-turbo",
-            "api_key": "",
-            "max_retries": 1,
-        }
-
-    @property
-    def has_tool_calling(self) -> bool:
-        return True
-
-    @property
-    def has_tool_choice(self) -> bool:
-        return False
-
-    @property
-    def has_structured_output(self) -> bool:
-        return False
-
-    @property
-    def supports_json_mode(self) -> bool:
-        return False
-
-    @property
-    def supports_image_inputs(self) -> bool:
-        return False
-
-    @property
-    def returns_usage_metadata(self) -> bool:
-        return True
-
-    @property
-    def supports_anthropic_inputs(self) -> bool:
-        return False
-
-    @property
-    def supports_image_tool_message(self) -> bool:
-        return False
-
-    def test_litellm_delta_to_langchain_message_chunk(self):
-        """Test the litellm._convert_delta_to_message_chunk method, to ensure compatibility when converting a LiteLLM delta to a LangChain message chunk."""
-        mock_content = "This is a test content"
-        mock_tool_call_id = "call_test"
-        mock_tool_call_name = "test_tool_call"
-        mock_tool_call_arguments = ""
-        mock_tool_call_index = 3
-        mock_delta = Delta(
-            content=mock_content,
-            role="assistant",
-            tool_calls=[
-                ChatCompletionDeltaToolCall(
-                    id=mock_tool_call_id,
-                    function=Function(
-                        arguments=mock_tool_call_arguments, name=mock_tool_call_name
-                    ),
-                    type="function",
-                    index=mock_tool_call_index,
-                )
-            ],
-        )
-        message_chunk = _convert_delta_to_message_chunk(mock_delta, AIMessageChunk)
-        assert isinstance(message_chunk, AIMessageChunk)
-        assert message_chunk.content == mock_content
-        tool_call_chunk = message_chunk.tool_call_chunks[0]
-        assert tool_call_chunk["id"] == mock_tool_call_id
-        assert tool_call_chunk["name"] == mock_tool_call_name
-        assert tool_call_chunk["args"] == mock_tool_call_arguments
-        assert tool_call_chunk["index"] == mock_tool_call_index
-
-    def test_convert_dict_to_tool_message(self):
-        """Ensure tool role dicts convert to ToolMessage."""
-        from langchain_litellm.chat_models.litellm import _convert_dict_to_message
-
-        mock_dict = {"role": "tool", "content": "result", "tool_call_id": "123"}
-        message = _convert_dict_to_message(mock_dict)
-        from langchain_core.messages import ToolMessage
-
-        assert isinstance(message, ToolMessage)
-        assert message.content == "result"
-        assert message.tool_call_id == "123"
+"""Test chat model integration."""
+
+from typing import Type
+
+from langchain_core.messages import AIMessageChunk
+from langchain_tests.unit_tests import ChatModelUnitTests
+from litellm.types.utils import ChatCompletionDeltaToolCall, Delta, Function
+
+from langchain_litellm.chat_models import ChatLiteLLM
+from langchain_litellm.chat_models.litellm import _convert_delta_to_message_chunk
+
+
+class TestChatLiteLLMUnit(ChatModelUnitTests):
+    @property
+    def chat_model_class(self) -> Type[ChatLiteLLM]:
+        return ChatLiteLLM
+
+    @property
+    def chat_model_params(self) -> dict:
+        # These should be parameters used to initialize your integration for testing
+        return {
+            "custom_llm_provider": "openai",
+            "model": "gpt-3.5-turbo",
+            "api_key": "",
+            "max_retries": 1,
+        }
+
+    @property
+    def has_tool_calling(self) -> bool:
+        return True
+
+    @property
+    def has_tool_choice(self) -> bool:
+        return False
+
+    @property
+    def has_structured_output(self) -> bool:
+        return False
+
+    @property
+    def supports_json_mode(self) -> bool:
+        return False
+
+    @property
+    def supports_image_inputs(self) -> bool:
+        return False
+
+    @property
+    def returns_usage_metadata(self) -> bool:
+        return True
+
+    @property
+    def supports_anthropic_inputs(self) -> bool:
+        return False
+
+    @property
+    def supports_image_tool_message(self) -> bool:
+        return False
+
+    def test_litellm_delta_to_langchain_message_chunk(self):
+        """Test the litellm._convert_delta_to_message_chunk method, to ensure compatibility when converting a LiteLLM delta to a LangChain message chunk."""
+        mock_content = "This is a test content"
+        mock_tool_call_id = "call_test"
+        mock_tool_call_name = "test_tool_call"
+        mock_tool_call_arguments = ""
+        mock_tool_call_index = 3
+        mock_delta = Delta(
+            content=mock_content,
+            role="assistant",
+            tool_calls=[
+                ChatCompletionDeltaToolCall(
+                    id=mock_tool_call_id,
+                    function=Function(
+                        arguments=mock_tool_call_arguments, name=mock_tool_call_name
+                    ),
+                    type="function",
+                    index=mock_tool_call_index,
+                )
+            ],
+        )
+        message_chunk = _convert_delta_to_message_chunk(mock_delta, AIMessageChunk)
+        assert isinstance(message_chunk, AIMessageChunk)
+        assert message_chunk.content == mock_content
+        tool_call_chunk = message_chunk.tool_call_chunks[0]
+        assert tool_call_chunk["id"] == mock_tool_call_id
+        assert tool_call_chunk["name"] == mock_tool_call_name
+        assert tool_call_chunk["args"] == mock_tool_call_arguments
+        assert tool_call_chunk["index"] == mock_tool_call_index
+
+    def test_convert_dict_to_tool_message(self):
+        """Ensure tool role dicts convert to ToolMessage."""
+        from langchain_litellm.chat_models.litellm import _convert_dict_to_message
+
+        mock_dict = {"role": "tool", "content": "result", "tool_call_id": "123"}
+        message = _convert_dict_to_message(mock_dict)
+        from langchain_core.messages import ToolMessage
+
+        assert isinstance(message, ToolMessage)
+        assert message.content == "result"
+        assert message.tool_call_id == "123"
+
+    def test_default_params_includes_stream_options_when_streaming(self):
+        """Test that _default_params includes stream_options when streaming is enabled."""
+        from langchain_litellm.chat_models.litellm import ChatLiteLLM
+
+        # Test with streaming=True
+        llm = ChatLiteLLM(model="gpt-3.5-turbo", streaming=True)
+        params = llm._default_params
+        assert "stream_options" in params
+        assert params["stream_options"] == {"include_usage": True}
+
+        # Test with streaming=False
+        llm_no_stream = ChatLiteLLM(model="gpt-3.5-turbo", streaming=False)
+        params_no_stream = llm_no_stream._default_params
+        assert "stream_options" not in params_no_stream
+
+    def test_create_usage_metadata_basic(self):
+        """Test _create_usage_metadata with basic token usage."""
+        from langchain_litellm.chat_models.litellm import _create_usage_metadata
+
+        token_usage = {
+            "prompt_tokens": 10,
+            "completion_tokens": 20,
+            "total_tokens": 30
+        }
+
+        usage_metadata = _create_usage_metadata(token_usage)
+        assert usage_metadata["input_tokens"] == 10
+        assert usage_metadata["output_tokens"] == 20
+        assert usage_metadata["total_tokens"] == 30
+        assert usage_metadata["input_token_details"] == {}
+        assert usage_metadata["output_token_details"] == {}
+
+    def test_create_usage_metadata_with_cache_tokens(self):
+        """Test _create_usage_metadata with cache tokens."""
+        from langchain_litellm.chat_models.litellm import _create_usage_metadata
+
+        token_usage = {
+            "prompt_tokens": 10,
+            "completion_tokens": 20,
+            "total_tokens": 30,
+            "cache_read_input_tokens": 5
+        }
+
+        usage_metadata = _create_usage_metadata(token_usage)
+        assert usage_metadata["input_tokens"] == 10
+        assert usage_metadata["output_tokens"] == 20
+        assert usage_metadata["total_tokens"] == 30
+        assert usage_metadata["input_token_details"] == {"cache_read": 5}
+        assert usage_metadata["output_token_details"] == {}
+
+    def test_create_usage_metadata_with_reasoning_tokens(self):
+        """Test _create_usage_metadata with reasoning tokens."""
+        from langchain_litellm.chat_models.litellm import _create_usage_metadata
+
+        token_usage = {
+            "prompt_tokens": 10,
+            "completion_tokens": 20,
+            "total_tokens": 30,
+            "completion_tokens_details": {
+                "reasoning_tokens": 15
+            }
+        }
+
+        usage_metadata = _create_usage_metadata(token_usage)
+        assert usage_metadata["input_tokens"] == 10
+        assert usage_metadata["output_tokens"] == 20
+        assert usage_metadata["total_tokens"] == 30
+        assert usage_metadata["input_token_details"] == {}
+        assert usage_metadata["output_token_details"] == {"reasoning": 15}
+
+    def test_create_usage_metadata_with_all_advanced_fields(self):
+        """Test _create_usage_metadata with all advanced fields."""
+        from langchain_litellm.chat_models.litellm import _create_usage_metadata
+
+        token_usage = {
+            "prompt_tokens": 10,
+            "completion_tokens": 20,
+            "total_tokens": 30,
+            "cache_read_input_tokens": 5,
+            "completion_tokens_details": {
+                "reasoning_tokens": 15
+            }
+        }
+
+        usage_metadata = _create_usage_metadata(token_usage)
+        assert usage_metadata["input_tokens"] == 10
+        assert usage_metadata["output_tokens"] == 20
+        assert usage_metadata["total_tokens"] == 30
+        assert usage_metadata["input_token_details"] == {"cache_read": 5}
+        assert usage_metadata["output_token_details"] == {"reasoning": 15}
+
+    def test_litellm_normalize_messages(self):
+        """
+        Test that _normalize_messages correctly handles different multimodal formats:
+        - LiteLLM format should be preserved (not transformed)
+        - OpenAI format should be transformed correctly
+        - Vertex format should be preserved
+        """
+        import base64
+        from langchain_core.messages import HumanMessage
+        from langchain_core.language_models._utils import _normalize_messages
+
+        # Create dummy PDF data
+        dummy_pdf_data = base64.b64encode(b"dummy pdf content").decode('utf-8')
+
+        # Test 1: LiteLLM's official format should be preserved
+        litellm_message = HumanMessage(content=[
+            {"type": "text", "text": "Analyze this PDF"},
+            {
+                "type": "file",
+                "file": {
+                    "file_data": f"data:application/pdf;base64,{dummy_pdf_data}"
+                }
+            }
+        ])
+
+        normalized_litellm = _normalize_messages([litellm_message])[0]
+        litellm_file_content = next(
+            (item for item in normalized_litellm.content if isinstance(item, dict) and item.get('type') == 'file'),
+            None
+        )
+
+        assert litellm_file_content is not None, "LiteLLM file content not found"
+        # LiteLLM format should be preserved
+        assert 'file' in litellm_file_content, "LiteLLM format should preserve 'file' key"
+        assert 'file_data' in litellm_file_content['file'], "LiteLLM format should preserve 'file_data'"
+        assert 'source_type' not in litellm_file_content, "LiteLLM format should not have 'source_type' key"
+
+        # Test 2: OpenAI format should be transformed appropriately
+        openai_message = HumanMessage(content=[
+            {"type": "text", "text": "Analyze this image"},
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url": f"data:image/jpeg;base64,{dummy_pdf_data}"
+                }
+            }
+        ])
+
+        normalized_openai = _normalize_messages([openai_message])[0]
+        openai_file_content = next(
+            (item for item in normalized_openai.content if isinstance(item, dict) and item.get('type') == 'image_url'),
+            None
+        )
+
+        assert openai_file_content is not None, "OpenAI file content not found"
+        # OpenAI format should be preserved as-is
+        assert 'image_url' in openai_file_content, "OpenAI format should preserve 'image_url' key"
+        assert 'url' in openai_file_content['image_url'], "OpenAI format should preserve 'url'"
+
+        # Test 3: Vertex format should be preserved
+        vertex_message = HumanMessage(content=[
+            {"type": "text", "text": "Analyze this file"},
+            {
+                "type": "file",
+                "file": {
+                    "file_data": f"data:application/pdf;base64,{dummy_pdf_data}",
+                    "format": "application/pdf"
+                }
+            }
+        ])
+
+        normalized_vertex = _normalize_messages([vertex_message])[0]
+        vertex_file_content = next(
+            (item for item in normalized_vertex.content if isinstance(item, dict) and item.get('type') == 'file'),
+            None
+        )
+
+        assert vertex_file_content is not None, "Vertex file content not found"
+        # Vertex format should be preserved
+        assert 'file' in vertex_file_content, "Vertex format should preserve 'file' key"
+        assert 'format' in vertex_file_content['file'], "Vertex format should preserve 'format' key"
+        assert vertex_file_content['file']['format'] == "application/pdf", "Vertex format should preserve format value"
+ """Test _create_usage_metadata with audio tokens for multimodal models.""" + token_usage = { + "prompt_tokens": 300, + "completion_tokens": 150, + "total_tokens": 450, + "audio_input_tokens": 25, + "audio_output_tokens": 35 + } + + metadata = _create_usage_metadata(token_usage) + + assert metadata.input_tokens == 300 + assert metadata.output_tokens == 150 + assert metadata.total_tokens == 450 + assert metadata.input_token_details["audio"] == 25 + assert metadata.output_token_details["audio"] == 35 + + +def test_create_usage_metadata_with_reasoning_tokens(): + """Test _create_usage_metadata with reasoning tokens for thinking models.""" + token_usage = { + "prompt_tokens": 400, + "completion_tokens": 200, + "total_tokens": 600, + "completion_tokens_details": { + "reasoning_tokens": 180 + } + } + + metadata = _create_usage_metadata(token_usage) + + assert metadata.input_tokens == 400 + assert metadata.output_tokens == 200 + assert metadata.total_tokens == 600 + assert metadata.input_token_details == {} + assert metadata.output_token_details["reasoning"] == 180 + + +def test_create_usage_metadata_complete_schema(): + """Test _create_usage_metadata with complete schema including all token types.""" + token_usage = { + "prompt_tokens": 350, + "completion_tokens": 240, + "total_tokens": 590, + "cache_read_input_tokens": 100, + "cache_creation_input_tokens": 200, + "audio_input_tokens": 10, + "audio_output_tokens": 10, + "completion_tokens_details": { + "reasoning_tokens": 200 + } + } + + metadata = _create_usage_metadata(token_usage) + + # Basic tokens + assert metadata.input_tokens == 350 + assert metadata.output_tokens == 240 + assert metadata.total_tokens == 590 + + # Input token details + assert metadata.input_token_details["cache_read"] == 100 + assert metadata.input_token_details["cache_creation"] == 200 + assert metadata.input_token_details["audio"] == 10 + + # Output token details + assert metadata.output_token_details["audio"] == 10 + assert metadata.output_token_details["reasoning"] == 200 + + +def test_create_usage_metadata_edge_cases(): + """Test _create_usage_metadata with edge cases and missing fields.""" + # Test with empty completion_tokens_details + token_usage = { + "prompt_tokens": 100, + "completion_tokens": 50, + "completion_tokens_details": {} + } + + metadata = _create_usage_metadata(token_usage) + assert metadata.output_token_details == {} + + # Test with missing completion_tokens_details + token_usage = { + "prompt_tokens": 100, + "completion_tokens": 50 + } + + metadata = _create_usage_metadata(token_usage) + assert metadata.output_token_details == {} + + # Test with zero tokens + token_usage = { + "prompt_tokens": 0, + "completion_tokens": 0, + "cache_read_input_tokens": 0, + "cache_creation_input_tokens": 0, + "audio_input_tokens": 0, + "audio_output_tokens": 0 + } + + metadata = _create_usage_metadata(token_usage) + assert metadata.input_tokens == 0 + assert metadata.output_tokens == 0 + assert metadata.total_tokens == 0 + assert metadata.input_token_details["cache_read"] == 0 + assert metadata.input_token_details["cache_creation"] == 0 + assert metadata.input_token_details["audio"] == 0 + assert metadata.output_token_details["audio"] == 0 + + +def test_create_usage_metadata_missing_optional_fields(): + """Test _create_usage_metadata with missing optional fields.""" + token_usage = { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150 + # No cache, audio, or reasoning tokens + } + + metadata = _create_usage_metadata(token_usage) + + assert 
metadata.input_tokens == 100 + assert metadata.output_tokens == 50 + assert metadata.total_tokens == 150 + assert metadata.input_token_details == {} + assert metadata.output_token_details == {} \ No newline at end of file
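Because the usage block rides on an `AIMessageChunk`, callers that accumulate chunks with `+` should end up with usage on the combined message as well; this relies on langchain-core's chunk-merging behaviour rather than anything in this diff, so treat it as a sketch (model and key handling are illustrative):

```python
import os

from langchain_core.messages import HumanMessage

from langchain_litellm.chat_models import ChatLiteLLM

llm = ChatLiteLLM(
    model="gpt-3.5-turbo",  # illustrative
    openai_api_key=os.environ["OPENAI_API_KEY"],
    streaming=True,
)

full = None
for chunk in llm.stream([HumanMessage(content="Say hi.")]):
    # AIMessageChunk.__add__ concatenates content and merges usage_metadata.
    full = chunk if full is None else full + chunk

print(full.content)
print(full.usage_metadata)  # populated from the chunk that carried usage
```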