@0xrushi 0xrushi commented Nov 12, 2025

Resolves #3385 (comment)

Summary

  • Fine-tuning Gemma‑3 on AMD Strix Halo (HIP/ROCm) produced NaN losses.
  • The NaNs came from the forward pass of the first transformer block, not from the optimizer.

Root Cause

  • On HIP (gfx1151, ROCm 6.4), bfloat16 FlashAttention2 can be numerically unstable.
  • Unsloth routed Gemma‑3 attention through FlashAttention2 in bf16, triggering NaN activations.
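
The general trade-off between float16 and bfloat16 can be illustrated on the CPU with the standard library alone. The sketch below only demonstrates properties of the two number formats; it does not reproduce the HIP kernel behaviour reported here. `round_to_fp16` and `round_to_bf16` are hypothetical helper names, and bfloat16 rounding is emulated by keeping the top 16 bits of the float32 encoding.

```python
import struct


def round_to_fp16(x: float) -> float:
    """Round x to IEEE half precision; out-of-range values become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses to pack values beyond the float16 range (max 65504).
        return float("inf") if x > 0 else float("-inf")


def round_to_bf16(x: float) -> float:
    """Emulate bfloat16 by rounding the float32 encoding to its top 16 bits.

    NaN inputs are not handled; this is an illustration only.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round to nearest, ties to even, on the 16 bits that are dropped.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


if __name__ == "__main__":
    for value in (1.001, 70000.0):
        print(value, "fp16:", round_to_fp16(value), "bf16:", round_to_bf16(value))
```

float16 keeps more mantissa bits (10 versus 7) but overflows past 65504, so a float16 compute path presumably relies on the attention inputs staying within that range.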

What We Changed

  • Keep FlashAttention2 for performance, but on HIP run its math in float16 (safer), then cast the results back to the input dtype.
  • Added opt‑in env toggles for adjacent kernels (RoPE, RMSNorm); these are intended for diagnostics only.
  • Added a debug log (logger.debug) to confirm the actual paths/dtypes when DEBUG verbosity is enabled.
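
The cast-compute-cast-back pattern described above can be sketched as a small wrapper. This is a minimal illustration, not the PR's actual code: `run_in_compute_dtype` and `attn_fn` are hypothetical names, and the wrapper is duck-typed so it works with any object exposing `.dtype` and `.to(dtype)`. With PyTorch tensors, the input dtype would be `torch.bfloat16` and `compute_dtype` would be `torch.float16`.

```python
from typing import Any, Callable, Optional


def run_in_compute_dtype(
    attn_fn: Callable[..., Any],
    q: Any,
    k: Any,
    v: Any,
    compute_dtype: Optional[Any] = None,
    **kwargs: Any,
) -> Any:
    """Call attn_fn on q/k/v cast to compute_dtype, then cast the result back.

    If compute_dtype is None or already matches the input dtype, the inputs
    are passed through untouched.
    """
    in_dtype = q.dtype
    if compute_dtype is None or compute_dtype == in_dtype:
        return attn_fn(q, k, v, **kwargs)
    out = attn_fn(
        q.to(compute_dtype), k.to(compute_dtype), v.to(compute_dtype), **kwargs
    )
    # Restore the caller's dtype so downstream layers see what they expect.
    return out.to(in_dtype)
```

The extra casts cost memory bandwidth, which is one reason to apply such a wrapper only on the platform where the bf16 path misbehaves.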

Validation

import torch, importlib
import os 

mods = ["unsloth","unsloth_zoo","transformers","trl","accelerate","peft","xformers","bitsandbytes","triton"]
for m in mods:
    try:
        print(m, importlib.import_module(m).__version__)
    except Exception as e:
        print(m, "not found")
print("torch:", torch.__version__, "HIP:", torch.version.hip)
print("cuda.is_available:", torch.cuda.is_available(), "bf16_supported:", torch.cuda.is_bf16_supported())
print("device:", torch.cuda.get_device_name(0))
unsloth 2025.11.3
unsloth_zoo 2025.11.3
transformers 4.57.1
trl 0.24.0
accelerate 1.11.0
peft 0.17.1
xformers not found
bitsandbytes 0.49.0.dev0
triton 3.5.1
torch: 2.10.0a0+rocm7.10.0a20251015 HIP: 7.1.25413-7721681424
cuda.is_available: True bf16_supported: True
device: Radeon 8060S Graphics
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
import os
os.environ['UNSLOTH_FA2_COMPUTE_DTYPE'] = 'float16'
os.environ['UNSLOTH_ROPE_IMPL'] = 'slow'
os.environ['UNSLOTH_DISABLE_TRITON_RMSNORM'] = '1'

import unsloth, inspect
import unsloth.models.llama as L
print("unsloth_file:", unsloth.__file__)
print("llama_file:", L.__file__)

from unsloth import FastModel
from transformers import AutoTokenizer
import torch

tok = AutoTokenizer.from_pretrained("unsloth/gemma-3-4b-it")
m,_ = FastModel.from_pretrained(
    "unsloth/gemma-3-4b-it",
    load_in_4bit=False, load_in_8bit=False, full_finetuning=False,
)
m.train().cuda()
b = tok(["hello world"]*2, return_tensors="pt", padding=True).to("cuda")
out = m(**b, labels=b["input_ids"])
print("loss_is_nan:", torch.isnan(out.loss).item(), "loss:", float(out.loss))
out.loss.backward()
has_nan = any(p.grad is not None and torch.isnan(p.grad).any() for p in m.parameters())
print("grad_has_nan:", has_nan)
bitsandbytes library load error: Configured ROCm binary not found at /opt/venv/lib64/python3.13/site-packages/bitsandbytes/libbitsandbytes_rocm71.so
Traceback (most recent call last):
  File "/opt/venv/lib64/python3.13/site-packages/bitsandbytes/cextension.py", line 313, in <module>
    lib = get_native_library()
  File "/opt/venv/lib64/python3.13/site-packages/bitsandbytes/cextension.py", line 282, in get_native_library
    raise RuntimeError(f"Configured {BNB_BACKEND} binary not found at {cuda_binary_path}")
RuntimeError: Configured ROCm binary not found at /opt/venv/lib64/python3.13/site-packages/bitsandbytes/libbitsandbytes_rocm71.so


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


/opt/venv/lib64/python3.13/site-packages/torch/library.py:356: UserWarning: Warning only once for all operators,  other operators may also be overridden.
  Overriding a previously registered kernel for the same operator and the same dispatch key
  operator: flash_attn::_flash_attn_backward(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(a6!)? dq, Tensor(a7!)? dk, Tensor(a8!)? dv, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left, SymInt window_size_right, float softcap, Tensor? alibi_slopes, bool deterministic, Tensor? rng_state=None) -> Tensor
    registered at /opt/venv/lib64/python3.13/site-packages/torch/_library/custom_ops.py:922
  dispatch key: ADInplaceOrView
  previous kernel: no debug info
       new kernel: registered at /opt/venv/lib64/python3.13/site-packages/torch/_library/custom_ops.py:922 (Triggered internally at /__w/TheRock/TheRock/external-builds/pytorch/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:208.)
  self.m.impl(


🦥 Unsloth Zoo will now patch everything to make training faster!
unsloth_file: /opt/venv/lib64/python3.13/site-packages/unsloth/__init__.py
llama_file: /opt/venv/lib64/python3.13/site-packages/unsloth/models/llama.py


/opt/venv/lib64/python3.13/site-packages/unsloth_zoo/gradient_checkpointing.py:348: UserWarning: expandable_segments not supported on this platform (Triggered internally at /__w/TheRock/TheRock/external-builds/pytorch/pytorch/c10/hip/HIPAllocatorConfig.h:36.)
  GPU_BUFFERS = tuple([torch.empty(2*256*2048, dtype = dtype, device = f"{DEVICE_TYPE_TORCH}:{i}") for i in range(n_gpus)])


==((====))==  Unsloth 2025.11.3: Fast Gemma3 patching. Transformers: 4.57.1.
   \\   /|    Radeon 8060S Graphics. Num GPUs = 1. Max memory: 128.0 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.10.0a0+rocm7.10.0a20251015. ROCm Toolkit: 7.1.25413-7721681424. Triton: 3.5.1
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = True]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Gemma3 does not support SDPA - switching to fast eager.
Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]


/tmp/ipykernel_3852/715560139.py:23: UserWarning: Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.
Consider using tensor.detach() first. (Triggered internally at /__w/TheRock/TheRock/external-builds/pytorch/pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:836.)
  print("loss_is_nan:", torch.isnan(out.loss).item(), "loss:", float(out.loss))


loss_is_nan: False loss: 15.633398056030273
grad_has_nan: False
import unsloth, torch
from unsloth import FastModel
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("unsloth/gemma-3-4b-it")
m,_ = FastModel.from_pretrained("unsloth/gemma-3-4b-it", load_in_4bit=False, load_in_8bit=False, full_finetuning=False)
m.train().cuda()
b = tok(["hello world"]*2, return_tensors="pt", padding=True).to("cuda")
out = m(**b, labels=b["input_ids"])
print("forward_loss_is_nan:", torch.isnan(out.loss).item(), "loss:", float(out.loss))
out.loss.backward()
has_nan = any(p.grad is not None and torch.isnan(p.grad).any() for p in m.parameters())
print("grad_has_nan:", has_nan)
==((====))==  Unsloth 2025.11.3: Fast Gemma3 patching. Transformers: 4.57.1.
   \\   /|    Radeon 8060S Graphics. Num GPUs = 1. Max memory: 128.0 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.10.0a0+rocm7.10.0a20251015. ROCm Toolkit: 7.1.25413-7721681424. Triton: 3.5.1
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = True]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Gemma3 does not support SDPA - switching to fast eager.
Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]


forward_loss_is_nan: False loss: 15.633398056030273
grad_has_nan: False
import unsloth, torch
from unsloth import FastModel
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("unsloth/gemma-3-4b-it")
m,_ = FastModel.from_pretrained("unsloth/gemma-3-4b-it", load_in_4bit=False, load_in_8bit=False, full_finetuning=False)
m.train().cuda()
b = tok(["hello world"]*2, return_tensors="pt", padding=True).to("cuda")
out = m(**b, labels=b["input_ids"])
print("FA/xformers disabled -> loss_is_nan:", torch.isnan(out.loss).item(), "loss:", float(out.loss))
==((====))==  Unsloth 2025.11.3: Fast Gemma3 patching. Transformers: 4.57.1.
   \\   /|    Radeon 8060S Graphics. Num GPUs = 1. Max memory: 128.0 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.10.0a0+rocm7.10.0a20251015. ROCm Toolkit: 7.1.25413-7721681424. Triton: 3.5.1
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = True]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Gemma3 does not support SDPA - switching to fast eager.
Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]


FA/xformers disabled -> loss_is_nan: False loss: 15.633398056030273
import unsloth, torch
from unsloth import FastModel
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("unsloth/gemma-3-4b-it")
m,_ = FastModel.from_pretrained("unsloth/gemma-3-4b-it", load_in_4bit=False, load_in_8bit=False, full_finetuning=False)
m = m.to(dtype=torch.float32).cuda()
b = tok(["hello world"]*2, return_tensors="pt", padding=True)
b = {k:(v.to("cuda").to(torch.float32) if v.dtype.is_floating_point else v.to("cuda")) for k,v in b.items()}
with torch.autocast(device_type="cuda", dtype=torch.float32, enabled=False):
    out = m(**b, labels=b["input_ids"])
print("fp32 forced -> loss_is_nan:", torch.isnan(out.loss).item(), "loss:", float(out.loss))
==((====))==  Unsloth 2025.11.3: Fast Gemma3 patching. Transformers: 4.57.1.
   \\   /|    Radeon 8060S Graphics. Num GPUs = 1. Max memory: 128.0 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.10.0a0+rocm7.10.0a20251015. ROCm Toolkit: 7.1.25413-7721681424. Triton: 3.5.1
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = True]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Gemma3 does not support SDPA - switching to fast eager.
Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]


fp32 forced -> loss_is_nan: False loss: 15.536109924316406
import unsloth, torch
from unsloth import FastModel
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("unsloth/gemma-3-4b-it")
m,_ = FastModel.from_pretrained("unsloth/gemma-3-4b-it", load_in_4bit=False, load_in_8bit=False, full_finetuning=False)
m.eval().cuda()
b = tok(["hello world"]*2, return_tensors="pt", padding=True).to("cuda")
with torch.no_grad():
    out = m(**b, return_dict=True)
logits = out.logits
print("logits_dtype:", logits.dtype, "shape:", tuple(logits.shape))
print("logits_has_nan:", torch.isnan(logits).any().item(), "has_inf:", torch.isinf(logits).any().item())
==((====))==  Unsloth 2025.11.3: Fast Gemma3 patching. Transformers: 4.57.1.
   \\   /|    Radeon 8060S Graphics. Num GPUs = 1. Max memory: 128.0 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.10.0a0+rocm7.10.0a20251015. ROCm Toolkit: 7.1.25413-7721681424. Triton: 3.5.1
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = True]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Gemma3 does not support SDPA - switching to fast eager.
Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]


logits_dtype: torch.bfloat16 shape: (2, 3, 262208)
logits_has_nan: False has_inf: False

0xrushi and others added 2 commits November 11, 2025 22:21
…lementations

- Introduced UNSLOTH_DISABLE_FLASH_ATTENTION to disable Flash Attention based on environment variable.
- Added checks for UNSLOTH_DISABLE_TRITON_RMSNORM and UNSLOTH_LAYERNORM_IMPL to conditionally use different LayerNorm implementations in Gemma and Llama models.
- Updated dtype handling in Llama model to allow overriding automatic dtype casting via environment variable.
@gemini-code-assist
Contributor

Summary of Changes

Hello @0xrushi, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses critical numerical instability issues encountered when fine-tuning Gemma-3 models on AMD Strix Halo (gfx1151) using HIP/ROCm, which previously led to NaN losses. The core solution involves carefully managing data types within FlashAttention2 operations and providing configurable fallbacks for other performance-optimized kernels like RoPE and RMSNorm, ensuring stable training while retaining performance where possible.

Highlights

  • Fix NaN Losses on AMD Strix Halo (gfx1151): This pull request resolves numerical instability issues that caused NaN losses during Gemma-3 fine-tuning on AMD Strix Halo (gfx1151) using HIP/ROCm, specifically tracing the issue to the first transformer block's forward pass.
  • FlashAttention2 Dtype Handling for HIP: To prevent NaNs, FlashAttention2 on HIP platforms (gfx1151) will now perform its internal math in float16 when inputs are bfloat16, then cast the results back to the original bfloat16 dtype, improving numerical stability while retaining performance.
  • Configurable Kernel Implementations via Environment Variables: New environment variables have been introduced to provide granular control over various kernel implementations. Users can now opt out of fast kernels (e.g., UNSLOTH_DISABLE_FLASH_ATTENTION, UNSLOTH_ROPE_IMPL='slow', UNSLOTH_DISABLE_TRITON_RMSNORM) or specify compute dtypes (UNSLOTH_FA2_COMPUTE_DTYPE) for debugging and stability purposes.
  • Enhanced Debugging and Logging: Debug logging has been added for RoPE, FlashAttention2, and RMSNorm implementations. When debug verbosity is enabled, these logs will confirm the actual paths and data types being used, aiding in diagnostics.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces several fixes and diagnostic toggles to address numerical instability issues with bfloat16 on AMD gfx1151 hardware. The core change involves conditionally casting to float16 for FlashAttention2 on HIP, which is a sound approach to mitigate NaN losses. Additionally, environment variables are added to control RoPE and RMSNorm kernel implementations, providing valuable debugging flexibility. My review focuses on improving code maintainability by reducing duplication and making logging more consistent. Overall, these are solid changes for improving hardware compatibility and debuggability.

Comment on lines 792 to 804
        _disable_triton_rms = (
            os.environ.get("UNSLOTH_DISABLE_TRITON_RMSNORM", "0") == "1"
        )
        _ln_impl = os.environ.get("UNSLOTH_LAYERNORM_IMPL", "").lower()
        if _disable_triton_rms or _ln_impl == "python":
            hidden_states = self.input_layernorm(hidden_states)
            if not _LOGGED_RMSNORM:
                logger.debug(
                    "Unsloth: RMSNorm=torch. env(UNSLOTH_DISABLE_TRITON_RMSNORM=%s, UNSLOTH_LAYERNORM_IMPL=%s)",
                    os.environ.get("UNSLOTH_DISABLE_TRITON_RMSNORM"),
                    os.environ.get("UNSLOTH_LAYERNORM_IMPL"),
                )
                _LOGGED_RMSNORM = True

Contributor

medium

The debug logging for RMSNorm implementation only covers the torch path. For consistency and better diagnostics, a logger.debug call should also be added to the else block to indicate when the fast Triton kernel (fast_rms_layernorm) is being used. This can be refactored to reduce duplication, similar to the RoPE logging.

        _disable_triton_rms = os.environ.get("UNSLOTH_DISABLE_TRITON_RMSNORM", "0") == "1"
        _ln_impl = os.environ.get("UNSLOTH_LAYERNORM_IMPL", "").lower()
        if _disable_triton_rms or _ln_impl == "python":
            hidden_states = self.input_layernorm(hidden_states)
            _rms_impl_name = "torch"
        else:
            hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
            _rms_impl_name = "triton"

        if not _LOGGED_RMSNORM:
            logger.debug(
                "Unsloth: RMSNorm=%s. env(UNSLOTH_DISABLE_TRITON_RMSNORM=%s, UNSLOTH_LAYERNORM_IMPL=%s)",
                _rms_impl_name,
                os.environ.get("UNSLOTH_DISABLE_TRITON_RMSNORM"),
                os.environ.get("UNSLOTH_LAYERNORM_IMPL"),
            )
            _LOGGED_RMSNORM = True

hidden_states = fast_rms_layernorm(
    self.input_layernorm, hidden_states, gemma = True
)
_disable_triton_rms = os.getenv("UNSLOTH_DISABLE_TRITON_RMSNORM", "0") == "1"

Collaborator

@danielhanchen I think adding a flag called impl (or similar) to fast_rms_layernorm might be a cleaner approach, but I'll let you comment on the performance implications of that.

Also, I don't think these flags change at runtime, so they can be read once at the top of the model file and shared, instead of being recomputed for every layer.

Contributor

Yes, we should do it only once at the start, as a global variable.

_fa2_in_dtype = Q.dtype
if DEVICE_TYPE == "hip":
    target = os.environ.get("UNSLOTH_FA2_COMPUTE_DTYPE", "").lower()
    if target in _FA2_COMPUTE_DTYPE_MAP:

Collaborator

Even these can potentially be inferred once at startup.
Also, does the FA inaccuracy bug affect every model? If so, we should probably replicate this fix across other models (for example Mistral or Qwen).
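
Inferring the FlashAttention2 compute dtype once at startup might look like the following sketch. The contents of the PR's `_FA2_COMPUTE_DTYPE_MAP` are not shown in this thread, so the map below (`COMPUTE_DTYPE_NAMES`) and the function name `resolve_fa2_compute_dtype` are assumptions, and dtype names are plain strings to keep the example free of a PyTorch dependency.

```python
import os
from typing import Mapping, Optional

# Hypothetical alias table; in real code the values would be torch dtypes.
COMPUTE_DTYPE_NAMES = {
    "float16": "float16",
    "fp16": "float16",
    "bfloat16": "bfloat16",
    "bf16": "bfloat16",
}


def resolve_fa2_compute_dtype(
    device_type: str, environ: Mapping[str, str] = os.environ
) -> Optional[str]:
    """Return the requested compute dtype name, or None to keep the input dtype.

    The override is only honoured on HIP, and unknown values are ignored.
    """
    if device_type != "hip":
        return None
    target = environ.get("UNSLOTH_FA2_COMPUTE_DTYPE", "").lower()
    return COMPUTE_DTYPE_NAMES.get(target)


if __name__ == "__main__":
    # Resolve once; attention layers would then reuse the cached result.
    print(resolve_fa2_compute_dtype("hip"))
```

Because the result depends only on the device type and the environment, the same helper could be shared by other model files if the issue turns out to affect them.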

Author

I was running the gpt20b notebook and I'm getting kernel crashes; it's a long way to debug.

Collaborator

Let me know when you feel comfortable with this. I will be happy to review :)

Author

@0xrushi 0xrushi Nov 12, 2025

qwen is fine

image

mistral too

image

gpt oss

image

Author

@Datta0 gptoss needs that extra casting

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = m(**b, labels=b["input_ids"])

It's fine here, but will that break somewhere else that abstracts something similar?

Collaborator

Oh that is interesting. Do you happen to have the error stack trace when the autocast is not set?

Author

g++ (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Copyright (C) 2023 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.11.3: Fast Gpt_Oss patching. Transformers: 4.57.1.
   \\   /|    AMD Radeon Graphics. Num GPUs = 1. Max memory: 128.0 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.1+git99ccf24. ROCm Toolkit: 6.4.43484-123eb5128. Triton: 3.3.1
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.30+13c93f39.d20251112. FA2 = True]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
/opt/conda/envs/py_3.12/lib/python3.12/site-packages/accelerate/utils/modeling.py:804: UserWarning: expandable_segments not supported on this platform (Triggered internally at /var/lib/jenkins/pytorch/c10/hip/HIPAllocatorConfig.h:29.)
  _ = torch.tensor([0], device=i)
Loading checkpoint shards:   0%|          | 0/9 [00:00<?, ?it/s]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 26
     23 m.train().cuda()
     24 b = tok(["hello world"]*2, return_tensors="pt", padding=True).to("cuda")
---> 26 out = m(**b, labels=b["input_ids"])
     27 loss = out.loss
     29 print("loss:", float(loss.detach()))

File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None
   1765 called_always_called_hooks = set()

File ~/finetuning-workspace/workspace/unsloth_compiled_cache/unsloth_compiled_module_gpt_oss.py:726, in GptOssForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_router_logits, cache_position, logits_to_keep, **kwargs)
    712 def forward(
    713     self,
    714     input_ids: Optional[torch.LongTensor] = None,
   (...)    724     **kwargs: Unpack[TransformersKwargs],
    725 ) -> MoeCausalLMOutputWithPast:
--> 726     return GptOssForCausalLM_forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_router_logits, cache_position, logits_to_keep, **kwargs)

File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/transformers/utils/generic.py:918, in can_return_tuple.<locals>.wrapper(self, *args, **kwargs)
    916 if return_dict_passed is not None:
    917     return_dict = return_dict_passed
--> 918 output = func(self, *args, **kwargs)
    919 if not return_dict and not isinstance(output, tuple):
    920     output = output.to_tuple()

File ~/finetuning-workspace/workspace/unsloth_compiled_cache/unsloth_compiled_module_gpt_oss.py:547, in GptOssForCausalLM_forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_router_logits, cache_position, logits_to_keep, **kwargs)
    542 output_router_logits = (
    543     output_router_logits if output_router_logits is not None else self.config.output_router_logits
    544 )
    546 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 547 outputs: MoeModelOutputWithPast = self.model(
    548     input_ids=input_ids,
    549     attention_mask=attention_mask,
    550     position_ids=position_ids,
    551     past_key_values=past_key_values,
    552     inputs_embeds=inputs_embeds,
    553     use_cache=use_cache,
    554     output_router_logits=output_router_logits,
    555     cache_position=cache_position,
    556     **kwargs,
    557 )
    559 hidden_states = outputs.last_hidden_state
    560 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss

File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None
   1765 called_always_called_hooks = set()

File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/unsloth_zoo/temporary_patches/gpt_oss.py:1249, in patch_GptOssModel.<locals>.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, cache_position, **kwargs)
   1247 for decoder_layer in self.layers:
   1248     mask = attention_mask[decoder_layer.attention_type] if isinstance(attention_mask, dict) else attention_mask
-> 1249     hidden_states = decoder_layer(
   1250         hidden_states,
   1251         attention_mask=mask,
   1252         position_ids=position_ids,
   1253         past_key_values=past_key_values,
   1254         use_cache=use_cache,
   1255         cache_position=cache_position,
   1256         position_embeddings=position_embeddings,
   1257         **kwargs,
   1258     )
   1259 pass
   1260 hidden_states = self.norm(hidden_states)

File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/transformers/modeling_layers.py:93, in GradientCheckpointingLayer.__call__(self, *args, **kwargs)
     90         message = message.rstrip(",") + "."
     91         logger.warning_once(message)
---> 93     return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args)
     94 return super().__call__(*args, **kwargs)

File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/_compile.py:51, in _disable_dynamo.<locals>.inner(*args, **kwargs)
     48     disable_fn = torch._dynamo.disable(fn, recursive)
     49     fn.__dynamo_disable = disable_fn  # type: ignore[attr-defined]
---> 51 return disable_fn(*args, **kwargs)

File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:838, in DisableContext.__call__.<locals>._fn(*args, **kwargs)
    836 _maybe_set_eval_frame(_callback_from_stance(self.callback))
    837 try:
--> 838     return fn(*args, **kwargs)
    839 finally:
    840     set_eval_frame(None)

File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/utils/checkpoint.py:488, in checkpoint(function, use_reentrant, context_fn, determinism_check, debug, *args, **kwargs)
    483     if context_fn is not noop_context_fn or debug is not False:
    484         raise ValueError(
    485             "Passing `context_fn` or `debug` is only supported when "
    486             "use_reentrant=False."
    487         )
--> 488     return CheckpointFunction.apply(function, preserve, *args)
    489 else:
    490     gen = _checkpoint_without_reentrant_generator(
    491         function, preserve, context_fn, determinism_check, debug, *args, **kwargs
    492     )

File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/autograd/function.py:575, in Function.apply(cls, *args, **kwargs)
    572 if not torch._C._are_functorch_transforms_active():
    573     # See NOTE: [functorch vjp and autograd interaction]
    574     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 575     return super().apply(*args, **kwargs)  # type: ignore[misc]
    577 if not is_setup_ctx_defined:
    578     raise RuntimeError(
    579         "In order to use an autograd.Function with functorch transforms "
    580         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    581         "staticmethod. For more details, please see "
    582         "https://pytorch.org/docs/main/notes/extending.func.html"
    583     )

File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/utils/checkpoint.py:263, in CheckpointFunction.forward(ctx, run_function, preserve_rng_state, *args)
    260 ctx.save_for_backward(*tensor_inputs)
    262 with torch.no_grad():
--> 263     outputs = run_function(*args)
    264 return outputs

File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None
   1765 called_always_called_hooks = set()

File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/transformers/utils/deprecation.py:172, in deprecate_kwarg.<locals>.wrapper.<locals>.wrapped_func(*args, **kwargs)
    168 elif minimum_action in (Action.NOTIFY, Action.NOTIFY_ALWAYS) and not is_torchdynamo_compiling():
    169     # DeprecationWarning is ignored by default, so we use FutureWarning instead
    170     warnings.warn(message, FutureWarning, stacklevel=2)
--> 172 return func(*args, **kwargs)

File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py:386, in GptOssDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_values, use_cache, cache_position, position_embeddings, **kwargs)
    384 residual = hidden_states
    385 hidden_states = self.post_attention_layernorm(hidden_states)
--> 386 hidden_states, _ = self.mlp(hidden_states)  # diff with llama: router scores
    387 hidden_states = residual + hidden_states
    388 return hidden_states

File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None
   1765 called_always_called_hooks = set()

File ~/finetuning-workspace/workspace/unsloth_compiled_cache/unsloth_compiled_module_gpt_oss.py:309, in GptOssMLP.forward(self, hidden_states)
    308 def forward(self, hidden_states):
--> 309     return GptOssMLP_forward(self, hidden_states)

File ~/finetuning-workspace/workspace/unsloth_compiled_cache/unsloth_compiled_module_gpt_oss.py:298, in GptOssMLP_forward(self, hidden_states)
    295 @torch.compiler.disable(recursive = False)
    296 def GptOssMLP_forward(self, hidden_states):
    297     router_scores, router_indices = self.router(hidden_states)  # (num_experts, seq_len)
--> 298     routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
    299     return routed_out, router_scores

File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None
   1765 called_always_called_hooks = set()

File ~/finetuning-workspace/workspace/unsloth_compiled_cache/unsloth_compiled_module_gpt_oss.py:270, in GptOssExperts.forward(self, hidden_states, router_indices, routing_weights)
    269 def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
--> 270     return GptOssExperts_forward(self, hidden_states, router_indices, routing_weights)

File ~/finetuning-workspace/workspace/unsloth_compiled_cache/unsloth_compiled_module_gpt_oss.py:236, in GptOssExperts_forward(self, hidden_states, router_indices, routing_weights)
    234 glu = gate * torch.sigmoid(gate * self.alpha)
    235 gated_output = (up + 1) * glu
--> 236 out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
    237 weighted_output = out * routing_weights[token_idx, expert_idx, None]
    238 next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))

RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16

But SFTTrainer worked; the notebook is at unslothai/notebooks#126.

if _DISABLE_TRITON_RMSNORM or _LAYERNORM_IMPL == "python":
    hidden_states = self.input_layernorm(hidden_states)
else:
    hidden_states = fast_rms_layernorm(

Contributor

Wait, are you certain this causes NaN issues? That's extremely weird.

Author

Google Colab, NVIDIA T4:
image

Strix Halo:
image

Let me pinpoint it further.

Author

Casting eps to float32 resolved that.

image

But I'm still getting NaNs for:

m, tok = FastModel.from_pretrained(
    "unsloth/gemma-3-4b-it",
    load_in_4bit=False, load_in_8bit=False, full_finetuning=False,
)
m.train().cuda()
b = tok(text=["hello world"]*2, return_tensors="pt", padding=True).to("cuda")
out = m(**b, labels=b["input_ids"])

# else inplace_rope_embedding(Q, K, cos, sin, position_ids)
# )
Q, K = fast_rope_embedding(Q, K, cos, sin)
# Allow switching RoPE implementation to a non-Triton path on HIP if needed

Contributor

It is very weird if even RMS LayerNorm doesn't work! MI300X seems to be fine for GPT-OSS, hmm.

Did you try https://docs.unsloth.ai/get-started/install-and-update/amd, i.e.:

pip install --upgrade torch==2.8.0 pytorch-triton-rocm torchvision torchaudio torchao==0.13.0 xformers --index-url https://download.pytorch.org/whl/rocm6.4

pip install --no-deps unsloth unsloth-zoo
pip install --no-deps git+https://github.com/unslothai/unsloth-zoo.git
pip install "unsloth[amd] @ git+https://github.com/unslothai/unsloth"

Contributor

@billishyahao This might be interesting, i.e. normal Unsloth using Triton and FA2 gets NaNs for GPT-OSS on Strix Halo.

Author

Pretty sure the precompiled torch from the ROCm PyPI index doesn't work. Is torch==2.8.0 a must?
I did try pip install "unsloth[amd] @ git+https://github.com/unslothai/unsloth"; let me try it again.

Author

@0xrushi Nov 14, 2025

Yeah, it doesn't. The Dockerfile I linked works better. I used the same Dockerfile to install "unsloth[amd] @ git+https://github.com/unslothai/unsloth" and to run pip install --no-deps git+https://github.com/unslothai/unsloth-zoo.git.

import torch
torch.__version__
'2.8.0+rocm6.4'

import os, torch
from unsloth import FastModel
from transformers import AutoTokenizer

# os.environ['UNSLOTH_FA2_COMPUTE_DTYPE'] = 'bfloat16'
# os.environ['UNSLOTH_ROPE_IMPL'] = 'slow'
# os.environ['UNSLOTH_DISABLE_TRITON_RMSNORM'] = '1'

model_id = "unsloth/gpt-oss-20b-BF16"
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
tok.pad_token = tok.pad_token or tok.eos_token

m, _ = FastModel.from_pretrained(
    model_id,
    load_in_4bit=False,
    load_in_8bit=False,
    full_finetuning=False,
    float32_mixed_precision=False,
    use_gradient_checkpointing=False,
    attn_implementation="eager",
)

m.train().cuda()
b = tok(["hello world"]*2, return_tensors="pt", padding=True).to("cuda")

out = m(**b, labels=b["input_ids"])
loss = out.loss

print("loss:", float(loss.detach()))
loss.backward() # exactly once per forward

has_nan = any(p.grad is not None and torch.isnan(p.grad).any() for p in m.parameters())
print("grad_has_nan:", has_nan)
---------------------------------------------------------------------------
AcceleratorError                          Traceback (most recent call last)
Cell In[1], line 13
     10 tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
     11 tok.pad_token = tok.pad_token or tok.eos_token
---> 13 m, _ = FastModel.from_pretrained(
     14 model_id,
     15 load_in_4bit=False,
     16 load_in_8bit=False,
     17 full_finetuning=False,
     18 float32_mixed_precision=False,
     19 use_gradient_checkpointing=False,
     20 attn_implementation="eager",
     21 )
     23 m.train().cuda()
     24 b = tok(["hello world"]*2, return_tensors="pt", padding=True).to("cuda")

File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/unsloth/models/loader.py:1065, in FastModel.from_pretrained(model_name, max_seq_length, dtype, load_in_4bit, load_in_8bit, load_in_16bit, full_finetuning, token, device_map, rope_scaling, fix_tokenizer, trust_remote_code, use_gradient_checkpointing, resize_model_vocab, revision, return_logits, fullgraph, use_exact_model_name, auto_model, whisper_language, whisper_task, unsloth_force_compile, offload_embedding, float32_mixed_precision, fast_inference, gpu_memory_utilization, float8_kv_cache, random_state, max_lora_rank, disable_log_stats, qat_scheme, *args, **kwargs)
   1062 if auto_model is None:
   1063     auto_model = AutoModelForVision2Seq if is_vlm else AutoModelForCausalLM
-> 1065 model, tokenizer = FastBaseModel.from_pretrained(
   1066     model_name = model_name,
   1067     max_seq_length = max_seq_length,
   1068     dtype = _get_dtype(dtype),
   1069     load_in_4bit = load_in_4bit,
   1070     load_in_8bit = load_in_8bit,
   1071     load_in_16bit = load_in_16bit,
   1072     full_finetuning = full_finetuning,
   1073     token = token,
   1074     device_map = device_map,
   1075     trust_remote_code = trust_remote_code,
   1076     revision = revision if not is_peft else None,
   1077     model_types = model_types,
   1078     tokenizer_name = tokenizer_name,
   1079     auto_model = auto_model,
   1080     use_gradient_checkpointing = use_gradient_checkpointing,
   1081     supports_sdpa = supports_sdpa,
   1082     whisper_language = whisper_language,
   1083     whisper_task = whisper_task,
   1084     auto_config = model_config,
   1085     offload_embedding = offload_embedding,
   1086     float32_mixed_precision = float32_mixed_precision,
   1087     # Pass vLLM/inference parameters
   1088     fast_inference = fast_inference,
   1089     gpu_memory_utilization = gpu_memory_utilization,
   1090     float8_kv_cache = float8_kv_cache,
   1091     random_state = random_state,
   1092     max_lora_rank = max_lora_rank,
   1093     disable_log_stats = disable_log_stats,
   1094     *args,
   1095     **kwargs,
   1096 )
   1098 if resize_model_vocab is not None:
   1099     model.resize_token_embeddings(resize_model_vocab)

File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/unsloth/models/vision.py:763, in FastBaseModel.from_pretrained(model_name, max_seq_length, dtype, load_in_4bit, load_in_8bit, load_in_16bit, full_finetuning, token, device_map, trust_remote_code, model_types, tokenizer_name, auto_model, use_gradient_checkpointing, supports_sdpa, whisper_language, whisper_task, auto_config, offload_embedding, float32_mixed_precision, fast_inference, gpu_memory_utilization, float8_kv_cache, random_state, max_lora_rank, disable_log_stats, unsloth_vllm_standby, **kwargs)
    761     with torch.no_grad():
    762         for jj, (name, module) in enumerate(model.named_modules()):
--> 763             exec(custom_datatype)
    764 # Clear deleted GPU items
    765 for _ in range(3):

File <string>:2

AcceleratorError: HIP error: invalid device function
HIP kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing AMD_SERIALIZE_KERNEL=3
Compile with `TORCH_USE_HIP_DSA` to enable device-side assertions.
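
The error text above suggests passing AMD_SERIALIZE_KERNEL=3 so that the failing kernel is reported at the call that launched it. In a notebook, one way to do that is to set the variable before torch is first imported. This is a minimal sketch, assuming the HIP runtime reads the variable when it initialises; the helper name `enable_serialized_hip_kernels` is hypothetical and not part of Unsloth or PyTorch.

```python
import os
import sys


def enable_serialized_hip_kernels(level="3"):
    """Set AMD_SERIALIZE_KERNEL for this process.

    Returns False when torch was already imported, in which case the
    runtime may have initialised without seeing the new value.
    """
    torch_already_loaded = "torch" in sys.modules
    os.environ["AMD_SERIALIZE_KERNEL"] = level
    return not torch_already_loaded


applied_early = enable_serialized_hip_kernels()
if not applied_early:
    # Assumption: a fresh kernel/process is the safest way to apply it.
    print("torch was already imported; a kernel restart may be needed")
```

Running this as the first notebook cell, before `from unsloth import FastModel`, keeps the ordering unambiguous.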

0xrushi and others added 2 commits November 15, 2025 21:33
- Updated RMS LayerNorm implementation to use float32 for improved precision.
- Introduced safe mode for MLP operations in Gemma and Llama models to prevent dtype mismatches.
- Replaced fused cross-entropy loss with standard PyTorch CE loss in Mistral and CausalLM for better handling of NaNs on Strix Halo.
- Added environment variable checks to conditionally apply these changes based on the UNSLOTH_STRIX_HALO_SAFE setting.
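
The last commit bullet describes gating the safer code paths on UNSLOTH_STRIX_HALO_SAFE. A rough sketch of such a gate is shown below. The helper names and the set of accepted truthy spellings are assumptions for illustration; the actual check in the PR may differ.

```python
import os

# Assumption: which spellings count as "enabled" is illustrative only.
_TRUTHY = {"1", "true", "yes", "on"}


def strix_halo_safe_enabled(environ=None):
    """Return True when UNSLOTH_STRIX_HALO_SAFE is set to a truthy value."""
    env = os.environ if environ is None else environ
    return env.get("UNSLOTH_STRIX_HALO_SAFE", "0").strip().lower() in _TRUTHY


def pick_loss_path(environ=None):
    """Pick a loss path by name (hypothetical dispatch, for illustration)."""
    if strix_halo_safe_enabled(environ):
        return "torch_cross_entropy"      # plain PyTorch CE
    return "fast_cross_entropy_loss"      # fused kernel


print(pick_loss_path({"UNSLOTH_STRIX_HALO_SAFE": "1"}))
print(pick_loss_path({}))
```

Passing the environment as an argument keeps the gate easy to test without mutating `os.environ`.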
@0xrushi
Author

0xrushi commented Nov 16, 2025

@danielhanchen @Datta0 you were right, Mistral and Qwen were also broken.
I had some trouble cleaning up the old cache, but instead of keeping the disable_rmsnorm flag and the other disable flags, I fixed the issues by converting the affected values to float32. GPT-OSS was trickier because it kept generating a temporary cache, and I had to patch that in unsloth/unsloth/models/_utils.py.

In a few places, fast_cross_entropy_loss also had to be replaced with the standard PyTorch cross_entropy loss.
Do you know where fast_cross_entropy_loss is actually defined? If we can track it down, maybe we can patch it directly at the source.

Let me know if this makes sense or if I’m still missing something.
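
When deciding whether a loss kernel is at fault, it can help to compare its output on a tiny batch against an independent reference. The sketch below is a pure-Python cross-entropy using the log-sum-exp trick; `reference_cross_entropy` is a hypothetical helper, and the `ignore_index=-100` default is chosen to mirror PyTorch's default. It is not the fused kernel's algorithm, only something to compare against.

```python
import math


def reference_cross_entropy(logits, labels, ignore_index=-100):
    """Mean cross-entropy over rows of `logits` (list of lists of floats).

    Rows whose label equals `ignore_index` are skipped. Returns NaN when
    every row is skipped.
    """
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue
        row_max = max(row)
        # Subtracting the row max keeps exp() from overflowing.
        log_sum_exp = row_max + math.log(sum(math.exp(v - row_max) for v in row))
        total += log_sum_exp - row[label]
        count += 1
    return total / count if count else float("nan")


print(reference_cross_entropy([[0.0, 0.0], [1000.0, 0.0]], [0, 0]))
```

Feeding the same logits and labels (converted with `.float().tolist()`) from a failing step into this function gives a finite value to compare the kernel's loss against.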

@danielhanchen
Contributor

@0xrushi Is the CE loss the issue? Regarding the temporary cache, use the snippet below to prevent it from being overwritten:

import os
os.environ["UNSLOTH_COMPILE_OVERWRITE"] = "0"

...
from unsloth import FastLanguageModel

@0xrushi
Author

0xrushi commented Nov 18, 2025

fast_cross_entropy_loss

Yes, fast_cross_entropy_loss is one of the issues; the normal PyTorch CE is fine.
The second is the eps in RMS norm, fixed with eps_f32 = tl.full((), eps, tl.float32).
The third is similar typecasting in the MLP.
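
To give a feel for why the dtype of a small constant like eps can matter, the sketch below rounds eps = 1e-6 to float32, bfloat16 and float16 using only the standard library. It does not establish why the Triton kernel misbehaved on gfx1151; it only illustrates how differently the three formats hold such a value. The helper names are hypothetical, and `to_bf16` handles ordinary finite values only.

```python
import struct


def to_fp32(x):
    """Round a Python float to IEEE float32 and back."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_fp16(x):
    """Round a Python float to IEEE float16 and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x):
    """Round a finite Python float to bfloat16 (nearest-even) and back."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # round on the 16 dropped bits
    bits &= 0xFFFF0000                    # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


eps = 1e-6
for name, cast in (("float32", to_fp32), ("bfloat16", to_bf16), ("float16", to_fp16)):
    stored = cast(eps)
    rel_err = abs(stored - eps) / eps
    absorbed = cast(1.0 + stored) == 1.0  # does 1 + eps collapse to 1?
    print(f"{name:9s} eps={stored:.6e} rel_err={rel_err:.1e} absorbed={absorbed}")
```

In float16 the value 1e-6 falls in the subnormal range, and in both 16-bit formats adding it to 1.0 leaves 1.0 unchanged, whereas float32 keeps the difference.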

This patch resolves the compilation issue: https://github.com/0xrushi/unsloth/blob/3aa027c8cc3e136bb40363cb94d9b04071e611e6/unsloth/models/_utils.py#L1848

@danielhanchen
Contributor

@0xrushi I think @billishyahao is going to investigate this from the AMD side as well!

@0xrushi
Author

0xrushi commented Nov 20, 2025

@0xrushi I think @billishyahao is going to investigate this from the AMD side as well!

exciting to see AMD devs getting involved!


Development

Successfully merging this pull request may close these issues.

Training on ROCm (gfx1151, Strix Halo) results in NaN losses with Gemma3 fine-tuning

3 participants