diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index c69c60e19..c7b10aa61 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -116,9 +116,12 @@ async def generate_content_async( ) logger.debug(_build_request_log(llm_request)) - # add tracking headers to custom headers given it will override the headers - # set in the api client constructor - if llm_request.config and llm_request.config.http_options: + # Always add tracking headers to custom headers given it will override + # the headers set in the api client constructor to avoid tracking headers + # being dropped if user provides custom headers or overrides the api client. + if llm_request.config: + if not llm_request.config.http_options: + llm_request.config.http_options = types.HttpOptions() if not llm_request.config.http_options.headers: llm_request.config.http_options.headers = {} llm_request.config.http_options.headers.update(self._tracking_headers) diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py index bb11a5d1e..4e99c5a56 100644 --- a/tests/unittests/models/test_google_llm.py +++ b/tests/unittests/models/test_google_llm.py @@ -27,12 +27,27 @@ from google.adk.models.llm_response import LlmResponse from google.adk.utils.variant_utils import GoogleLLMVariant from google.genai import types -from google.genai import version as genai_version from google.genai.types import Content from google.genai.types import Part import pytest +class MockAsyncIterator: + """Mock for async iterator.""" + + def __init__(self, seq): + self.iter = iter(seq) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self.iter) + except StopIteration as exc: + raise StopAsyncIteration from exc + + @pytest.fixture def generate_content_response(): return types.GenerateContentResponse( @@ -215,21 +230,6 @@ async def mock_coro(): @pytest.mark.asyncio async def test_generate_content_async_stream(gemini_llm, llm_request): with mock.patch.object(gemini_llm, "api_client") as mock_client: - # Create mock stream responses - class MockAsyncIterator: - - def __init__(self, seq): - self.iter = iter(seq) - - def __aiter__(self): - return self - - async def __anext__(self): - try: - return next(self.iter) - except StopIteration: - raise StopAsyncIteration - mock_responses = [ types.GenerateContentResponse( candidates=[ @@ -292,21 +292,6 @@ async def test_generate_content_async_stream_preserves_thinking_and_text_parts( gemini_llm, llm_request ): with mock.patch.object(gemini_llm, "api_client") as mock_client: - - class MockAsyncIterator: - - def __init__(self, seq): - self._iter = iter(seq) - - def __aiter__(self): - return self - - async def __anext__(self): - try: - return next(self._iter) - except StopIteration: - raise StopAsyncIteration - response1 = types.GenerateContentResponse( candidates=[ types.Candidate( @@ -436,21 +421,6 @@ async def test_generate_content_async_stream_with_custom_headers( llm_request.config.http_options = types.HttpOptions(headers=custom_headers) with mock.patch.object(gemini_llm, "api_client") as mock_client: - # Create mock stream responses - class MockAsyncIterator: - - def __init__(self, seq): - self.iter = iter(seq) - - def __aiter__(self): - return self - - async def __anext__(self): - try: - return next(self.iter) - except StopIteration: - raise StopAsyncIteration - mock_responses = [ types.GenerateContentResponse( candidates=[ @@ -488,35 +458,58 @@ async def mock_coro(): assert len(responses) == 2 +@pytest.mark.parametrize("stream", [True, False]) @pytest.mark.asyncio -async def test_generate_content_async_without_custom_headers( - gemini_llm, llm_request, generate_content_response +async def test_generate_content_async_patches_tracking_headers( + stream, gemini_llm, llm_request, generate_content_response ): - """Test that tracking headers are not modified when no custom headers exist.""" - # Ensure no http_options exist initially + """Tests that tracking headers are added to the request config.""" + # Set the request's config.http_options to None. llm_request.config.http_options = None with mock.patch.object(gemini_llm, "api_client") as mock_client: + if stream: + # Create a mock coroutine that returns the mock_responses. + async def mock_coro(): + return MockAsyncIterator([generate_content_response]) - async def mock_coro(): - return generate_content_response + # Mock for streaming response. + mock_client.aio.models.generate_content_stream.return_value = mock_coro() + else: + # Create a mock coroutine that returns the generate_content_response. + async def mock_coro(): + return generate_content_response - mock_client.aio.models.generate_content.return_value = mock_coro() + # Mock for non-streaming response. + mock_client.aio.models.generate_content.return_value = mock_coro() + # Call the generate_content_async method. responses = [ resp async for resp in gemini_llm.generate_content_async( - llm_request, stream=False + llm_request, stream=stream ) ] - # Verify that the config passed to generate_content has no http_options - mock_client.aio.models.generate_content.assert_called_once() - call_args = mock_client.aio.models.generate_content.call_args - config_arg = call_args.kwargs["config"] - assert config_arg.http_options is None + # Assert that the config passed to the generate_content or + # generate_content_stream method contains the tracking headers. + if stream: + mock_client.aio.models.generate_content_stream.assert_called_once() + call_args = mock_client.aio.models.generate_content_stream.call_args + else: + mock_client.aio.models.generate_content.assert_called_once() + call_args = mock_client.aio.models.generate_content.call_args - assert len(responses) == 1 + final_config = call_args.kwargs["config"] + + assert final_config is not None + assert final_config.http_options is not None + assert ( + final_config.http_options.headers["x-goog-api-client"] + == gemini_llm._tracking_headers["x-goog-api-client"] + ) + + assert len(responses) == 2 if stream else 1 def test_live_api_version_vertex_ai(gemini_llm): @@ -665,8 +658,7 @@ async def test_preprocess_request_handles_backend_specific_fields( expected_inline_display_name: Optional[str], expected_labels: Optional[str], ): - """ - Tests that _preprocess_request correctly sanitizes fields based on the API backend. + """Tests that _preprocess_request correctly sanitizes fields based on the API backend. - For GEMINI_API, it should remove 'display_name' from file/inline data and remove 'labels' from the config. @@ -732,21 +724,6 @@ async def test_generate_content_async_stream_aggregated_content_regardless_of_fi ) with mock.patch.object(gemini_llm, "api_client") as mock_client: - - class MockAsyncIterator: - - def __init__(self, seq): - self.iter = iter(seq) - - def __aiter__(self): - return self - - async def __anext__(self): - try: - return next(self.iter) - except StopIteration: - raise StopAsyncIteration - # Test with different finish reasons test_cases = [ types.FinishReason.MAX_TOKENS, @@ -820,21 +797,6 @@ async def test_generate_content_async_stream_with_thought_and_text_error_handlin ) with mock.patch.object(gemini_llm, "api_client") as mock_client: - - class MockAsyncIterator: - - def __init__(self, seq): - self.iter = iter(seq) - - def __aiter__(self): - return self - - async def __anext__(self): - try: - return next(self.iter) - except StopIteration: - raise StopAsyncIteration - mock_responses = [ types.GenerateContentResponse( candidates=[ @@ -902,21 +864,6 @@ async def test_generate_content_async_stream_error_info_none_for_stop_finish_rea ) with mock.patch.object(gemini_llm, "api_client") as mock_client: - - class MockAsyncIterator: - - def __init__(self, seq): - self.iter = iter(seq) - - def __aiter__(self): - return self - - async def __anext__(self): - try: - return next(self.iter) - except StopIteration: - raise StopAsyncIteration - mock_responses = [ types.GenerateContentResponse( candidates=[ @@ -980,21 +927,6 @@ async def test_generate_content_async_stream_error_info_set_for_non_stop_finish_ ) with mock.patch.object(gemini_llm, "api_client") as mock_client: - - class MockAsyncIterator: - - def __init__(self, seq): - self.iter = iter(seq) - - def __aiter__(self): - return self - - async def __anext__(self): - try: - return next(self.iter) - except StopIteration: - raise StopAsyncIteration - mock_responses = [ types.GenerateContentResponse( candidates=[ @@ -1058,21 +990,6 @@ async def test_generate_content_async_stream_no_aggregated_content_without_text( ) with mock.patch.object(gemini_llm, "api_client") as mock_client: - - class MockAsyncIterator: - - def __init__(self, seq): - self.iter = iter(seq) - - def __aiter__(self): - return self - - async def __anext__(self): - try: - return next(self.iter) - except StopIteration: - raise StopAsyncIteration - # Mock response with no text content mock_responses = [ types.GenerateContentResponse( @@ -1127,21 +1044,6 @@ async def test_generate_content_async_stream_mixed_text_function_call_text(): ) with mock.patch.object(gemini_llm, "api_client") as mock_client: - - class MockAsyncIterator: - - def __init__(self, seq): - self.iter = iter(seq) - - def __aiter__(self): - return self - - async def __anext__(self): - try: - return next(self.iter) - except StopIteration: - raise StopAsyncIteration - # Create responses with pattern: text -> function_call -> text mock_responses = [ # First text chunk @@ -1247,21 +1149,6 @@ async def test_generate_content_async_stream_multiple_text_parts_in_single_respo ) with mock.patch.object(gemini_llm, "api_client") as mock_client: - - class MockAsyncIterator: - - def __init__(self, seq): - self.iter = iter(seq) - - def __aiter__(self): - return self - - async def __anext__(self): - try: - return next(self.iter) - except StopIteration: - raise StopAsyncIteration - # Create a response with multiple text parts mock_responses = [ types.GenerateContentResponse( @@ -1314,21 +1201,6 @@ async def test_generate_content_async_stream_complex_mixed_thought_text_function ) with mock.patch.object(gemini_llm, "api_client") as mock_client: - - class MockAsyncIterator: - - def __init__(self, seq): - self.iter = iter(seq) - - def __aiter__(self): - return self - - async def __anext__(self): - try: - return next(self.iter) - except StopIteration: - raise StopAsyncIteration - # Complex pattern: thought -> text -> function_call -> thought -> text mock_responses = [ # Thought @@ -1450,21 +1322,6 @@ async def test_generate_content_async_stream_two_separate_text_aggregations(): ) with mock.patch.object(gemini_llm, "api_client") as mock_client: - - class MockAsyncIterator: - - def __init__(self, seq): - self.iter = iter(seq) - - def __aiter__(self): - return self - - async def __anext__(self): - try: - return next(self.iter) - except StopIteration: - raise StopAsyncIteration - # Create responses: multiple text chunks -> function_call -> multiple text chunks mock_responses = [ # First text accumulation (multiple chunks)