
Commit fe69243

[None][chore] Add placement test for ray executor (#9122)
Signed-off-by: Erin Ho <[email protected]>
1 parent: bdcf837

File tree

1 file changed: +58 −3 lines

tests/unittest/_torch/ray_orchestrator/multi_gpu/test_executor.py

Lines changed: 58 additions & 3 deletions
@@ -1,10 +1,14 @@
 import os

 import pytest
+import ray
+from ray.util.placement_group import (PlacementGroupSchedulingStrategy,
+                                      placement_group, remove_placement_group)
 from utils.llm_data import llm_models_root

 from tensorrt_llm import LLM
 from tensorrt_llm._torch.utils import get_device_uuid
+from tensorrt_llm.llmapi import KvCacheConfig


 class DummyWorkerExtension:
@@ -22,17 +26,68 @@ def test_worker_extension():
     assert result[0] == "SUCCESS"


+@pytest.mark.gpu4
+def test_bundle_indices(monkeypatch):
+    """Placement via bundle indices"""
+
+    monkeypatch.setenv("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", "1")
+    monkeypatch.setenv("TLLM_RAY_USE_RPC", "1")
+
+    pg = None
+    try:
+        ray.init()
+        pg = placement_group([{"GPU": 1, "CPU": 1}] * 4)
+        ray.get(pg.ready())
+        print(f"Placement group ready with bundles {pg.bundle_specs}")
+
+        bundle_indices = [2, 3]
+        runtime_env = {
+            "env_vars": {
+                "TRTLLM_RAY_PER_WORKER_GPUS": "0.8",
+                "TRTLLM_RAY_BUNDLE_INDICES": ",".join(map(str, bundle_indices))
+            }
+        }
+
+        llm = ray.remote(
+            num_cpus=0,
+            num_gpus=0,
+            runtime_env=runtime_env,
+            scheduling_strategy=PlacementGroupSchedulingStrategy(
+                placement_group=pg,
+                placement_group_capture_child_tasks=True,
+            ),
+        )(LLM).remote(
+            model=os.path.join(llm_models_root(), "llama-models-v2",
+                               "TinyLlama-1.1B-Chat-v1.0"),
+            kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1),
+            tensor_parallel_size=2,
+            orchestrator_type="ray",
+        )
+
+        inference_actor_uuids = ray.get(
+            llm._collective_rpc.remote("report_device_id"))
+
+        expected_uuids = [get_device_uuid(idx) for idx in bundle_indices]
+
+        assert sorted(inference_actor_uuids) == sorted(expected_uuids), \
+            f"Workers not placed on expected GPUs. Expected UUIDs: {expected_uuids}, Got: {inference_actor_uuids}"
+
+    finally:
+        if pg is not None:
+            remove_placement_group(pg)
+        ray.shutdown()
+
+
 @pytest.mark.gpu2
-def test_cuda_visible_device():
+def test_cuda_visible_device(monkeypatch):
     """Placement via cuda_visible_device"""
-    os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+    monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "1")

     llm = LLM(model=llm_models_root() /
               "llama-models-v2/TinyLlama-1.1B-Chat-v1.0",
               orchestrator_type="ray")

     infer_actor_uuids = llm._collective_rpc("report_device_id")

-    del os.environ["CUDA_VISIBLE_DEVICES"]
     assert infer_actor_uuids[0] == get_device_uuid(1)
     print(f"{infer_actor_uuids=}")

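Editor's note: TRTLLM_RAY_PER_WORKER_GPUS and TRTLLM_RAY_BUNDLE_INDICES are TensorRT-LLM-specific knobs; the placement they drive rests on Ray's standard placement-group scheduling. Below is a minimal standalone sketch of that underlying pattern using only public Ray APIs, assuming a machine with at least 4 GPUs; GpuWorker is a hypothetical stand-in for the inference actors, not part of this patch:

import os

import ray
from ray.util.placement_group import placement_group, remove_placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy


@ray.remote(num_gpus=1)
class GpuWorker:

    def visible_devices(self):
        # Ray sets CUDA_VISIBLE_DEVICES to the GPU backing this actor's
        # bundle (unless RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1,
        # as in test_bundle_indices above).
        return os.environ.get("CUDA_VISIBLE_DEVICES")


ray.init()
# Reserve four GPU+CPU bundles, mirroring the test.
pg = placement_group([{"GPU": 1, "CPU": 1}] * 4)
ray.get(pg.ready())

# Pin the actor to bundle 2 of the placement group.
worker = GpuWorker.options(
    scheduling_strategy=PlacementGroupSchedulingStrategy(
        placement_group=pg,
        placement_group_bundle_index=2,
    )).remote()
print(ray.get(worker.visible_devices.remote()))

remove_placement_group(pg)
ray.shutdown()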