|
20 | 20 | import time |
21 | 21 | import uuid |
22 | 22 | from abc import ABC, abstractmethod |
23 | | -from typing import Any, Dict, List, Optional |
| 23 | +from typing import Any, Dict, List, Optional, Union |
24 | 24 |
|
25 | 25 | from nvflare.apis.fl_context import FLContext |
26 | 26 | from nvflare.apis.job_def import Job, JobDataKey, JobMetaKey, job_from_meta |
|
30 | 30 | from nvflare.fuel.utils import fobs |
31 | 31 | from nvflare.fuel.utils.zip_utils import unzip_all_from_bytes, zip_directory_to_bytes |
32 | 32 |
|
| 33 | +_OBJ_TAG_SCHEDULED = "scheduled" |
| 34 | + |
| 35 | + |
| 36 | +class JobInfo: |
| 37 | + def __init__(self, meta: dict, job_id: str, uri: str): |
| 38 | + self.meta = meta |
| 39 | + self.job_id = job_id |
| 40 | + self.uri = uri |
| 41 | + |
33 | 42 |
|
class _JobFilter(ABC):
    """Visitor applied to each job found while scanning the job store.

    Concrete filters accumulate their selection in an instance attribute
    (by convention named ``result``).
    """

    @abstractmethod
    def filter_job(self, info: JobInfo) -> bool:
        """Process one job; return True to continue the scan, False to stop it early."""
        pass
38 | 47 |
|
39 | 48 |
|
class _StatusFilter(_JobFilter):
    """Collects jobs whose status matches any of the given status values."""

    def __init__(self, status_to_check):
        self.result = []
        # Normalize a single status into a one-element list so that the
        # membership test in filter_job works uniformly.
        if isinstance(status_to_check, list):
            self.status_to_check = status_to_check
        else:
            self.status_to_check = [status_to_check]

    def filter_job(self, info: JobInfo):
        status = info.meta.get(JobMetaKey.STATUS.value)
        if status in self.status_to_check:
            self.result.append(job_from_meta(info.meta))
        # Always keep scanning - this filter never stops the scan early.
        return True
49 | 62 |
|
50 | 63 |
|
class _AllJobsFilter(_JobFilter):
    """Unconditionally collects every job encountered during the scan."""

    def __init__(self):
        self.result = []

    def filter_job(self, info: JobInfo):
        # No selection criteria: turn every meta into a Job and keep scanning.
        self.result.append(job_from_meta(info.meta))
        return True
58 | 71 |
|
59 | 72 |
|
class _ReviewerFilter(_JobFilter):
    """Collects jobs that the given reviewer has not yet approved."""

    def __init__(self, reviewer_name):
        """Not used yet, for use in future implementations."""
        self.result = []
        self.reviewer_name = reviewer_name

    def filter_job(self, info: JobInfo):
        approvals = info.meta.get(JobMetaKey.APPROVALS.value)
        if approvals and self.reviewer_name in approvals:
            # Already reviewed by this reviewer - nothing to collect.
            return True
        # No approvals at all, or this reviewer is absent: still waiting.
        self.result.append(job_from_meta(info.meta))
        return True
71 | 84 |
|
72 | 85 |
|
73 | | -# TODO:: use try block around storage calls |
class _ScheduleJobFilter(_JobFilter):
    """Selects jobs eligible for scheduling (status SUBMITTED).

    This filter is optimized for selecting jobs to schedule since it is used so
    frequently (every 1 sec): any job whose status has moved past SUBMITTED is
    tagged in the store so that subsequent scans skip its meta file entirely.
    """

    def __init__(self, store):
        self.store = store
        self.result = []

    def filter_job(self, info: JobInfo):
        status = info.meta.get(JobMetaKey.STATUS.value)
        if not status:
            # No status recorded: neither schedulable nor tag-worthy; move on.
            return True
        if status == RunStatus.SUBMITTED.value:
            self.result.append(job_from_meta(info.meta))
            return True
        # skip this job in all future calls (so the meta file of this job won't be read)
        self.store.tag_object(uri=info.uri, tag=_OBJ_TAG_SCHEDULED)
        return True
74 | 104 |
|
75 | 105 |
|
76 | 106 | class SimpleJobDefManager(JobDefManagerSpec): |
@@ -239,28 +269,40 @@ def get_all_jobs(self, fl_ctx: FLContext) -> List[Job]: |
239 | 269 | self._scan(job_filter, fl_ctx) |
240 | 270 | return job_filter.result |
241 | 271 |
|
242 | | - def _scan(self, job_filter: _JobFilter, fl_ctx: FLContext): |
| 272 | + def get_jobs_to_schedule(self, fl_ctx: FLContext) -> List[Job]: |
| 273 | + job_filter = _ScheduleJobFilter(self._get_job_store(fl_ctx)) |
| 274 | + self._scan(job_filter, fl_ctx, skip_tag=_OBJ_TAG_SCHEDULED) |
| 275 | + return job_filter.result |
| 276 | + |
| 277 | + def _scan(self, job_filter: _JobFilter, fl_ctx: FLContext, skip_tag=None): |
243 | 278 | store = self._get_job_store(fl_ctx) |
244 | | - jid_paths = store.list_objects(self.uri_root) |
245 | | - if not jid_paths: |
| 279 | + obj_uris = store.list_objects(self.uri_root, without_tag=skip_tag) |
| 280 | + self.log_debug(fl_ctx, f"objects to scan: {len(obj_uris)}") |
| 281 | + if not obj_uris: |
246 | 282 | return |
247 | 283 |
|
248 | | - for jid_path in jid_paths: |
249 | | - jid = pathlib.PurePath(jid_path).name |
250 | | - |
251 | | - meta = store.get_meta(self.job_uri(jid)) |
| 284 | + for uri in obj_uris: |
| 285 | + jid = pathlib.PurePath(uri).name |
| 286 | + job_uri = self.job_uri(jid) |
| 287 | + meta = store.get_meta(job_uri) |
252 | 288 | if meta: |
253 | | - ok = job_filter.filter_job(meta) |
| 289 | + ok = job_filter.filter_job(JobInfo(meta, jid, job_uri)) |
254 | 290 | if not ok: |
255 | 291 | break |
256 | 292 |
|
257 | | - def get_jobs_by_status(self, status, fl_ctx: FLContext) -> List[Job]: |
| 293 | + def get_jobs_by_status(self, status: Union[RunStatus, List[RunStatus]], fl_ctx: FLContext) -> List[Job]: |
| 294 | + """Get jobs that are in the specified status |
| 295 | + Args: |
| 296 | + status: a single status value or a list of status values |
| 297 | + fl_ctx: the FL context |
| 298 | + Returns: list of jobs that are in specified status |
| 299 | + """ |
258 | 300 | job_filter = _StatusFilter(status) |
259 | 301 | self._scan(job_filter, fl_ctx) |
260 | 302 | return job_filter.result |
261 | 303 |
|
262 | 304 | def get_jobs_waiting_for_review(self, reviewer_name: str, fl_ctx: FLContext) -> List[Job]: |
263 | | - job_filter = _ReviewerFilter(reviewer_name, fl_ctx) |
| 305 | + job_filter = _ReviewerFilter(reviewer_name) |
264 | 306 | self._scan(job_filter, fl_ctx) |
265 | 307 | return job_filter.result |
266 | 308 |
|
|
0 commit comments