diff --git a/datadog/dogstatsd/base.py b/datadog/dogstatsd/base.py index 28a62d896..690922c4c 100644 --- a/datadog/dogstatsd/base.py +++ b/datadog/dogstatsd/base.py @@ -1501,29 +1501,31 @@ def pre_fork(self): self._stop_flush_thread() self._stop_sender_thread() + # Prevent concurrent calls to system libraries (notably + # getaddrinfo) which may leave internal locks in a locked + # state and deadlock the child. + self._socket_lock.acquire() + def post_fork_parent(self): """Restore the client state after a fork in the parent process.""" + self._socket_lock.release() self._start_flush_thread() self._start_sender_thread() self._config_lock.release() def post_fork_child(self): """Restore the client state after a fork in the child process.""" + self._socket_lock.release() self._config_lock.release() # Discard the locks that could have been locked at the time # when we forked. This may cause inconsistent internal state, # which we will fix in the next steps. - self._socket_lock = Lock() self._buffer_lock = RLock() # Reset the buffer so we don't send metrics from the parent # process. Also makes sure buffer properties are consistent. self._reset_buffer() - # Execute the socket_path setter to reconcile transport and - # payload size properties in respect to socket_path value. - self.socket_path = self.socket_path - self.close_socket() with self._config_lock: self._start_flush_thread() diff --git a/tests/integration/dogstatsd/test_statsd_fork.py b/tests/integration/dogstatsd/test_statsd_fork.py index c856376e0..49337a827 100644 --- a/tests/integration/dogstatsd/test_statsd_fork.py +++ b/tests/integration/dogstatsd/test_statsd_fork.py @@ -2,11 +2,14 @@ import itertools import socket import threading +import logging +import time import pytest from datadog.dogstatsd.base import DogStatsd, SUPPORTS_FORKING +logging.getLogger("datadog.dogstatsd").setLevel(logging.FATAL) @pytest.mark.parametrize( "disable_background_sender, disable_buffering", @@ -47,18 +50,20 @@ def inner(*args, **kwargs): def sender_a(statsd, running): while running[0]: statsd.gauge("spam", 1) + time.sleep(0) -def sender_b(statsd, signal): +def sender_b(statsd, running): while running[0]: with statsd: statsd.gauge("spam", 1) + time.sleep(0) @pytest.mark.parametrize( - "disable_background_sender, disable_buffering, sender", + "disable_background_sender, disable_buffering, sender_fn", list(itertools.product([True, False], [True, False], [sender_a, sender_b])), ) -def test_fork_with_thread(disable_background_sender, disable_buffering, sender): +def test_fork_with_thread(disable_background_sender, disable_buffering, sender_fn): if not SUPPORTS_FORKING: pytest.skip("os.register_at_fork is required for this test") @@ -71,12 +76,13 @@ def test_fork_with_thread(disable_background_sender, disable_buffering, sender): sender = None try: sender_running = [True] - sender = threading.Thread(target=sender, args=(statsd, sender_running)) + sender = threading.Thread(target=sender_fn, args=(statsd, sender_running)) sender.daemon = True sender.start() pid = os.fork() if pid == 0: + statsd.gauge("spam", 2) os._exit(42) assert pid > 0 @@ -84,7 +90,7 @@ def test_fork_with_thread(disable_background_sender, disable_buffering, sender): assert os.WEXITSTATUS(status) == 42 finally: - statsd.stop() if sender: sender_running[0] = False sender.join() + statsd.stop() diff --git a/tests/unit/dogstatsd/test_statsd.py b/tests/unit/dogstatsd/test_statsd.py index 02edcdf4c..8949f6431 100644 --- a/tests/unit/dogstatsd/test_statsd.py +++ b/tests/unit/dogstatsd/test_statsd.py @@ -2056,16 +2056,3 @@ def test_max_payload_size(self): self.assertEqual(statsd._max_payload_size, UDP_OPTIMAL_PAYLOAD_LENGTH) statsd.socket_path = "/foo" self.assertEqual(statsd._max_payload_size, UDS_OPTIMAL_PAYLOAD_LENGTH) - - def test_post_fork_locks(self): - def inner(): - statsd = DogStatsd(socket_path=None, port=8125) - # Statsd should survive this sequence of events - statsd.pre_fork() - statsd.get_socket() - statsd.post_fork_parent() - t = Thread(target=inner) - t.daemon = True - t.start() - t.join(timeout=5) - self.assertFalse(t.is_alive())