@@ -554,11 +554,11 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
554554 hidden_size = model_config .pretrained_config .hidden_size ,
555555 vocab_size = model_config .pretrained_config .vocab_size )
556556 self .draft_model = None
557+ self .draft_config = None
557558 spec_config = getattr (model_config , 'spec_config' , None )
558559 if spec_config and spec_config .spec_dec_mode .use_one_engine ():
559- draft_config = None
560560 if spec_config .spec_dec_mode .is_eagle3_one_model ():
561- draft_config = ModelConfig .from_pretrained (
561+ self . draft_config = ModelConfig .from_pretrained (
562562 model_config .spec_config .speculative_model_dir ,
563563 trust_remote_code = True ,
564564 attn_backend = model_config .attn_backend ,
@@ -567,17 +567,17 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
567567 spec_config = model_config .spec_config ,
568568 max_num_tokens = model_config .max_num_tokens ,
569569 moe_max_num_tokens = model_config .moe_max_num_tokens )
570- draft_config .quant_config .kv_cache_quant_algo = \
570+ self . draft_config .quant_config .kv_cache_quant_algo = \
571571 model_config .quant_config .kv_cache_quant_algo
572572
573- self .draft_model = get_draft_model (model_config , draft_config ,
573+ self .draft_model = get_draft_model (model_config , self . draft_config ,
574574 self .lm_head , self .model )
575575 self .spec_worker = get_spec_worker (model_config .spec_config ,
576576 model_config ,
577577 model_config .mapping )
578578
579- if draft_config is not None :
580- for key , value in draft_config .extra_attrs .items ():
579+ if self . draft_config is not None :
580+ for key , value in self . draft_config .extra_attrs .items ():
581581 assert key in ('attn_layers' , 'mla_layers' )
582582 assert key in model_config .extra_attrs
583583 model_config .extra_attrs [key ].update (value )
0 commit comments