Commit 8ca8f52

Eagle: PostNorm and multilayer options
Signed-off-by: Izzy Putterman <[email protected]>
1 parent 470d777 commit 8ca8f52

1 file changed: +58 -26 lines


tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 58 additions & 26 deletions
@@ -30,8 +30,10 @@ def __init__(
         self,
         model_config: ModelConfig[LlamaConfig],
         layer_idx: Optional[int] = None,
+        next_layer_regular: bool = False,
     ):
         config = model_config.pretrained_config
+        self._next_layer_regular = next_layer_regular
         super().__init__(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -52,19 +54,20 @@ def __init__(
             tp_size = 1
         # Override the QKV projection. The number of input features
         # is twice as big for EAGLE3 draft models.
-        self.qkv_proj = Linear(
-            2 * self.hidden_size,
-            tp_size * self.q_size + 2 * tp_size * self.kv_size,
-            bias=config.attention_bias,
-            dtype=config.torch_dtype,
-            mapping=self.qkv_proj.mapping,
-            tensor_parallel_mode=TensorParallelMode.COLUMN,
-            weights_loading_config=WeightsLoadingConfig(
-                weight_mode=WeightMode.FUSED_QKV_LINEAR),
-            quant_config=model_config.get_quant_config(),
-            skip_create_weights_in_init=model_config.
-            skip_create_weights_in_init,
-        )
+        if not self._next_layer_regular:
+            self.qkv_proj = Linear(
+                2 * self.hidden_size,
+                tp_size * self.q_size + 2 * tp_size * self.kv_size,
+                bias=config.attention_bias,
+                dtype=config.torch_dtype,
+                mapping=self.qkv_proj.mapping,
+                tensor_parallel_mode=TensorParallelMode.COLUMN,
+                weights_loading_config=WeightsLoadingConfig(
+                    weight_mode=WeightMode.FUSED_QKV_LINEAR),
+                quant_config=model_config.get_quant_config(),
+                skip_create_weights_in_init=model_config.
+                skip_create_weights_in_init,
+            )


 class Eagle3DecoderLayer(DecoderLayer):
@@ -73,12 +76,18 @@ def __init__(
         self,
         model_config: LlamaConfig,
         layer_idx: int = 0,
+        is_first_layer: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         super().__init__()
         config = model_config.pretrained_config
+        eagle_config = config.eagle_config if hasattr(config,
+                                                      "eagle_config") else {}
         self.layer_idx = layer_idx
-
-        self.self_attn = Eagle3Attention(model_config, layer_idx)
+        self._next_layer_regular = (eagle_config.get("next_layer_regular", True)
+                                    and not is_first_layer) or eagle_config.get(
+                                        "eh_proj_before_attn", False)
+        self.self_attn = Eagle3Attention(model_config, layer_idx,
+                                         self._next_layer_regular)

         if config.model_type == "llama4_text":
             inter_size = config.intermediate_size_mlp
@@ -94,9 +103,10 @@ def __init__(
             overridden_tp_size=1
             if model_config.mapping.enable_attention_dp else None,
         )
-        self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
-                                       eps=config.rms_norm_eps,
-                                       dtype=config.torch_dtype)
+        if not self._next_layer_regular:
+            self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
+                                           eps=config.rms_norm_eps,
+                                           dtype=config.torch_dtype)

         self.hidden_norm = RMSNorm(hidden_size=config.hidden_size,
                                    eps=config.rms_norm_eps,
@@ -116,10 +126,10 @@ def forward(
     ) -> torch.Tensor:
         residual = hidden_states

-        embeds = self.input_layernorm(embeds)
         hidden_states = self.hidden_norm(hidden_states)
-
-        hidden_states = torch.cat([embeds, hidden_states], dim=-1)
+        if not self._next_layer_regular:
+            embeds = self.input_layernorm(embeds)
+            hidden_states = torch.cat([embeds, hidden_states], dim=-1)

         hidden_states = self.self_attn(
             position_ids=position_ids,
@@ -150,17 +160,24 @@ def __init__(
         super().__init__(model_config)

         config = model_config.pretrained_config
+        eagle_config = config.eagle_config if hasattr(config,
+                                                      "eagle_config") else {}
         self.spec_config = model_config.spec_config
         self.dtype = config.torch_dtype
         self.hidden_size = config.hidden_size
         self.mapping = model_config.mapping
         self.num_layers = model_config.pretrained_config.num_hidden_layers
+        self._eh_proj_before_attn = eagle_config.get("eh_proj_before_attn",
+                                                     False)

         if hasattr(config, "target_hidden_size"):
             self.hidden_size_in = config.target_hidden_size
         else:
             self.hidden_size_in = config.hidden_size

+        self._return_hidden_post_norm = eagle_config.get(
+            "return_hidden_post_norm", False)
+
         if self.spec_config.num_capture_layers > 1:
             self.fc = Linear(self.hidden_size_in *
                              self.spec_config.num_capture_layers,
@@ -170,7 +187,7 @@ def __init__(

         if self.num_layers > 1:
             self.midlayer = nn.ModuleList([
-                Eagle3DecoderLayer(model_config, start_layer_idx + i)
+                Eagle3DecoderLayer(model_config, start_layer_idx + i, is_first_layer=(i == 0))
                 for i in range(self.num_layers)
             ])
         else:
@@ -184,6 +201,15 @@ def __init__(
             self.d2t = nn.Parameter(torch.empty((config.draft_vocab_size, ),
                                                 dtype=torch.int32),
                                     requires_grad=False)
+        if self._eh_proj_before_attn:
+            self.enorm = RMSNorm(hidden_size=config.hidden_size,
+                                 eps=config.rms_norm_eps,
+                                 dtype=config.torch_dtype)
+            self.eh_proj = nn.Linear(config.hidden_size * 2,
+                                     config.hidden_size,
+                                     bias=eagle_config.get(
+                                         "eh_proj_bias", False),
+                                     dtype=config.torch_dtype)

         if self.hidden_size_in != config.hidden_size:
             if model_config.mapping.enable_attention_dp:
@@ -225,11 +251,15 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids).to(self.dtype)

         assert hidden_states is not None
-
         # NOTE: If hidden states from the target model have to be concatenated,
-        # we expect that to happen outside the model definition. This helps us
-        # avoid data-dependent control flow and gives us better CUDA graph
-        # coverage.
+        # ideally, we expect that to happen outside the model definition. This
+        # helps us avoid data-dependent control flow and gives us better CUDA
+        # graph coverage.
+        if self._eh_proj_before_attn:
+            input_embeds = self.enorm(inputs_embeds)
+            hidden_states = torch.cat([input_embeds, hidden_states], dim=-1)
+            hidden_states = self.eh_proj(hidden_states)
+
         residual = None
         if self.num_layers > 1:
             for layer in self.midlayer:
@@ -249,6 +279,8 @@ def forward(

         hidden_states, hidden_states_to_save = self.norm(
             hidden_states, residual)
+        if self._return_hidden_post_norm:
+            return hidden_states, hidden_states
         return hidden_states, hidden_states_to_save
