Skip to content

Commit 6ce10ab

Browse files
committed
Use separate weight mapper for draft model
Signed-off-by: Anurag Mukkara <[email protected]>
1 parent d11acee commit 6ce10ab

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -554,11 +554,11 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
554554
hidden_size=model_config.pretrained_config.hidden_size,
555555
vocab_size=model_config.pretrained_config.vocab_size)
556556
self.draft_model = None
557+
self.draft_config = None
557558
spec_config = getattr(model_config, 'spec_config', None)
558559
if spec_config and spec_config.spec_dec_mode.use_one_engine():
559-
draft_config = None
560560
if spec_config.spec_dec_mode.is_eagle3_one_model():
561-
draft_config = ModelConfig.from_pretrained(
561+
self.draft_config = ModelConfig.from_pretrained(
562562
model_config.spec_config.speculative_model_dir,
563563
trust_remote_code=True,
564564
attn_backend=model_config.attn_backend,
@@ -567,17 +567,17 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
567567
spec_config=model_config.spec_config,
568568
max_num_tokens=model_config.max_num_tokens,
569569
moe_max_num_tokens=model_config.moe_max_num_tokens)
570-
draft_config.quant_config.kv_cache_quant_algo = \
570+
self.draft_config.quant_config.kv_cache_quant_algo = \
571571
model_config.quant_config.kv_cache_quant_algo
572572

573-
self.draft_model = get_draft_model(model_config, draft_config,
573+
self.draft_model = get_draft_model(model_config, self.draft_config,
574574
self.lm_head, self.model)
575575
self.spec_worker = get_spec_worker(model_config.spec_config,
576576
model_config,
577577
model_config.mapping)
578578

579-
if draft_config is not None:
580-
for key, value in draft_config.extra_attrs.items():
579+
if self.draft_config is not None:
580+
for key, value in self.draft_config.extra_attrs.items():
581581
assert key in ('attn_layers', 'mla_layers')
582582
assert key in model_config.extra_attrs
583583
model_config.extra_attrs[key].update(value)

tensorrt_llm/_torch/pyexecutor/model_loader.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,8 +278,10 @@ def init_meta_tensor(t: torch.Tensor):
278278
):
279279
weights = checkpoint_loader.load_weights(
280280
self.spec_config.speculative_model_dir)
281+
draft_weight_mapper = checkpoint_loader.get_initialized_weight_mapper(
282+
model.draft_model, model.draft_config)
281283
self._call_load_weights(model.load_draft_weights, weights,
282-
self.weight_mapper)
284+
draft_weight_mapper)
283285

284286
elif load_format == LoadFormat.DUMMY:
285287
self.weight_mapper = checkpoint_loader.get_initialized_weight_mapper(

0 commit comments

Comments
 (0)