@@ -30,8 +30,10 @@ def __init__(
         self,
         model_config: ModelConfig[LlamaConfig],
         layer_idx: Optional[int] = None,
+        next_layer_regular: bool = False,
     ):
         config = model_config.pretrained_config
+        self._next_layer_regular = next_layer_regular
         super().__init__(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -52,19 +54,20 @@ def __init__(
             tp_size = 1
         # Override the QKV projection. The number of input features
         # is twice as big for EAGLE3 draft models.
-        self.qkv_proj = Linear(
-            2 * self.hidden_size,
-            tp_size * self.q_size + 2 * tp_size * self.kv_size,
-            bias=config.attention_bias,
-            dtype=config.torch_dtype,
-            mapping=self.qkv_proj.mapping,
-            tensor_parallel_mode=TensorParallelMode.COLUMN,
-            weights_loading_config=WeightsLoadingConfig(
-                weight_mode=WeightMode.FUSED_QKV_LINEAR),
-            quant_config=model_config.get_quant_config(),
-            skip_create_weights_in_init=model_config.
-            skip_create_weights_in_init,
-        )
+        if not self._next_layer_regular:
+            self.qkv_proj = Linear(
+                2 * self.hidden_size,
+                tp_size * self.q_size + 2 * tp_size * self.kv_size,
+                bias=config.attention_bias,
+                dtype=config.torch_dtype,
+                mapping=self.qkv_proj.mapping,
+                tensor_parallel_mode=TensorParallelMode.COLUMN,
+                weights_loading_config=WeightsLoadingConfig(
+                    weight_mode=WeightMode.FUSED_QKV_LINEAR),
+                quant_config=model_config.get_quant_config(),
+                skip_create_weights_in_init=model_config.
+                skip_create_weights_in_init,
+            )
 
 
 class Eagle3DecoderLayer(DecoderLayer):
@@ -73,12 +76,18 @@ def __init__(
         self,
         model_config: LlamaConfig,
         layer_idx: int = 0,
+        is_first_layer: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         super().__init__()
         config = model_config.pretrained_config
+        eagle_config = config.eagle_config if hasattr(config,
+                                                      "eagle_config") else {}
         self.layer_idx = layer_idx
-
-        self.self_attn = Eagle3Attention(model_config, layer_idx)
+        self._next_layer_regular = (eagle_config.get("next_layer_regular", True)
+                                    and not is_first_layer) or eagle_config.get(
+                                        "eh_proj_before_attn", False)
+        self.self_attn = Eagle3Attention(model_config, layer_idx,
+                                         self._next_layer_regular)
 
         if config.model_type == "llama4_text":
             inter_size = config.intermediate_size_mlp
@@ -94,9 +103,10 @@ def __init__(
             overridden_tp_size=1
             if model_config.mapping.enable_attention_dp else None,
         )
-        self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
-                                       eps=config.rms_norm_eps,
-                                       dtype=config.torch_dtype)
+        if not self._next_layer_regular:
+            self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
+                                           eps=config.rms_norm_eps,
+                                           dtype=config.torch_dtype)
 
         self.hidden_norm = RMSNorm(hidden_size=config.hidden_size,
                                    eps=config.rms_norm_eps,
@@ -116,10 +126,10 @@ def forward(
     ) -> torch.Tensor:
         residual = hidden_states
 
-        embeds = self.input_layernorm(embeds)
         hidden_states = self.hidden_norm(hidden_states)
-
-        hidden_states = torch.cat([embeds, hidden_states], dim=-1)
+        if not self._next_layer_regular:
+            embeds = self.input_layernorm(embeds)
+            hidden_states = torch.cat([embeds, hidden_states], dim=-1)
 
         hidden_states = self.self_attn(
             position_ids=position_ids,
@@ -150,17 +160,24 @@ def __init__(
         super().__init__(model_config)
 
         config = model_config.pretrained_config
+        eagle_config = config.eagle_config if hasattr(config,
+                                                      "eagle_config") else {}
         self.spec_config = model_config.spec_config
         self.dtype = config.torch_dtype
         self.hidden_size = config.hidden_size
         self.mapping = model_config.mapping
         self.num_layers = model_config.pretrained_config.num_hidden_layers
+        self._eh_proj_before_attn = eagle_config.get("eh_proj_before_attn",
+                                                     False)
 
         if hasattr(config, "target_hidden_size"):
             self.hidden_size_in = config.target_hidden_size
         else:
             self.hidden_size_in = config.hidden_size
 
+        self._return_hidden_post_norm = eagle_config.get(
+            "return_hidden_post_norm", False)
+
         if self.spec_config.num_capture_layers > 1:
             self.fc = Linear(self.hidden_size_in *
                              self.spec_config.num_capture_layers,
@@ -170,7 +187,7 @@ def __init__(
 
         if self.num_layers > 1:
             self.midlayer = nn.ModuleList([
-                Eagle3DecoderLayer(model_config, start_layer_idx + i)
+                Eagle3DecoderLayer(model_config, start_layer_idx + i, is_first_layer=(i == 0))
                 for i in range(self.num_layers)
             ])
         else:
@@ -184,6 +201,15 @@ def __init__(
         self.d2t = nn.Parameter(torch.empty((config.draft_vocab_size, ),
                                             dtype=torch.int32),
                                 requires_grad=False)
+        if self._eh_proj_before_attn:
+            self.enorm = RMSNorm(hidden_size=config.hidden_size,
+                                 eps=config.rms_norm_eps,
+                                 dtype=config.torch_dtype)
+            self.eh_proj = nn.Linear(config.hidden_size * 2,
+                                     config.hidden_size,
+                                     bias=eagle_config.get(
+                                         "eh_proj_bias", False),
+                                     dtype=config.torch_dtype)
 
         if self.hidden_size_in != config.hidden_size:
             if model_config.mapping.enable_attention_dp:
@@ -225,11 +251,15 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids).to(self.dtype)
 
         assert hidden_states is not None
-
         # NOTE: If hidden states from the target model have to be concatenated,
-        # we expect that to happen outside the model definition. This helps us
-        # avoid data-dependent control flow and gives us better CUDA graph
-        # coverage.
+        # ideally, we expect that to happen outside the model definition. This
+        # helps us avoid data-dependent control flow and gives us better CUDA
+        # graph coverage.
+        if self._eh_proj_before_attn:
+            input_embeds = self.enorm(inputs_embeds)
+            hidden_states = torch.cat([input_embeds, hidden_states], dim=-1)
+            hidden_states = self.eh_proj(hidden_states)
+
         residual = None
         if self.num_layers > 1:
             for layer in self.midlayer:
@@ -249,6 +279,8 @@ def forward(
 
         hidden_states, hidden_states_to_save = self.norm(
             hidden_states, residual)
+        if self._return_hidden_post_norm:
+            return hidden_states, hidden_states
        return hidden_states, hidden_states_to_save
 
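For context, the new code paths are driven by an optional eagle_config dict attached to the pretrained config. Below is a minimal sketch of how the per-layer flag resolves for a two-layer draft model; the config values are hypothetical and only the key names come from the code above.

# Sketch of the _next_layer_regular resolution mirrored from Eagle3DecoderLayer.__init__.
# The values in eagle_config are illustrative assumptions, not defaults from the source.
eagle_config = {
    "next_layer_regular": True,    # layers after the first behave like regular decoder layers
    "eh_proj_before_attn": False,  # fuse embeds and hidden states once, before the first layer
}

def resolve_next_layer_regular(is_first_layer: bool) -> bool:
    return (eagle_config.get("next_layer_regular", True)
            and not is_first_layer) or eagle_config.get("eh_proj_before_attn", False)

for i in range(2):
    print(f"draft layer {i}: _next_layer_regular={resolve_next_layer_regular(i == 0)}")
# draft layer 0: False -> keeps input_layernorm and the 2x-wide fused qkv_proj
# draft layer 1: True  -> plain attention input; no embed concat in its forward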
254286
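A rough illustration of the two fusion strategies the forward path now supports, using plain torch modules as stand-ins (LayerNorm in place of RMSNorm, a toy hidden size, no tensor parallelism); this is a sketch of the idea, not the real Eagle3 classes.

import torch
import torch.nn as nn

hidden = 64
embeds = torch.randn(2, 8, hidden)         # draft-token embeddings
target_hidden = torch.randn(2, 8, hidden)  # hidden states captured from the target model

# Default EAGLE3 path: a non-regular layer norms embeds and hidden states separately,
# concatenates them, and feeds a QKV projection that takes 2 * hidden_size inputs.
input_layernorm, hidden_norm = nn.LayerNorm(hidden), nn.LayerNorm(hidden)
qkv_wide = nn.Linear(2 * hidden, 3 * hidden, bias=False)
fused = torch.cat([input_layernorm(embeds), hidden_norm(target_hidden)], dim=-1)
qkv = qkv_wide(fused)

# eh_proj_before_attn path: fuse once up front through enorm + eh_proj, after which
# every layer sees a regular hidden_size input and a regular-width QKV projection.
enorm = nn.LayerNorm(hidden)
eh_proj = nn.Linear(2 * hidden, hidden, bias=False)
qkv_regular = nn.Linear(hidden, 3 * hidden, bias=False)
fused_once = eh_proj(torch.cat([enorm(embeds), target_hidden], dim=-1))
qkv2 = qkv_regular(hidden_norm(fused_once))

print(qkv.shape, qkv2.shape)  # torch.Size([2, 8, 192]) torch.Size([2, 8, 192])

Either way, only the layer that consumes the concatenation needs the doubled input width, which is why the diff overrides qkv_proj solely when _next_layer_regular is False.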