diff --git a/monarch_rdma/extension/lib.rs b/monarch_rdma/extension/lib.rs index 6edbc97be..10b68a341 100644 --- a/monarch_rdma/extension/lib.rs +++ b/monarch_rdma/extension/lib.rs @@ -89,6 +89,26 @@ impl PyRdmaBuffer { )) } + #[classmethod] + fn create_rdma_buffer_blocking<'py>( + _cls: &Bound<'_, PyType>, + py: Python<'py>, + addr: usize, + size: usize, + proc_id: String, + client: PyMailbox, + ) -> PyResult<PyRdmaBuffer> { + if !ibverbs_supported() { + return Err(PyException::new_err( + "ibverbs is not supported on this system", + )); + } + signal_safe_block_on( + py, + create_rdma_buffer(addr, size, proc_id.parse().unwrap(), client), + )? + } + #[classmethod] fn rdma_supported<'py>(_cls: &Bound<'_, PyType>, _py: Python<'py>) -> bool { ibverbs_supported() diff --git a/python/monarch/_rust_bindings/rdma/__init__.pyi b/python/monarch/_rust_bindings/rdma/__init__.pyi index 624128d0f..baaf2ea78 100644 --- a/python/monarch/_rust_bindings/rdma/__init__.pyi +++ b/python/monarch/_rust_bindings/rdma/__init__.pyi @@ -23,6 +23,10 @@ async def create_rdma_manager_nonblocking(proc_mesh: Any) -> Optional[_RdmaManag class _RdmaBuffer: name: str + @classmethod + def create_rdma_buffer_blocking( + cls, addr: int, size: int, proc_id: str, client: Any + ) -> _RdmaBuffer: ... @classmethod def create_rdma_buffer_nonblocking( cls, addr: int, size: int, proc_id: str, client: Any diff --git a/python/monarch/_src/actor/proc_mesh.py b/python/monarch/_src/actor/proc_mesh.py index 7765d0434..89163ff5c 100644 --- a/python/monarch/_src/actor/proc_mesh.py +++ b/python/monarch/_src/actor/proc_mesh.py @@ -6,6 +6,7 @@ # pyre-strict +import logging import os import sys import warnings @@ -24,7 +25,6 @@ ) from monarch._rust_bindings.monarch_extension.logging import LoggingMeshClient - from monarch._rust_bindings.monarch_hyperactor.alloc import ( # @manual=//monarch/monarch_extension:monarch_extension Alloc, AllocConstraints, @@ -63,13 +63,14 @@ HAS_TENSOR_ENGINE = False try: - # TODO: while the tensor_engine submodule doesn't exist yet, use the - # available of monarch.rdma as a proxy.
- # type: ignore - from monarch.rdma import RDMAManager # @manual + from monarch._rust_bindings.rdma import ( # type: ignore[import] + _RdmaManager, + create_rdma_manager_blocking, + ) HAS_TENSOR_ENGINE = True except ImportError: + logging.warning("RDMA is not available on this platform") pass @@ -102,7 +103,7 @@ def __init__( self._proc_mesh = hy_proc_mesh self._mock_shape: Optional[Shape] = _mock_shape # type: ignore[21] - self._rdma_manager: Optional["RDMAManager"] = None + self._rdma_manager: Optional["_RdmaManager"] = None self._debug_manager: Optional[DebugManager] = None self._mailbox: Mailbox = self._proc_mesh.client self._code_sync_client: Optional[CodeSyncMeshClient] = None @@ -117,7 +118,7 @@ def __init__( with fake_sync_state(): if _mock_shape is None and HAS_TENSOR_ENGINE: # type: ignore[21] - self._rdma_manager = self.spawn("rdma_manager", RDMAManager).get() + self._rdma_manager = create_rdma_manager_blocking(self._proc_mesh) if not _is_initializing_debugger and _mock_shape is None: self._debug_manager = self.spawn( _DEBUG_MANAGER_ACTOR_NAME, DebugManager, debug_client() diff --git a/python/monarch/_src/tensor_engine/rdma.py b/python/monarch/_src/tensor_engine/rdma.py index 11f504c63..a18b9a1b7 100644 --- a/python/monarch/_src/tensor_engine/rdma.py +++ b/python/monarch/_src/tensor_engine/rdma.py @@ -9,12 +9,15 @@ from typing import Optional import torch -from monarch._rust_bindings.rdma import _RdmaBuffer +try: + from monarch._rust_bindings.rdma import _RdmaBuffer +except ImportError as e: + logging.error("RDMA is not available: {}".format(e)) + raise e +from monarch._src.actor.actor_mesh import MonarchContext from monarch._src.actor.future import Future -from monarch.actor import MonarchContext - # RDMARead/WriteTransferWarnings are warnings that are only printed once per process. # Remove these once GPU support is added. @@ -30,7 +33,7 @@ class RDMAWriteTransferWarning(Warning): warnings.simplefilter("once", RDMAWriteTransferWarning) -def rdma_supported(): +def is_available(): return _RdmaBuffer.rdma_supported() @@ -52,7 +55,9 @@ def __init__(self, data: torch.Tensor) -> None: TODO: Create TensorBuffer, which will be main user API supporting non-contiguous , multi-byte-per-elment tensors """ - assert _RdmaBuffer.rdma_supported() + assert ( + is_available() + ), "Tried to create an RDMABuffer, but RDMA is not available on this platform." if data.device.type != "cpu": # TODO - CUDA support for RDMABuffer exists at the Rust layer, but @@ -72,16 +77,12 @@ def __init__(self, data: torch.Tensor) -> None: addr: int = storage.data_ptr() size = storage.element_size() * data.numel() ctx = MonarchContext.get() - f = Future( - impl=lambda: _RdmaBuffer.create_rdma_buffer_nonblocking( - addr=addr, - size=size, - proc_id=ctx.proc_id, - client=ctx.mailbox, - ), - requires_loop=False, + self._buffer: _RdmaBuffer = _RdmaBuffer.create_rdma_buffer_blocking( + addr=addr, + size=size, + proc_id=ctx.proc_id, + client=ctx.mailbox, ) - self._buffer: _RdmaBuffer = f.get() # TODO - specific exception except Exception as e: logging.error("Failed to create buffer %s", e) diff --git a/python/monarch/rdma.py b/python/monarch/rdma.py deleted file mode 100644 index b9cc771a0..000000000 --- a/python/monarch/rdma.py +++ /dev/null @@ -1,161 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import ctypes - -from dataclasses import dataclass -from typing import cast, Dict, Optional, Tuple - -import torch - -from monarch._rust_bindings.monarch_hyperactor.proc import ActorId -from monarch._src.actor.actor_mesh import ( - _ActorMeshRefImpl, - Actor, - ActorMeshRef, - endpoint, - MonarchContext, -) - - -@dataclass -class LocalRDMARecord: - data: torch.Tensor - - -_local_buffers: Dict[int, "LocalRDMARecord"] = {} - - -def _get_bytes(storage: torch.Tensor, offset: int, size: int) -> bytearray: - """Extracts a bytearray from a 1D, 1byte per item tensor.""" - if offset + size > storage.numel(): - raise ValueError(f"Read out of range: {offset + size} > {storage.size()}") - addr = storage.data_ptr() - if storage.device.type != "cpu": - result = bytearray(size) - result_tensor = torch.frombuffer( - result, - dtype=torch.uint8, - ) - source_tensor = storage[offset:] - result_tensor.copy_(source_tensor) - else: - ctypes_array = (ctypes.c_byte * size).from_address(addr) - result = bytearray(ctypes_array) - return result - - -class RDMAManager(Actor): - @staticmethod - def on_proc(proc_id: str) -> "RDMAManager": - ctx = MonarchContext.get() - return cast( - RDMAManager, - ActorMeshRef( - RDMAManager, - _ActorMeshRefImpl.from_actor_id( - ctx.mailbox, - ActorId.from_string(f"{proc_id}.rdma_manager[0]"), - ), - ctx.mailbox, - ), - ) - - @endpoint - async def drop(self, addr: int) -> None: - if addr in _local_buffers: - del _local_buffers[addr] - - @endpoint - async def fetch(self, addr: int, offset: int, nbytes: int) -> bytearray: - if addr not in _local_buffers: - raise ValueError(f"Unknown buffer {addr}") - storage = _local_buffers[addr].data - return _get_bytes(storage, offset, nbytes) - - @endpoint - async def put(self, addr: int, offset: int, bytes: bytearray) -> None: - if addr not in _local_buffers: - raise ValueError(f"Unknown buffer {addr}") - storage = _local_buffers[addr].data - storage[offset : offset + len(bytes)] = torch.frombuffer( - bytes, dtype=storage.dtype - ) - - -def _assert_tensor_is_1d_contiguous_uint8(t: torch.Tensor) -> None: - if t.ndim != 1: - raise ValueError(f"Tensor must be 1D, got {t.ndim}D") - if t.dtype != torch.uint8: - raise ValueError(f"Tensor must be uint8, got {t.dtype}") - if not t.is_contiguous(): - raise ValueError("Tensor must be contiguous") - - -class RDMABuffer: - def __init__(self, data: torch.Tensor) -> None: - """ - RDMABuffer only supports 1D contiguous tensors that are 1 byte per item. - - To create a 1 byte, 1D view, use t.view(torch.uint8).flatten() - - TODO: Create TensorBuffer, which will be main user API supporting non-contiguous , multi-byte-per-elment tensors - """ - _assert_tensor_is_1d_contiguous_uint8(data) - assert data.storage_offset() == 0 - storage = data.untyped_storage() - self.addr: int = storage.data_ptr() - self.begin = 0 - self.end: int = storage.size() - self.proc_id: str = MonarchContext.get().proc_id - self.local_data: object = None - _local_buffers[self.addr] = LocalRDMARecord(data) - - def drop(self) -> None: - if self.proc_id is None: - del _local_buffers[self.addr] - return - rmda_actor = RDMAManager.on_proc(self.proc_id) - # pyre-ignore[16]: Undefined attribute [16]: `Endpoint` has no attribute `cast`. 
- rmda_actor.drop.cast(self.addr) - - def __getstate__(self) -> Tuple[int, int, int, Optional[str]]: - proc_id = self.proc_id - # locally created RDMABuffer being set remotely, - # record its proc_id so we know how to establish connections to it - if proc_id is None: - proc_id = MonarchContext.get().proc_id - return (self.addr, self.begin, self.end, proc_id) - - def __setstate__(self, state: Tuple[int, int, int, str]) -> None: - self.local_data = None - self.addr, self.begin, self.end, self.proc_id = state - - async def read_into(self, dst: torch.Tensor, offset: int = 0) -> None: - """ - Read data from the RDMABuffer into a destination tensor. - - The destination tensor must be contiguous and 1 byte per item. - """ - _assert_tensor_is_1d_contiguous_uint8(dst) - bytes = await RDMAManager.on_proc(self.proc_id).fetch.call_one( - self.addr, offset, dst.numel() - ) - dst.copy_(torch.frombuffer(bytes, dtype=torch.uint8)) - - async def write(self, src: torch.Tensor, offset: int = 0) -> None: - """ - Write data from a source tensor into the RDMABuffer. - - The source tensor must be contiguous and 1 byte per item. - """ - _assert_tensor_is_1d_contiguous_uint8(src) - bytes = _get_bytes( - src, - cast(int, src.storage_offset()), - src.numel(), - ) - await RDMAManager.on_proc(self.proc_id).put.call_one(self.addr, offset, bytes) diff --git a/python/monarch/tensor_engine/__init__.py b/python/monarch/tensor_engine/__init__.py new file mode 100644 index 000000000..172a7a5b1 --- /dev/null +++ b/python/monarch/tensor_engine/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Monarch Tensor Engine API - Public interface for tensor engine functionality. 
+""" + +from monarch._src.tensor_engine.rdma import ( + is_available, + RDMABuffer, + RDMAReadTransferWarning, + RDMAWriteTransferWarning, +) + +__all__ = [ + "is_available", + "RDMABuffer", + "RDMAReadTransferWarning", + "RDMAWriteTransferWarning", +] diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index 57513d46b..96e2a43ad 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -33,7 +33,6 @@ local_proc_mesh, proc_mesh, ) -from monarch.rdma import RDMABuffer from typing_extensions import assert_type @@ -66,26 +65,6 @@ async def call_value(self, c: Counter) -> int: return await c.value.choose() -class ParameterServer(Actor): - def __init__(self): - self.params = torch.rand(10, 10) - self.grad_buffer = torch.rand(10, 10) - - @endpoint - async def grad_handle(self) -> RDMABuffer: - byte_tensor = self.grad_buffer.view(torch.uint8).flatten() - return RDMABuffer(byte_tensor) - - @endpoint - async def update(self): - self.params += 0.01 * self.grad_buffer - - @endpoint - async def get_grad_buffer(self) -> torch.Tensor: - # just used for testing - return self.grad_buffer - - async def test_choose(): proc = await local_proc_mesh(gpus=2) v = await proc.spawn("counter", Counter, 3) @@ -112,79 +91,6 @@ async def test_stream(): assert 8 == sum([x async for x in v.value.stream()]) -class ParameterClient(Actor): - def __init__(self, server, buffer): - self.server = server - byte_tensor = buffer.view(torch.uint8).flatten() - self.buffer = byte_tensor - - @endpoint - async def upload(self, tensor): - gh = await self.server.grad_handle.call_one() - await gh.write(tensor) - - @endpoint - async def download(self): - gh = await self.server.grad_handle.call_one() - await gh.read_into(self.buffer) - - @endpoint - async def get_buffer(self): - return self.buffer - - -@needs_cuda -async def test_proc_mesh_rdma(): - proc = await proc_mesh(gpus=1) - server = await proc.spawn("server", ParameterServer) - - # --- CPU TESTS --- - client_cpu = await proc.spawn( - "client_cpu", ParameterClient, server, torch.ones(10, 10) - ) - x = await client_cpu.get_buffer.call_one() - assert torch.sum(x.view(torch.float32).view(10, 10)) == 100 - zeros = torch.zeros(10, 10) - await client_cpu.upload.call_one(zeros.view(torch.uint8).flatten()) - await client_cpu.download.call_one() - x = await client_cpu.get_buffer.call_one() - assert torch.sum(x.view(torch.float32).view(10, 10)) == 0 - - # --- Modify server's backing buffer directly --- - await server.update.call_one() - - # Should reflect updated values - await client_cpu.download.call_one() - - buffer = await client_cpu.get_buffer.call_one() - remote_grad = await server.get_grad_buffer.call_one() - assert torch.allclose(buffer.view(torch.float32).view(10, 10), remote_grad) - - # --- GPU TESTS --- - client_gpu = await proc.spawn( - "client_gpu", ParameterClient, server, torch.ones(10, 10, device="cuda") - ) - x = await client_gpu.get_buffer.call_one() - buffer = x.view(torch.float32).view(10, 10) - assert torch.sum(buffer) == 100 - zeros = torch.zeros(10, 10, device="cuda") - await client_gpu.upload.call_one(zeros.view(torch.uint8).flatten()) - await client_gpu.download.call_one() - x = await client_gpu.get_buffer.call_one() - buffer_gpu = x.view(torch.float32).view(10, 10) - assert torch.sum(buffer_gpu) == 0 - # copying a tensor across hosts moves it to CPU - assert buffer_gpu.device.type == "cpu" - - # Modify server state again - await server.update.call_one() - await client_gpu.download.call_one() - x = 
await client_gpu.get_buffer.call_one() - buffer_gpu = x.view(torch.float32).view(10, 10) - remote_grad = await server.get_grad_buffer.call_one() - assert torch.allclose(buffer_gpu.cpu(), remote_grad) - - class To(Actor): @endpoint async def whoami(self): @@ -256,69 +162,6 @@ async def test_rank_size(): assert 4 == await acc.accumulate(lambda: current_size()["gpus"]) -class TrainerActor(Actor): - def __init__(self): - super().__init__() - self.trainer = torch.nn.Linear(10, 10).to("cuda") - self.trainer.weight.data.zero_() - - @endpoint - async def init(self, gen): - ranks = current_rank() - self.gen = gen.slice(**ranks) - - @endpoint - async def exchange_metadata(self): - byte_tensor = self.trainer.weight.data.view(torch.uint8).flatten() - self.handle = RDMABuffer(byte_tensor) - await self.gen.attach_weight_buffer.call(self.handle) - - @endpoint - async def weights_ready(self): - self.trainer.weight.data.add_(1.0) - - -class GeneratorActor(Actor): - def __init__(self): - super().__init__() - self.generator = torch.nn.Linear(10, 10).to("cuda") - self.step = 0 - - @endpoint - async def init(self, trainer): - ranks = current_rank() - self.trainer = trainer.slice(**ranks) - - @endpoint - async def attach_weight_buffer(self, handle): - self.handle = handle - - @endpoint - async def update_weights(self): - self.step += 1 - byte_tensor = self.generator.weight.data.view(torch.uint8).flatten() - await self.handle.read_into(byte_tensor) - assert ( - torch.sum(self.generator.weight.data) == self.step * 100 - ), f"{torch.sum(self.generator.weight.data)=}, {self.step=}" - - -@needs_cuda -async def test_gpu_trainer_generator(): - trainer_proc = await proc_mesh(gpus=1) - gen_proc = await proc_mesh(gpus=1) - trainer = await trainer_proc.spawn("trainer", TrainerActor) - generator = await gen_proc.spawn("gen", GeneratorActor) - - await generator.init.call(trainer) - await trainer.init.call(generator) - await trainer.exchange_metadata.call() - - for _ in range(3): - await trainer.weights_ready.call() - await generator.update_weights.call() - - class SyncActor(Actor): @endpoint def sync_endpoint(self, a_counter: Counter): @@ -333,22 +176,6 @@ async def test_sync_actor(): assert r == 5 -@needs_cuda -def test_gpu_trainer_generator_sync() -> None: - trainer_proc = proc_mesh(gpus=1).get() - gen_proc = proc_mesh(gpus=1).get() - trainer = trainer_proc.spawn("trainer", TrainerActor).get() - generator = gen_proc.spawn("gen", GeneratorActor).get() - - generator.init.call(trainer).get() - trainer.init.call(generator).get() - trainer.exchange_metadata.call().get() - - for _ in range(3): - trainer.weights_ready.call().get() - generator.update_weights.call().get() - - def test_sync_actor_sync_client(): proc = local_proc_mesh(gpus=2).get() a = proc.spawn("actor", SyncActor).get() diff --git a/python/tests/test_rdma.py b/python/tests/test_rdma.py new file mode 100644 index 000000000..9259c1f26 --- /dev/null +++ b/python/tests/test_rdma.py @@ -0,0 +1,198 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import pytest + +import torch +from monarch.actor import Actor, current_rank, endpoint, proc_mesh +from monarch.tensor_engine import is_available as rdma_available, RDMABuffer + + +needs_cuda = pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA not available", +) +needs_rdma = pytest.mark.skipif( + not rdma_available(), + reason="RDMA not available", +) + + +class ParameterServer(Actor): + def __init__(self): + self.params = torch.rand(10, 10) + self.grad_buffer = torch.rand(10, 10) + + @endpoint + async def grad_handle(self) -> RDMABuffer: + byte_tensor = self.grad_buffer.view(torch.uint8).flatten() + buffer = RDMABuffer(byte_tensor) + return buffer + + @endpoint + async def update(self): + self.params += 0.01 * self.grad_buffer + + @endpoint + async def get_grad_buffer(self) -> torch.Tensor: + # just used for testing + return self.grad_buffer + + +class ParameterClient(Actor): + def __init__(self, server, buffer): + self.server = server + byte_tensor = buffer.view(torch.uint8).flatten() + self.buffer = byte_tensor + + @endpoint + async def upload(self, tensor): + gh = await self.server.grad_handle.call_one() + await gh.write_from(tensor) + + @endpoint + async def download(self): + gh = await self.server.grad_handle.call_one() + await gh.read_into(self.buffer) + + @endpoint + async def get_buffer(self): + return self.buffer + + +@needs_rdma +@needs_cuda +async def test_proc_mesh_rdma(): + proc = await proc_mesh(gpus=1) + server = await proc.spawn("server", ParameterServer) + + # --- CPU TESTS --- + client_cpu = await proc.spawn( + "client_cpu", ParameterClient, server, torch.ones(10, 10) + ) + x = await client_cpu.get_buffer.call_one() + assert torch.sum(x.view(torch.float32).view(10, 10)) == 100 + zeros = torch.zeros(10, 10) + await client_cpu.upload.call_one(zeros.view(torch.uint8).flatten()) + await client_cpu.download.call_one() + x = await client_cpu.get_buffer.call_one() + assert torch.sum(x.view(torch.float32).view(10, 10)) == 0 + + # --- Modify server's backing buffer directly --- + await server.update.call_one() + + # Should reflect updated values + await client_cpu.download.call_one() + + buffer = await client_cpu.get_buffer.call_one() + remote_grad = await server.get_grad_buffer.call_one() + assert torch.allclose(buffer.view(torch.float32).view(10, 10), remote_grad) + + # --- GPU TESTS --- + client_gpu = await proc.spawn( + "client_gpu", ParameterClient, server, torch.ones(10, 10, device="cuda") + ) + x = await client_gpu.get_buffer.call_one() + buffer = x.view(torch.float32).view(10, 10) + assert torch.sum(buffer) == 100 + zeros = torch.zeros(10, 10, device="cuda") + await client_gpu.upload.call_one(zeros.view(torch.uint8).flatten()) + await client_gpu.download.call_one() + x = await client_gpu.get_buffer.call_one() + buffer_gpu = x.view(torch.float32).view(10, 10) + assert torch.sum(buffer_gpu) == 0 + # copying a tensor across hosts moves it to CPU + assert buffer_gpu.device.type == "cpu" + + # Modify server state again + await server.update.call_one() + await client_gpu.download.call_one() + x = await client_gpu.get_buffer.call_one() + buffer_gpu = x.view(torch.float32).view(10, 10) + remote_grad = await server.get_grad_buffer.call_one() + assert torch.allclose(buffer_gpu.cpu(), remote_grad) + + +class TrainerActor(Actor): + def __init__(self): + super().__init__() + # TODO - switch to CUDA once GPU support is added + self.trainer = torch.nn.Linear(10, 10).to("cpu") + self.trainer.weight.data.zero_() + + @endpoint + async def init(self, gen): + ranks = 
current_rank() + self.gen = gen.slice(**ranks) + + @endpoint + async def exchange_metadata(self): + byte_tensor = self.trainer.weight.data.view(torch.uint8).flatten() + self.handle = RDMABuffer(byte_tensor) + await self.gen.attach_weight_buffer.call(self.handle) + + @endpoint + async def weights_ready(self): + self.trainer.weight.data.add_(1.0) + + +class GeneratorActor(Actor): + def __init__(self): + super().__init__() + self.generator = torch.nn.Linear(10, 10).to("cuda") + self.step = 0 + + @endpoint + async def init(self, trainer): + ranks = current_rank() + self.trainer = trainer.slice(**ranks) + + @endpoint + async def attach_weight_buffer(self, handle): + self.handle = handle + + @endpoint + async def update_weights(self): + self.step += 1 + byte_tensor = self.generator.weight.data.view(torch.uint8).flatten() + await self.handle.read_into(byte_tensor) + assert ( + torch.sum(self.generator.weight.data) == self.step * 100 + ), f"{torch.sum(self.generator.weight.data)=}, {self.step=}" + + +@needs_rdma +@needs_cuda +async def test_gpu_trainer_generator(): + trainer_proc = await proc_mesh(gpus=1) + gen_proc = await proc_mesh(gpus=1) + trainer = await trainer_proc.spawn("trainer", TrainerActor) + generator = await gen_proc.spawn("gen", GeneratorActor) + + await generator.init.call(trainer) + await trainer.init.call(generator) + await trainer.exchange_metadata.call() + + for _ in range(3): + await trainer.weights_ready.call() + await generator.update_weights.call() + + +@needs_rdma +@needs_cuda +def test_gpu_trainer_generator_sync() -> None: + trainer_proc = proc_mesh(gpus=1).get() + gen_proc = proc_mesh(gpus=1).get() + trainer = trainer_proc.spawn("trainer", TrainerActor).get() + generator = gen_proc.spawn("gen", GeneratorActor).get() + + generator.init.call(trainer).get() + trainer.init.call(generator).get() + trainer.exchange_metadata.call().get() + + for _ in range(1): + trainer.weights_ready.call().get() + generator.update_weights.call().get()
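
A minimal usage sketch of the new monarch.tensor_engine surface this patch introduces, following the call pattern exercised in python/tests/test_rdma.py: guard on is_available(), build an RDMABuffer over a 1-D contiguous uint8 view from inside an actor, and move bytes with the awaitable read_into / write_from methods. The Owner and Peer actors, the tensor shapes, and the main() driver below are illustrative assumptions, not code from this change.

import torch

from monarch.actor import Actor, endpoint, proc_mesh
from monarch.tensor_engine import is_available, RDMABuffer


class Owner(Actor):
    """Owns a tensor and hands out an RDMA handle to its bytes (illustrative)."""

    def __init__(self):
        self.data = torch.ones(10, 10)

    @endpoint
    async def handle(self) -> RDMABuffer:
        # RDMABuffer expects a 1-D, contiguous, 1-byte-per-element view.
        return RDMABuffer(self.data.view(torch.uint8).flatten())


class Peer(Actor):
    """Reads from and writes to the remote buffer through its RDMABuffer handle."""

    def __init__(self, owner):
        self.owner = owner
        self.staging = torch.zeros(10, 10).view(torch.uint8).flatten()

    @endpoint
    async def pull(self):
        buf = await self.owner.handle.call_one()
        await buf.read_into(self.staging)

    @endpoint
    async def push(self):
        buf = await self.owner.handle.call_one()
        await buf.write_from(self.staging)


async def main():
    assert is_available(), "RDMA is not available on this platform"
    proc = await proc_mesh(gpus=1)
    owner = await proc.spawn("owner", Owner)
    peer = await proc.spawn("peer", Peer, owner)
    await peer.pull.call_one()
    await peer.push.call_one()

As in the tests, the buffer is created and used from inside actor endpoints, where MonarchContext is available, and CPU tensors are used because the TODOs in this diff note that GPU support at the Python layer is still pending.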