Skip to content

Commit 2dc7ab9

Browse files
committed
Enhance numerical stability for HIP/Strix Halo compatibility
- Updated RMS LayerNorm implementation to use float32 for improved precision.
- Introduced safe mode for MLP operations in Gemma and Llama models to prevent dtype mismatches.
- Replaced fused cross-entropy loss with standard PyTorch CE loss in Mistral and CausalLM for better handling of NaNs on Strix Halo.
- Added environment variable checks to conditionally apply these changes based on the UNSLOTH_STRIX_HALO_SAFE setting.
1 parent 3c87aa4 commit 2dc7ab9

File tree

5 files changed

+232
-132
lines changed

5 files changed

+232
-132
lines changed

unsloth/kernels/rms_layernorm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,8 @@ def _gemma_rms_layernorm_forward(
147147
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
148148

149149
row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
150-
inv_var = tl.math.rsqrt(row_var + eps)
150+
eps_f32 = tl.full((), eps, tl.float32)
151+
inv_var = tl.math.rsqrt(row_var + eps_f32)
151152
tl.store(r, inv_var)
152153
normed = X_row * inv_var
153154
output = normed * (W_row + 1.0)

unsloth/models/_utils.py

100644100755
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1842,6 +1842,104 @@ def unsloth_compile_transformers(
18421842
return_logits = return_logits,
18431843
supports_sdpa = supports_sdpa,
18441844
)
1845+
1846+
# After compilation, patch GPT-OSS experts on HIP/Strix-safe machines so
# that expert matmuls use matching dtypes (avoids float vs bfloat16).
# NOTE(review): this only activates when DEVICE_TYPE == "hip" AND the user
# has opted in via UNSLOTH_STRIX_HALO_SAFE=1; otherwise it is a no-op.
if DEVICE_TYPE == "hip" and os.environ.get("UNSLOTH_STRIX_HALO_SAFE", "0") == "1":
    try:
        # The compiled GPT-OSS module is generated elsewhere; if it is not
        # importable here, the bare except below leaves the default forward
        # untouched (best-effort patching).
        import unsloth_compiled_module_gpt_oss as _gpt_oss_compiled

        @torch.compiler.disable(recursive = False)
        def GptOssExperts_forward_safe(
            self,
            hidden_states: torch.Tensor,
            router_indices = None,
            routing_weights = None,
        ) -> torch.Tensor:
            """Dtype-safe replacement for the compiled GptOssExperts forward.

            Mirrors the stock mixture-of-experts forward, but casts the
            gated activations to the down-projection weight's dtype before
            the final matmul so both operands match on HIP.

            hidden_states is flattened to (tokens, hidden_size); the result
            is reshaped back using the original leading batch dimension.
            NOTE(review): router_indices / routing_weights are assumed to be
            the compiled module's routing tensors (routing_weights has one
            column per expert) — confirm against the caller.
            """
            batch_size = hidden_states.shape[0]
            hidden_states = hidden_states.reshape(-1, self.hidden_size)
            num_experts = routing_weights.shape[1]

            if hidden_states.device.type == "cpu" or self.training:
                # Sparse path (CPU or training): visit only experts that
                # actually received tokens, accumulating into next_states.
                next_states = torch.zeros_like(
                    hidden_states,
                    dtype = hidden_states.dtype,
                    device = hidden_states.device,
                )
                with torch.no_grad():
                    # one_hot over num_experts + 1 classes: the extra class
                    # is a sentinel index, skipped in the loop below.
                    expert_mask = torch.nn.functional.one_hot(
                        router_indices,
                        num_classes = num_experts + 1,
                    )
                    expert_mask = expert_mask.permute(2, 1, 0)
                    # Experts with at least one routed token.
                    expert_hit = torch.greater(
                        expert_mask.sum(dim = (-1, -2)), 0
                    ).nonzero()

                for expert_idx in expert_hit[:]:
                    expert_idx = expert_idx[0]
                    if expert_idx == num_experts:
                        # Sentinel "no expert" class — nothing to compute.
                        continue
                    with torch.no_grad():
                        _, token_idx = torch.where(expert_mask[expert_idx])
                    current_state = hidden_states[token_idx]
                    gate_up = (
                        current_state @ self.gate_up_proj[expert_idx]
                        + self.gate_up_proj_bias[expert_idx]
                    )
                    # Even/odd interleaved columns hold gate and up halves.
                    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
                    gate = gate.clamp(min = None, max = self.limit)
                    up = up.clamp(min = -self.limit, max = self.limit)
                    glu = gate * torch.sigmoid(gate * self.alpha)
                    gated_output = (up + 1) * glu

                    # Ensure matmul uses a consistent dtype
                    w = self.down_proj[expert_idx]
                    gated_output = gated_output.to(w.dtype)
                    out = gated_output @ w + self.down_proj_bias[expert_idx]

                    weighted_output = (
                        out * routing_weights[token_idx, expert_idx, None]
                    )
                    # Cast back before accumulation so index_add_ operands
                    # share next_states' dtype.
                    next_states.index_add_(
                        0,
                        token_idx,
                        weighted_output.to(hidden_states.dtype),
                    )
                next_states = next_states.view(batch_size, -1, self.hidden_size)
            else:
                # Dense path: run every expert over all tokens with batched
                # matmuls, then weight per-expert outputs and sum them.
                # NOTE(review): no explicit dtype cast here — presumably the
                # bmm operands already share a dtype on this path; confirm.
                hidden_states = hidden_states.repeat(num_experts, 1)
                hidden_states = hidden_states.view(
                    num_experts,
                    -1,
                    self.hidden_size,
                )
                gate_up = (
                    torch.bmm(hidden_states, self.gate_up_proj)
                    + self.gate_up_proj_bias[..., None, :]
                )
                gate, up = gate_up[..., ::2], gate_up[..., 1::2]
                gate = gate.clamp(min = None, max = self.limit)
                up = up.clamp(min = -self.limit, max = self.limit)
                glu = gate * torch.sigmoid(gate * self.alpha)
                next_states = torch.bmm(((up + 1) * glu), self.down_proj)
                next_states = next_states + self.down_proj_bias[..., None, :]
                next_states = next_states.view(
                    num_experts,
                    batch_size,
                    -1,
                    self.hidden_size,
                )
                next_states = next_states * routing_weights.transpose(
                    0, 1
                ).view(num_experts, batch_size, -1)[..., None]
                next_states = next_states.sum(dim = 0)
            return next_states

        # Install the safe forward on the compiled module, then drop the
        # module reference so it does not linger in this namespace.
        _gpt_oss_compiled.GptOssExperts_forward = GptOssExperts_forward_safe
        del _gpt_oss_compiled
    except Exception:
        # Best-effort: if the compiled module is missing or its interface
        # changed, keep the default (unpatched) forward.
        pass
18451943
# Redo patches which override compiler
18461944
for temporary_patch in TEMPORARY_PATCHES:
18471945
temporary_patch()

unsloth/models/gemma.py

100644100755
Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,8 @@
1919
import math
2020
import os
2121

22-
_DISABLE_TRITON_RMSNORM = os.getenv("UNSLOTH_DISABLE_TRITON_RMSNORM", "0") == "1"
23-
_LAYERNORM_IMPL = os.getenv("UNSLOTH_LAYERNORM_IMPL", "").lower()
24-
_DISABLE_AUTODTYPE_CAST = os.getenv("UNSLOTH_DISABLE_AUTODTYPE_CAST", "0") == "1"
22+
_STRIX_HALO_SAFE = os.getenv("UNSLOTH_STRIX_HALO_SAFE", "0") == "1"
23+
_GEMMA_STRIX_SAFE = ("hip" == DEVICE_TYPE) and _STRIX_HALO_SAFE
2524

2625
try:
2726
from transformers.models.gemma.modeling_gemma import (
@@ -123,16 +122,18 @@ def GemmaDecoderLayer_fast_forward(
123122
hidden_states = fast_rms_layernorm_inference_gemma(
124123
self.post_attention_layernorm, hidden_states, out_weight
125124
)
126-
hidden_states = fast_geglu_inference(self.mlp, hidden_states)
127-
hidden_states += residual
125+
if _GEMMA_STRIX_SAFE:
126+
mlp_in = hidden_states.to(torch.float32)
127+
mlp_out = self.mlp(mlp_in)
128+
hidden_states = residual + mlp_out.to(hidden_states.dtype)
129+
else:
130+
hidden_states = fast_geglu_inference(self.mlp, hidden_states)
131+
hidden_states += residual
128132
else:
129133
residual = hidden_states
130-
if _DISABLE_TRITON_RMSNORM or _LAYERNORM_IMPL == "python":
131-
hidden_states = self.input_layernorm(hidden_states)
132-
else:
133-
hidden_states = fast_rms_layernorm(
134-
self.input_layernorm, hidden_states, gemma = True
135-
)
134+
hidden_states = fast_rms_layernorm(
135+
self.input_layernorm, hidden_states, gemma = True
136+
)
136137
hidden_states, self_attn_weights, present_key_value = self.self_attn(
137138
hidden_states = hidden_states,
138139
causal_mask = causal_mask,
@@ -147,14 +148,19 @@ def GemmaDecoderLayer_fast_forward(
147148

148149
# Fully Connected
149150
residual = hidden_states
150-
if _DISABLE_TRITON_RMSNORM or _LAYERNORM_IMPL == "python":
151-
hidden_states = self.post_attention_layernorm(hidden_states)
151+
hidden_states = fast_rms_layernorm(
152+
self.post_attention_layernorm, hidden_states, gemma = True
153+
)
154+
155+
# On Strix Halo (HIP) in safe mode, run Gemma MLP in float32 for
156+
# numerical stability, then cast back. Else use the default path.
157+
if _GEMMA_STRIX_SAFE:
158+
mlp_in = hidden_states.to(torch.float32)
159+
mlp_out = self.mlp(mlp_in)
160+
hidden_states = residual + mlp_out.to(hidden_states.dtype)
152161
else:
153-
hidden_states = fast_rms_layernorm(
154-
self.post_attention_layernorm, hidden_states, gemma = True
155-
)
156-
hidden_states = self.mlp(hidden_states)
157-
hidden_states = residual + hidden_states
162+
hidden_states = self.mlp(hidden_states)
163+
hidden_states = residual + hidden_states
158164

159165
outputs = (hidden_states,)
160166
if output_attentions:
@@ -186,8 +192,7 @@ def GemmaModel_fast_forward_inference(
186192
)
187193
input_ids = input_ids[:, : self.max_seq_length]
188194
hidden_states = self.model.embed_tokens(input_ids)
189-
if not _DISABLE_AUTODTYPE_CAST:
190-
hidden_states = hidden_states.to(_get_dtype(dtype_from_config(self.config)))
195+
hidden_states = hidden_states.to(_get_dtype(dtype_from_config(self.config)))
191196
# 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
192197
# 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
193198
hidden_states *= torch.tensor(

0 commit comments

Comments (0)