3 changes: 3 additions & 0 deletions unsloth/models/_utils.py
@@ -751,6 +751,9 @@ def _is_openai_available():
elif DEVICE_TYPE == "xpu":
SUPPORTS_BFLOAT16 = True

if os.getenv("UNSLOTH_DISABLE_FLASH_ATTENTION", "0") == "1":
HAS_FLASH_ATTENTION = False

# =============================================
# Get Xformers
try:
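A minimal usage sketch for the new switch (it is evaluated at module import in _utils.py, so it has to be set before unsloth is imported; FastLanguageModel here is just the usual entry point):

import os
os.environ["UNSLOTH_DISABLE_FLASH_ATTENTION"] = "1"  # force the non-FlashAttention path

from unsloth import FastLanguageModel  # import only after the flag is set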
26 changes: 19 additions & 7 deletions unsloth/models/gemma.py
@@ -17,6 +17,11 @@
from unsloth_zoo.utils import _get_dtype
from unsloth_zoo.hf_utils import dtype_from_config
import math
import os

_DISABLE_TRITON_RMSNORM = os.getenv("UNSLOTH_DISABLE_TRITON_RMSNORM", "0") == "1"
_LAYERNORM_IMPL = os.getenv("UNSLOTH_LAYERNORM_IMPL", "").lower()
_DISABLE_AUTODTYPE_CAST = os.getenv("UNSLOTH_DISABLE_AUTODTYPE_CAST", "0") == "1"

try:
from transformers.models.gemma.modeling_gemma import (
@@ -122,9 +127,12 @@ def GemmaDecoderLayer_fast_forward(
hidden_states += residual
else:
residual = hidden_states
hidden_states = fast_rms_layernorm(
self.input_layernorm, hidden_states, gemma = True
)
if _DISABLE_TRITON_RMSNORM or _LAYERNORM_IMPL == "python":
    hidden_states = self.input_layernorm(hidden_states)
else:
    hidden_states = fast_rms_layernorm(
Contributor:
Wait, are you certain this causes NaN issues? That's extremely weird.

Author:
Google Colab NVIDIA T4:
[screenshot]

Strix Halo:
[screenshot]

Let me pinpoint it more.

Author:
Casting eps to float32 resolved that:
[screenshot]
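A minimal sketch of what that cast amounts to numerically (plain PyTorch reference, not the Triton kernel; the Gemma-style (1 + weight) scaling is included only for illustration):

import torch

def rms_layernorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Reduce in float32 and keep eps as a float32/Python scalar so it is not
    # rounded away when activations arrive in float16 or bfloat16.
    x32 = x.float()
    variance = x32.pow(2).mean(-1, keepdim=True)
    x32 = x32 * torch.rsqrt(variance + float(eps))
    # Gemma scales by (1 + weight) rather than weight.
    return ((1.0 + weight.float()) * x32).to(x.dtype)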

But I'm still getting NaNs for:

from unsloth import FastModel

m, tok = FastModel.from_pretrained(
    "unsloth/gemma-3-4b-it",
    load_in_4bit=False, load_in_8bit=False, full_finetuning=False,
)
m.train().cuda()
b = tok(text=["hello world"]*2, return_tensors="pt", padding=True).to("cuda")
out = m(**b, labels=b["input_ids"])

        self.input_layernorm, hidden_states, gemma = True
    )
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states = hidden_states,
causal_mask = causal_mask,
@@ -139,9 +147,12 @@ def GemmaDecoderLayer_fast_forward(

# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm(
self.post_attention_layernorm, hidden_states, gemma = True
)
if _DISABLE_TRITON_RMSNORM or _LAYERNORM_IMPL == "python":
    hidden_states = self.post_attention_layernorm(hidden_states)
else:
    hidden_states = fast_rms_layernorm(
        self.post_attention_layernorm, hidden_states, gemma = True
    )
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

@@ -175,7 +186,8 @@ def GemmaModel_fast_forward_inference(
)
input_ids = input_ids[:, : self.max_seq_length]
hidden_states = self.model.embed_tokens(input_ids)
hidden_states = hidden_states.to(_get_dtype(dtype_from_config(self.config)))
if not _DISABLE_AUTODTYPE_CAST:
    hidden_states = hidden_states.to(_get_dtype(dtype_from_config(self.config)))
# 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
# 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
hidden_states *= torch.tensor(
107 changes: 97 additions & 10 deletions unsloth/models/llama.py
@@ -94,8 +94,29 @@
from peft import PeftModelForCausalLM, PeftModelForSequenceClassification
from ..save import patch_saving_functions
import re, os, inspect, math, sys

# One-time debug flags to avoid repeated logs in hot paths
_LOGGED_ROPE = False
_LOGGED_ATTENTION_FA2 = False
_LOGGED_RMSNORM = False
import types

_FA2_COMPUTE_DTYPE_MAP = {
"bf16": torch.bfloat16,
"bfloat16": torch.bfloat16,
"f16": torch.float16,
"float16": torch.float16,
"fp16": torch.float16,
}

_ROPE_IMPL = os.getenv("UNSLOTH_ROPE_IMPL", "").lower()
_DISABLE_TRITON_ROPE = os.getenv("UNSLOTH_DISABLE_TRITON_ROPE", "0") == "1"

_LAYERNORM_IMPL = os.getenv("UNSLOTH_LAYERNORM_IMPL", "").lower()
_DISABLE_TRITON_RMSNORM = os.getenv("UNSLOTH_DISABLE_TRITON_RMSNORM", "0") == "1"

_FA2_COMPUTE_DTYPE_TARGET = os.getenv("UNSLOTH_FA2_COMPUTE_DTYPE", "").lower()

try:
from huggingface_hub.utils import get_token
except:
@@ -549,8 +570,13 @@ def LlamaAttention_fast_forward(
del self.RH_Q
del self.attention

global _LOGGED_ROPE, _LOGGED_ATTENTION_FA2

bsz, q_len, _ = hidden_states.size()

# Preserve original dtype for attention output
_orig_out_dtype = hidden_states.dtype

n_heads = self.config.num_attention_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.config.num_key_value_heads
@@ -580,12 +606,24 @@
# cos, sin = rotary_emb.get_cached(seq_len = kv_seq_len, device = Q.device)
cos, sin = rotary_emb.get_cached(kv_seq_len, Q.device.index)

# Q, K = (
# fast_rope_embedding(Q, K, cos, sin)
# if position_ids is None
# else inplace_rope_embedding(Q, K, cos, sin, position_ids)
# )
Q, K = fast_rope_embedding(Q, K, cos, sin)
# Allow switching RoPE implementation to a non-Triton path on HIP if needed
Contributor:
This is very weird if even RMS LayerNorm doesn't work! MI300X seems to be fine for GPT-OSS, hmm.

Did you try https://docs.unsloth.ai/get-started/install-and-update/amd, i.e.

pip install --upgrade torch==2.8.0 pytorch-triton-rocm torchvision torchaudio torchao==0.13.0 xformers --index-url https://download.pytorch.org/whl/rocm6.4

pip install --no-deps unsloth unsloth-zoo
pip install --no-deps git+https://github.com/unslothai/unsloth-zoo.git
pip install "unsloth[amd] @ git+https://github.com/unslothai/unsloth"

Contributor:
@billishyahao This might be interesting, i.e. normal Unsloth using Triton and FA2 gets NaNs for GPT-OSS on Strix Halo.

Author:
Pretty sure the precompiled torch from the ROCm PyPI index doesn't work. Is torch==2.8.0 a must?
I did try pip install "unsloth[amd] @ git+https://github.com/unslothai/unsloth"; let me try it again.

Author (@0xrushi, Nov 14, 2025):
Yeah, it doesn't. The Dockerfile I linked works better. I used the same Dockerfile to install "unsloth[amd] @ git+https://github.com/unslothai/unsloth" and pip install --no-deps git+https://github.com/unslothai/unsloth-zoo.git.

import torch
torch.__version__
'2.8.0+rocm6.4'

import os, torch
from unsloth import FastModel
from transformers import AutoTokenizer

# os.environ['UNSLOTH_FA2_COMPUTE_DTYPE'] = 'bfloat16'
# os.environ['UNSLOTH_ROPE_IMPL'] = 'slow'
# os.environ['UNSLOTH_DISABLE_TRITON_RMSNORM'] = '1'

model_id = "unsloth/gpt-oss-20b-BF16"
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
tok.pad_token = tok.pad_token or tok.eos_token

m, _ = FastModel.from_pretrained(
    model_id,
    load_in_4bit=False,
    load_in_8bit=False,
    full_finetuning=False,
    float32_mixed_precision=False,
    use_gradient_checkpointing=False,
    attn_implementation="eager",
)

m.train().cuda()
b = tok(["hello world"]*2, return_tensors="pt", padding=True).to("cuda")

out = m(**b, labels=b["input_ids"])
loss = out.loss

print("loss:", float(loss.detach()))
loss.backward() # exactly once per forward

has_nan = any(p.grad is not None and torch.isnan(p.grad).any() for p in m.parameters())
print("grad_has_nan:", has_nan)
---------------------------------------------------------------------------
AcceleratorError                          Traceback (most recent call last)
Cell In[1], line 13
     10 tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
     11 tok.pad_token = tok.pad_token or tok.eos_token
---> 13 m, _ = FastModel.from_pretrained(
     14 model_id,
     15 load_in_4bit=False,
     16 load_in_8bit=False,
     17 full_finetuning=False,
     18 float32_mixed_precision=False,
     19 use_gradient_checkpointing=False,
     20 attn_implementation="eager",
     21 )
     23 m.train().cuda()
     24 b = tok(["hello world"]*2, return_tensors="pt", padding=True).to("cuda")

File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/unsloth/models/loader.py:1065, in FastModel.from_pretrained(model_name, max_seq_length, dtype, load_in_4bit, load_in_8bit, load_in_16bit, full_finetuning, token, device_map, rope_scaling, fix_tokenizer, trust_remote_code, use_gradient_checkpointing, resize_model_vocab, revision, return_logits, fullgraph, use_exact_model_name, auto_model, whisper_language, whisper_task, unsloth_force_compile, offload_embedding, float32_mixed_precision, fast_inference, gpu_memory_utilization, float8_kv_cache, random_state, max_lora_rank, disable_log_stats, qat_scheme, *args, **kwargs)
   1062 if auto_model is None:
   1063     auto_model = AutoModelForVision2Seq if is_vlm else AutoModelForCausalLM
-> 1065 model, tokenizer = FastBaseModel.from_pretrained(
   1066     model_name = model_name,
   1067     max_seq_length = max_seq_length,
   1068     dtype = _get_dtype(dtype),
   1069     load_in_4bit = load_in_4bit,
   1070     load_in_8bit = load_in_8bit,
   1071     load_in_16bit = load_in_16bit,
   1072     full_finetuning = full_finetuning,
   1073     token = token,
   1074     device_map = device_map,
   1075     trust_remote_code = trust_remote_code,
   1076     revision = revision if not is_peft else None,
   1077     model_types = model_types,
   1078     tokenizer_name = tokenizer_name,
   1079     auto_model = auto_model,
   1080     use_gradient_checkpointing = use_gradient_checkpointing,
   1081     supports_sdpa = supports_sdpa,
   1082     whisper_language = whisper_language,
   1083     whisper_task = whisper_task,
   1084     auto_config = model_config,
   1085     offload_embedding = offload_embedding,
   1086     float32_mixed_precision = float32_mixed_precision,
   1087     # Pass vLLM/inference parameters
   1088     fast_inference = fast_inference,
   1089     gpu_memory_utilization = gpu_memory_utilization,
   1090     float8_kv_cache = float8_kv_cache,
   1091     random_state = random_state,
   1092     max_lora_rank = max_lora_rank,
   1093     disable_log_stats = disable_log_stats,
   1094     *args,
   1095     **kwargs,
   1096 )
   1098 if resize_model_vocab is not None:
   1099     model.resize_token_embeddings(resize_model_vocab)

File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/unsloth/models/vision.py:763, in FastBaseModel.from_pretrained(model_name, max_seq_length, dtype, load_in_4bit, load_in_8bit, load_in_16bit, full_finetuning, token, device_map, trust_remote_code, model_types, tokenizer_name, auto_model, use_gradient_checkpointing, supports_sdpa, whisper_language, whisper_task, auto_config, offload_embedding, float32_mixed_precision, fast_inference, gpu_memory_utilization, float8_kv_cache, random_state, max_lora_rank, disable_log_stats, unsloth_vllm_standby, **kwargs)
    761     with torch.no_grad():
    762         for jj, (name, module) in enumerate(model.named_modules()):
--> 763             exec(custom_datatype)
    764 # Clear deleted GPU items
    765 for _ in range(3):

File <string>:2

AcceleratorError: HIP error: invalid device function
HIP kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing AMD_SERIALIZE_KERNEL=3
Compile with `TORCH_USE_HIP_DSA` to enable device-side assertions.

if _ROPE_IMPL == "slow" or _DISABLE_TRITON_ROPE:
Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
_rope_impl_name = "slow (torch)"
else:
Q, K = fast_rope_embedding(Q, K, cos, sin)
_rope_impl_name = "triton (fast)"

if not _LOGGED_ROPE:
logger.debug(
"Unsloth: RoPE=%s. device=%s cos_dtype=%s env(UNSLOTH_ROPE_IMPL=%s, UNSLOTH_DISABLE_TRITON_ROPE=%s)",
_rope_impl_name,
DEVICE_TYPE_TORCH,
str(cos.dtype),
os.getenv("UNSLOTH_ROPE_IMPL"),
os.getenv("UNSLOTH_DISABLE_TRITON_ROPE"),
)
_LOGGED_ROPE = True

if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim = 2)
@@ -618,7 +656,33 @@ def LlamaAttention_fast_forward(
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
# On HIP, allow selecting the FA2 compute dtype via UNSLOTH_FA2_COMPUTE_DTYPE,
# defaulting to fp16 when the inputs arrive in bf16.
_fa2_in_dtype = Q.dtype
if DEVICE_TYPE == "hip":
    if _FA2_COMPUTE_DTYPE_TARGET in _FA2_COMPUTE_DTYPE_MAP:
        target_dtype = _FA2_COMPUTE_DTYPE_MAP[_FA2_COMPUTE_DTYPE_TARGET]
        if Q.dtype != target_dtype:
            Q = Q.to(target_dtype)
            K = K.to(target_dtype)
            V = V.to(target_dtype)
    elif Q.dtype == torch.bfloat16:
        Q = Q.to(torch.float16)
        K = K.to(torch.float16)
        V = V.to(torch.float16)
A = flash_attn_func(Q, K, V, causal = True)
if A.dtype != _orig_out_dtype:
    A = A.to(_orig_out_dtype)
if not _LOGGED_ATTENTION_FA2:
    logger.debug(
        "Unsloth: Attention=FlashAttention2. device=%s in_dtype=%s used_dtype=%s out_dtype=%s env(UNSLOTH_FA2_COMPUTE_DTYPE=%s, UNSLOTH_DISABLE_FLASH_ATTENTION=%s)",
        DEVICE_TYPE,
        str(_fa2_in_dtype),
        str(Q.dtype),
        str(A.dtype),
        os.getenv("UNSLOTH_FA2_COMPUTE_DTYPE"),
        os.getenv("UNSLOTH_DISABLE_FLASH_ATTENTION"),
    )
    _LOGGED_ATTENTION_FA2 = True
else:
# when qlen==vlen and attn_mask is None, we should use causal attention
Q_len = Q.shape[-2]
@@ -697,6 +761,8 @@ def LlamaDecoderLayer_fast_forward(
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
global _LOGGED_RMSNORM

if use_cache and hasattr(self, "_flag_for_generation"):
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(
@@ -724,7 +790,21 @@ def LlamaDecoderLayer_fast_forward(
hidden_states += residual
else:
residual = hidden_states
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
if _DISABLE_TRITON_RMSNORM or _LAYERNORM_IMPL == "python":
    hidden_states = self.input_layernorm(hidden_states)
    _rms_impl_name = "torch"
else:
    hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
    _rms_impl_name = "triton"

if not _LOGGED_RMSNORM:
    logger.debug(
        "Unsloth: RMSNorm=%s. env(UNSLOTH_DISABLE_TRITON_RMSNORM=%s, UNSLOTH_LAYERNORM_IMPL=%s)",
        _rms_impl_name,
        os.getenv("UNSLOTH_DISABLE_TRITON_RMSNORM"),
        os.getenv("UNSLOTH_LAYERNORM_IMPL"),
    )
    _LOGGED_RMSNORM = True
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states = hidden_states,
causal_mask = causal_mask,
@@ -740,7 +820,12 @@ def LlamaDecoderLayer_fast_forward(

# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
if _DISABLE_TRITON_RMSNORM or _LAYERNORM_IMPL == "python":
    hidden_states = self.post_attention_layernorm(hidden_states)
else:
    hidden_states = fast_rms_layernorm(
        self.post_attention_layernorm, hidden_states
    )
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

@@ -853,7 +938,9 @@ def LlamaModel_fast_forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

inputs_embeds = inputs_embeds.to(_get_dtype(dtype_from_config(self.config)))
# Allow overriding the automatic dtype cast via env for diagnostics
if os.getenv("UNSLOTH_DISABLE_AUTODTYPE_CAST", "0") != "1":
inputs_embeds = inputs_embeds.to(_get_dtype(dtype_from_config(self.config)))

# Normalized from Gemma
IS_GEMMA = self.config.model_type.startswith("gemma")
@@ -1349,7 +1436,7 @@ def _CausalLM_fast_forward(
labels = labels.to(lm_head_device)

# Output last hidden states without logits if asked
if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1":
if os.getenv("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1":
if num_logits_to_keep != 0:
hidden_states = hidden_states[:, -num_logits_to_keep:, :]
return CausalLMOutputWithPast(