From 4675d75b813b6db3b92717b55691c55e1c4c8def Mon Sep 17 00:00:00 2001 From: Tianyi Chen Date: Wed, 5 Nov 2025 10:54:48 +0800 Subject: [PATCH 1/2] Add pkill when timeout in worker stopping. (#1659) * add pkill when stopping worker timeout * debug ut * fix ut --- .../python/elastic_agent/torch/training.py | 20 +++++++- .../tests/test_elastic_training_agent.py | 49 ++++++++++++++++++- 2 files changed, 67 insertions(+), 2 deletions(-) diff --git a/dlrover/python/elastic_agent/torch/training.py b/dlrover/python/elastic_agent/torch/training.py index d69582b22..ae945790c 100644 --- a/dlrover/python/elastic_agent/torch/training.py +++ b/dlrover/python/elastic_agent/torch/training.py @@ -17,6 +17,7 @@ import shutil import signal import socket +import subprocess import sys import tempfile import time @@ -961,7 +962,24 @@ def _stop_workers( signal.alarm(0) def _stop_timeout_handler(self, signum, frame): - raise StopWorkerTimeoutError("Timed out waiting for stopping workers.") + logger.warning( + "Use pkill to kill all sub-processes in 'stop_timeout_handler'." + ) + try: + subprocess.run( + ["pkill", "-9", "-g", str(os.getpgid(os.getpid()))], + capture_output=True, + text=True, + timeout=10, + ) + except Exception as e: + logger.error( + f"Unexpected error in stop_timeout_handler when killing process: {e}" + ) + + raise StopWorkerTimeoutError( + "Timed out waiting for stopping workers, forcefully kill all sub-processes." + ) def _set_numa_affinity(self): """set numa affinity to workers processes, diff --git a/dlrover/python/tests/test_elastic_training_agent.py b/dlrover/python/tests/test_elastic_training_agent.py index 893504773..9f26b7d55 100644 --- a/dlrover/python/tests/test_elastic_training_agent.py +++ b/dlrover/python/tests/test_elastic_training_agent.py @@ -1081,7 +1081,8 @@ def test_orphan_workers(self): self.assertTrue(orphan_killed) - def test_stop_workers(self): + @patch("subprocess.run") + def test_stop_workers(self, mock_run): agent = ElasticTrainingAgent( node_rank=0, config=self.config, @@ -1099,6 +1100,7 @@ def sleep_10_seconds(*args, **kwargs): time.sleep(10) # with timeout + mock_run.return_value = MagicMock(returncode=0, stderr="") with patch.object( LocalElasticAgent, "_stop_workers", side_effect=sleep_10_seconds ): @@ -1117,6 +1119,51 @@ def sleep_10_seconds(*args, **kwargs): except StopWorkerTimeoutError: self.assertTrue(True) + @patch("os.getpgid") + @patch("subprocess.run") + def test_stop_timeout_handler_pkill(self, mock_run, mock_getpgid): + """Test _stop_timeout_handler with pkill implementation""" + agent = ElasticTrainingAgent( + node_rank=0, + config=self.config, + entrypoint="echo", + spec=self.spec, + start_method=self.config.start_method, + log_dir=self.config.log_dir, + exit_barrier_timeout=1, + ) + + # Mock getpgid to return a safe process group ID + mock_getpgid.return_value = 9999 + mock_run.return_value = MagicMock(returncode=0, stderr="") + + with self.assertRaises(StopWorkerTimeoutError): + agent._stop_timeout_handler(signal.SIGALRM, None) + + mock_run.assert_called_once() + args, kwargs = mock_run.call_args + self.assertIn("pkill", args[0]) + self.assertIn("-9", args[0]) + self.assertIn("9999", str(args[0])) + + mock_run.return_value = MagicMock( + returncode=1, stderr="permission denied" + ) + with self.assertRaises(StopWorkerTimeoutError): + agent._stop_timeout_handler(signal.SIGALRM, None) + + mock_run.side_effect = subprocess.TimeoutExpired("pkill", 5) + with self.assertRaises(StopWorkerTimeoutError): + agent._stop_timeout_handler(signal.SIGALRM, None) + + mock_run.side_effect = subprocess.CalledProcessError(1, "pkill") + with self.assertRaises(StopWorkerTimeoutError): + agent._stop_timeout_handler(signal.SIGALRM, None) + + mock_run.side_effect = Exception("unexpected error") + with self.assertRaises(StopWorkerTimeoutError): + agent._stop_timeout_handler(signal.SIGALRM, None) + def test_diagnosis(self): agent = ElasticTrainingAgent( node_rank=0, From ee1700df9c254914734f5f6e798da3ae4954571a Mon Sep 17 00:00:00 2001 From: "chentianyi.cty" Date: Tue, 11 Nov 2025 17:21:54 +0800 Subject: [PATCH 2/2] set trace in debug logging --- dlrover/python/util/function_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dlrover/python/util/function_util.py b/dlrover/python/util/function_util.py index 50deb10bc..a535ae1fe 100644 --- a/dlrover/python/util/function_util.py +++ b/dlrover/python/util/function_util.py @@ -104,9 +104,9 @@ def wrapped(*args, **kwargs): type(e), e, e.__traceback__, limit=3 ) logger.warning( - f"Retry {i} to {class_name}.{func_name} with failure {e}, ", - f"with traceback {tb}", + f"Retry {i} to {class_name}.{func_name} with failure {e}" ) + logger.debug(f"Caused traceback: {tb}") exception = e time.sleep(retry_interval) if exception: