diff --git a/ddtrace/internal/remoteconfig/products/apm_tracing.py b/ddtrace/internal/remoteconfig/products/apm_tracing.py index 9abbd23bba2..0ff977ebf9f 100644 --- a/ddtrace/internal/remoteconfig/products/apm_tracing.py +++ b/ddtrace/internal/remoteconfig/products/apm_tracing.py @@ -1,3 +1,5 @@ +from collections import ChainMap +import enum import typing as t from ddtrace import config @@ -17,20 +19,23 @@ log = get_logger(__name__) -def _rc_callback(payloads: t.Sequence[Payload]) -> None: - for payload in payloads: - if payload.metadata is None or (content := payload.content) is None: - continue +class APMCapabilities(enum.IntFlag): + APM_TRACING_MULTICONFIG = 1 << 45 - if (service_target := t.cast(t.Optional[dict], content.get("service_target"))) is not None: - if (service := t.cast(str, service_target.get("service"))) is not None and service != config.service: - continue - if (env := t.cast(str, service_target.get("env"))) is not None and env != config.env: - continue +def config_key(payload: Payload) -> int: + content = t.cast(dict, payload.content) + + service_target = t.cast(t.Optional[dict], content.get("service_target")) + service = t.cast(str, service_target.get("service")) if service_target is not None else None + env = t.cast(str, service_target.get("env")) if service_target is not None else None + cluster_target = t.cast(t.Optional[dict], content.get("k8s_target_v2")) - if (lib_config := t.cast(dict, content.get("lib_config"))) is not None: - dispatch("apm-tracing.rc", (lib_config, config)) + return ( + ((service is not None and service != "*") << 2) + | ((env is not None and env != "*") << 1) + | (cluster_target is not None) << 0 + ) class APMTracingAdapter(PubSub): @@ -40,7 +45,54 @@ class APMTracingAdapter(PubSub): def __init__(self): self._publisher = self.__publisher_class__(self.__shared_data__) - self._subscriber = self.__subscriber_class__(self.__shared_data__, _rc_callback, "APM_TRACING") + self._subscriber = self.__subscriber_class__(self.__shared_data__, self.rc_callback, "APM_TRACING") + + # Configuration overrides + self.config_map = {} # type: t.Dict[str, Payload] + + def get_chained_lib_config(self) -> t.ChainMap: + return ChainMap( + *( + t.cast(dict, content["lib_config"]) + for content in ( + p.content + for p in sorted(self.config_map.values(), key=config_key, reverse=True) + if p.content is not None and "lib_config" in p.content + ) + ) + ) + + def rc_callback(self, payloads: t.Sequence[Payload]) -> None: + seen_config_ids = set() + for payload in payloads: + if payload.metadata is None: + continue + + config_id = payload.metadata.id + seen_config_ids.add(config_id) + + if (content := payload.content) is None: + continue + + service_target = t.cast(t.Optional[dict], content.get("service_target")) + + service = t.cast(str, service_target.get("service")) if service_target is not None else None + env = t.cast(str, service_target.get("env")) if service_target is not None else None + + if service is not None and service != "*" and service != config.service: + continue + + if env is not None and env != "*" and env != config.env: + continue + + self.config_map[config_id] = payload + + # Remove configurations that are no longer present + for config_id in set(self.config_map.keys()) - seen_config_ids: + log.debug("Removing APM tracing config %s", config_id) + self.config_map.pop(config_id, None) + + dispatch("apm-tracing.rc", (self.get_chained_lib_config(), config)) def post_preload():