@@ -17,8 +17,9 @@
 from ..modules.gated_mlp import GatedMLP
 from ..modules.linear import Linear, TensorParallelMode
 from ..modules.rms_norm import RMSNorm
-from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
-                             register_auto_model)
+from ..speculative import SpecMetadata
+from .modeling_speculative import SpecDecOneEngineForCausalLM
+from .modeling_utils import DecoderModel, register_auto_model
 
 
 def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int:
@@ -117,6 +118,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
                  block_config: Dict[str, Any], layer_idx: int):
         super().__init__()
         config = model_config.pretrained_config
+        self.layer_idx = layer_idx
         self.block_config = block_config
         if not self.block_config.attention.no_op:
             self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
@@ -150,6 +152,7 @@ def forward(
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         if not self.block_config.attention.no_op:
@@ -178,6 +181,11 @@ def forward(
                 hidden_states, residual)
         hidden_states = self.mlp(hidden_states, **kwargs)
 
+        # Capture hidden states for speculative decoding
+        if spec_metadata is not None:
+            spec_metadata.maybe_capture_hidden_states(self.layer_idx,
+                                                      hidden_states, residual)
+
         return hidden_states, residual
 
 
@@ -238,6 +246,7 @@ def forward(
         input_ids: Optional[torch.IntTensor] = None,
         position_ids: Optional[torch.IntTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         lora_params: Optional[dict] = None,
         **kwargs,
     ) -> torch.Tensor:
@@ -259,6 +268,7 @@ def forward(
                 hidden_states=hidden_states,
                 attn_metadata=attn_metadata,
                 residual=residual,
+                spec_metadata=spec_metadata,
                 lora_params=lora_params,
             )
 
@@ -267,11 +277,8 @@ def forward(
 
 
 @register_auto_model("DeciLMForCausalLM")
-class NemotronNASForCausalLM(DecoderModelForCausalLM[NemotronNASModel,
-                                                     PretrainedConfig]):
+class NemotronNASForCausalLM(SpecDecOneEngineForCausalLM[NemotronNASModel,
+                                                         PretrainedConfig]):
 
     def __init__(self, model_config: ModelConfig[PretrainedConfig]):
-        super().__init__(NemotronNASModel(model_config),
-                         config=model_config,
-                         hidden_size=model_config.pretrained_config.hidden_size,
-                         vocab_size=model_config.pretrained_config.vocab_size)
+        super().__init__(NemotronNASModel(model_config), model_config)
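
Taken together, the hunks thread a `SpecMetadata` object from the causal-LM wrapper down through `NemotronNASModel.forward` to every decoder layer, where `maybe_capture_hidden_states` records that layer's output for a one-engine speculative decoder (for example, an EAGLE-style draft head that consumes intermediate hidden states). The diff never shows `SpecMetadata` itself; the sketch below is a hypothetical minimal version of that capture hook, written only to illustrate the pattern. The class name `SpecMetadataSketch`, its fields `layers_to_capture` and `captured`, and the residual-summing step are all assumptions here, not the actual implementation in `..speculative`.

```python
# Hypothetical sketch of the per-layer capture hook used in the diff above.
# NOT the actual TensorRT-LLM SpecMetadata implementation.
from typing import Optional

import torch


class SpecMetadataSketch:
    """Records hidden states from selected decoder layers so a draft
    model running in the same engine can reuse them later."""

    def __init__(self, layers_to_capture: set[int]):
        self.layers_to_capture = layers_to_capture
        self.captured: dict[int, torch.Tensor] = {}

    def maybe_capture_hidden_states(
            self,
            layer_idx: int,
            hidden_states: torch.Tensor,
            residual: Optional[torch.Tensor] = None) -> None:
        # Cheap no-op for layers the drafter does not need, so the
        # call can be made unconditionally from every decoder layer.
        if layer_idx not in self.layers_to_capture:
            return
        # The layers return (hidden_states, residual) pairs; summing the
        # pair (when a residual exists) is one plausible way to recover
        # the full hidden state -- an assumption of this sketch.
        full = hidden_states + residual if residual is not None else hidden_states
        self.captured[layer_idx] = full
```

With the metadata threaded through the model, the wrapper change reduces to the one-line `super().__init__(NemotronNASModel(model_config), model_config)`: the explicit `hidden_size`/`vocab_size` arguments disappear, presumably because `SpecDecOneEngineForCausalLM` reads them from the config it is given.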