Skip to content

Commit 77ade82

Browse files
authored
set_correlation_ids should handle the empty context case (#334)
* set_correlation_ids handles empty trace context case
1 parent 834aa15 commit 77ade82

File tree

2 files changed

+56
-16
lines changed

2 files changed

+56
-16
lines changed

datadog_lambda/tracing.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -515,10 +515,12 @@ def get_dd_trace_context():
515515
automatically, but this function can be used to manually inject the trace
516516
context to an outgoing request.
517517
"""
518-
global dd_trace_context
518+
if dd_tracing_enabled:
519+
dd_trace_py_context = _get_dd_trace_py_context()
520+
if dd_trace_py_context is not None:
521+
return _context_obj_to_headers(dd_trace_py_context)
519522

520-
context = None
521-
xray_context = None
523+
global dd_trace_context
522524

523525
try:
524526
xray_context = _get_xray_trace_context() # xray (sub)segment
@@ -527,22 +529,17 @@ def get_dd_trace_context():
527529
"get_dd_trace_context couldn't read from segment from x-ray, with error %s"
528530
% e
529531
)
532+
if not xray_context:
533+
return {}
530534

531-
if xray_context and not dd_trace_context:
532-
context = xray_context
533-
elif xray_context and dd_trace_context:
534-
context = dd_trace_context.copy()
535-
context["parent-id"] = xray_context.get("parent-id")
536-
logger.debug(
537-
"Set parent id from xray trace context: %s", context.get("parent-id")
538-
)
535+
if not dd_trace_context:
536+
return _context_obj_to_headers(xray_context)
539537

540-
if dd_tracing_enabled:
541-
dd_trace_py_context = _get_dd_trace_py_context()
542-
if dd_trace_py_context is not None:
543-
context = dd_trace_py_context
538+
context = dd_trace_context.copy()
539+
context["parent-id"] = xray_context.get("parent-id")
540+
logger.debug("Set parent id from xray trace context: %s", context.get("parent-id"))
544541

545-
return _context_obj_to_headers(context) if context is not None else {}
542+
return _context_obj_to_headers(context)
546543

547544

548545
def set_correlation_ids():
@@ -561,6 +558,8 @@ def set_correlation_ids():
561558
return
562559

563560
context = get_dd_trace_context()
561+
if not context:
562+
return
564563

565564
span = tracer.trace("dummy.span")
566565
span.trace_id = int(context[TraceHeader.TRACE_ID])

tests/test_tracing.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,32 @@
4242

4343
event_samples = "tests/event_samples/"
4444

45+
span_to_finish = None
46+
47+
48+
def _clean_up_span():
49+
global span_to_finish
50+
if span_to_finish is not None:
51+
span_to_finish.finish()
52+
span_to_finish = None
53+
54+
55+
def register_span(span):
56+
global span_to_finish
57+
_clean_up_span()
58+
span_to_finish = span
59+
return span
60+
61+
62+
def wrapped_span_creator(span_creator_func):
63+
def result_func(*args, **kwargs):
64+
return register_span(span_creator_func(*args, **kwargs))
65+
66+
return result_func
67+
68+
69+
create_inferred_span = wrapped_span_creator(create_inferred_span)
70+
4571

4672
class ClientContext(object):
4773
def __init__(self, custom=None):
@@ -482,6 +508,15 @@ def test_set_correlation_ids(self):
482508
span = tracer.current_span()
483509
self.assertEqual(span.trace_id, 123)
484510
self.assertEqual(span.span_id, 456)
511+
span.finish()
512+
513+
def test_set_correlation_ids_handle_empty_trace_context(self):
514+
# neither x-ray or ddtrace is used. no tracing context at all.
515+
self.mock_get_dd_trace_context.return_value = {}
516+
# no exception thrown
517+
set_correlation_ids()
518+
span = tracer.current_span()
519+
self.assertIsNone(span)
485520

486521

487522
class TestFunctionSpanTags(unittest.TestCase):
@@ -587,6 +622,9 @@ def setUp(self):
587622
self.mock_span_stop = patcher.start()
588623
self.addCleanup(patcher.stop)
589624

625+
def tearDown(self):
626+
_clean_up_span()
627+
590628
def test_create_inferred_span_from_authorizer_request_api_gateway_v1_event(self):
591629
event_sample_source = "authorizer-request-api-gateway-v1"
592630
finish_time = (
@@ -733,6 +771,9 @@ def _basic_common_checks(
733771

734772

735773
class TestInferredSpans(unittest.TestCase):
774+
def tearDown(self):
775+
_clean_up_span()
776+
736777
def test_create_inferred_span_from_api_gateway_event(self):
737778
event_sample_source = "api-gateway"
738779
test_file = event_samples + event_sample_source + ".json"

0 commit comments

Comments
 (0)