
Commit ad3d329

fixes: dont_throw on chunk processing, unify some async and sync work (#181)
* fixes: dont_throw on chunk processing, unify some async and sync work
* minor fix: replace nested if with and
* further small fixes
* remove duplicated code
* bump version to 0.7.9
1 parent 9725bc5 commit ad3d329
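The headline change: per-chunk processing in the streaming wrappers now happens inside a shared helper that carries its own @dont_throw, because a @dont_throw applied to a generator function only guards the code that runs before the generator object is handed back; anything that fails while the caller iterates escapes it. Below is a minimal sketch of what a dont_throw decorator of this kind typically looks like. The real one lives in the instrumentation's utils module, which is not included in this excerpt, so the details are an assumption.

import logging
import traceback
from functools import wraps

logger = logging.getLogger(__name__)


def dont_throw(func):
    """Swallow and log anything the wrapped function raises.

    Instrumentation must never take the host application down, so failures
    are logged at debug level and None is returned instead of propagating.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception:
            logger.debug(
                "lmnr failed to trace %s:\n%s", func.__name__, traceback.format_exc()
            )
            return None

    return wrapper

When func is a generator function, the try/except above only wraps the call that creates the generator object; the body runs later, during iteration, outside this frame. That is exactly why the commit moves the per-chunk logic into a plain function that gets its own decorator.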

File tree

10 files changed: +558 -200 lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@

 [project]
 name = "lmnr"
-version = "0.7.8"
+version = "0.7.9"
 description = "Python SDK for Laminar"
 authors = [
   { name = "lmnr.ai", email = "[email protected]" }

src/lmnr/opentelemetry_lib/opentelemetry/instrumentation/google_genai/__init__.py

Lines changed: 103 additions & 116 deletions
@@ -8,6 +8,7 @@

 from google.genai import types

+from lmnr.opentelemetry_lib.decorators import json_dumps
 from lmnr.opentelemetry_lib.tracing.context import (
     get_current_context,
     get_event_attributes_from_context,
@@ -20,9 +21,10 @@
 from .utils import (
     dont_throw,
     get_content,
+    process_content_union,
+    process_stream_chunk,
     role_from_content_union,
     set_span_attribute,
-    process_content_union,
     to_dict,
     with_tracer_wrapper,
 )
@@ -139,9 +141,7 @@ def _set_request_attributes(span, args, kwargs):
        try:
            set_span_attribute(
                span,
-                # TODO: change to SpanAttributes.LLM_REQUEST_STRUCTURED_OUTPUT_SCHEMA
-                # when we upgrade to opentelemetry-semantic-conventions-ai>=0.4.10
-                "gen_ai.request.structured_output_schema",
+                SpanAttributes.LLM_REQUEST_STRUCTURED_OUTPUT_SCHEMA,
                json.dumps(process_schema(schema), cls=SchemaJSONEncoder),
            )
        except Exception:
@@ -150,10 +150,8 @@ def _set_request_attributes(span, args, kwargs):
        try:
            set_span_attribute(
                span,
-                # TODO: change to SpanAttributes.LLM_REQUEST_STRUCTURED_OUTPUT_SCHEMA
-                # when we upgrade to opentelemetry-semantic-conventions-ai>=0.4.10
-                "gen_ai.request.structured_output_schema",
-                json.dumps(json_schema),
+                SpanAttributes.LLM_REQUEST_STRUCTURED_OUTPUT_SCHEMA,
+                json_dumps(json_schema),
            )
        except Exception:
            pass
@@ -182,7 +180,7 @@ def _set_request_attributes(span, args, kwargs):
            set_span_attribute(
                span,
                f"{SpanAttributes.LLM_REQUEST_FUNCTIONS}.{tool_num}.parameters",
-                json.dumps(tool_dict.get("parameters")),
+                json_dumps(tool_dict.get("parameters")),
            )

    if should_send_prompts():
@@ -215,7 +213,7 @@ def _set_request_attributes(span, args, kwargs):
                (
                    content_str
                    if isinstance(content_str, str)
-                    else json.dumps(content_str)
+                    else json_dumps(content_str)
                ),
            )
            blocks = (
@@ -248,7 +246,7 @@ def _set_request_attributes(span, args, kwargs):
                set_span_attribute(
                    span,
                    f"{gen_ai_attributes.GEN_AI_PROMPT}.{i}.tool_calls.{tool_call_index}.arguments",
-                    json.dumps(function_call.get("arguments")),
+                    json_dumps(function_call.get("arguments")),
                )
                tool_call_index += 1
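The json.dumps call sites above are switched to a json_dumps helper imported from lmnr.opentelemetry_lib.decorators. Its implementation is not shown in this diff, so the sketch below is only an assumption about the usual shape of such a helper: serialize best-effort and never let a non-serializable value raise out of the tracing code.

import json


def json_dumps(value) -> str:
    """Hypothetical sketch of a fail-safe JSON serializer.

    Assumption: the real helper in lmnr.opentelemetry_lib.decorators may
    differ; the point is that span attributes should never fail to serialize.
    """
    try:
        # default=str covers datetimes and other objects the stock encoder
        # cannot handle.
        return json.dumps(value, default=str)
    except Exception:
        # Last resort: stringify the whole value.
        return str(value)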

@@ -300,22 +298,26 @@ def _set_response_attributes(span, response: types.GenerateContentResponse):
            span, f"{gen_ai_attributes.GEN_AI_COMPLETION}.0.role", "model"
        )
    candidates_list = candidates if isinstance(candidates, list) else [candidates]
-    for i, candidate in enumerate(candidates_list):
+    i = 0
+    for candidate in candidates_list:
+        has_content = False
        processed_content = process_content_union(candidate.content)
        content_str = get_content(processed_content)

        set_span_attribute(
            span, f"{gen_ai_attributes.GEN_AI_COMPLETION}.{i}.role", "model"
        )
-        set_span_attribute(
-            span,
-            f"{gen_ai_attributes.GEN_AI_COMPLETION}.{i}.content",
-            (
-                content_str
-                if isinstance(content_str, str)
-                else json.dumps(content_str)
-            ),
-        )
+        if content_str:
+            has_content = True
+            set_span_attribute(
+                span,
+                f"{gen_ai_attributes.GEN_AI_COMPLETION}.{i}.content",
+                (
+                    content_str
+                    if isinstance(content_str, str)
+                    else json_dumps(content_str)
+                ),
+            )
        blocks = (
            processed_content
            if isinstance(processed_content, list)
@@ -328,6 +330,7 @@ def _set_response_attributes(span, response: types.GenerateContentResponse):
            if not block_dict.get("function_call"):
                continue
            function_call = to_dict(block_dict.get("function_call", {}))
+            has_content = True
            set_span_attribute(
                span,
                f"{gen_ai_attributes.GEN_AI_COMPLETION}.{i}.tool_calls.{tool_call_index}.name",
@@ -345,9 +348,11 @@ def _set_response_attributes(span, response: types.GenerateContentResponse):
            set_span_attribute(
                span,
                f"{gen_ai_attributes.GEN_AI_COMPLETION}.{i}.tool_calls.{tool_call_index}.arguments",
-                json.dumps(function_call.get("arguments")),
+                json_dumps(function_call.get("arguments")),
            )
            tool_call_index += 1
+        if has_content:
+            i += 1


 @dont_throw
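The candidate loop now advances the completion index i only for candidates that actually contributed text or a tool call, so the gen_ai.completion.N.* attributes stay densely numbered even when some candidates are empty. A tiny standalone illustration of the pattern, with a made-up list of strings standing in for candidates:

# Hypothetical stand-in data: one "candidate" yields no content at all.
candidates = ["hello", "", "world"]

attributes = {}
i = 0
for candidate in candidates:
    has_content = bool(candidate)
    if has_content:
        attributes[f"gen_ai.completion.{i}.content"] = candidate
    # The index bump happens at the end of the loop body, after any other
    # per-candidate attributes would have been set, mirroring the diff above.
    if has_content:
        i += 1

# Indices stay dense: completion.0 and completion.1, with no gap left by the
# empty candidate in the middle.
assert attributes == {
    "gen_ai.completion.0.content": "hello",
    "gen_ai.completion.1.content": "world",
}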
@@ -359,54 +364,45 @@ def _build_from_streaming_response(
    aggregated_usage_metadata = defaultdict(int)
    model_version = None
    for chunk in response:
-        if chunk.model_version:
-            model_version = chunk.model_version
-
-        if chunk.candidates:
-            # Currently gemini throws an error if you pass more than one candidate
-            # with streaming
-            if chunk.candidates and len(chunk.candidates) > 0:
-                final_parts += chunk.candidates[0].content.parts or []
-                role = chunk.candidates[0].content.role or role
-        if chunk.usage_metadata:
-            usage_dict = to_dict(chunk.usage_metadata)
-            # prompt token count is sent in every chunk
-            # (and is less by 1 in the last chunk, so we set it once);
-            # total token count in every chunk is greater by prompt token count than it should be,
-            # thus this awkward logic here
-            if aggregated_usage_metadata.get("prompt_token_count") is None:
-                # or 0, not .get(key, 0), because sometimes the value is explicitly None
-                aggregated_usage_metadata["prompt_token_count"] = (
-                    usage_dict.get("prompt_token_count") or 0
-                )
-                aggregated_usage_metadata["total_token_count"] = (
-                    usage_dict.get("total_token_count") or 0
-                )
-            aggregated_usage_metadata["candidates_token_count"] += (
-                usage_dict.get("candidates_token_count") or 0
-            )
-            aggregated_usage_metadata["total_token_count"] += (
-                usage_dict.get("candidates_token_count") or 0
-            )
+        # Important: do all processing in a separate sync function, that is
+        # wrapped in @dont_throw. If we did it here, the @dont_throw on top of
+        # this function would not be able to catch the errors, as they are
+        # raised later, after the generator is returned, and when it is being
+        # consumed.
+        chunk_result = process_stream_chunk(
+            chunk,
+            role,
+            model_version,
+            aggregated_usage_metadata,
+            final_parts,
+        )
+        # even though process_stream_chunk can't return None, the result can be
+        # None, if the processing throws an error (see @dont_throw)
+        if chunk_result:
+            role = chunk_result["role"]
+            model_version = chunk_result["model_version"]
        yield chunk

-    compound_response = types.GenerateContentResponse(
-        candidates=[
-            {
-                "content": {
-                    "parts": final_parts,
-                    "role": role,
-                },
-            }
-        ],
-        usage_metadata=types.GenerateContentResponseUsageMetadataDict(
-            **aggregated_usage_metadata
-        ),
-        model_version=model_version,
-    )
-    if span.is_recording():
-        _set_response_attributes(span, compound_response)
-    span.end()
+    try:
+        compound_response = types.GenerateContentResponse(
+            candidates=[
+                {
+                    "content": {
+                        "parts": final_parts,
+                        "role": role,
+                    },
+                }
+            ],
+            usage_metadata=types.GenerateContentResponseUsageMetadataDict(
+                **aggregated_usage_metadata
+            ),
+            model_version=model_version,
+        )
+        if span.is_recording():
+            _set_response_attributes(span, compound_response)
+    finally:
+        if span.is_recording():
+            span.end()


 @dont_throw
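process_stream_chunk itself is defined in the package's utils module, one of the 10 changed files but not shown on this page. Judging from the call site and from the inline logic it replaces, it plausibly looks like the sketch below; the exact signature and return shape beyond what the call site implies are assumptions. The to_dict and dont_throw names refer to the existing helpers in the same utils module.

@dont_throw
def process_stream_chunk(
    chunk,
    role,
    model_version,
    aggregated_usage_metadata,
    final_parts,
):
    # Hypothetical reconstruction: mirrors the inline logic removed above.
    # aggregated_usage_metadata and final_parts are mutated in place.
    if chunk.model_version:
        model_version = chunk.model_version

    # Gemini currently returns a single candidate when streaming.
    if chunk.candidates and len(chunk.candidates) > 0:
        final_parts += chunk.candidates[0].content.parts or []
        role = chunk.candidates[0].content.role or role

    if chunk.usage_metadata:
        usage_dict = to_dict(chunk.usage_metadata)
        # prompt_token_count repeats in every chunk, so record it once;
        # candidates_token_count is additive across chunks.
        if aggregated_usage_metadata.get("prompt_token_count") is None:
            aggregated_usage_metadata["prompt_token_count"] = (
                usage_dict.get("prompt_token_count") or 0
            )
            aggregated_usage_metadata["total_token_count"] = (
                usage_dict.get("total_token_count") or 0
            )
        aggregated_usage_metadata["candidates_token_count"] += (
            usage_dict.get("candidates_token_count") or 0
        )
        aggregated_usage_metadata["total_token_count"] += (
            usage_dict.get("candidates_token_count") or 0
        )

    return {"role": role, "model_version": model_version}

Because both _build_from_streaming_response and _abuild_from_streaming_response funnel their chunks through this one function, the sync and async paths share a single implementation, and a failure is swallowed by the helper's own @dont_throw and surfaces as a None result, which the callers check before using it.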
@@ -418,54 +414,45 @@ async def _abuild_from_streaming_response(
    aggregated_usage_metadata = defaultdict(int)
    model_version = None
    async for chunk in response:
-        if chunk.model_version:
-            model_version = chunk.model_version
-
-        if chunk.candidates:
-            # Currently gemini throws an error if you pass more than one candidate
-            # with streaming
-            if chunk.candidates and len(chunk.candidates) > 0:
-                final_parts += chunk.candidates[0].content.parts or []
-                role = chunk.candidates[0].content.role or role
-        if chunk.usage_metadata:
-            usage_dict = to_dict(chunk.usage_metadata)
-            # prompt token count is sent in every chunk
-            # (and is less by 1 in the last chunk, so we set it once);
-            # total token count in every chunk is greater by prompt token count than it should be,
-            # thus this awkward logic here
-            if aggregated_usage_metadata.get("prompt_token_count") is None:
-                # or 0, not .get(key, 0), because sometimes the value is explicitly None
-                aggregated_usage_metadata["prompt_token_count"] = (
-                    usage_dict.get("prompt_token_count") or 0
-                )
-                aggregated_usage_metadata["total_token_count"] = (
-                    usage_dict.get("total_token_count") or 0
-                )
-            aggregated_usage_metadata["candidates_token_count"] += (
-                usage_dict.get("candidates_token_count") or 0
-            )
-            aggregated_usage_metadata["total_token_count"] += (
-                usage_dict.get("candidates_token_count") or 0
-            )
+        # Important: do all processing in a separate sync function, that is
+        # wrapped in @dont_throw. If we did it here, the @dont_throw on top of
+        # this function would not be able to catch the errors, as they are
+        # raised later, after the generator is returned, and when it is being
+        # consumed.
+        chunk_result = process_stream_chunk(
+            chunk,
+            role,
+            model_version,
+            aggregated_usage_metadata,
+            final_parts,
+        )
+        # even though process_stream_chunk can't return None, the result can be
+        # None, if the processing throws an error (see @dont_throw)
+        if chunk_result:
+            role = chunk_result["role"]
+            model_version = chunk_result["model_version"]
        yield chunk

-    compound_response = types.GenerateContentResponse(
-        candidates=[
-            {
-                "content": {
-                    "parts": final_parts,
-                    "role": role,
-                },
-            }
-        ],
-        usage_metadata=types.GenerateContentResponseUsageMetadataDict(
-            **aggregated_usage_metadata
-        ),
-        model_version=model_version,
-    )
-    if span.is_recording():
-        _set_response_attributes(span, compound_response)
-    span.end()
+    try:
+        compound_response = types.GenerateContentResponse(
+            candidates=[
+                {
+                    "content": {
+                        "parts": final_parts,
+                        "role": role,
+                    },
+                }
+            ],
+            usage_metadata=types.GenerateContentResponseUsageMetadataDict(
+                **aggregated_usage_metadata
+            ),
+            model_version=model_version,
+        )
+        if span.is_recording():
+            _set_response_attributes(span, compound_response)
+    finally:
+        if span.is_recording():
+            span.end()


 @with_tracer_wrapper
@@ -502,7 +489,7 @@ def _wrap(tracer: Tracer, to_wrap, wrapped, instance, args, kwargs):
        span.record_exception(e, attributes=attributes)
        span.set_status(Status(StatusCode.ERROR, str(e)))
        span.end()
-        raise e
+        raise


 @with_tracer_wrapper
@@ -541,7 +528,7 @@ async def _awrap(tracer: Tracer, to_wrap, wrapped, instance, args, kwargs):
        span.record_exception(e, attributes=attributes)
        span.set_status(Status(StatusCode.ERROR, str(e)))
        span.end()
-        raise e
+        raise


 class GoogleGenAiSdkInstrumentor(BaseInstrumentor):
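Both wrappers also drop raise e in favor of a bare raise. Re-raising the caught object with raise e adds an extra traceback entry pointing at the instrumentation's own raise line, while a bare raise propagates the original exception and its traceback untouched, so the user's stack trace is not cluttered by the SDK. A small, self-contained comparison with made-up names, unrelated to the SDK:

import traceback


def boom():
    raise ValueError("boom")


def reraise_with_e():
    try:
        boom()
    except ValueError as e:
        raise e  # traceback gains an extra entry pointing at this line


def reraise_bare():
    try:
        boom()
    except ValueError:
        raise  # original traceback propagates unchanged


for fn in (reraise_with_e, reraise_bare):
    try:
        fn()
    except ValueError:
        print(f"--- {fn.__name__} ---")
        traceback.print_exc()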

src/lmnr/opentelemetry_lib/opentelemetry/instrumentation/google_genai/schema_utils.py

Lines changed: 6 additions & 3 deletions
@@ -10,9 +10,12 @@

 def process_schema(schema: Any) -> dict[str, Any]:
     # The only thing we need from the client is the t_schema function
-    json_schema = t_schema(DUMMY_CLIENT, schema).json_schema.model_dump(
-        exclude_unset=True, exclude_none=True
-    )
+    try:
+        json_schema = t_schema(DUMMY_CLIENT, schema).json_schema.model_dump(
+            exclude_unset=True, exclude_none=True
+        )
+    except Exception:
+        json_schema = {}
     return json_schema

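With process_schema guarded, a response schema that google-genai's t_schema transformer cannot handle now degrades to an empty dict instead of raising inside the tracing code. A hedged usage sketch, assuming the import path implied by the file location above and using a made-up Pydantic model (google-genai accepts Pydantic models as response schemas):

from pydantic import BaseModel

from lmnr.opentelemetry_lib.opentelemetry.instrumentation.google_genai.schema_utils import (
    process_schema,
)


class Recipe(BaseModel):
    name: str
    servings: int


# Converts the model into the JSON-schema dict that gets recorded on the span.
print(process_schema(Recipe))

# If t_schema cannot transform the input, process_schema now returns {}
# instead of letting the exception escape into the instrumentation.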
