
Commit a03df6a

danielhanchen authored and ArthurZucker committed
Fix GPT-OSS swiglu_limit not passed in for MXFP4 (#40197)
Add swiglu_limit = 7.0
1 parent 170b270 commit a03df6a

File tree: 1 file changed, +2 -2 lines changed

src/transformers/integrations/mxfp4.py

Lines changed: 2 additions & 2 deletions
@@ -172,7 +172,7 @@ def __init__(self, config):
             torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), requires_grad=False
         )
         self.alpha = 1.702
-
+        self.limit = getattr(config, "swiglu_limit", 7.0)
         self.gate_up_proj_precision_config = None
         self.down_proj_precision_config = None
 
@@ -185,7 +185,7 @@ def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter
         swiglu_fn = triton_kernels_hub.swiglu.swiglu_fn
 
         with torch.cuda.device(hidden_states.device):
-            act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, None), 2)
+            act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, self.limit), 2)
 
             intermediate_cache1 = matmul_ogs(
                 hidden_states,
0 commit comments
