
Commit 9270041

pcastonguay, raayandhar, reasonsolo, and coderabbitai[bot] authored and committed
AGI 0819 cherry-pick NVIDIA#6369
Signed-off-by: Patrice Castonguay <[email protected]>
Signed-off-by: raayandhar <[email protected]>
Signed-off-by: Lizhi Zhou <[email protected]>
Signed-off-by: pcastonguay <[email protected]>
Co-authored-by: raayandhar <[email protected]>
Co-authored-by: Lizhi Zhou <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
1 parent aa7549c commit 9270041

File tree

15 files changed: +510 -21 lines changed


cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

Lines changed: 2 additions & 0 deletions
@@ -808,6 +808,8 @@ void CacheFormatter::formatInput(LlmRequest const& llmRequest,
     if (selfConfig.getModelConfig().mNbKvHeadsPerLayer.size() != destConfig.getModelConfig().mNbKvHeadsPerLayer.size())
     {
         TLLM_LOG_WARNING("CacheFormatter::inquireSupport: only support same number of layers");
+        TLLM_LOG_WARNING("self: %zu dest %zu", selfConfig.getModelConfig().mNbKvHeadsPerLayer.size(),
+            destConfig.getModelConfig().mNbKvHeadsPerLayer.size());
         return false;
     }
     int selfNumLayers = selfConfig.getModelConfig().mNbKvHeadsPerLayer.size();

scripts/build_wheel.py

Lines changed: 4 additions & 1 deletion
@@ -71,7 +71,10 @@ def clear_folder(folder_path):
         if os.path.isdir(item_path) and not os.path.islink(item_path):
             rmtree(item_path)
         else:
-            os.remove(item_path)
+            try:
+                os.remove(item_path)
+            except (OSError, IOError) as e:
+                print(f"Failed to remove {item_path}: {e}", file=sys.stderr)


 def sysconfig_scheme(override_vars=None):
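For context, a self-contained sketch of the resulting cleanup behavior. The clear_folder body outside the hunk is inferred from the context lines, and the new print call assumes sys is imported at module level in build_wheel.py:

import os
import sys
from shutil import rmtree


def clear_folder(folder_path):
    # Remove every entry in folder_path: directories recursively, files and
    # symlinks individually. A file that cannot be removed is now reported to
    # stderr instead of aborting the whole cleanup (the behavior this commit adds).
    for item in os.listdir(folder_path):
        item_path = os.path.join(folder_path, item)
        if os.path.isdir(item_path) and not os.path.islink(item_path):
            rmtree(item_path)
        else:
            try:
                os.remove(item_path)
            except (OSError, IOError) as e:
                print(f"Failed to remove {item_path}: {e}", file=sys.stderr)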

tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py

Lines changed: 3 additions & 3 deletions
@@ -83,13 +83,13 @@ def __init__(self, mapping: Mapping, comm_type: CommTypeCpp,
                  attention_type: AttentionTypeCpp,
                  cache_transceiver_config: CacheTransceiverConfig):
         world_config = mapping_to_world_config(mapping)
-        num_kv_heads_per_layer = kv_cache_manager.num_kv_heads_per_layer
+        total_num_kv_heads_per_layer = kv_cache_manager.total_num_kv_heads_per_layer
         head_dim = kv_cache_manager.head_dim
         tokens_per_block = kv_cache_manager.tokens_per_block
         dtype = kv_cache_manager.dtype

-        self.impl = CacheTransceiverCpp(kv_cache_manager.impl, comm_type,
-                                        num_kv_heads_per_layer, head_dim,
+        self.impl = CacheTransceiverCpp(kv_cache_manager.impl,
+                                        total_num_kv_heads_per_layer, head_dim,
                                         tokens_per_block, world_config, dtype,
                                         attention_type,
                                         cache_transceiver_config)
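To illustrate the distinction this change relies on, a minimal sketch with hypothetical sizes. Only the attribute names num_kv_heads_per_layer and total_num_kv_heads_per_layer come from the diff; the numbers are invented:

# Hypothetical model: 8 layers, 8 KV heads per layer, tp_size=2, pp_size=2.
# Under pipeline parallelism each rank only manages its own layer slice, so the
# per-rank list has num_local_layers entries, while the total list covers every
# layer. The cache transceiver describes the whole model's KV layout to its peer,
# which is why it now receives the total list.
num_layers, num_kv_heads, tp_size, pp_size = 8, 8, 2, 2
heads_per_rank = (num_kv_heads + tp_size - 1) // tp_size  # ceil division across TP ranks

num_local_layers = num_layers // pp_size
num_kv_heads_per_layer = [heads_per_rank] * num_local_layers   # local view: [4, 4, 4, 4]
total_num_kv_heads_per_layer = [heads_per_rank] * num_layers   # full-model view: [4] * 8
assert len(total_num_kv_heads_per_layer) == num_layers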

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 50 additions & 2 deletions
@@ -185,6 +185,7 @@ class BatchState:
 @dataclasses.dataclass
 class BatchStatePP(BatchState):
     microbatch_id: int = -1
+    scheduled_ctx_reqs: list[LlmRequest] = None


 class PyExecutor:
@@ -744,6 +745,7 @@ def _need_return_log_probs(self, scheduled_requests: ScheduledRequests):
         return False

     def _executor_loop_pp(self):
+        logger.debug(f"Starting executor loop for pp_rank {self.dist.pp_rank}")
         torch.cuda.set_device(self.device_id)
         microbatch_id = 0
         with self._profiler() as profile_step:
@@ -757,16 +759,33 @@ def _executor_loop_pp(self):
                 if self.should_stop_processing:
                     break

+                if self.kv_cache_transceiver:
+                    self._check_disagg_gen_transfer_status()
+
                 if self.enable_iter_perf_stats:
                     iter_stats = self._get_init_iter_stats(
                         len(new_requests),
                         self.new_active_requests_queue_latency_ms)

                 self._pad_attention_dp_dummy_request()

-                scheduled_batch, _, _ = self._schedule()
+                scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
+                )
+
+                if self.kv_cache_transceiver:
+                    # For requests that are fitting disagg gen init, also prepare resources for KV cache manager
+                    self._prepare_disagg_gen_init(
+                        fitting_disagg_gen_init_requests)
+
+                    if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests:
+                        logger.warning(
+                            "num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
+                        )
+                        self.kv_cache_transceiver.check_context_transfer_status(
+                            1)

                 self.num_scheduled_requests = scheduled_batch.batch_size
+
                 logger.debug(
                     f'has {len(self.active_requests)} active_request, '
                     f'scheduled {len(scheduled_batch.context_requests)} context requests and '
@@ -779,7 +798,7 @@ def _executor_loop_pp(self):
                     can_queue = 0 not in tp_batch_sizes
                 else:
                     can_queue = scheduled_batch.batch_size > 0
-                if not can_queue:
+                if not can_queue and not self.kv_cache_transceiver:
                     assert len(self.inflight_req_ids) > 0, (
                         "fail to schedule any pending request, probably run out of resource"
                     )
@@ -788,8 +807,28 @@ def _executor_loop_pp(self):
                     self.micro_batches[microbatch_id] = None
                 else:
                     self._add_inflight_ids(scheduled_batch)
+
+                    if self.kv_cache_transceiver:
+                        # For generation requests which have completed KV cache transfer
+                        self._prepare_disagg_gen_transmission_complete(
+                            scheduled_batch)
+
                     self.resource_manager.prepare_resources(scheduled_batch)

+                    # The generation requests that do not have a batch_idx
+                    # need to be at the front of the batch due to the assumptions
+                    # made in model_engine.py::_forward_step. This is only important
+                    # for disaggregated serving. For non-disaggregated serving,
+                    # the generation requests always have a batch_idx.
+                    scheduled_batch.generation_requests = sorted(  # stable sort
+                        scheduled_batch.generation_requests,
+                        key=lambda req: int(req.py_batch_idx is not None),
+                    )
+
+                    if self.kv_cache_transceiver:
+                        # Return the first token to the client
+                        self._handle_first_token_response(scheduled_batch)
+
                     # Stage 1: Async forward (all ranks) and decoding pass (last rank only)
                     if not self.dist.is_last_pp_rank:
                         sample_state = self._forward_step_inter_pp(
@@ -814,6 +853,7 @@ def _executor_loop_pp(self):
                         iter_start_time=iter_start_time,
                         iter_stats=iter_stats,
                         microbatch_id=microbatch_id,
+                        scheduled_ctx_reqs=scheduled_batch.context_requests,
                     )

                     self.micro_batches[microbatch_id] = batch_state
@@ -878,6 +918,11 @@ def _executor_loop_pp(self):
                 if previous_batch is not None:
                     with torch.cuda.nvtx.range("_handle_previous_batch_pp"):
                         self._update_requests(previous_batch.sample_state)
+
+                        if self.kv_cache_transceiver and previous_batch.scheduled_ctx_reqs:
+                            self._send_disagg_ctx_cache(
+                                previous_batch.scheduled_ctx_reqs)
+
                         self._handle_canceled_requests()
                         finished_requests = self._handle_responses()
                         previous_scheduled_batch = previous_batch.sample_state.scheduled_requests
@@ -886,6 +931,9 @@ def _executor_loop_pp(self):
                         self._remove_inflight_ids(previous_scheduled_batch)
                     self.micro_batches[prev_microbatch_id] = None

+                if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
+                    self._terminate_ctx_finished_requests()
+
                 # march forward in microbatch slots
                 microbatch_id = (microbatch_id + 1) % self.num_micro_batches
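To illustrate the stable-sort ordering added above, a toy sketch. FakeRequest is a stand-in for LlmRequest, used only for this example:

from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeRequest:  # stand-in for LlmRequest, illustration only
    name: str
    py_batch_idx: Optional[int]


generation_requests = [
    FakeRequest("a", py_batch_idx=0),
    FakeRequest("b", py_batch_idx=None),  # e.g. a disagg gen request without a slot yet
    FakeRequest("c", py_batch_idx=1),
    FakeRequest("d", py_batch_idx=None),
]

# int(req.py_batch_idx is not None) is 0 for requests without a batch_idx and 1
# otherwise, so Python's stable sort moves the un-indexed requests to the front
# while preserving the original relative order within each group.
generation_requests = sorted(
    generation_requests,
    key=lambda req: int(req.py_batch_idx is not None),
)
assert [r.name for r in generation_requests] == ["b", "d", "a", "c"]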

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 20 additions & 5 deletions
@@ -154,18 +154,33 @@ def __init__(
                 (num_kv_heads + tp_size - 1) // tp_size
                 for _ in range(self.num_local_layers)
             ]
+            self.total_num_kv_heads_per_layer = [
+                (num_kv_heads + tp_size - 1) // tp_size
+                for _ in range(self.num_layers)
+            ]
         else:
             assert len(num_kv_heads) == self.num_layers

+            def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int],
+                                             kv_head: Optional[int]):
+                if kv_head is not None:
+                    num_kv_heads_per_layer.append(
+                        (kv_head + tp_size - 1) // tp_size)
+                else:
+                    num_kv_heads_per_layer.append(0)
+
             self.num_kv_heads_per_layer = []
             if self.num_local_layers > 0:
                 for i in self.pp_layers:
                     kv_head = num_kv_heads[i]
-                    if kv_head is not None:
-                        self.num_kv_heads_per_layer.append(
-                            (kv_head + tp_size - 1) // tp_size)
-                    else:
-                        self.num_kv_heads_per_layer.append(0)
+                    append_to_kv_heads_per_layer(self.num_kv_heads_per_layer,
+                                                 kv_head)
+
+            self.total_num_kv_heads_per_layer = []
+            for i in range(self.num_layers):
+                kv_head = num_kv_heads[i]
+                append_to_kv_heads_per_layer(self.total_num_kv_heads_per_layer,
+                                             kv_head)

         self.num_kv_heads = num_kv_heads
         self.head_dim = head_dim
tests/integration/defs/accuracy/accuracy_core.py

Lines changed: 11 additions & 0 deletions
@@ -735,3 +735,14 @@ def setup_class(cls):
         logger.set_level("info")
         yield
         logger.set_level(original_level)
+
+
+def get_accuracy_task(dataset_name: str):
+    try:
+        task_class = globals()[dataset_name]
+        if issubclass(task_class, AccuracyTask):
+            return task_class
+        else:
+            raise ValueError(f"Unknown dataset: {dataset_name}.")
+    except KeyError:
+        raise ValueError(f"Not registered dataset: {dataset_name}.")

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 107 additions & 9 deletions
@@ -20,9 +20,11 @@
 from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams
 from tensorrt_llm.llmapi.llm_args import LlmArgs

-from ..conftest import llm_models_root, parametrize_with_ids, skip_pre_hopper
+from ..conftest import (get_device_count, llm_models_root, parametrize_with_ids,
+                        skip_pre_hopper)
 from ..trt_test_alternative import popen
-from .accuracy_core import GSM8K, MMLU, LlmapiAccuracyTestHarness
+from .accuracy_core import (GSM8K, MMLU, LlmapiAccuracyTestHarness,
+                            get_accuracy_task)


 class Result(GenerationResultBase):
@@ -71,6 +73,12 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
     temp_dir = tempfile.TemporaryDirectory()
     disaggregated_serving_config_path = os.path.join(
         temp_dir.name, "disaggregated_serving_config.yaml")
+
+    if tensor_parallel_size > 1:
+        print(
+            f"Using unified tp parameter for testing is not recommended. Please use server configs instead."
+        )
+
     with open(disaggregated_serving_config_path, "w") as f:
         yaml.dump(disaggregated_server_config, f)
     ctx_server_config_path = os.path.join(temp_dir.name,
@@ -88,21 +96,47 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
     trtllm_serve_path = "trtllm-serve"
     # Common arguments for both servers
     common_args = [
-        trtllm_serve_path, model_name, "--host", "localhost", "--backend",
-        "pytorch"
+        trtllm_serve_path,
+        model_name,
+        "--host",
+        "localhost",
+        "--backend",
+        "pytorch",
     ]
-    if tensor_parallel_size > 1:
-        common_args.append(f"--tp_size={tensor_parallel_size}")
+    gen_tp, gen_pp = gen_server_config.get(
+        "tensor_parallel_size",
+        tensor_parallel_size), gen_server_config.get("pipeline_parallel_size",
+                                                     1)
+    ctx_tp, ctx_pp = ctx_server_config.get(
+        "tensor_parallel_size",
+        tensor_parallel_size), ctx_server_config.get("pipeline_parallel_size",
+                                                     1)
+
+    ctx_total_gpus = ctx_tp * ctx_pp
+    gen_total_gpus = gen_tp * gen_pp

     env_ctx = os.environ.copy()
     env_ctx["TRTLLM_USE_UCX_KVCACHE"] = "1"
-    env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(
-        map(str, range(tensor_parallel_size)))
+    env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(ctx_total_gpus)))

     env_gen = os.environ.copy()
     env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
     env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(
-        map(str, range(tensor_parallel_size, 2 * tensor_parallel_size)))
+        map(str, range(ctx_total_gpus, ctx_total_gpus + gen_total_gpus)))
+    ctx_server_args = common_args + [
+        "--port", "8001", "--extra_llm_api_options", ctx_server_config_path,
+        f"--tp_size={ctx_tp}", f"--pp_size={ctx_pp}"
+    ]
+    gen_server_args = common_args + [
+        "--port", "8002", "--extra_llm_api_options", gen_server_config_path,
+        f"--tp_size={gen_tp}", f"--pp_size={gen_pp}"
+    ]
+    if "max_num_tokens" in ctx_server_config:
+        ctx_server_args.append(
+            f"--max_num_tokens={ctx_server_config['max_num_tokens']}")
+    if "max_num_tokens" in gen_server_config:
+        gen_server_args.append(
+            f"--max_num_tokens={gen_server_config['max_num_tokens']}")

     with (MyThreadPoolExecutor(max_workers=16) as thread_pool, temp_dir,
           popen(common_args + [
@@ -177,6 +211,56 @@ def generate_async(prompt: str,
         disaggregated_server.wait()


+def run_parallel_test(model_name: str, model_path: str, ctx_pp: int,
+                      ctx_tp: int, gen_pp: int, gen_tp: int,
+                      test_set: LlmapiAccuracyTestHarness):
+    if ctx_tp * ctx_pp + gen_tp * gen_pp > get_device_count():
+        pytest.fail(
+            f"Not enough devices for ctx_pp={ctx_pp}+ctx_tp={ctx_tp} and gen_pp={gen_pp}+gen_tp={gen_tp} test"
+        )
+
+    kv_cache_config = {
+        "free_gpu_memory_fraction": 0.5,
+        "enable_block_reuse": False
+    }
+    ctx_server_config = {
+        "pipeline_parallel_size": ctx_pp,
+        "tensor_parallel_size": ctx_tp,
+        "disable_overlap_scheduler": True,
+        "kv_cache_config": kv_cache_config,
+        "cache_transceiver_config": {
+            "backend": "default"
+        }
+    }
+    gen_server_config = {
+        "tensor_parallel_size": gen_tp,
+        "pipeline_parallel_size": gen_pp,
+        "disable_overlap_scheduler": True,
+        "kv_cache_config": kv_cache_config,
+        "cache_transceiver_config": {
+            "backend": "default"
+        }
+    }
+    disaggregated_server_config = {
+        "hostname": "localhost",
+        "port": 8000,
+        "backend": "pytorch",
+        "context_servers": {
+            "num_instances": 1,
+            "urls": ["localhost:8001"]
+        },
+        "generation_servers": {
+            "num_instances": 1,
+            "urls": ["localhost:8002"]
+        }
+    }
+    with launch_disaggregated_llm(disaggregated_server_config,
+                                  ctx_server_config, gen_server_config,
+                                  model_path) as llm:
+        task = test_set(model_name)
+        task.evaluate(llm)
+
+
 @pytest.mark.timeout(3600)
 class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
@@ -252,6 +336,20 @@ def test_ngram(self):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)

+    @pytest.mark.parametrize("tp,pp", [(1, 2), (2, 1), (2, 2)],
+                             ids=["tp1pp2", "tp2pp1", "tp2pp2"])
+    @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
+    def test_tp_pp_symmetric(self, tp, pp, testset):
+        return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, pp, tp, pp,
+                                 tp, get_accuracy_task(testset))
+
+    @parametrize_with_ids("ctx_pp", [2, 4])
+    @parametrize_with_ids("gen_tp", [1, 2])
+    @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
+    def test_ctx_pp_gen_tp_asymmetric(self, ctx_pp, gen_tp, testset):
+        return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, ctx_pp, 1, 1,
+                                 gen_tp, get_accuracy_task(testset))
+

 @pytest.mark.timeout(3600)
 @pytest.mark.skip_less_device_memory(140000)
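The GPU partitioning that launch_disaggregated_llm now performs can be summarized with a small sketch. The sizes are hypothetical; only the ctx-first, gen-second device ordering comes from the diff:

# Example: context server uses tp=2, pp=2 (4 GPUs); generation server uses tp=2, pp=1 (2 GPUs).
ctx_tp, ctx_pp = 2, 2
gen_tp, gen_pp = 2, 1

ctx_total_gpus = ctx_tp * ctx_pp
gen_total_gpus = gen_tp * gen_pp

# The context server gets the first block of GPUs and the generation server the
# next block, matching the CUDA_VISIBLE_DEVICES ranges built in launch_disaggregated_llm.
ctx_devices = ",".join(map(str, range(ctx_total_gpus)))
gen_devices = ",".join(map(str, range(ctx_total_gpus, ctx_total_gpus + gen_total_gpus)))

print(ctx_devices)  # "0,1,2,3"
print(gen_devices)  # "4,5"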
