Commit 2b27810

[https://nvbugs/5494718][fix] Fix Single GPU Multi-node issue and OOM on DGX Spark (#8514)
Signed-off-by: Simeng Liu <[email protected]>
1 parent 812bc8c commit 2b27810

File tree

8 files changed: +39 additions, -12 deletions

tensorrt_llm/_torch/modules/linear.py
Lines changed: 11 additions & 0 deletions

@@ -14,8 +14,10 @@
 
 import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
 from tensorrt_llm._torch.peft.lora.layer import LoraLayer
+from tensorrt_llm._utils import is_device_integrated
 from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
                                      AllReduceStrategy)
+from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.quantization.functional import \
     preprocess_weights_for_mixed_gemm
@@ -67,6 +69,15 @@ def load_weight_shard(
     tensor_parallel_mode: Optional[TensorParallelMode] = None,
     device: torch.device = torch.device('cpu'),
 ) -> torch.Tensor:
+    # Skip device transfers on integrated GPUs to conserve shared memory.
+    if weight.device.type != device.type and is_device_integrated():
+        # For integrated GPU systems (e.g., DGX Spark), CPU and GPU share limited physical memory.
+        # Avoiding device transfers reduces memory consumption and unnecessary data copies,
+        # enabling support for larger models on memory-constrained systems.
+        logger.warning(
+            f"[load_weight_shard] Skipping device transfer from {weight.device} to {device} on integrated GPU to conserve shared memory."
+        )
+        device = weight.device
     if isinstance(weight, torch.Tensor):
         tensor_shape = weight.shape
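
For context, the new guard in load_weight_shard reduces to the standalone sketch below. The helper names (_is_integrated_gpu, move_shard) are illustrative and not part of the diff, and getattr is used defensively because the is_integrated device property is only exposed by the newer PyTorch builds this change targets.

import torch


def _is_integrated_gpu() -> bool:
    # Mirrors is_device_integrated(): an integrated GPU (e.g., DGX Spark)
    # shares physical memory with the CPU.
    if not torch.cuda.is_available():
        return False
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    return bool(getattr(props, "is_integrated", False))


def move_shard(weight: torch.Tensor, device: torch.device) -> torch.Tensor:
    # On an integrated GPU, keep the shard on its current device: a .to()
    # copy would duplicate the data inside the shared CPU/GPU memory pool.
    if weight.device.type != device.type and _is_integrated_gpu():
        return weight
    return weight.to(device)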

tensorrt_llm/_utils.py
Lines changed: 17 additions & 12 deletions

@@ -577,7 +577,7 @@ def local_mpi_barrier():
 
 
 def mpi_broadcast(obj, root=0):
-    return mpi_comm().bcast(obj, root) if is_multi_device_enable() else obj
+    return mpi_comm().bcast(obj, root) if global_mpi_size() > 1 else obj
 
 
 def mpi_allgather(obj):
@@ -1141,17 +1141,6 @@ def _unique_tokens_to_json(data):
     }
 
 
-def is_multi_device_enable():
-    """
-    This method evaluates if we are running on multiple GPUs and the flag ENABLE_MULTI_DEVICE is set.
-    So we can avoid broadcast calls on single GPU.
-    Issue: https://github.com/NVIDIA/TensorRT-LLM/issues/5927
-    ENABLE_MULTI_DEVICE is true by default when building TensorRT LLM so we need to also check
-    the number of devices
-    """
-    return local_mpi_size() > 1
-
-
 def set_prometheus_multiproc_dir() -> object:
     # Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.10/python/sglang/srt/utils.py#L1266
     global prometheus_multiproc_dir
@@ -1174,3 +1163,19 @@ def torch_pybind11_abi() -> str:
     if TORCH_PYBIND11_ABI is None:
         TORCH_PYBIND11_ABI = f"{torch._C._PYBIND11_COMPILER_TYPE}{torch._C._PYBIND11_STDLIB}{torch._C._PYBIND11_BUILD_ABI}"
     return TORCH_PYBIND11_ABI
+
+
+@lru_cache(maxsize=1)
+def is_device_integrated() -> bool:
+    """Check if the current GPU device is integrated (shares physical memory with CPU).
+
+    Integrated GPU systems include DGX Spark and other unified memory architectures.
+    This function caches the result to avoid repeated CUDA device property queries.
+
+    Returns:
+        bool: True if the GPU is integrated, False otherwise. Returns False if CUDA
+        is not available.
+    """
+    if not torch.cuda.is_available():
+        return False
+    return torch.cuda.get_device_properties().is_integrated
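
The mpi_broadcast change is the actual single-GPU multi-node fix: the removed is_multi_device_enable() guard checked local_mpi_size() > 1, so a launch with one GPU per node (local size 1, global size > 1) skipped the broadcast entirely. Checking global_mpi_size() > 1 keeps the single-process fast path while still broadcasting across nodes. A minimal sketch of the same guard written directly against mpi4py; the standalone broadcast helper below is illustrative, not TensorRT-LLM API:

from mpi4py import MPI


def broadcast(obj, root: int = 0):
    comm = MPI.COMM_WORLD
    # Only broadcast when more than one rank exists in the *global*
    # communicator; with a single rank, bcast is pointless overhead.
    if comm.Get_size() > 1:
        return comm.bcast(obj, root=root)
    return obj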

tests/integration/defs/accuracy/test_llm_api_pytorch.py
Lines changed: 3 additions & 0 deletions

@@ -3414,6 +3414,7 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
 
     MODEL_PATH = f"{llm_models_root()}/gpt_oss/gpt-oss-120b"
 
+    @pytest.mark.skip(reason="https://nvbugs/5596343")
     @pytest.mark.parametrize(
         "kv_cache_dtype",
         ["auto", pytest.param("fp8", marks=skip_pre_blackwell)])
@@ -3465,6 +3466,7 @@ def test_dummy_load_format(self):
         task = GSM8K(model_name)
         task.evaluate(llm, is_integration_test=True)
 
+    @pytest.mark.skip(reason="https://nvbugs/5596343")
     @pytest.mark.skip_less_device(4)
     @pytest.mark.parametrize(
         "kv_cache_dtype",
@@ -3668,6 +3670,7 @@ class TestQwen2_VL_7B(LlmapiAccuracyTestHarness):
 
     kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6)
 
+    @pytest.mark.skip(reason="https://nvbugs/5601909")
     def test_auto_dtype(self):
         with LLM(self.MODEL_PATH,
                  max_num_tokens=16384,
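
The remaining test files all apply the same pattern: a @pytest.mark.skip marker referencing the tracking bug is stacked on top of the existing marks, so the tests are still collected but reported as skipped. A generic, self-contained example of the pattern (the bug URL below is a placeholder, not one from this commit):

import pytest


@pytest.mark.skip(reason="https://nvbugs/0000000")  # placeholder tracking bug
@pytest.mark.parametrize("value", [1, 2])
def test_example(value):
    # Never runs while the skip marker is present; pytest reports each
    # parametrized case as skipped with the given reason.
    assert value > 0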

tests/integration/defs/cpp/test_e2e.py
Lines changed: 1 addition & 0 deletions

@@ -263,6 +263,7 @@ def test_model(build_google_tests, model, prepare_model, run_model_tests,
     run_model_tests(model, run_fp8)
 
 
+@pytest.mark.skip(reason="https://nvbugs/5601670")
 @pytest.mark.parametrize("build_google_tests", ["80", "86", "89", "90"],
                          indirect=True)
 @pytest.mark.parametrize("model", ["bart", "gpt", "t5"])

tests/integration/defs/disaggregated/test_workers.py
Lines changed: 1 addition & 0 deletions

@@ -509,6 +509,7 @@ def background_workers(llm_venv, config_file: str, num_ranks: int = None):
         proc.wait()
 
 
+@pytest.mark.skip(reason="https://nvbugs/5372970")
 @pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
                          indirect=True)
 def test_workers_conditional_disaggregation(disaggregated_test_root,

tests/integration/defs/test_e2e.py
Lines changed: 1 addition & 0 deletions

@@ -1665,6 +1665,7 @@ def test_openai_lora(llm_root, llm_venv):
     llm_venv.run_cmd(["-m", "pytest", str(test_root / "_test_openai_lora.py")])
 
 
+@pytest.mark.skip(reason="https://nvbugs/5596377")
 def test_openai_chat_multimodal_example(llm_root, llm_venv):
     test_root = unittest_path() / "llmapi" / "apps"
     llm_venv.run_cmd(

tests/unittest/executor/test_rpc_proxy.py
Lines changed: 1 addition & 0 deletions

@@ -55,6 +55,7 @@ def create_proxy(self, tp_size: int):
 
         return proxy
 
+    @pytest.mark.skip(reason="https://nvbugs/5579234")
     @pytest.mark.parametrize("num_reqs", [1, 10])
     def test_tp1(self, num_reqs):
         tokenizer = TransformersTokenizer.from_pretrained(model_path)

tests/unittest/executor/test_rpc_worker.py
Lines changed: 4 additions & 0 deletions

@@ -98,6 +98,7 @@ def test_fetch_responses_streaming_sync(self):
                 break
         assert 0 < len(results) <= 5
 
+    @pytest.mark.skip(reason="https://nvbugs/5583261")
     @pytest.mark.asyncio
     @pytest.mark.parametrize("req_count", [10])
     async def test_main_loop_async(self, req_count: int):
@@ -175,6 +176,7 @@ async def process_request_streaming():
 
         await process_request_streaming()
 
+    @pytest.mark.skip(reason="https://nvbugs/5583261")
     @pytest.mark.asyncio
     async def test_fetch_stats_loop_async(self):
         await asyncio.sleep(1)
@@ -227,13 +229,15 @@ def create_rpc_client(self, addr: str):
 
     @skip_single_gpu
     @pytest.mark.gpu2
+    @pytest.mark.skip(reason="https://nvbugs/5583261")
     def test_create_shutdown(self):
         # Invoke setup_engine in rank 0, and that will unblock all the ranks to
         # invoke setup_engine simultaneously.
         pass
 
     @skip_single_gpu
     @pytest.mark.gpu2
+    @pytest.mark.skip(reason="https://nvbugs/5583261")
     def test_fetch_responses_sync(self):
         # Wait a bit to ensure engine is ready
         time.sleep(1)
