|
4 | 4 | # This source code is licensed under the BSD-style license found in the
|
5 | 5 | # LICENSE file in the root directory of this source tree.
|
6 | 6 |
|
7 |
| -import ctypes |
| 7 | +import warnings |
8 | 8 |
|
9 |
| -from dataclasses import dataclass |
10 |
| -from typing import cast, Dict, Optional, Tuple |
11 |
| - |
12 |
| -import torch |
13 |
| - |
14 |
| -from monarch._rust_bindings.monarch_hyperactor.proc import ActorId |
15 |
| -from monarch._src.actor.actor_mesh import ( |
16 |
| - _ActorMeshRefImpl, |
17 |
| - Actor, |
18 |
| - ActorMeshRef, |
19 |
| - endpoint, |
20 |
| - MonarchContext, |
# Warn once at import time that this module has moved; stacklevel=2 makes
# the warning point at the importing module rather than at this shim.
_DEPRECATION_MSG = (
    "monarch.rdma is deprecated, please import from monarch.tensor_engine.rdma instead."
)
warnings.warn(_DEPRECATION_MSG, DeprecationWarning, stacklevel=2)
|
22 | 14 |
|
23 |
| - |
24 |
| -@dataclass |
25 |
| -class LocalRDMARecord: |
26 |
| - data: torch.Tensor |
27 |
| - |
28 |
| - |
29 |
| -_local_buffers: Dict[int, "LocalRDMARecord"] = {} |
30 |
| - |
31 |
| - |
32 |
| -def _get_bytes(storage: torch.Tensor, offset: int, size: int) -> bytearray: |
33 |
| - """Extracts a bytearray from a 1D, 1byte per item tensor.""" |
34 |
| - if offset + size > storage.numel(): |
35 |
| - raise ValueError(f"Read out of range: {offset + size} > {storage.size()}") |
36 |
| - addr = storage.data_ptr() |
37 |
| - if storage.device.type != "cpu": |
38 |
| - result = bytearray(size) |
39 |
| - result_tensor = torch.frombuffer( |
40 |
| - result, |
41 |
| - dtype=torch.uint8, |
42 |
| - ) |
43 |
| - source_tensor = storage[offset:] |
44 |
| - result_tensor.copy_(source_tensor) |
45 |
| - else: |
46 |
| - ctypes_array = (ctypes.c_byte * size).from_address(addr) |
47 |
| - result = bytearray(ctypes_array) |
48 |
| - return result |
49 |
| - |
50 |
| - |
51 |
class RDMAManager(Actor):
    """Per-process actor that serves remote reads/writes against the buffers
    registered in this process's ``_local_buffers`` table."""

    @staticmethod
    def on_proc(proc_id: str) -> "RDMAManager":
        """Return a reference to the singleton ``rdma_manager`` actor on ``proc_id``."""
        ctx = MonarchContext.get()
        actor_id = ActorId.from_string(f"{proc_id}.rdma_manager[0]")
        mesh_impl = _ActorMeshRefImpl.from_actor_id(ctx.mailbox, actor_id)
        return cast(RDMAManager, ActorMeshRef(RDMAManager, mesh_impl, ctx.mailbox))

    @endpoint
    async def drop(self, addr: int) -> None:
        """Unregister the local buffer at ``addr``; a no-op for unknown addresses."""
        _local_buffers.pop(addr, None)

    @endpoint
    async def fetch(self, addr: int, offset: int, nbytes: int) -> bytearray:
        """Read ``nbytes`` starting at ``offset`` from the buffer registered at ``addr``."""
        record = _local_buffers.get(addr)
        if record is None:
            raise ValueError(f"Unknown buffer {addr}")
        return _get_bytes(record.data, offset, nbytes)

    @endpoint
    async def put(self, addr: int, offset: int, bytes: bytearray) -> None:
        """Write ``bytes`` into the buffer registered at ``addr``, starting at ``offset``."""
        record = _local_buffers.get(addr)
        if record is None:
            raise ValueError(f"Unknown buffer {addr}")
        target = record.data
        target[offset : offset + len(bytes)] = torch.frombuffer(
            bytes, dtype=target.dtype
        )
87 |
| - |
88 |
| - |
89 |
| -def _assert_tensor_is_1d_contiguous_uint8(t: torch.Tensor) -> None: |
90 |
| - if t.ndim != 1: |
91 |
| - raise ValueError(f"Tensor must be 1D, got {t.ndim}D") |
92 |
| - if t.dtype != torch.uint8: |
93 |
| - raise ValueError(f"Tensor must be uint8, got {t.dtype}") |
94 |
| - if not t.is_contiguous(): |
95 |
| - raise ValueError("Tensor must be contiguous") |
96 |
| - |
97 |
| - |
98 |
class RDMABuffer:
    """Handle to a registered RDMA-readable/writable region backed by a local
    1D uint8 tensor.

    The handle is picklable: ``__getstate__``/``__setstate__`` carry only
    (addr, begin, end, proc_id), so a deserialized copy refers back to the
    originating process via its RDMAManager actor rather than holding the data.
    """

    def __init__(self, data: torch.Tensor) -> None:
        """
        RDMABuffer only supports 1D contiguous tensors that are 1 byte per item.

        To create a 1 byte, 1D view, use t.view(torch.uint8).flatten()

        TODO: Create TensorBuffer, which will be main user API supporting non-contiguous , multi-byte-per-elment tensors
        """
        _assert_tensor_is_1d_contiguous_uint8(data)
        # The whole-storage registration below assumes the view starts at the
        # beginning of its storage.
        assert data.storage_offset() == 0
        storage = data.untyped_storage()
        # Base address of the backing storage; also the key in _local_buffers.
        self.addr: int = storage.data_ptr()
        self.begin = 0
        self.end: int = storage.size()
        # NOTE(review): annotated str, but drop()/__getstate__ test it against
        # None — presumably None marks a purely-local buffer; confirm intent.
        self.proc_id: str = MonarchContext.get().proc_id
        self.local_data: object = None
        # Register so remote peers can fetch/put via this process's RDMAManager.
        _local_buffers[self.addr] = LocalRDMARecord(data)

    def drop(self) -> None:
        # Unregister the buffer: directly when local (proc_id is None),
        # otherwise by casting a drop message to the owning process's manager.
        if self.proc_id is None:
            del _local_buffers[self.addr]
            return
        # NOTE(review): "rmda_actor" looks like a typo for "rdma_actor".
        rmda_actor = RDMAManager.on_proc(self.proc_id)
        # pyre-ignore[16]: Undefined attribute [16]: `Endpoint` has no attribute `cast`.
        rmda_actor.drop.cast(self.addr)

    def __getstate__(self) -> Tuple[int, int, int, Optional[str]]:
        proc_id = self.proc_id
        # locally created RDMABuffer being set remotely,
        # record its proc_id so we know how to establish connections to it
        if proc_id is None:
            proc_id = MonarchContext.get().proc_id
        return (self.addr, self.begin, self.end, proc_id)

    def __setstate__(self, state: Tuple[int, int, int, str]) -> None:
        # A deserialized handle never owns local data; it only refers to the
        # origin process identified by proc_id.
        self.local_data = None
        self.addr, self.begin, self.end, self.proc_id = state

    async def read_into(self, dst: torch.Tensor, offset: int = 0) -> None:
        """
        Read data from the RDMABuffer into a destination tensor.

        The destination tensor must be contiguous and 1 byte per item.
        """
        _assert_tensor_is_1d_contiguous_uint8(dst)
        # NOTE(review): local name shadows the builtin `bytes`.
        bytes = await RDMAManager.on_proc(self.proc_id).fetch.call_one(
            self.addr, offset, dst.numel()
        )
        dst.copy_(torch.frombuffer(bytes, dtype=torch.uint8))

    async def write(self, src: torch.Tensor, offset: int = 0) -> None:
        """
        Write data from a source tensor into the RDMABuffer.

        The source tensor must be contiguous and 1 byte per item.
        """
        _assert_tensor_is_1d_contiguous_uint8(src)
        # Snapshot src's bytes locally, then push them to the owning process.
        bytes = _get_bytes(
            src,
            cast(int, src.storage_offset()),
            src.numel(),
        )
        await RDMAManager.on_proc(self.proc_id).put.call_one(self.addr, offset, bytes)
| 15 | +from monarch.tensor_engine import * # noqa |
0 commit comments