
Commit 8908273

allenwang28 authored and facebook-github-bot committed
(6/final) Updates rdma.py with tensor_engine RDMA (#582)
Summary: This diff:

- Marks rdma.py for deprecation
- Replaces the existing RDMABuffer stubs with the real monarch_extension based version
- Updates the examples/meta/rl toy_actor example to use RDMABuffer
- Splits out a few RDMA tests from test_python_actors.py into `test_rdma.py`

Differential Revision: D78366430
1 parent 0aae80c commit 8908273
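In practice, this moves user code from the pure-Python `monarch.rdma` stubs to the bindings-backed `monarch.tensor_engine` API. A minimal sketch of the new entry points (it assumes the code runs inside a Monarch proc mesh / actor context, which is elided here; the flat uint8 view requirement comes from the `RDMABuffer` docstring in the diff below):

```python
import torch

from monarch.tensor_engine import is_available, RDMABuffer

# RDMA support depends on the platform and the Rust bindings, so probe first.
if is_available():
    t = torch.zeros(64, dtype=torch.float32)
    # RDMABuffer only accepts 1D contiguous tensors with 1 byte per item;
    # expose the tensor as a flat uint8 view before wrapping it.
    buf = RDMABuffer(t.view(torch.uint8).flatten())
```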

6 files changed (+245, −337 lines)

python/monarch/_src/actor/proc_mesh.py

Lines changed: 8 additions & 7 deletions
```diff
@@ -6,6 +6,7 @@
 
 # pyre-strict
 
+import logging
 import os
 import sys
 import warnings
@@ -24,7 +25,6 @@
 )
 
 from monarch._rust_bindings.monarch_extension.logging import LoggingMeshClient
-
 from monarch._rust_bindings.monarch_hyperactor.alloc import (  # @manual=//monarch/monarch_extension:monarch_extension
     Alloc,
     AllocConstraints,
@@ -63,13 +63,14 @@
 
 HAS_TENSOR_ENGINE = False
 try:
-    # TODO: while the tensor_engine submodule doesn't exist yet, use the
-    # available of monarch.rdma as a proxy.
-    # type: ignore
-    from monarch.rdma import RDMAManager  # @manual
+    from monarch._rust_bindings.rdma import (  # type: ignore[import]
+        _RdmaManager,
+        create_rdma_manager_blocking,
+    )
 
     HAS_TENSOR_ENGINE = True
 except ImportError:
+    logging.warning("RDMA is not available on this platform")
     pass
 
 
@@ -102,7 +103,7 @@ def __init__(
         self._proc_mesh = hy_proc_mesh
         self._mock_shape: Optional[Shape] = _mock_shape
         # type: ignore[21]
-        self._rdma_manager: Optional["RDMAManager"] = None
+        self._rdma_manager: Optional["_RdmaManager"] = None
         self._debug_manager: Optional[DebugManager] = None
         self._mailbox: Mailbox = self._proc_mesh.client
         self._code_sync_client: Optional[CodeSyncMeshClient] = None
@@ -118,7 +119,7 @@ def __init__(
         with fake_sync_state():
             if _mock_shape is None and HAS_TENSOR_ENGINE:
                 # type: ignore[21]
-                self._rdma_manager = self.spawn("rdma_manager", RDMAManager).get()
+                self._rdma_manager = create_rdma_manager_blocking(self._proc_mesh)
         if not _is_initializing_debugger and _mock_shape is None:
             self._debug_manager = self.spawn(
                 _DEBUG_MANAGER_ACTOR_NAME, DebugManager, debug_client()
```
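The net effect of these hunks is a standard optional-dependency guard: the Rust RDMA bindings may be missing on some platforms, so the import failure is downgraded to a logged warning plus a flag that the rest of `ProcMesh.__init__` consults. In isolation, the pattern looks like this (a sketch, not the full module):

```python
import logging

HAS_TENSOR_ENGINE = False
try:
    # The Rust extension is only present on platforms with RDMA support.
    from monarch._rust_bindings.rdma import create_rdma_manager_blocking  # noqa: F401

    HAS_TENSOR_ENGINE = True
except ImportError:
    logging.warning("RDMA is not available on this platform")

# Later, RDMA setup is skipped rather than crashing the mesh, e.g.:
# if HAS_TENSOR_ENGINE: self._rdma_manager = create_rdma_manager_blocking(...)
```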

python/monarch/_src/tensor_engine/rdma.py

Lines changed: 10 additions & 5 deletions
```diff
@@ -9,12 +9,15 @@
 from typing import Optional
 
 import torch
-from monarch._rust_bindings.rdma import _RdmaBuffer
 
+try:
+    from monarch._rust_bindings.rdma import _RdmaBuffer
+except ImportError as e:
+    logging.error("RDMA is not available: {}".format(e))
+    raise e
+from monarch._src.actor.actor_mesh import MonarchContext
 from monarch._src.actor.future import Future
 
-from monarch.actor import MonarchContext
-
 
 # RDMARead/WriteTransferWarnings are warnings that are only printed once per process.
 # Remove these once GPU support is added.
@@ -30,7 +33,7 @@ class RDMAWriteTransferWarning(Warning):
 warnings.simplefilter("once", RDMAWriteTransferWarning)
 
 
-def rdma_supported():
+def is_available():
     return _RdmaBuffer.rdma_supported()
 
 
@@ -52,7 +55,9 @@ def __init__(self, data: torch.Tensor) -> None:
 
         TODO: Create TensorBuffer, which will be main user API supporting non-contiguous , multi-byte-per-elment tensors
         """
-        assert _RdmaBuffer.rdma_supported()
+        assert (
+            is_available()
+        ), "Tried to create an RDMABuffer, but RDMA is not available on this platform."
 
         if data.device.type != "cpu":
             # TODO - CUDA support for RDMABuffer exists at the Rust layer, but
```
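With `rdma_supported()` renamed to `is_available()` and the constructor now asserting with a readable message, callers can probe availability up front instead of hitting the `AssertionError`. A hedged sketch (`make_buffer` is a hypothetical helper, not part of the diff):

```python
import torch

from monarch.tensor_engine import is_available, RDMABuffer


def make_buffer(t: torch.Tensor) -> RDMABuffer:
    # Probe availability explicitly so callers can fall back gracefully
    # rather than relying on the constructor's assert.
    if not is_available():
        raise RuntimeError("RDMA unavailable; use a plain copy path instead")
    return RDMABuffer(t.view(torch.uint8).flatten())
```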

python/monarch/rdma.py

Lines changed: 6 additions & 152 deletions
```diff
@@ -4,158 +4,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import ctypes
+import warnings
 
-from dataclasses import dataclass
-from typing import cast, Dict, Optional, Tuple
-
-import torch
-
-from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
-from monarch._src.actor.actor_mesh import (
-    _ActorMeshRefImpl,
-    Actor,
-    ActorMeshRef,
-    endpoint,
-    MonarchContext,
+warnings.warn(
+    "monarch.rdma is deprecated, please import from monarch.tensor_engine.rdma instead.",
+    DeprecationWarning,
+    stacklevel=2,
 )
 
-
-@dataclass
-class LocalRDMARecord:
-    data: torch.Tensor
-
-
-_local_buffers: Dict[int, "LocalRDMARecord"] = {}
-
-
-def _get_bytes(storage: torch.Tensor, offset: int, size: int) -> bytearray:
-    """Extracts a bytearray from a 1D, 1byte per item tensor."""
-    if offset + size > storage.numel():
-        raise ValueError(f"Read out of range: {offset + size} > {storage.size()}")
-    addr = storage.data_ptr()
-    if storage.device.type != "cpu":
-        result = bytearray(size)
-        result_tensor = torch.frombuffer(
-            result,
-            dtype=torch.uint8,
-        )
-        source_tensor = storage[offset:]
-        result_tensor.copy_(source_tensor)
-    else:
-        ctypes_array = (ctypes.c_byte * size).from_address(addr)
-        result = bytearray(ctypes_array)
-    return result
-
-
-class RDMAManager(Actor):
-    @staticmethod
-    def on_proc(proc_id: str) -> "RDMAManager":
-        ctx = MonarchContext.get()
-        return cast(
-            RDMAManager,
-            ActorMeshRef(
-                RDMAManager,
-                _ActorMeshRefImpl.from_actor_id(
-                    ctx.mailbox,
-                    ActorId.from_string(f"{proc_id}.rdma_manager[0]"),
-                ),
-                ctx.mailbox,
-            ),
-        )
-
-    @endpoint
-    async def drop(self, addr: int) -> None:
-        if addr in _local_buffers:
-            del _local_buffers[addr]
-
-    @endpoint
-    async def fetch(self, addr: int, offset: int, nbytes: int) -> bytearray:
-        if addr not in _local_buffers:
-            raise ValueError(f"Unknown buffer {addr}")
-        storage = _local_buffers[addr].data
-        return _get_bytes(storage, offset, nbytes)
-
-    @endpoint
-    async def put(self, addr: int, offset: int, bytes: bytearray) -> None:
-        if addr not in _local_buffers:
-            raise ValueError(f"Unknown buffer {addr}")
-        storage = _local_buffers[addr].data
-        storage[offset : offset + len(bytes)] = torch.frombuffer(
-            bytes, dtype=storage.dtype
-        )
-
-
-def _assert_tensor_is_1d_contiguous_uint8(t: torch.Tensor) -> None:
-    if t.ndim != 1:
-        raise ValueError(f"Tensor must be 1D, got {t.ndim}D")
-    if t.dtype != torch.uint8:
-        raise ValueError(f"Tensor must be uint8, got {t.dtype}")
-    if not t.is_contiguous():
-        raise ValueError("Tensor must be contiguous")
-
-
-class RDMABuffer:
-    def __init__(self, data: torch.Tensor) -> None:
-        """
-        RDMABuffer only supports 1D contiguous tensors that are 1 byte per item.
-
-        To create a 1 byte, 1D view, use t.view(torch.uint8).flatten()
-
-        TODO: Create TensorBuffer, which will be main user API supporting non-contiguous , multi-byte-per-elment tensors
-        """
-        _assert_tensor_is_1d_contiguous_uint8(data)
-        assert data.storage_offset() == 0
-        storage = data.untyped_storage()
-        self.addr: int = storage.data_ptr()
-        self.begin = 0
-        self.end: int = storage.size()
-        self.proc_id: str = MonarchContext.get().proc_id
-        self.local_data: object = None
-        _local_buffers[self.addr] = LocalRDMARecord(data)
-
-    def drop(self) -> None:
-        if self.proc_id is None:
-            del _local_buffers[self.addr]
-            return
-        rmda_actor = RDMAManager.on_proc(self.proc_id)
-        # pyre-ignore[16]: Undefined attribute [16]: `Endpoint` has no attribute `cast`.
-        rmda_actor.drop.cast(self.addr)
-
-    def __getstate__(self) -> Tuple[int, int, int, Optional[str]]:
-        proc_id = self.proc_id
-        # locally created RDMABuffer being set remotely,
-        # record its proc_id so we know how to establish connections to it
-        if proc_id is None:
-            proc_id = MonarchContext.get().proc_id
-        return (self.addr, self.begin, self.end, proc_id)
-
-    def __setstate__(self, state: Tuple[int, int, int, str]) -> None:
-        self.local_data = None
-        self.addr, self.begin, self.end, self.proc_id = state
-
-    async def read_into(self, dst: torch.Tensor, offset: int = 0) -> None:
-        """
-        Read data from the RDMABuffer into a destination tensor.
-
-        The destination tensor must be contiguous and 1 byte per item.
-        """
-        _assert_tensor_is_1d_contiguous_uint8(dst)
-        bytes = await RDMAManager.on_proc(self.proc_id).fetch.call_one(
-            self.addr, offset, dst.numel()
-        )
-        dst.copy_(torch.frombuffer(bytes, dtype=torch.uint8))
-
-    async def write(self, src: torch.Tensor, offset: int = 0) -> None:
-        """
-        Write data from a source tensor into the RDMABuffer.
-
-        The source tensor must be contiguous and 1 byte per item.
-        """
-        _assert_tensor_is_1d_contiguous_uint8(src)
-        bytes = _get_bytes(
-            src,
-            cast(int, src.storage_offset()),
-            src.numel(),
-        )
-        await RDMAManager.on_proc(self.proc_id).put.call_one(self.addr, offset, bytes)
+from monarch.tensor_engine import *  # noqa
```
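The module is now a pure deprecation shim: it warns at import time (`stacklevel=2` attributes the warning to the importer) and re-exports the new API. A quick way to observe the behavior, assuming this is the first import of `monarch.rdma` in the process (module-level warnings fire only once) and that the RDMA bindings are present so the re-export chain imports cleanly:

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    # DeprecationWarning is ignored by default; surface it for the check.
    warnings.simplefilter("always")
    import monarch.rdma  # noqa: F401

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```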
python/monarch/tensor_engine/__init__.py

Lines changed: 23 additions & 0 deletions
```diff
@@ -0,0 +1,23 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Monarch Tensor Engine API - Public interface for tensor engine functionality.
+"""
+
+from monarch._src.tensor_engine.rdma import (
+    is_available,
+    RDMABuffer,
+    RDMAReadTransferWarning,
+    RDMAWriteTransferWarning,
+)
+
+__all__ = [
+    "is_available",
+    "RDMABuffer",
+    "RDMAReadTransferWarning",
+    "RDMAWriteTransferWarning",
+]
```
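Because the new package defines `__all__`, the star import in the `monarch.rdma` shim above re-exports exactly these four names. A small check of that surface (again assuming the RDMA bindings import cleanly, since `monarch._src.tensor_engine.rdma` re-raises on ImportError):

```python
import monarch.tensor_engine as te

# Uppercase names sort before lowercase ones in ASCII order.
assert sorted(te.__all__) == [
    "RDMABuffer",
    "RDMAReadTransferWarning",
    "RDMAWriteTransferWarning",
    "is_available",
]
```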
