Skip to content

Commit f685fbe

Browse files
authored
(torchx/workspace) Support multi-project/directory workspace
Differential Revision: D82169554 Pull Request resolved: #1114
1 parent 3f9e19e commit f685fbe

File tree

10 files changed

+410
-62
lines changed

10 files changed

+410
-62
lines changed

torchx/cli/cmd_run.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
)
3737
from torchx.util.log_tee_helpers import tee_logs
3838
from torchx.util.types import none_throws
39+
from torchx.workspace import Workspace
3940

4041

4142
MISSING_COMPONENT_ERROR_MSG = (
@@ -92,7 +93,7 @@ def torchx_run_args_from_json(json_data: Dict[str, Any]) -> TorchXRunArgs:
9293

9394
torchx_args = TorchXRunArgs(**filtered_json_data)
9495
if torchx_args.workspace == "":
95-
torchx_args.workspace = f"file://{Path.cwd()}"
96+
torchx_args.workspace = f"{Path.cwd()}"
9697
return torchx_args
9798

9899

@@ -250,7 +251,7 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
250251
subparser.add_argument(
251252
"--workspace",
252253
"--buck-target",
253-
default=f"file://{Path.cwd()}",
254+
default=f"{Path.cwd()}",
254255
action=torchxconfig_run,
255256
help="local workspace to build/patch (buck-target of main binary if using buck)",
256257
)
@@ -289,12 +290,14 @@ def _run_inner(self, runner: Runner, args: TorchXRunArgs) -> None:
289290
else args.component_args
290291
)
291292
try:
293+
workspace = Workspace.from_str(args.workspace) if args.workspace else None
294+
292295
if args.dryrun:
293296
dryrun_info = runner.dryrun_component(
294297
args.component_name,
295298
component_args,
296299
args.scheduler,
297-
workspace=args.workspace,
300+
workspace=workspace,
298301
cfg=args.scheduler_cfg,
299302
parent_run_id=args.parent_run_id,
300303
)

torchx/cli/test/cmd_run_test.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@ def test_verify_no_extra_args_stdin_with_boolean_flags(self) -> None:
401401

402402
def test_verify_no_extra_args_stdin_with_value_args(self) -> None:
403403
"""Test that arguments with values conflict with stdin."""
404-
args = self.parser.parse_args(["--stdin", "--workspace", "file:///custom/path"])
404+
args = self.parser.parse_args(["--stdin", "--workspace", "/custom/path"])
405405
with self.assertRaises(SystemExit):
406406
self.cmd_run.verify_no_extra_args(args)
407407

@@ -499,7 +499,7 @@ def test_torchx_run_args_from_json(self) -> None:
499499
self.assertEqual(result.dryrun, False)
500500
self.assertEqual(result.wait, False)
501501
self.assertEqual(result.log, False)
502-
self.assertEqual(result.workspace, f"file://{Path.cwd()}")
502+
self.assertEqual(result.workspace, f"{Path.cwd()}")
503503
self.assertEqual(result.parent_run_id, None)
504504
self.assertEqual(result.tee_logs, False)
505505
self.assertEqual(result.component_args, {})
@@ -515,7 +515,7 @@ def test_torchx_run_args_from_json(self) -> None:
515515
"dryrun": True,
516516
"wait": True,
517517
"log": True,
518-
"workspace": "file:///custom/path",
518+
"workspace": "/custom/path",
519519
"parent_run_id": "parent123",
520520
"tee_logs": True,
521521
}
@@ -529,7 +529,7 @@ def test_torchx_run_args_from_json(self) -> None:
529529
self.assertEqual(result2.dryrun, True)
530530
self.assertEqual(result2.wait, True)
531531
self.assertEqual(result2.log, True)
532-
self.assertEqual(result2.workspace, "file:///custom/path")
532+
self.assertEqual(result2.workspace, "/custom/path")
533533
self.assertEqual(result2.parent_run_id, "parent123")
534534
self.assertEqual(result2.tee_logs, True)
535535

@@ -626,7 +626,7 @@ def test_torchx_run_args_from_argparse(self) -> None:
626626
args.dryrun = True
627627
args.wait = False
628628
args.log = True
629-
args.workspace = "file:///custom/workspace"
629+
args.workspace = "/custom/workspace"
630630
args.parent_run_id = "parent_123"
631631
args.tee_logs = False
632632

@@ -654,7 +654,7 @@ def test_torchx_run_args_from_argparse(self) -> None:
654654
self.assertEqual(result.dryrun, True)
655655
self.assertEqual(result.wait, False)
656656
self.assertEqual(result.log, True)
657-
self.assertEqual(result.workspace, "file:///custom/workspace")
657+
self.assertEqual(result.workspace, "/custom/workspace")
658658
self.assertEqual(result.parent_run_id, "parent_123")
659659
self.assertEqual(result.tee_logs, False)
660660
self.assertEqual(result.component_args, {})

torchx/runner/api.py

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
from torchx.util.session import get_session_id_or_create_new, TORCHX_INTERNAL_SESSION_ID
5555

