27 |    27 | from google.adk.models.llm_response import LlmResponse
28 |    28 | from google.adk.utils.variant_utils import GoogleLLMVariant
29 |    29 | from google.genai import types
30 |       | -from google.genai import version as genai_version
31 |    30 | from google.genai.types import Content
32 |    31 | from google.genai.types import Part
33 |    32 | import pytest
34 |    33 |
35 |    34 |
   |    35 | +class MockAsyncIterator:
   |    36 | +  """Mock for async iterator."""
   |    37 | +
   |    38 | +  def __init__(self, seq):
   |    39 | +    self.iter = iter(seq)
   |    40 | +
   |    41 | +  def __aiter__(self):
   |    42 | +    return self
   |    43 | +
   |    44 | +  async def __anext__(self):
   |    45 | +    try:
   |    46 | +      return next(self.iter)
   |    47 | +    except StopIteration as exc:
   |    48 | +      raise StopAsyncIteration from exc
   |    49 | +
   |    50 | +
36 |    51 | @pytest.fixture
37 |    52 | def generate_content_response():
38 |    53 |   return types.GenerateContentResponse(
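
Note: the module-level MockAsyncIterator added above replaces the identical
helper class that was previously copy-pasted into each streaming test (removed
in the hunks below). A minimal sketch of how the streaming tests drive it
(illustrative only; mock_client and mock_responses follow the names used in
this diff, and the placeholder list is not real test data):

    from unittest import mock

    mock_client = mock.MagicMock()
    mock_responses = [...]  # list of types.GenerateContentResponse chunks

    async def mock_coro():
      # The mocked generate_content_stream call is awaited once and must
      # resolve to an async iterator, so the tests wrap the iterator in a
      # coroutine.
      return MockAsyncIterator(mock_responses)

    mock_client.aio.models.generate_content_stream.return_value = mock_coro()
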
@@ -215,21 +230,6 @@ async def mock_coro():

215 |   230 | @pytest.mark.asyncio
216 |   231 | async def test_generate_content_async_stream(gemini_llm, llm_request):
217 |   232 |   with mock.patch.object(gemini_llm, "api_client") as mock_client:
218 |       | -    # Create mock stream responses
219 |       | -    class MockAsyncIterator:
220 |       | -
221 |       | -      def __init__(self, seq):
222 |       | -        self.iter = iter(seq)
223 |       | -
224 |       | -      def __aiter__(self):
225 |       | -        return self
226 |       | -
227 |       | -      async def __anext__(self):
228 |       | -        try:
229 |       | -          return next(self.iter)
230 |       | -        except StopIteration:
231 |       | -          raise StopAsyncIteration
232 |       | -
233 |   233 |     mock_responses = [
234 |   234 |         types.GenerateContentResponse(
235 |   235 |             candidates=[

@@ -292,21 +292,6 @@ async def test_generate_content_async_stream_preserves_thinking_and_text_parts(

292 |   292 |     gemini_llm, llm_request
293 |   293 | ):
294 |   294 |   with mock.patch.object(gemini_llm, "api_client") as mock_client:
295 |       | -
296 |       | -    class MockAsyncIterator:
297 |       | -
298 |       | -      def __init__(self, seq):
299 |       | -        self._iter = iter(seq)
300 |       | -
301 |       | -      def __aiter__(self):
302 |       | -        return self
303 |       | -
304 |       | -      async def __anext__(self):
305 |       | -        try:
306 |       | -          return next(self._iter)
307 |       | -        except StopIteration:
308 |       | -          raise StopAsyncIteration
309 |       | -
310 |   295 |     response1 = types.GenerateContentResponse(
311 |   296 |         candidates=[
312 |   297 |             types.Candidate(

@@ -436,21 +421,6 @@ async def test_generate_content_async_stream_with_custom_headers(

436 |   421 |   llm_request.config.http_options = types.HttpOptions(headers=custom_headers)
437 |   422 |
438 |   423 |   with mock.patch.object(gemini_llm, "api_client") as mock_client:
439 |       | -    # Create mock stream responses
440 |       | -    class MockAsyncIterator:
441 |       | -
442 |       | -      def __init__(self, seq):
443 |       | -        self.iter = iter(seq)
444 |       | -
445 |       | -      def __aiter__(self):
446 |       | -        return self
447 |       | -
448 |       | -      async def __anext__(self):
449 |       | -        try:
450 |       | -          return next(self.iter)
451 |       | -        except StopIteration:
452 |       | -          raise StopAsyncIteration
453 |       | -
454 |   424 |     mock_responses = [
455 |   425 |         types.GenerateContentResponse(
456 |   426 |             candidates=[
@@ -488,35 +458,58 @@ async def mock_coro():

488 |   458 |   assert len(responses) == 2
489 |   459 |
490 |   460 |
    |   461 | +@pytest.mark.parametrize("stream", [True, False])
491 |   462 | @pytest.mark.asyncio
492 |       | -async def test_generate_content_async_without_custom_headers(
493 |       | -    gemini_llm, llm_request, generate_content_response
    |   463 | +async def test_generate_content_async_patches_tracking_headers(
    |   464 | +    stream, gemini_llm, llm_request, generate_content_response
494 |   465 | ):
495 |       | -  """Test that tracking headers are not modified when no custom headers exist."""
496 |       | -  # Ensure no http_options exist initially
    |   466 | +  """Tests that tracking headers are added to the request config."""
    |   467 | +  # Set the request's config.http_options to None.
497 |   468 |   llm_request.config.http_options = None
498 |   469 |
499 |   470 |   with mock.patch.object(gemini_llm, "api_client") as mock_client:
    |   471 | +    if stream:
    |   472 | +      # Create a mock coroutine that returns the mock response stream.
    |   473 | +      async def mock_coro():
    |   474 | +        return MockAsyncIterator([generate_content_response])
500 |   475 |
501 |       | -    async def mock_coro():
502 |       | -      return generate_content_response
    |   476 | +      # Mock for streaming response.
    |   477 | +      mock_client.aio.models.generate_content_stream.return_value = mock_coro()
    |   478 | +    else:
    |   479 | +      # Create a mock coroutine that returns the generate_content_response.
    |   480 | +      async def mock_coro():
    |   481 | +        return generate_content_response
503 |   482 |
504 |       | -    mock_client.aio.models.generate_content.return_value = mock_coro()
    |   483 | +      # Mock for non-streaming response.
    |   484 | +      mock_client.aio.models.generate_content.return_value = mock_coro()
505 |   485 |
    |   486 | +    # Call the generate_content_async method.
506 |   487 |     responses = [
507 |   488 |         resp
508 |   489 |         async for resp in gemini_llm.generate_content_async(
509 |       | -            llm_request, stream=False
    |   490 | +            llm_request, stream=stream
510 |   491 |         )
511 |   492 |     ]
512 |   493 |
513 |       | -    # Verify that the config passed to generate_content has no http_options
514 |       | -    mock_client.aio.models.generate_content.assert_called_once()
515 |       | -    call_args = mock_client.aio.models.generate_content.call_args
516 |       | -    config_arg = call_args.kwargs["config"]
517 |       | -    assert config_arg.http_options is None
    |   494 | +    # Assert that the config passed to the generate_content or
    |   495 | +    # generate_content_stream method contains the tracking headers.
    |   496 | +    if stream:
    |   497 | +      mock_client.aio.models.generate_content_stream.assert_called_once()
    |   498 | +      call_args = mock_client.aio.models.generate_content_stream.call_args
    |   499 | +    else:
    |   500 | +      mock_client.aio.models.generate_content.assert_called_once()
    |   501 | +      call_args = mock_client.aio.models.generate_content.call_args
518 |   502 |
519 |       | -  assert len(responses) == 1
    |   503 | +    final_config = call_args.kwargs["config"]
    |   504 | +
    |   505 | +    assert final_config is not None
    |   506 | +    assert final_config.http_options is not None
    |   507 | +    assert (
    |   508 | +        final_config.http_options.headers["x-goog-api-client"]
    |   509 | +        == gemini_llm._tracking_headers["x-goog-api-client"]
    |   510 | +    )
    |   511 | +
    |   512 | +    assert len(responses) == (2 if stream else 1)
520 |   513 |
521 |   514 |
522 |   515 | def test_live_api_version_vertex_ai(gemini_llm):
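
Note: the parametrized rewrite above covers both transports in one test body;
pytest runs it once with stream=True and once with stream=False. The final
assertion expects two responses in the streaming case because the streaming
path yields a partial LlmResponse for the chunk plus a final aggregated
response, versus exactly one response when not streaming. The parametrization
is roughly equivalent to writing (illustrative sketch; run_case is a
hypothetical helper, not part of the diff):

    @pytest.mark.asyncio
    async def test_tracking_headers_streaming():
      await run_case(stream=True)  # one chunk -> partial + aggregated = 2

    @pytest.mark.asyncio
    async def test_tracking_headers_non_streaming():
      await run_case(stream=False)  # single response = 1
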
@@ -665,8 +658,7 @@ async def test_preprocess_request_handles_backend_specific_fields(

665 |   658 |     expected_inline_display_name: Optional[str],
666 |   659 |     expected_labels: Optional[str],
667 |   660 | ):
668 |       | -  """
669 |       | -  Tests that _preprocess_request correctly sanitizes fields based on the API backend.
    |   661 | +  """Tests that _preprocess_request correctly sanitizes fields based on the API backend.
670 |   662 |
671 |   663 |   - For GEMINI_API, it should remove 'display_name' from file/inline data
672 |   664 |     and remove 'labels' from the config.
@@ -732,21 +724,6 @@ async def test_generate_content_async_stream_aggregated_content_regardless_of_fi

 732 |   724 |   )
 733 |   725 |
 734 |   726 |   with mock.patch.object(gemini_llm, "api_client") as mock_client:
 735 |       | -
 736 |       | -    class MockAsyncIterator:
 737 |       | -
 738 |       | -      def __init__(self, seq):
 739 |       | -        self.iter = iter(seq)
 740 |       | -
 741 |       | -      def __aiter__(self):
 742 |       | -        return self
 743 |       | -
 744 |       | -      async def __anext__(self):
 745 |       | -        try:
 746 |       | -          return next(self.iter)
 747 |       | -        except StopIteration:
 748 |       | -          raise StopAsyncIteration
 749 |       | -
 750 |   727 |     # Test with different finish reasons
 751 |   728 |     test_cases = [
 752 |   729 |         types.FinishReason.MAX_TOKENS,

@@ -820,21 +797,6 @@ async def test_generate_content_async_stream_with_thought_and_text_error_handlin

 820 |   797 |   )
 821 |   798 |
 822 |   799 |   with mock.patch.object(gemini_llm, "api_client") as mock_client:
 823 |       | -
 824 |       | -    class MockAsyncIterator:
 825 |       | -
 826 |       | -      def __init__(self, seq):
 827 |       | -        self.iter = iter(seq)
 828 |       | -
 829 |       | -      def __aiter__(self):
 830 |       | -        return self
 831 |       | -
 832 |       | -      async def __anext__(self):
 833 |       | -        try:
 834 |       | -          return next(self.iter)
 835 |       | -        except StopIteration:
 836 |       | -          raise StopAsyncIteration
 837 |       | -
 838 |   800 |     mock_responses = [
 839 |   801 |         types.GenerateContentResponse(
 840 |   802 |             candidates=[

@@ -902,21 +864,6 @@ async def test_generate_content_async_stream_error_info_none_for_stop_finish_rea

 902 |   864 |   )
 903 |   865 |
 904 |   866 |   with mock.patch.object(gemini_llm, "api_client") as mock_client:
 905 |       | -
 906 |       | -    class MockAsyncIterator:
 907 |       | -
 908 |       | -      def __init__(self, seq):
 909 |       | -        self.iter = iter(seq)
 910 |       | -
 911 |       | -      def __aiter__(self):
 912 |       | -        return self
 913 |       | -
 914 |       | -      async def __anext__(self):
 915 |       | -        try:
 916 |       | -          return next(self.iter)
 917 |       | -        except StopIteration:
 918 |       | -          raise StopAsyncIteration
 919 |       | -
 920 |   867 |     mock_responses = [
 921 |   868 |         types.GenerateContentResponse(
 922 |   869 |             candidates=[

@@ -980,21 +927,6 @@ async def test_generate_content_async_stream_error_info_set_for_non_stop_finish_

 980 |   927 |   )
 981 |   928 |
 982 |   929 |   with mock.patch.object(gemini_llm, "api_client") as mock_client:
 983 |       | -
 984 |       | -    class MockAsyncIterator:
 985 |       | -
 986 |       | -      def __init__(self, seq):
 987 |       | -        self.iter = iter(seq)
 988 |       | -
 989 |       | -      def __aiter__(self):
 990 |       | -        return self
 991 |       | -
 992 |       | -      async def __anext__(self):
 993 |       | -        try:
 994 |       | -          return next(self.iter)
 995 |       | -        except StopIteration:
 996 |       | -          raise StopAsyncIteration
 997 |       | -
 998 |   930 |     mock_responses = [
 999 |   931 |         types.GenerateContentResponse(
1000 |   932 |             candidates=[

@@ -1058,21 +990,6 @@ async def test_generate_content_async_stream_no_aggregated_content_without_text(

1058 |   990 |   )
1059 |   991 |
1060 |   992 |   with mock.patch.object(gemini_llm, "api_client") as mock_client:
1061 |       | -
1062 |       | -    class MockAsyncIterator:
1063 |       | -
1064 |       | -      def __init__(self, seq):
1065 |       | -        self.iter = iter(seq)
1066 |       | -
1067 |       | -      def __aiter__(self):
1068 |       | -        return self
1069 |       | -
1070 |       | -      async def __anext__(self):
1071 |       | -        try:
1072 |       | -          return next(self.iter)
1073 |       | -        except StopIteration:
1074 |       | -          raise StopAsyncIteration
1075 |       | -
1076 |   993 |     # Mock response with no text content
1077 |   994 |     mock_responses = [
1078 |   995 |         types.GenerateContentResponse(

@@ -1127,21 +1044,6 @@ async def test_generate_content_async_stream_mixed_text_function_call_text():

1127 |  1044 |   )
1128 |  1045 |
1129 |  1046 |   with mock.patch.object(gemini_llm, "api_client") as mock_client:
1130 |       | -
1131 |       | -    class MockAsyncIterator:
1132 |       | -
1133 |       | -      def __init__(self, seq):
1134 |       | -        self.iter = iter(seq)
1135 |       | -
1136 |       | -      def __aiter__(self):
1137 |       | -        return self
1138 |       | -
1139 |       | -      async def __anext__(self):
1140 |       | -        try:
1141 |       | -          return next(self.iter)
1142 |       | -        except StopIteration:
1143 |       | -          raise StopAsyncIteration
1144 |       | -
1145 |  1047 |     # Create responses with pattern: text -> function_call -> text
1146 |  1048 |     mock_responses = [
1147 |  1049 |         # First text chunk

@@ -1247,21 +1149,6 @@ async def test_generate_content_async_stream_multiple_text_parts_in_single_respo

1247 |  1149 |   )
1248 |  1150 |
1249 |  1151 |   with mock.patch.object(gemini_llm, "api_client") as mock_client:
1250 |       | -
1251 |       | -    class MockAsyncIterator:
1252 |       | -
1253 |       | -      def __init__(self, seq):
1254 |       | -        self.iter = iter(seq)
1255 |       | -
1256 |       | -      def __aiter__(self):
1257 |       | -        return self
1258 |       | -
1259 |       | -      async def __anext__(self):
1260 |       | -        try:
1261 |       | -          return next(self.iter)
1262 |       | -        except StopIteration:
1263 |       | -          raise StopAsyncIteration
1264 |       | -
1265 |  1152 |     # Create a response with multiple text parts
1266 |  1153 |     mock_responses = [
1267 |  1154 |         types.GenerateContentResponse(

@@ -1314,21 +1201,6 @@ async def test_generate_content_async_stream_complex_mixed_thought_text_function

1314 |  1201 |   )
1315 |  1202 |
1316 |  1203 |   with mock.patch.object(gemini_llm, "api_client") as mock_client:
1317 |       | -
1318 |       | -    class MockAsyncIterator:
1319 |       | -
1320 |       | -      def __init__(self, seq):
1321 |       | -        self.iter = iter(seq)
1322 |       | -
1323 |       | -      def __aiter__(self):
1324 |       | -        return self
1325 |       | -
1326 |       | -      async def __anext__(self):
1327 |       | -        try:
1328 |       | -          return next(self.iter)
1329 |       | -        except StopIteration:
1330 |       | -          raise StopAsyncIteration
1331 |       | -
1332 |  1204 |     # Complex pattern: thought -> text -> function_call -> thought -> text
1333 |  1205 |     mock_responses = [
1334 |  1206 |         # Thought

@@ -1450,21 +1322,6 @@ async def test_generate_content_async_stream_two_separate_text_aggregations():

1450 |  1322 |   )
1451 |  1323 |
1452 |  1324 |   with mock.patch.object(gemini_llm, "api_client") as mock_client:
1453 |       | -
1454 |       | -    class MockAsyncIterator:
1455 |       | -
1456 |       | -      def __init__(self, seq):
1457 |       | -        self.iter = iter(seq)
1458 |       | -
1459 |       | -      def __aiter__(self):
1460 |       | -        return self
1461 |       | -
1462 |       | -      async def __anext__(self):
1463 |       | -        try:
1464 |       | -          return next(self.iter)
1465 |       | -        except StopIteration:
1466 |       | -          raise StopAsyncIteration
1467 |       | -
1468 |  1325 |     # Create responses: multiple text chunks -> function_call -> multiple text chunks
1469 |  1326 |     mock_responses = [
1470 |  1327 |         # First text accumulation (multiple chunks)