@@ -30,8 +30,10 @@ def __init__(
         self,
         model_config: ModelConfig[LlamaConfig],
         layer_idx: Optional[int] = None,
+        next_layer_regular: bool = False,
     ):
         config = model_config.pretrained_config
+        self._next_layer_regular = next_layer_regular
         super().__init__(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -52,19 +54,20 @@ def __init__(
             tp_size = 1
         # Override the QKV projection. The number of input features
         # is twice as big for EAGLE3 draft models.
-        self.qkv_proj = Linear(
-            2 * self.hidden_size,
-            tp_size * self.q_size + 2 * tp_size * self.kv_size,
-            bias=config.attention_bias,
-            dtype=config.torch_dtype,
-            mapping=self.qkv_proj.mapping,
-            tensor_parallel_mode=TensorParallelMode.COLUMN,
-            weights_loading_config=WeightsLoadingConfig(
-                weight_mode=WeightMode.FUSED_QKV_LINEAR),
-            quant_config=model_config.get_quant_config(),
-            skip_create_weights_in_init=model_config.
-            skip_create_weights_in_init,
-        )
+        if not self._next_layer_regular:
+            self.qkv_proj = Linear(
+                2 * self.hidden_size,
+                tp_size * self.q_size + 2 * tp_size * self.kv_size,
+                bias=config.attention_bias,
+                dtype=config.torch_dtype,
+                mapping=self.qkv_proj.mapping,
+                tensor_parallel_mode=TensorParallelMode.COLUMN,
+                weights_loading_config=WeightsLoadingConfig(
+                    weight_mode=WeightMode.FUSED_QKV_LINEAR),
+                quant_config=model_config.get_quant_config(),
+                skip_create_weights_in_init=model_config.
+                skip_create_weights_in_init,
+            )


 class Eagle3DecoderLayer(DecoderLayer):
@@ -73,12 +76,18 @@ def __init__(
         self,
         model_config: LlamaConfig,
         layer_idx: int = 0,
+        is_first_layer: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         super().__init__()
         config = model_config.pretrained_config
+        eagle_config = config.eagle_config if hasattr(config,
+                                                      "eagle_config") else {}
         self.layer_idx = layer_idx
-
-        self.self_attn = Eagle3Attention(model_config, layer_idx)
+        self._next_layer_regular = (eagle_config.get("next_layer_regular", True)
+                                    and not is_first_layer) or eagle_config.get(
+                                        "eh_proj_before_attn", False)
+        self.self_attn = Eagle3Attention(model_config, layer_idx,
+                                         self._next_layer_regular)

         if config.model_type == "llama4_text":
             inter_size = config.intermediate_size_mlp
@@ -94,9 +103,10 @@ def __init__(
             overridden_tp_size=1
             if model_config.mapping.enable_attention_dp else None,
         )
-        self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
-                                       eps=config.rms_norm_eps,
-                                       dtype=config.torch_dtype)
+        if not self._next_layer_regular:
+            self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
+                                           eps=config.rms_norm_eps,
+                                           dtype=config.torch_dtype)

         self.hidden_norm = RMSNorm(hidden_size=config.hidden_size,
                                    eps=config.rms_norm_eps,
@@ -116,10 +126,10 @@ def forward(
     ) -> torch.Tensor:
         residual = hidden_states

-        embeds = self.input_layernorm(embeds)
         hidden_states = self.hidden_norm(hidden_states)
-
-        hidden_states = torch.cat([embeds, hidden_states], dim=-1)
+        if not self._next_layer_regular:
+            embeds = self.input_layernorm(embeds)
+            hidden_states = torch.cat([embeds, hidden_states], dim=-1)

         hidden_states = self.self_attn(
             position_ids=position_ids,
@@ -150,17 +160,24 @@ def __init__(
         super().__init__(model_config)

         config = model_config.pretrained_config
+        eagle_config = config.eagle_config if hasattr(config,
+                                                      "eagle_config") else {}
         self.spec_config = model_config.spec_config
         self.dtype = config.torch_dtype
         self.hidden_size = config.hidden_size
         self.mapping = model_config.mapping
         self.num_layers = model_config.pretrained_config.num_hidden_layers
+        self._eh_proj_before_attn = eagle_config.get("eh_proj_before_attn",
+                                                     False)

         if hasattr(config, "target_hidden_size"):
             self.hidden_size_in = config.target_hidden_size
         else:
             self.hidden_size_in = config.hidden_size

+        self._return_hidden_post_norm = eagle_config.get(
+            "return_hidden_post_norm", False)
+
         if self.spec_config.num_capture_layers > 1:
             self.fc = Linear(self.hidden_size_in *
                              self.spec_config.num_capture_layers,
@@ -170,7 +187,9 @@ def __init__(

         if self.num_layers > 1:
             self.midlayer = nn.ModuleList([
-                Eagle3DecoderLayer(model_config, start_layer_idx + i)
+                Eagle3DecoderLayer(model_config,
+                                   start_layer_idx + i,
+                                   is_first_layer=(i == 0))
                 for i in range(self.num_layers)
             ])
         else:
@@ -184,6 +203,15 @@ def __init__(
         self.d2t = nn.Parameter(torch.empty((config.draft_vocab_size, ),
                                             dtype=torch.int32),
                                 requires_grad=False)
+        if self._eh_proj_before_attn:
+            self.enorm = RMSNorm(hidden_size=config.hidden_size,
+                                 eps=config.rms_norm_eps,
+                                 dtype=config.torch_dtype)
+            self.eh_proj = nn.Linear(config.hidden_size * 2,
+                                     config.hidden_size,
+                                     bias=eagle_config.get(
+                                         "eh_proj_bias", False),
+                                     dtype=config.torch_dtype)

         if self.hidden_size_in != config.hidden_size:
             if model_config.mapping.enable_attention_dp:
@@ -225,11 +253,15 @@ def forward(
         inputs_embeds = self.embed_tokens(input_ids).to(self.dtype)

         assert hidden_states is not None
-
         # NOTE: If hidden states from the target model have to be concatenated,
-        # we expect that to happen outside the model definition. This helps us
-        # avoid data-dependent control flow and gives us better CUDA graph
-        # coverage.
+        # ideally, we expect that to happen outside the model definition. This
+        # helps us avoid data-dependent control flow and gives us better CUDA
+        # graph coverage.
+        if self._eh_proj_before_attn:
+            input_embeds = self.enorm(inputs_embeds)
+            hidden_states = torch.cat([input_embeds, hidden_states], dim=-1)
+            hidden_states = self.eh_proj(hidden_states)
+
         residual = None
         if self.num_layers > 1:
             for layer in self.midlayer:
@@ -249,6 +281,8 @@ def forward(

         hidden_states, hidden_states_to_save = self.norm(
             hidden_states, residual)
+        if self._return_hidden_post_norm:
+            return hidden_states, hidden_states
         return hidden_states, hidden_states_to_save

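For reference, below is a minimal standalone sketch (not part of the diff) of how the new eagle_config keys resolve the per-layer flag added above. The resolve_next_layer_regular helper and the sample configs are hypothetical names used only for illustration; the dictionary keys (next_layer_regular, eh_proj_before_attn, eh_proj_bias, return_hidden_post_norm) are the ones read in the diff.

# Minimal sketch, assuming only the eagle_config keys introduced in this diff.
# resolve_next_layer_regular() mirrors the expression added to
# Eagle3DecoderLayer.__init__; the configs below are made up for illustration.


def resolve_next_layer_regular(eagle_config: dict, is_first_layer: bool) -> bool:
    # Layers after the first skip the doubled-width QKV projection, the
    # input_layernorm and the embeds/hidden concat unless next_layer_regular
    # is disabled; eh_proj_before_attn makes every layer regular because the
    # concat already happens via eh_proj in Eagle3DraftModel.forward.
    return (eagle_config.get("next_layer_regular", True)
            and not is_first_layer) or eagle_config.get(
                "eh_proj_before_attn", False)


if __name__ == "__main__":
    configs = {
        "default": {},  # legacy EAGLE3: first layer concatenates embeds
        "eh_proj": {
            "eh_proj_before_attn": True,
            "eh_proj_bias": False,
            "return_hidden_post_norm": True,
        },
    }
    for name, cfg in configs.items():
        for i in range(2):
            regular = resolve_next_layer_regular(cfg, is_first_layer=(i == 0))
            print(f"{name}: layer {i} -> next_layer_regular={regular}")

Under the default (empty) config this prints False for layer 0 and True for layer 1, which matches the comment in Eagle3Attention: the QKV input is twice as wide only for layers that still receive the concatenated embeds.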