5656
from torchx.util.types import none_throws
57-
from torchx.workspace.api import WorkspaceMixin
57+
from torchx.workspace.api import Workspace, WorkspaceMixin
5858

5959
if TYPE_CHECKING:
6060
from typing_extensions import Self
@@ -171,7 +171,7 @@ def run_component(
171171
component_args: Union[list[str], dict[str, Any]],
172172
scheduler: str,
173173
cfg: Optional[Mapping[str, CfgVal]] = None,
174-
workspace: Optional[str] = None,
174+
workspace: Optional[Union[Workspace, str]] = None,
175175
parent_run_id: Optional[str] = None,
176176
) -> AppHandle:
177177
"""
@@ -206,7 +206,7 @@ def run_component(
206206
ComponentNotFoundException: if the ``component_path`` is failed to resolve.
207207
"""
208208

209-
with log_event("run_component", workspace=workspace) as ctx:
209+
with log_event("run_component") as ctx:
210210
dryrun_info = self.dryrun_component(
211211
component,
212212
component_args,
@@ -217,7 +217,8 @@ def run_component(
217217
)
218218
handle = self.schedule(dryrun_info)
219219
app = none_throws(dryrun_info._app)
220-
ctx._torchx_event.workspace = workspace
220+
221+
ctx._torchx_event.workspace = str(workspace)
221222
ctx._torchx_event.scheduler = none_throws(dryrun_info._scheduler)
222223
ctx._torchx_event.app_image = app.roles[0].image
223224
ctx._torchx_event.app_id = parse_app_handle(handle)[2]
@@ -230,7 +231,7 @@ def dryrun_component(
230231
component_args: Union[list[str], dict[str, Any]],
231232
scheduler: str,
232233
cfg: Optional[Mapping[str, CfgVal]] = None,
233-
workspace: Optional[str] = None,
234+
workspace: Optional[Union[Workspace, str]] = None,
234235
parent_run_id: Optional[str] = None,
235236
) -> AppDryRunInfo:
236237
"""
@@ -259,7 +260,7 @@ def run(
259260
app: AppDef,
260261
scheduler: str,
261262
cfg: Optional[Mapping[str, CfgVal]] = None,
262-
workspace: Optional[str] = None,
263+
workspace: Optional[Union[Workspace, str]] = None,
263264
parent_run_id: Optional[str] = None,
264265
) -> AppHandle:
265266
"""
@@ -272,9 +273,7 @@ def run(
272273
An application handle that is used to call other action APIs on the app.
273274
"""
274275

275-
with log_event(
276-
api="run", runcfg=json.dumps(cfg) if cfg else None, workspace=workspace
277-
) as ctx:
276+
with log_event(api="run") as ctx:
278277
dryrun_info = self.dryrun(
279278
app,
280279
scheduler,
@@ -283,10 +282,15 @@ def run(
283282
parent_run_id=parent_run_id,
284283
)
285284
handle = self.schedule(dryrun_info)
286-
ctx._torchx_event.scheduler = none_throws(dryrun_info._scheduler)
287-
ctx._torchx_event.app_image = none_throws(dryrun_info._app).roles[0].image
288-
ctx._torchx_event.app_id = parse_app_handle(handle)[2]
289-
ctx._torchx_event.app_metadata = app.metadata
285+
286+
event = ctx._torchx_event
287+
event.scheduler = scheduler
288+
event.runcfg = json.dumps(cfg) if cfg else None
289+
event.workspace = str(workspace)
290+
event.app_id = parse_app_handle(handle)[2]
291+
event.app_image = none_throws(dryrun_info._app).roles[0].image
292+
event.app_metadata = app.metadata
293+
290294
return handle
291295

292296
def schedule(self, dryrun_info: AppDryRunInfo) -> AppHandle:
@@ -320,21 +324,22 @@ def schedule(self, dryrun_info: AppDryRunInfo) -> AppHandle:
320324
321325
"""
322326
scheduler = none_throws(dryrun_info._scheduler)
323-
app_image = none_throws(dryrun_info._app).roles[0].image
324327
cfg = dryrun_info._cfg
325-
with log_event(
326-
"schedule",
327-
scheduler,
328-
app_image=app_image,
329-
runcfg=json.dumps(cfg) if cfg else None,
330-
) as ctx:
328+
with log_event("schedule") as ctx:
331329
sched = self._scheduler(scheduler)
332330
app_id = sched.schedule(dryrun_info)
333331
app_handle = make_app_handle(scheduler, self._name, app_id)
332+
334333
app = none_throws(dryrun_info._app)
335334
self._apps[app_handle] = app
336-
_, _, app_id = parse_app_handle(app_handle)
337-
ctx._torchx_event.app_id = app_id
335+
336+
event = ctx._torchx_event
337+
event.scheduler = scheduler
338+
event.runcfg = json.dumps(cfg) if cfg else None
339+
event.app_id = app_id
340+
event.app_image = none_throws(dryrun_info._app).roles[0].image
341+
event.app_metadata = app.metadata
342+
338343
return app_handle
339344

340345
def name(self) -> str:
@@ -345,7 +350,7 @@ def dryrun(
345350
app: AppDef,
346351
scheduler: str,
347352
cfg: Optional[Mapping[str, CfgVal]] = None,
348-
workspace: Optional[str] = None,
353+
workspace: Optional[Union[Workspace, str]] = None,
349354
parent_run_id: Optional[str] = None,
350355
) -> AppDryRunInfo:
351356
"""
@@ -414,7 +419,7 @@ def dryrun(
414419
"dryrun",
415420
scheduler,
416421
runcfg=json.dumps(cfg) if cfg else None,
417-
workspace=workspace,
422+
workspace=str(workspace),
418423
):
419424
sched = self._scheduler(scheduler)
420425
resolved_cfg = sched.run_opts().resolve(cfg)
@@ -429,7 +434,7 @@ def dryrun(
429434
logger.info(
430435
'To disable workspaces pass: --workspace="" from CLI or workspace=None programmatically.'
431436
)
432-
sched.build_workspace_and_update_role(role, workspace, resolved_cfg)
437+
sched.build_workspace_and_update_role2(role, workspace, resolved_cfg)
433438

434439
if old_img != role.image:
435440
logger.info(

torchx/runner/events/__init__.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@
3333

3434
from .api import SourceType, TorchxEvent # noqa F401
3535

36-
# pyre-fixme[9]: _events_logger is a global variable
37-
_events_logger: logging.Logger = None
36+
_events_logger: Optional[logging.Logger] = None
37+
38+
log: logging.Logger = logging.getLogger(__name__)
3839

3940

4041
def _get_or_create_logger(destination: str = "null") -> logging.Logger:
@@ -51,19 +52,28 @@ def _get_or_create_logger(destination: str = "null") -> logging.Logger:
5152
a new logger if None provided.
5253
"""
5354
global _events_logger
55+
5456
if _events_logger:
5557
return _events_logger
56-
logging_handler = get_logging_handler(destination)
57-
logging_handler.setLevel(logging.DEBUG)
58-
_events_logger = logging.getLogger(f"torchx-events-{destination}")
59-
# Do not propagate message to the root logger
60-
_events_logger.propagate = False
61-
_events_logger.addHandler(logging_handler)
62-
return _events_logger
58+
else:
59+
logging_handler = get_logging_handler(destination)
60+
logging_handler.setLevel(logging.DEBUG)
61+
_events_logger = logging.getLogger(f"torchx-events-{destination}")
62+
# Do not propagate message to the root logger
63+
_events_logger.propagate = False
64+
_events_logger.addHandler(logging_handler)
65+
66+
assert _events_logger # make type-checker happy
67+
return _events_logger
6368

6469

6570
def record(event: TorchxEvent, destination: str = "null") -> None:
66-
_get_or_create_logger(destination).info(event.serialize())
71+
try:
72+
serialized_event = event.serialize()
73+
except Exception:
74+
log.exception("failed to serialize event, will not record event")
75+
else:
76+
_get_or_create_logger(destination).info(serialized_event)
6777

6878

6979
class log_event:

torchx/runner/events/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class TorchxEvent:
2929
scheduler: Scheduler that is used to execute request
3030
api: Api name
3131
app_id: Unique id that is set by the underlying scheduler
32-
image: Image/container bundle that is used to execute request.
32+
app_image: Image/container bundle that is used to execute request.
3333
app_metadata: metadata to the app (treatment of metadata is scheduler dependent)
3434
runcfg: Run config that was used to schedule app.
3535
source: Type of source the event is generated.

torchx/runner/test/config_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from torchx.schedulers.api import DescribeAppResponse, ListAppResponse, Stream
2828
from torchx.specs import AppDef, AppDryRunInfo, CfgVal, runopts
2929
from torchx.test.fixtures import TestWithTmpDir
30+
from torchx.workspace import Workspace
3031

3132

3233
class TestScheduler(Scheduler):
@@ -506,3 +507,31 @@ def test_dump_and_load_all_registered_schedulers(self) -> None:
506507
opt_name in cfg,
507508
f"missing {opt_name} in {sched} run opts with cfg {cfg}",
508509
)
510+
511+
def test_get_workspace_config(self) -> None:
512+
configdir = self.tmpdir
513+
self.write(
514+
str(configdir / ".torchxconfig"),
515+
"""#
516+
[cli:run]
517+
workspace =
518+
/home/foo/third-party/verl: verl
519+
/home/foo/bar/scripts/.torchxconfig: verl/.torchxconfig
520+
/home/foo/baz:
521+
""",
522+
)
523+
524+
workspace_config = get_config(
525+
prefix="cli", name="run", key="workspace", dirs=[str(configdir)]
526+
)
527+
self.assertIsNotNone(workspace_config)
528+
529+
workspace = Workspace.from_str(workspace_config)
530+
self.assertDictEqual(
531+
{
532+
"/home/foo/third-party/verl": "verl",
533+
"/home/foo/bar/scripts/.torchxconfig": "verl/.torchxconfig",
534+
"/home/foo/baz": "",
535+
},
536+
workspace.projects,
537+
)

0 commit comments

Comments
 (0)