
Commit 9f92b44

xuanyang15 authored and copybara-github committed
fix: merge tracking headers even when llm_request.config.http_options is not set in Gemini.generate_content_async
PiperOrigin-RevId: 787278157
1 parent c69dcf8 commit 9f92b44

2 files changed: +60 -200 lines changed

src/google/adk/models/google_llm.py

Lines changed: 6 additions & 3 deletions
@@ -116,9 +116,12 @@ async def generate_content_async(
     )
     logger.debug(_build_request_log(llm_request))
 
-    # add tracking headers to custom headers given it will override the headers
-    # set in the api client constructor
-    if llm_request.config and llm_request.config.http_options:
+    # Always add tracking headers to custom headers given it will override
+    # the headers set in the api client constructor to avoid tracking headers
+    # being dropped if user provides custom headers or overrides the api client.
+    if llm_request.config:
+      if not llm_request.config.http_options:
+        llm_request.config.http_options = types.HttpOptions()
       if not llm_request.config.http_options.headers:
         llm_request.config.http_options.headers = {}
       llm_request.config.http_options.headers.update(self._tracking_headers)
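
The effect of the change above can be illustrated in isolation with the google.genai types it operates on. This is a minimal sketch, not the ADK code path itself: the tracking-header value is a placeholder, and the merge logic mirrors the patched block, which now creates http_options (and its headers dict) on demand instead of skipping the merge when they are unset.

from google.genai import types

# Placeholder value; the real headers come from Gemini._tracking_headers.
tracking_headers = {"x-goog-api-client": "adk-placeholder/0.0.0"}

# A config without http_options, as sent when the caller never set any headers.
config = types.GenerateContentConfig()
assert config.http_options is None  # previously this caused the merge to be skipped

# The patched logic: create http_options and headers on demand, then merge.
if not config.http_options:
  config.http_options = types.HttpOptions()
if not config.http_options.headers:
  config.http_options.headers = {}
config.http_options.headers.update(tracking_headers)

assert config.http_options.headers["x-goog-api-client"] == "adk-placeholder/0.0.0"
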

tests/unittests/models/test_google_llm.py

Lines changed: 54 additions & 197 deletions
@@ -27,12 +27,27 @@
 from google.adk.models.llm_response import LlmResponse
 from google.adk.utils.variant_utils import GoogleLLMVariant
 from google.genai import types
-from google.genai import version as genai_version
 from google.genai.types import Content
 from google.genai.types import Part
 import pytest
 
 
+class MockAsyncIterator:
+  """Mock for async iterator."""
+
+  def __init__(self, seq):
+    self.iter = iter(seq)
+
+  def __aiter__(self):
+    return self
+
+  async def __anext__(self):
+    try:
+      return next(self.iter)
+    except StopIteration as exc:
+      raise StopAsyncIteration from exc
+
+
 @pytest.fixture
 def generate_content_response():
   return types.GenerateContentResponse(
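
The shared MockAsyncIterator above replaces the copies previously defined inside each streaming test. Below is a minimal, self-contained sketch of the wiring pattern those tests use; the names (_FakeStream, demo, client) are illustrative stand-ins, not ADK symbols. It shows why the mock's return value is a coroutine: generate_content_stream is awaited first, and only the awaited result is iterated.

import asyncio
from unittest import mock


class _FakeStream:
  """Local stand-in with the same shape as MockAsyncIterator above."""

  def __init__(self, seq):
    self._iter = iter(seq)

  def __aiter__(self):
    return self

  async def __anext__(self):
    try:
      return next(self._iter)
    except StopIteration as exc:
      raise StopAsyncIteration from exc


async def demo():
  client = mock.MagicMock()

  # The awaited call must yield an async iterator, so the mocked return value
  # is a coroutine that resolves to one.
  async def fake_call():
    return _FakeStream(["chunk-1", "chunk-2"])

  client.aio.models.generate_content_stream.return_value = fake_call()

  stream = await client.aio.models.generate_content_stream()
  async for chunk in stream:
    print(chunk)


asyncio.run(demo())
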
@@ -215,21 +230,6 @@ async def mock_coro():
 @pytest.mark.asyncio
 async def test_generate_content_async_stream(gemini_llm, llm_request):
   with mock.patch.object(gemini_llm, "api_client") as mock_client:
-    # Create mock stream responses
-    class MockAsyncIterator:
-
-      def __init__(self, seq):
-        self.iter = iter(seq)
-
-      def __aiter__(self):
-        return self
-
-      async def __anext__(self):
-        try:
-          return next(self.iter)
-        except StopIteration:
-          raise StopAsyncIteration
-
     mock_responses = [
         types.GenerateContentResponse(
             candidates=[
@@ -292,21 +292,6 @@ async def test_generate_content_async_stream_preserves_thinking_and_text_parts(
     gemini_llm, llm_request
 ):
   with mock.patch.object(gemini_llm, "api_client") as mock_client:
-
-    class MockAsyncIterator:
-
-      def __init__(self, seq):
-        self._iter = iter(seq)
-
-      def __aiter__(self):
-        return self
-
-      async def __anext__(self):
-        try:
-          return next(self._iter)
-        except StopIteration:
-          raise StopAsyncIteration
-
     response1 = types.GenerateContentResponse(
         candidates=[
             types.Candidate(
@@ -436,21 +421,6 @@ async def test_generate_content_async_stream_with_custom_headers(
   llm_request.config.http_options = types.HttpOptions(headers=custom_headers)
 
   with mock.patch.object(gemini_llm, "api_client") as mock_client:
-    # Create mock stream responses
-    class MockAsyncIterator:
-
-      def __init__(self, seq):
-        self.iter = iter(seq)
-
-      def __aiter__(self):
-        return self
-
-      async def __anext__(self):
-        try:
-          return next(self.iter)
-        except StopIteration:
-          raise StopAsyncIteration
-
     mock_responses = [
         types.GenerateContentResponse(
             candidates=[
@@ -488,35 +458,58 @@ async def mock_coro():
     assert len(responses) == 2
 
 
+@pytest.mark.parametrize("stream", [True, False])
 @pytest.mark.asyncio
-async def test_generate_content_async_without_custom_headers(
-    gemini_llm, llm_request, generate_content_response
+async def test_generate_content_async_patches_tracking_headers(
+    stream, gemini_llm, llm_request, generate_content_response
 ):
-  """Test that tracking headers are not modified when no custom headers exist."""
-  # Ensure no http_options exist initially
+  """Tests that tracking headers are added to the request config."""
+  # Set the request's config.http_options to None.
   llm_request.config.http_options = None
 
   with mock.patch.object(gemini_llm, "api_client") as mock_client:
+    if stream:
+      # Create a mock coroutine that returns the mock_responses.
+      async def mock_coro():
+        return MockAsyncIterator([generate_content_response])
 
-    async def mock_coro():
-      return generate_content_response
+      # Mock for streaming response.
+      mock_client.aio.models.generate_content_stream.return_value = mock_coro()
+    else:
+      # Create a mock coroutine that returns the generate_content_response.
+      async def mock_coro():
+        return generate_content_response
 
-    mock_client.aio.models.generate_content.return_value = mock_coro()
+      # Mock for non-streaming response.
+      mock_client.aio.models.generate_content.return_value = mock_coro()
 
+    # Call the generate_content_async method.
     responses = [
         resp
        async for resp in gemini_llm.generate_content_async(
-            llm_request, stream=False
+            llm_request, stream=stream
        )
     ]
 
-    # Verify that the config passed to generate_content has no http_options
-    mock_client.aio.models.generate_content.assert_called_once()
-    call_args = mock_client.aio.models.generate_content.call_args
-    config_arg = call_args.kwargs["config"]
-    assert config_arg.http_options is None
+    # Assert that the config passed to the generate_content or
+    # generate_content_stream method contains the tracking headers.
+    if stream:
+      mock_client.aio.models.generate_content_stream.assert_called_once()
+      call_args = mock_client.aio.models.generate_content_stream.call_args
+    else:
+      mock_client.aio.models.generate_content.assert_called_once()
+      call_args = mock_client.aio.models.generate_content.call_args
 
-    assert len(responses) == 1
+    final_config = call_args.kwargs["config"]
+
+    assert final_config is not None
+    assert final_config.http_options is not None
+    assert (
+        final_config.http_options.headers["x-goog-api-client"]
+        == gemini_llm._tracking_headers["x-goog-api-client"]
+    )
+
+    assert len(responses) == 2 if stream else 1
 
 
 def test_live_api_version_vertex_ai(gemini_llm):
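
The assertions in the new test above rely on unittest.mock recording the keyword arguments of the call it intercepts. A small self-contained sketch of that inspection pattern follows; the model name and header value are placeholders, not the real ADK tracking headers.

from unittest import mock

from google.genai import types

client = mock.MagicMock()

# Stand-in for the call the code under test would make internally.
client.aio.models.generate_content(
    model="gemini-2.0-flash",
    contents=[],
    config=types.GenerateContentConfig(
        http_options=types.HttpOptions(
            headers={"x-goog-api-client": "adk-placeholder/0.0.0"}
        )
    ),
)

# call_args.kwargs exposes exactly what was sent, which is what the test asserts on.
client.aio.models.generate_content.assert_called_once()
sent_config = client.aio.models.generate_content.call_args.kwargs["config"]
assert sent_config.http_options.headers["x-goog-api-client"] == "adk-placeholder/0.0.0"
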
@@ -665,8 +658,7 @@ async def test_preprocess_request_handles_backend_specific_fields(
     expected_inline_display_name: Optional[str],
     expected_labels: Optional[str],
 ):
-  """
-  Tests that _preprocess_request correctly sanitizes fields based on the API backend.
+  """Tests that _preprocess_request correctly sanitizes fields based on the API backend.
 
   - For GEMINI_API, it should remove 'display_name' from file/inline data
     and remove 'labels' from the config.
@@ -732,21 +724,6 @@ async def test_generate_content_async_stream_aggregated_content_regardless_of_fi
   )
 
   with mock.patch.object(gemini_llm, "api_client") as mock_client:
-
-    class MockAsyncIterator:
-
-      def __init__(self, seq):
-        self.iter = iter(seq)
-
-      def __aiter__(self):
-        return self
-
-      async def __anext__(self):
-        try:
-          return next(self.iter)
-        except StopIteration:
-          raise StopAsyncIteration
-
     # Test with different finish reasons
     test_cases = [
         types.FinishReason.MAX_TOKENS,
@@ -820,21 +797,6 @@ async def test_generate_content_async_stream_with_thought_and_text_error_handlin
   )
 
   with mock.patch.object(gemini_llm, "api_client") as mock_client:
-
-    class MockAsyncIterator:
-
-      def __init__(self, seq):
-        self.iter = iter(seq)
-
-      def __aiter__(self):
-        return self
-
-      async def __anext__(self):
-        try:
-          return next(self.iter)
-        except StopIteration:
-          raise StopAsyncIteration
-
     mock_responses = [
         types.GenerateContentResponse(
             candidates=[
@@ -902,21 +864,6 @@ async def test_generate_content_async_stream_error_info_none_for_stop_finish_rea
   )
 
   with mock.patch.object(gemini_llm, "api_client") as mock_client:
-
-    class MockAsyncIterator:
-
-      def __init__(self, seq):
-        self.iter = iter(seq)
-
-      def __aiter__(self):
-        return self
-
-      async def __anext__(self):
-        try:
-          return next(self.iter)
-        except StopIteration:
-          raise StopAsyncIteration
-
     mock_responses = [
         types.GenerateContentResponse(
             candidates=[
@@ -980,21 +927,6 @@ async def test_generate_content_async_stream_error_info_set_for_non_stop_finish_
   )
 
   with mock.patch.object(gemini_llm, "api_client") as mock_client:
-
-    class MockAsyncIterator:
-
-      def __init__(self, seq):
-        self.iter = iter(seq)
-
-      def __aiter__(self):
-        return self
-
-      async def __anext__(self):
-        try:
-          return next(self.iter)
-        except StopIteration:
-          raise StopAsyncIteration
-
     mock_responses = [
         types.GenerateContentResponse(
             candidates=[
@@ -1058,21 +990,6 @@ async def test_generate_content_async_stream_no_aggregated_content_without_text(
   )
 
   with mock.patch.object(gemini_llm, "api_client") as mock_client:
-
-    class MockAsyncIterator:
-
-      def __init__(self, seq):
-        self.iter = iter(seq)
-
-      def __aiter__(self):
-        return self
-
-      async def __anext__(self):
-        try:
-          return next(self.iter)
-        except StopIteration:
-          raise StopAsyncIteration
-
     # Mock response with no text content
     mock_responses = [
         types.GenerateContentResponse(
@@ -1127,21 +1044,6 @@ async def test_generate_content_async_stream_mixed_text_function_call_text():
   )
 
   with mock.patch.object(gemini_llm, "api_client") as mock_client:
-
-    class MockAsyncIterator:
-
-      def __init__(self, seq):
-        self.iter = iter(seq)
-
-      def __aiter__(self):
-        return self
-
-      async def __anext__(self):
-        try:
-          return next(self.iter)
-        except StopIteration:
-          raise StopAsyncIteration
-
     # Create responses with pattern: text -> function_call -> text
     mock_responses = [
         # First text chunk
@@ -1247,21 +1149,6 @@ async def test_generate_content_async_stream_multiple_text_parts_in_single_respo
   )
 
   with mock.patch.object(gemini_llm, "api_client") as mock_client:
-
-    class MockAsyncIterator:
-
-      def __init__(self, seq):
-        self.iter = iter(seq)
-
-      def __aiter__(self):
-        return self
-
-      async def __anext__(self):
-        try:
-          return next(self.iter)
-        except StopIteration:
-          raise StopAsyncIteration
-
     # Create a response with multiple text parts
     mock_responses = [
         types.GenerateContentResponse(
@@ -1314,21 +1201,6 @@ async def test_generate_content_async_stream_complex_mixed_thought_text_function
   )
 
   with mock.patch.object(gemini_llm, "api_client") as mock_client:
-
-    class MockAsyncIterator:
-
-      def __init__(self, seq):
-        self.iter = iter(seq)
-
-      def __aiter__(self):
-        return self
-
-      async def __anext__(self):
-        try:
-          return next(self.iter)
-        except StopIteration:
-          raise StopAsyncIteration
-
     # Complex pattern: thought -> text -> function_call -> thought -> text
     mock_responses = [
         # Thought
@@ -1450,21 +1322,6 @@ async def test_generate_content_async_stream_two_separate_text_aggregations():
   )
 
   with mock.patch.object(gemini_llm, "api_client") as mock_client:
-
-    class MockAsyncIterator:
-
-      def __init__(self, seq):
-        self.iter = iter(seq)
-
-      def __aiter__(self):
-        return self
-
-      async def __anext__(self):
-        try:
-          return next(self.iter)
-        except StopIteration:
-          raise StopAsyncIteration
-
     # Create responses: multiple text chunks -> function_call -> multiple text chunks
     mock_responses = [
         # First text accumulation (multiple chunks)
