fix: merge tracking headers even when llm_request.config.http_options is not set in Gemini.generate_content_async #2190

Open · wants to merge 1 commit into main
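
Why this fix: `generate_content_async` passes `llm_request.config` straight to the GenAI client, and per-request `http_options` override the headers configured on the api client constructor. The old guard (`if llm_request.config and llm_request.config.http_options:`) skipped the merge whenever `http_options` was unset, so such requests went out without the ADK tracking headers. The patched code below creates `http_options` and its `headers` dict on demand and always merges. A minimal standalone sketch of that logic (the helper name and header value are illustrative stand-ins, not ADK API):

```python
# Minimal sketch of the patched merge logic; merge_tracking_headers and the
# header value are illustrative stand-ins, not part of the ADK API.
from typing import Optional

from google.genai import types

_TRACKING_HEADERS = {"x-goog-api-client": "google-adk/<version>"}  # illustrative


def merge_tracking_headers(config: Optional[types.GenerateContentConfig]) -> None:
  """Ensures tracking headers survive even when http_options is unset."""
  if not config:
    return
  if not config.http_options:
    # Before the fix, the merge was skipped entirely in this case, so the
    # request carried no tracking headers at all.
    config.http_options = types.HttpOptions()
  if not config.http_options.headers:
    config.http_options.headers = {}
  # dict.update keeps user-provided custom headers and adds the tracking ones.
  config.http_options.headers.update(_TRACKING_HEADERS)
```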
9 changes: 6 additions & 3 deletions src/google/adk/models/google_llm.py
@@ -116,9 +116,12 @@ async def generate_content_async(
)
logger.debug(_build_request_log(llm_request))

-    # add tracking headers to custom headers given it will override the headers
-    # set in the api client constructor
-    if llm_request.config and llm_request.config.http_options:
+    # Always add tracking headers to custom headers given it will override
+    # the headers set in the api client constructor to avoid tracking headers
+    # being dropped if user provides custom headers or overrides the api client.
+    if llm_request.config:
+      if not llm_request.config.http_options:
+        llm_request.config.http_options = types.HttpOptions()
if not llm_request.config.http_options.headers:
llm_request.config.http_options.headers = {}
llm_request.config.http_options.headers.update(self._tracking_headers)
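
A hedged usage example of the `merge_tracking_headers` sketch above, covering the two cases the updated tests exercise: a config with custom headers keeps them and gains the tracking headers, and a config with no `http_options` at all now gains them too (before the fix it simply kept `http_options=None`):

```python
# Illustrative check of the post-fix behavior using the sketch above; the
# header names and values are made up for the example.
from google.genai import types

# Case 1: user-supplied custom headers are preserved, tracking headers added.
config = types.GenerateContentConfig(
    http_options=types.HttpOptions(headers={"x-custom-header": "test-value"})
)
merge_tracking_headers(config)
assert config.http_options.headers["x-custom-header"] == "test-value"
assert "x-goog-api-client" in config.http_options.headers

# Case 2: no http_options at all; previously the merge was skipped here.
bare = types.GenerateContentConfig()
merge_tracking_headers(bare)
assert "x-goog-api-client" in bare.http_options.headers
```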
251 changes: 54 additions & 197 deletions tests/unittests/models/test_google_llm.py
@@ -27,12 +27,27 @@
from google.adk.models.llm_response import LlmResponse
from google.adk.utils.variant_utils import GoogleLLMVariant
from google.genai import types
+from google.genai import version as genai_version
from google.genai.types import Content
from google.genai.types import Part
import pytest


+class MockAsyncIterator:
+  """Mock for async iterator."""
+
+  def __init__(self, seq):
+    self.iter = iter(seq)
+
+  def __aiter__(self):
+    return self
+
+  async def __anext__(self):
+    try:
+      return next(self.iter)
+    except StopIteration as exc:
+      raise StopAsyncIteration from exc


@pytest.fixture
def generate_content_response():
return types.GenerateContentResponse(
@@ -215,21 +230,6 @@ async def mock_coro():
@pytest.mark.asyncio
async def test_generate_content_async_stream(gemini_llm, llm_request):
with mock.patch.object(gemini_llm, "api_client") as mock_client:
-    # Create mock stream responses
-    class MockAsyncIterator:
-
-      def __init__(self, seq):
-        self.iter = iter(seq)
-
-      def __aiter__(self):
-        return self
-
-      async def __anext__(self):
-        try:
-          return next(self.iter)
-        except StopIteration:
-          raise StopAsyncIteration

mock_responses = [
types.GenerateContentResponse(
candidates=[
@@ -292,21 +292,6 @@ async def test_generate_content_async_stream_preserves_thinking_and_text_parts(
gemini_llm, llm_request
):
with mock.patch.object(gemini_llm, "api_client") as mock_client:

-    class MockAsyncIterator:
-
-      def __init__(self, seq):
-        self._iter = iter(seq)
-
-      def __aiter__(self):
-        return self
-
-      async def __anext__(self):
-        try:
-          return next(self._iter)
-        except StopIteration:
-          raise StopAsyncIteration

response1 = types.GenerateContentResponse(
candidates=[
types.Candidate(
@@ -436,21 +421,6 @@ async def test_generate_content_async_stream_with_custom_headers(
llm_request.config.http_options = types.HttpOptions(headers=custom_headers)

with mock.patch.object(gemini_llm, "api_client") as mock_client:
-    # Create mock stream responses
-    class MockAsyncIterator:
-
-      def __init__(self, seq):
-        self.iter = iter(seq)
-
-      def __aiter__(self):
-        return self
-
-      async def __anext__(self):
-        try:
-          return next(self.iter)
-        except StopIteration:
-          raise StopAsyncIteration

mock_responses = [
types.GenerateContentResponse(
candidates=[
@@ -488,35 +458,58 @@ async def mock_coro():
assert len(responses) == 2


+@pytest.mark.parametrize("stream", [True, False])
 @pytest.mark.asyncio
-async def test_generate_content_async_without_custom_headers(
-    gemini_llm, llm_request, generate_content_response
+async def test_generate_content_async_patches_tracking_headers(
+    stream, gemini_llm, llm_request, generate_content_response
 ):
-  """Test that tracking headers are not modified when no custom headers exist."""
-  # Ensure no http_options exist initially
+  """Tests that tracking headers are added to the request config."""
+  # Set the request's config.http_options to None.
   llm_request.config.http_options = None

   with mock.patch.object(gemini_llm, "api_client") as mock_client:
-
-    async def mock_coro():
-      return generate_content_response
-
-    mock_client.aio.models.generate_content.return_value = mock_coro()
+    if stream:
+      # Create a mock coroutine that returns the mock_responses.
+      async def mock_coro():
+        return MockAsyncIterator([generate_content_response])
+
+      # Mock for streaming response.
+      mock_client.aio.models.generate_content_stream.return_value = mock_coro()
+    else:
+      # Create a mock coroutine that returns the generate_content_response.
+      async def mock_coro():
+        return generate_content_response
+
+      # Mock for non-streaming response.
+      mock_client.aio.models.generate_content.return_value = mock_coro()

+    # Call the generate_content_async method.
     responses = [
         resp
         async for resp in gemini_llm.generate_content_async(
-            llm_request, stream=False
+            llm_request, stream=stream
         )
     ]

-    # Verify that the config passed to generate_content has no http_options
-    mock_client.aio.models.generate_content.assert_called_once()
-    call_args = mock_client.aio.models.generate_content.call_args
-    config_arg = call_args.kwargs["config"]
-    assert config_arg.http_options is None
-
-    assert len(responses) == 1
+    # Assert that the config passed to the generate_content or
+    # generate_content_stream method contains the tracking headers.
+    if stream:
+      mock_client.aio.models.generate_content_stream.assert_called_once()
+      call_args = mock_client.aio.models.generate_content_stream.call_args
+    else:
+      mock_client.aio.models.generate_content.assert_called_once()
+      call_args = mock_client.aio.models.generate_content.call_args
+
+    final_config = call_args.kwargs["config"]
+
+    assert final_config is not None
+    assert final_config.http_options is not None
+    assert (
+        final_config.http_options.headers["x-goog-api-client"]
+        == gemini_llm._tracking_headers["x-goog-api-client"]
+    )
+
+    assert len(responses) == (2 if stream else 1)


def test_live_api_version_vertex_ai(gemini_llm):
@@ -665,8 +658,7 @@ async def test_preprocess_request_handles_backend_specific_fields(
expected_inline_display_name: Optional[str],
expected_labels: Optional[str],
):
"""
Tests that _preprocess_request correctly sanitizes fields based on the API backend.
"""Tests that _preprocess_request correctly sanitizes fields based on the API backend.

- For GEMINI_API, it should remove 'display_name' from file/inline data
and remove 'labels' from the config.
@@ -732,21 +724,6 @@ async def test_generate_content_async_stream_aggregated_content_regardless_of_fi
)

with mock.patch.object(gemini_llm, "api_client") as mock_client:

-    class MockAsyncIterator:
-
-      def __init__(self, seq):
-        self.iter = iter(seq)
-
-      def __aiter__(self):
-        return self
-
-      async def __anext__(self):
-        try:
-          return next(self.iter)
-        except StopIteration:
-          raise StopAsyncIteration

# Test with different finish reasons
test_cases = [
types.FinishReason.MAX_TOKENS,
@@ -820,21 +797,6 @@ async def test_generate_content_async_stream_with_thought_and_text_error_handlin
)

with mock.patch.object(gemini_llm, "api_client") as mock_client:

-    class MockAsyncIterator:
-
-      def __init__(self, seq):
-        self.iter = iter(seq)
-
-      def __aiter__(self):
-        return self
-
-      async def __anext__(self):
-        try:
-          return next(self.iter)
-        except StopIteration:
-          raise StopAsyncIteration

mock_responses = [
types.GenerateContentResponse(
candidates=[
@@ -902,21 +864,6 @@ async def test_generate_content_async_stream_error_info_none_for_stop_finish_rea
)

with mock.patch.object(gemini_llm, "api_client") as mock_client:

-    class MockAsyncIterator:
-
-      def __init__(self, seq):
-        self.iter = iter(seq)
-
-      def __aiter__(self):
-        return self
-
-      async def __anext__(self):
-        try:
-          return next(self.iter)
-        except StopIteration:
-          raise StopAsyncIteration

mock_responses = [
types.GenerateContentResponse(
candidates=[
@@ -980,21 +927,6 @@ async def test_generate_content_async_stream_error_info_set_for_non_stop_finish_
)

with mock.patch.object(gemini_llm, "api_client") as mock_client:

-    class MockAsyncIterator:
-
-      def __init__(self, seq):
-        self.iter = iter(seq)
-
-      def __aiter__(self):
-        return self
-
-      async def __anext__(self):
-        try:
-          return next(self.iter)
-        except StopIteration:
-          raise StopAsyncIteration

mock_responses = [
types.GenerateContentResponse(
candidates=[
@@ -1058,21 +990,6 @@ async def test_generate_content_async_stream_no_aggregated_content_without_text(
)

with mock.patch.object(gemini_llm, "api_client") as mock_client:

-    class MockAsyncIterator:
-
-      def __init__(self, seq):
-        self.iter = iter(seq)
-
-      def __aiter__(self):
-        return self
-
-      async def __anext__(self):
-        try:
-          return next(self.iter)
-        except StopIteration:
-          raise StopAsyncIteration

# Mock response with no text content
mock_responses = [
types.GenerateContentResponse(
@@ -1127,21 +1044,6 @@ async def test_generate_content_async_stream_mixed_text_function_call_text():
)

with mock.patch.object(gemini_llm, "api_client") as mock_client:

-    class MockAsyncIterator:
-
-      def __init__(self, seq):
-        self.iter = iter(seq)
-
-      def __aiter__(self):
-        return self
-
-      async def __anext__(self):
-        try:
-          return next(self.iter)
-        except StopIteration:
-          raise StopAsyncIteration

# Create responses with pattern: text -> function_call -> text
mock_responses = [
# First text chunk
@@ -1247,21 +1149,6 @@ async def test_generate_content_async_stream_multiple_text_parts_in_single_respo
)

with mock.patch.object(gemini_llm, "api_client") as mock_client:

-    class MockAsyncIterator:
-
-      def __init__(self, seq):
-        self.iter = iter(seq)
-
-      def __aiter__(self):
-        return self
-
-      async def __anext__(self):
-        try:
-          return next(self.iter)
-        except StopIteration:
-          raise StopAsyncIteration

# Create a response with multiple text parts
mock_responses = [
types.GenerateContentResponse(
@@ -1314,21 +1201,6 @@ async def test_generate_content_async_stream_complex_mixed_thought_text_function
)

with mock.patch.object(gemini_llm, "api_client") as mock_client:

-    class MockAsyncIterator:
-
-      def __init__(self, seq):
-        self.iter = iter(seq)
-
-      def __aiter__(self):
-        return self
-
-      async def __anext__(self):
-        try:
-          return next(self.iter)
-        except StopIteration:
-          raise StopAsyncIteration

# Complex pattern: thought -> text -> function_call -> thought -> text
mock_responses = [
# Thought
@@ -1450,21 +1322,6 @@ async def test_generate_content_async_stream_two_separate_text_aggregations():
)

with mock.patch.object(gemini_llm, "api_client") as mock_client:

-    class MockAsyncIterator:
-
-      def __init__(self, seq):
-        self.iter = iter(seq)
-
-      def __aiter__(self):
-        return self
-
-      async def __anext__(self):
-        try:
-          return next(self.iter)
-        except StopIteration:
-          raise StopAsyncIteration

# Create responses: multiple text chunks -> function_call -> multiple text chunks
mock_responses = [
# First text accumulation (multiple chunks)
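
Note on the test refactor: every per-test copy of `MockAsyncIterator` deleted above is replaced by the single module-level helper added at the top of the file. It adapts a plain list to the async-iterator protocol so a mocked `generate_content_stream` can be consumed with `async for`. A self-contained sketch of that behavior (the `asyncio.run` scaffolding is illustrative, not part of the test suite):

```python
# Standalone demo of the module-level MockAsyncIterator helper from the tests.
import asyncio


class MockAsyncIterator:
  """Mock for async iterator."""

  def __init__(self, seq):
    self.iter = iter(seq)

  def __aiter__(self):
    return self

  async def __anext__(self):
    try:
      return next(self.iter)
    except StopIteration as exc:
      raise StopAsyncIteration from exc


async def main():
  # Prints "part-1" then "part-2", mimicking a two-chunk streaming response.
  async for chunk in MockAsyncIterator(["part-1", "part-2"]):
    print(chunk)


asyncio.run(main())
```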