Skip to content

Commit c66b146

Browse files
If CUDA_VISIBLE_DEVICES is set, map cuda.current_device()
Perform a mapping between integer CUDA_VISIBLE_DEVICES values to find the host current device for DEEP_EP_DEVICE_TO_HCA_MAPPING. Signed-off-by: Clayton Coleman <[email protected]>
1 parent 2abd7e3 commit c66b146

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

deep_ep/buffer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,13 @@ def _setup_device_hca_mapping(self):
154154

155155
# Get current device and set appropriate HCA
156156
current_device = torch.cuda.current_device()
157+
# Translate CUDA_VISIBLE_DEVICES
158+
if 'CUDA_VISIBLE_DEVICES' in os.environ:
159+
visible_devices = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
160+
assert len(visible_devices) > current_device, f"CUDA_VISIBLE_DEVICES has {len(visible_devices)} entries which is fewer than the current device {current_device}"
161+
assert visible_devices[current_device].isdigit(), f"DEEP_EP_DEVICE_TO_HCA_MAPPING requires CUDA_VISIBLE_DEVICES to contain integer indices"
162+
current_device = int(visible_devices[current_device])
163+
157164
assert current_device in device_mapping, f"Current CUDA device {current_device} not found in DEEP_EP_DEVICE_TO_HCA_MAPPING"
158165
os.environ['NVSHMEM_ENABLE_PE_MAPPING'] = '1'
159166
os.environ['NVSHMEM_HCA_LIST'] = device_mapping[current_device]

0 commit comments

Comments
 (0)