
Commit a7c0b54

[None][feat] add specdec to nemotron nas (#8985)
Signed-off-by: Shreyas Misra <[email protected]>
Parent: 7ab02ad

1 file changed: +15, -8 lines

tensorrt_llm/_torch/models/modeling_nemotron_nas.py

Lines changed: 15 additions & 8 deletions
@@ -17,8 +17,9 @@
 from ..modules.gated_mlp import GatedMLP
 from ..modules.linear import Linear, TensorParallelMode
 from ..modules.rms_norm import RMSNorm
-from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
-                             register_auto_model)
+from ..speculative import SpecMetadata
+from .modeling_speculative import SpecDecOneEngineForCausalLM
+from .modeling_utils import DecoderModel, register_auto_model
 
 
 def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int:
@@ -117,6 +118,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
                  block_config: Dict[str, Any], layer_idx: int):
         super().__init__()
         config = model_config.pretrained_config
+        self.layer_idx = layer_idx
         self.block_config = block_config
         if not self.block_config.attention.no_op:
             self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
@@ -150,6 +152,7 @@ def forward(
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         if not self.block_config.attention.no_op:
@@ -178,6 +181,11 @@ def forward(
                 hidden_states, residual)
             hidden_states = self.mlp(hidden_states, **kwargs)
 
+        # Capture hidden states for speculative decoding
+        if spec_metadata is not None:
+            spec_metadata.maybe_capture_hidden_states(self.layer_idx,
+                                                      hidden_states, residual)
+
         return hidden_states, residual
 
 
@@ -238,6 +246,7 @@ def forward(
         input_ids: Optional[torch.IntTensor] = None,
         position_ids: Optional[torch.IntTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         lora_params: Optional[dict] = None,
         **kwargs,
     ) -> torch.Tensor:
@@ -259,6 +268,7 @@ def forward(
                 hidden_states=hidden_states,
                 attn_metadata=attn_metadata,
                 residual=residual,
+                spec_metadata=spec_metadata,
                 lora_params=lora_params,
             )
 
@@ -267,11 +277,8 @@
 
 
 @register_auto_model("DeciLMForCausalLM")
-class NemotronNASForCausalLM(DecoderModelForCausalLM[NemotronNASModel,
-                                                     PretrainedConfig]):
+class NemotronNASForCausalLM(SpecDecOneEngineForCausalLM[NemotronNASModel,
+                                                         PretrainedConfig]):
 
     def __init__(self, model_config: ModelConfig[PretrainedConfig]):
-        super().__init__(NemotronNASModel(model_config),
-                         config=model_config,
-                         hidden_size=model_config.pretrained_config.hidden_size,
-                         vocab_size=model_config.pretrained_config.vocab_size)
+        super().__init__(NemotronNASModel(model_config), model_config)
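
For reference, the capture hook added to the decoder layer's forward follows a simple pattern: the speculative-decoding runtime knows which layers' hidden states it needs (for example, to feed an EAGLE-style draft head), and each layer offers its output through maybe_capture_hidden_states. Below is a minimal, self-contained sketch of that pattern; the _ToySpecMetadata class and the chosen layer indices are hypothetical stand-ins for illustration, not TensorRT-LLM's actual SpecMetadata implementation.

# Illustrative sketch only: a toy recorder mimicking the capture pattern in
# the diff above. It is not TensorRT-LLM's SpecMetadata; names are made up.
from typing import Dict, Optional, Set, Tuple

import torch


class _ToySpecMetadata:
    """Records hidden states from a chosen subset of decoder layers."""

    def __init__(self, layers_to_capture: Set[int]):
        self.layers_to_capture = layers_to_capture
        self.captured: Dict[int, Tuple[torch.Tensor,
                                       Optional[torch.Tensor]]] = {}

    def maybe_capture_hidden_states(
            self,
            layer_idx: int,
            hidden_states: torch.Tensor,
            residual: Optional[torch.Tensor] = None) -> None:
        # Layers the speculative method does not need are a no-op.
        if layer_idx in self.layers_to_capture:
            self.captured[layer_idx] = (hidden_states, residual)


# A decoder layer then does exactly what the diff adds:
#     if spec_metadata is not None:
#         spec_metadata.maybe_capture_hidden_states(self.layer_idx,
#                                                   hidden_states, residual)
spec_md = _ToySpecMetadata(layers_to_capture={0, 15, 31})
spec_md.maybe_capture_hidden_states(15, torch.randn(2, 8, 4096))
print(sorted(spec_md.captured))  # -> [15]

Because the hook is guarded by "if spec_metadata is not None", the added plumbing is inert on the non-speculative path.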

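A note on how the feature would be exercised: with the model now deriving from SpecDecOneEngineForCausalLM, speculative decoding is driven from the LLM API rather than from this file. The sketch below is a hypothetical usage example under that assumption; the speculative_config argument and the EagleDecodingConfig class with fields max_draft_len, speculative_model_dir, and eagle3_one_model are assumptions about the tensorrt_llm.llmapi surface and may differ between releases, and both checkpoint paths are placeholders.

# Hypothetical usage sketch: config class/field names are assumptions and the
# checkpoint paths are placeholders, not values from this commit.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import EagleDecodingConfig

spec_config = EagleDecodingConfig(
    max_draft_len=3,                    # draft tokens proposed per step
    speculative_model_dir="<eagle3-draft-checkpoint>",
    eagle3_one_model=True,              # single-engine mode, as in SpecDecOneEngineForCausalLM
)

llm = LLM(model="<nemotron-nas-checkpoint>",
          speculative_config=spec_config)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)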