Skip to content

Commit f23893b

Browse files
committed
worker: move clean to after worker_main
1 parent e07e276 commit f23893b

File tree

2 files changed: 24 additions and 19 deletions.

src/dvc_task/worker/temporary.py

Lines changed: 23 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
import os
44
import threading
55
import time
6-
from typing import Any, List, Mapping
6+
from typing import Any, Dict, List, Mapping, Optional
77

88
from celery import Celery
99
from celery.utils.nodenames import default_nodename
@@ -36,6 +36,15 @@ def __init__( # pylint: disable=too-many-arguments
3636
self.timeout = timeout
3737
self.config = kwargs
3838

39+
def ping(self, name: str, timeout: float = 1.0) -> Optional[List[Dict[str, Any]]]:
40+
"""Ping the specified worker."""
41+
return self._ping(destination=[default_nodename(name)], timeout=timeout)
42+
43+
def _ping(
44+
self, *, destination: Optional[List[str]] = None, timeout: float = 1.0
45+
) -> Optional[List[Dict[str, Any]]]:
46+
return self.app.control.ping(destination=destination, timeout=timeout)
47+
3948
def start(self, name: str, fsapp_clean: bool = False) -> None:
4049
"""Start the worker if it does not already exist.
4150
@@ -50,19 +59,22 @@ def start(self, name: str, fsapp_clean: bool = False) -> None:
5059
# see https://github.com/celery/billiard/issues/247
5160
os.environ["FORKED_BY_MULTIPROCESSING"] = "1"
5261

53-
if not self.app.control.ping(destination=[name]):
62+
if not self.ping(name):
5463
monitor = threading.Thread(
5564
target=self.monitor,
5665
daemon=True,
5766
args=(name,),
58-
kwargs={"fsapp_clean": fsapp_clean},
5967
)
6068
monitor.start()
6169
config = dict(self.config)
6270
config["hostname"] = name
6371
argv = ["worker"]
6472
argv.extend(self._parse_config(config))
6573
self.app.worker_main(argv=argv)
74+
if fsapp_clean and isinstance(self.app, FSApp): # type: ignore[unreachable]
75+
logger.info("cleaning up FSApp broker.")
76+
self.app.clean()
77+
logger.info("done")
6678

6779
@staticmethod
6880
def _parse_config(config: Mapping[str, Any]) -> List[str]:
@@ -85,13 +97,9 @@ def _parse_config(config: Mapping[str, Any]) -> List[str]:
8597
argv.append("-E")
8698
return argv
8799

88-
def monitor(self, name: str, fsapp_clean: bool = False) -> None:
100+
def monitor(self, name: str) -> None:
89101
"""Monitor the worker and stop it when the queue is empty."""
90-
logger.debug("monitor: waiting for worker to start")
91102
nodename = default_nodename(name)
92-
while not self.app.control.ping(destination=[nodename]):
93-
# wait for worker to start
94-
time.sleep(1)
95103

96104
def _tasksets(nodes):
97105
for taskset in (
@@ -105,17 +113,16 @@ def _tasksets(nodes):
105113
if isinstance(self.app, FSApp):
106114
yield from self.app.iter_queued()
107115

108-
logger.info("monitor: watching celery worker '%s'", nodename)
109-
while self.app.control.ping(destination=[nodename]):
116+
logger.debug("monitor: watching celery worker '%s'", nodename)
117+
while True:
110118
time.sleep(self.timeout)
111119
nodes = self.app.control.inspect( # type: ignore[call-arg]
112-
destination=[nodename]
120+
destination=[nodename],
121+
limit=1,
113122
)
114123
if nodes is None or not any(tasks for tasks in _tasksets(nodes)):
115124
logger.info("monitor: shutting down due to empty queue.")
116-
self.app.control.shutdown(destination=[nodename])
117125
break
118-
if fsapp_clean and isinstance(self.app, FSApp):
119-
logger.info("monitor: cleanup FSApp broker.")
120-
self.app.clean()
121-
logger.info("monitor: done")
126+
logger.debug("monitor: sending shutdown to '%s'.", nodename)
127+
self.app.control.shutdown()
128+
logger.debug("monitor: done")

tests/worker/test_temporary.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -24,9 +24,7 @@ def test_start(celery_app: Celery, mocker: MockerFixture):
2424
assert kwargs["pool"] == TaskPool
2525
assert kwargs["concurrency"] == 1
2626
assert kwargs["prefetch_multiplier"] == 1
27-
thread.assert_called_once_with(
28-
target=worker.monitor, daemon=True, args=(name,), kwargs={"fsapp_clean": False}
29-
)
27+
thread.assert_called_once_with(target=worker.monitor, daemon=True, args=(name,))
3028

3129

3230
@pytest.mark.flaky(

0 commit comments

Comments
 (0)