fix: unsloth fixes for gfx1151 #3588
base: main
Conversation
…lementations

- Introduced UNSLOTH_DISABLE_FLASH_ATTENTION to disable Flash Attention based on an environment variable.
- Added checks for UNSLOTH_DISABLE_TRITON_RMSNORM and UNSLOTH_LAYERNORM_IMPL to conditionally use different LayerNorm implementations in the Gemma and Llama models.
- Updated dtype handling in the Llama model to allow overriding automatic dtype casting via an environment variable.
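For anyone trying these toggles, a minimal usage sketch (assuming the variable names above; setting them before importing unsloth is the safest order, so the patched paths see them):

import os

# Hedged example: enable the PR's fallback paths before importing unsloth.
os.environ["UNSLOTH_DISABLE_FLASH_ATTENTION"] = "1"  # fall back from Flash Attention
os.environ["UNSLOTH_DISABLE_TRITON_RMSNORM"] = "1"   # use torch LayerNorm, not the Triton kernel
os.environ["UNSLOTH_LAYERNORM_IMPL"] = "python"      # equivalent way to request the python path

from unsloth import FastLanguageModel  # import after setting the flags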
for more information, see https://pre-commit.ci
Summary of Changes
Hello @0xrushi, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request addresses critical numerical instability issues encountered when fine-tuning Gemma-3 models on AMD Strix Halo (gfx1151) using HIP/ROCm, which previously led to NaN losses. The core solution involves carefully managing data types within FlashAttention2 operations and providing configurable fallbacks for other performance-optimized kernels like RoPE and RMSNorm, ensuring stable training while retaining performance where possible.
Code Review
This pull request introduces several fixes and diagnostic toggles to address numerical instability issues with bfloat16 on AMD gfx1151 hardware. The core change involves conditionally casting to float16 for FlashAttention2 on HIP, which is a sound approach to mitigate NaN losses. Additionally, environment variables are added to control RoPE and RMSNorm kernel implementations, providing valuable debugging flexibility. My review focuses on improving code maintainability by reducing duplication and making logging more consistent. Overall, these are solid changes for improving hardware compatibility and debuggability.
unsloth/models/llama.py
Outdated
_disable_triton_rms = (
    os.environ.get("UNSLOTH_DISABLE_TRITON_RMSNORM", "0") == "1"
)
_ln_impl = os.environ.get("UNSLOTH_LAYERNORM_IMPL", "").lower()
if _disable_triton_rms or _ln_impl == "python":
    hidden_states = self.input_layernorm(hidden_states)
    if not _LOGGED_RMSNORM:
        logger.debug(
            "Unsloth: RMSNorm=torch. env(UNSLOTH_DISABLE_TRITON_RMSNORM=%s, UNSLOTH_LAYERNORM_IMPL=%s)",
            os.environ.get("UNSLOTH_DISABLE_TRITON_RMSNORM"),
            os.environ.get("UNSLOTH_LAYERNORM_IMPL"),
        )
        _LOGGED_RMSNORM = True
The debug logging for RMSNorm implementation only covers the torch path. For consistency and better diagnostics, a logger.debug call should also be added to the else block to indicate when the fast Triton kernel (fast_rms_layernorm) is being used. This can be refactored to reduce duplication, similar to the RoPE logging.
_disable_triton_rms = os.environ.get("UNSLOTH_DISABLE_TRITON_RMSNORM", "0") == "1"
_ln_impl = os.environ.get("UNSLOTH_LAYERNORM_IMPL", "").lower()
if _disable_triton_rms or _ln_impl == "python":
hidden_states = self.input_layernorm(hidden_states)
_rms_impl_name = "torch"
else:
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
_rms_impl_name = "triton"
if not _LOGGED_RMSNORM:
logger.debug(
"Unsloth: RMSNorm=%s. env(UNSLOTH_DISABLE_TRITON_RMSNORM=%s, UNSLOTH_LAYERNORM_IMPL=%s)",
_rms_impl_name,
os.environ.get("UNSLOTH_DISABLE_TRITON_RMSNORM"),
os.environ.get("UNSLOTH_LAYERNORM_IMPL"),
)
_LOGGED_RMSNORM = True
unsloth/models/gemma.py
Outdated
hidden_states = fast_rms_layernorm(
    self.input_layernorm, hidden_states, gemma = True
)
_disable_triton_rms = os.getenv("UNSLOTH_DISABLE_TRITON_RMSNORM", "0") == "1"
@danielhanchen I think adding a flag called impl or similar to fast_rms_layernorm would be a cleaner approach, but I'd let you comment on the performance implications of that.
Also, these flags don't change at runtime, so they can be read once at the start of the model file and shared, instead of being computed for every layer.
Yes, we should do it once only at the start, as a global variable.
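A minimal sketch of the hoisting being discussed (module scope is assumed; names are illustrative, not the PR's final code):

import os

# Read the toggles once at import time; they are not expected to change at runtime.
_DISABLE_TRITON_RMSNORM = os.environ.get("UNSLOTH_DISABLE_TRITON_RMSNORM", "0") == "1"
_LAYERNORM_IMPL = os.environ.get("UNSLOTH_LAYERNORM_IMPL", "").lower()
_USE_TORCH_RMSNORM = _DISABLE_TRITON_RMSNORM or _LAYERNORM_IMPL == "python"

# Inside each layer's forward, the per-call work then reduces to one boolean check:
# if _USE_TORCH_RMSNORM:
#     hidden_states = self.input_layernorm(hidden_states)
# else:
#     hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)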
unsloth/models/llama.py
Outdated
_fa2_in_dtype = Q.dtype
if DEVICE_TYPE == "hip":
    target = os.environ.get("UNSLOTH_FA2_COMPUTE_DTYPE", "").lower()
    if target in _FA2_COMPUTE_DTYPE_MAP:
Even these can potentially be inferred once at startup.
Also, does the FA inaccuracy bug affect every model? If yes, we should probably replicate this across other models (like Mistral or Qwen, for example).
I was running the gpt20b notebook and I'm getting kernel crashes; it's a long way to debug.
Let me know when you feel comfortable with this. I will be happy to review :)
@Datta0 gpt-oss needs that extra casting:

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = m(**b, labels=b["input_ids"])

It's fine here, but will that break somewhere else that abstracts anything similar?
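For reproduction, the full pattern looks roughly like this — the model id and loader arguments are taken from snippets elsewhere in this thread, so treat it as an illustrative sketch:

import torch
from transformers import AutoTokenizer
from unsloth import FastModel

model_id = "unsloth/gpt-oss-20b-BF16"  # assumed from later in this thread
tok = AutoTokenizer.from_pretrained(model_id)
tok.pad_token = tok.pad_token or tok.eos_token

m, _ = FastModel.from_pretrained(model_id, load_in_4bit=False)
m.train().cuda()
b = tok(["hello world"] * 2, return_tensors="pt", padding=True).to("cuda")

# The extra autocast discussed above: keep mixed ops in bfloat16 so the
# MoE expert matmuls see a consistent dtype.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = m(**b, labels=b["input_ids"])
print("loss:", float(out.loss.detach()))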
Oh that is interesting. Do you happen to have the error stack trace when the autocast is not set?
g++ (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Copyright (C) 2023 Free Software Foundation, Inc.
This is free software; see the source for copying conditions. There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))== Unsloth 2025.11.3: Fast Gpt_Oss patching. Transformers: 4.57.1.
\\ /| AMD Radeon Graphics. Num GPUs = 1. Max memory: 128.0 GB. Platform: Linux.
O^O/ \_/ \ Torch: 2.7.1+git99ccf24. ROCm Toolkit: 6.4.43484-123eb5128. Triton: 3.3.1
\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+13c93f39.d20251112. FA2 = True]
"-____-" Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
/opt/conda/envs/py_3.12/lib/python3.12/site-packages/accelerate/utils/modeling.py:804: UserWarning: expandable_segments not supported on this platform (Triggered internally at /var/lib/jenkins/pytorch/c10/hip/HIPAllocatorConfig.h:29.)
_ = torch.tensor([0], device=i)
Loading checkpoint shards: 0%| | 0/9 [00:00<?, ?it/s]
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[1], line 26
23 m.train().cuda()
24 b = tok(["hello world"]*2, return_tensors="pt", padding=True).to("cuda")
---> 26 out = m(**b, labels=b["input_ids"])
27 loss = out.loss
29 print("loss:", float(loss.detach()))
File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
1749 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1750 else:
-> 1751 return self._call_impl(*args, **kwargs)
File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
1757 # If we don't have any hooks, we want to skip the rest of the logic in
1758 # this function, and just call forward.
1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1760 or _global_backward_pre_hooks or _global_backward_hooks
1761 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762 return forward_call(*args, **kwargs)
1764 result = None
1765 called_always_called_hooks = set()
File ~/finetuning-workspace/workspace/unsloth_compiled_cache/unsloth_compiled_module_gpt_oss.py:726, in GptOssForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_router_logits, cache_position, logits_to_keep, **kwargs)
712 def forward(
713 self,
714 input_ids: Optional[torch.LongTensor] = None,
(...) 724 **kwargs: Unpack[TransformersKwargs],
725 ) -> MoeCausalLMOutputWithPast:
--> 726 return GptOssForCausalLM_forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_router_logits, cache_position, logits_to_keep, **kwargs)
File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/transformers/utils/generic.py:918, in can_return_tuple.<locals>.wrapper(self, *args, **kwargs)
916 if return_dict_passed is not None:
917 return_dict = return_dict_passed
--> 918 output = func(self, *args, **kwargs)
919 if not return_dict and not isinstance(output, tuple):
920 output = output.to_tuple()
File ~/finetuning-workspace/workspace/unsloth_compiled_cache/unsloth_compiled_module_gpt_oss.py:547, in GptOssForCausalLM_forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_router_logits, cache_position, logits_to_keep, **kwargs)
542 output_router_logits = (
543 output_router_logits if output_router_logits is not None else self.config.output_router_logits
544 )
546 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 547 outputs: MoeModelOutputWithPast = self.model(
548 input_ids=input_ids,
549 attention_mask=attention_mask,
550 position_ids=position_ids,
551 past_key_values=past_key_values,
552 inputs_embeds=inputs_embeds,
553 use_cache=use_cache,
554 output_router_logits=output_router_logits,
555 cache_position=cache_position,
556 **kwargs,
557 )
559 hidden_states = outputs.last_hidden_state
560 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
1749 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1750 else:
-> 1751 return self._call_impl(*args, **kwargs)
File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
1757 # If we don't have any hooks, we want to skip the rest of the logic in
1758 # this function, and just call forward.
1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1760 or _global_backward_pre_hooks or _global_backward_hooks
1761 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762 return forward_call(*args, **kwargs)
1764 result = None
1765 called_always_called_hooks = set()
File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/unsloth_zoo/temporary_patches/gpt_oss.py:1249, in patch_GptOssModel.<locals>.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, cache_position, **kwargs)
1247 for decoder_layer in self.layers:
1248 mask = attention_mask[decoder_layer.attention_type] if isinstance(attention_mask, dict) else attention_mask
-> 1249 hidden_states = decoder_layer(
1250 hidden_states,
1251 attention_mask=mask,
1252 position_ids=position_ids,
1253 past_key_values=past_key_values,
1254 use_cache=use_cache,
1255 cache_position=cache_position,
1256 position_embeddings=position_embeddings,
1257 **kwargs,
1258 )
1259 pass
1260 hidden_states = self.norm(hidden_states)
File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/transformers/modeling_layers.py:93, in GradientCheckpointingLayer.__call__(self, *args, **kwargs)
90 message = message.rstrip(",") + "."
91 logger.warning_once(message)
---> 93 return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args)
94 return super().__call__(*args, **kwargs)
File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/_compile.py:51, in _disable_dynamo.<locals>.inner(*args, **kwargs)
48 disable_fn = torch._dynamo.disable(fn, recursive)
49 fn.__dynamo_disable = disable_fn # type: ignore[attr-defined]
---> 51 return disable_fn(*args, **kwargs)
File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:838, in DisableContext.__call__.<locals>._fn(*args, **kwargs)
836 _maybe_set_eval_frame(_callback_from_stance(self.callback))
837 try:
--> 838 return fn(*args, **kwargs)
839 finally:
840 set_eval_frame(None)
File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/utils/checkpoint.py:488, in checkpoint(function, use_reentrant, context_fn, determinism_check, debug, *args, **kwargs)
483 if context_fn is not noop_context_fn or debug is not False:
484 raise ValueError(
485 "Passing `context_fn` or `debug` is only supported when "
486 "use_reentrant=False."
487 )
--> 488 return CheckpointFunction.apply(function, preserve, *args)
489 else:
490 gen = _checkpoint_without_reentrant_generator(
491 function, preserve, context_fn, determinism_check, debug, *args, **kwargs
492 )
File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/autograd/function.py:575, in Function.apply(cls, *args, **kwargs)
572 if not torch._C._are_functorch_transforms_active():
573 # See NOTE: [functorch vjp and autograd interaction]
574 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 575 return super().apply(*args, **kwargs) # type: ignore[misc]
577 if not is_setup_ctx_defined:
578 raise RuntimeError(
579 "In order to use an autograd.Function with functorch transforms "
580 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
581 "staticmethod. For more details, please see "
582 "https://pytorch.org/docs/main/notes/extending.func.html"
583 )
File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/utils/checkpoint.py:263, in CheckpointFunction.forward(ctx, run_function, preserve_rng_state, *args)
260 ctx.save_for_backward(*tensor_inputs)
262 with torch.no_grad():
--> 263 outputs = run_function(*args)
264 return outputs
File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
1749 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1750 else:
-> 1751 return self._call_impl(*args, **kwargs)
File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
1757 # If we don't have any hooks, we want to skip the rest of the logic in
1758 # this function, and just call forward.
1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1760 or _global_backward_pre_hooks or _global_backward_hooks
1761 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762 return forward_call(*args, **kwargs)
1764 result = None
1765 called_always_called_hooks = set()
File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/transformers/utils/deprecation.py:172, in deprecate_kwarg.<locals>.wrapper.<locals>.wrapped_func(*args, **kwargs)
168 elif minimum_action in (Action.NOTIFY, Action.NOTIFY_ALWAYS) and not is_torchdynamo_compiling():
169 # DeprecationWarning is ignored by default, so we use FutureWarning instead
170 warnings.warn(message, FutureWarning, stacklevel=2)
--> 172 return func(*args, **kwargs)
File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py:386, in GptOssDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_values, use_cache, cache_position, position_embeddings, **kwargs)
384 residual = hidden_states
385 hidden_states = self.post_attention_layernorm(hidden_states)
--> 386 hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores
387 hidden_states = residual + hidden_states
388 return hidden_states
File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
1749 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1750 else:
-> 1751 return self._call_impl(*args, **kwargs)
File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
1757 # If we don't have any hooks, we want to skip the rest of the logic in
1758 # this function, and just call forward.
1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1760 or _global_backward_pre_hooks or _global_backward_hooks
1761 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762 return forward_call(*args, **kwargs)
1764 result = None
1765 called_always_called_hooks = set()
File ~/finetuning-workspace/workspace/unsloth_compiled_cache/unsloth_compiled_module_gpt_oss.py:309, in GptOssMLP.forward(self, hidden_states)
308 def forward(self, hidden_states):
--> 309 return GptOssMLP_forward(self, hidden_states)
File ~/finetuning-workspace/workspace/unsloth_compiled_cache/unsloth_compiled_module_gpt_oss.py:298, in GptOssMLP_forward(self, hidden_states)
295 @torch.compiler.disable(recursive = False)
296 def GptOssMLP_forward(self, hidden_states):
297 router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len)
--> 298 routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
299 return routed_out, router_scores
File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
1749 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1750 else:
-> 1751 return self._call_impl(*args, **kwargs)
File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
1757 # If we don't have any hooks, we want to skip the rest of the logic in
1758 # this function, and just call forward.
1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1760 or _global_backward_pre_hooks or _global_backward_hooks
1761 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762 return forward_call(*args, **kwargs)
1764 result = None
1765 called_always_called_hooks = set()
File ~/finetuning-workspace/workspace/unsloth_compiled_cache/unsloth_compiled_module_gpt_oss.py:270, in GptOssExperts.forward(self, hidden_states, router_indices, routing_weights)
269 def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
--> 270 return GptOssExperts_forward(self, hidden_states, router_indices, routing_weights)
File ~/finetuning-workspace/workspace/unsloth_compiled_cache/unsloth_compiled_module_gpt_oss.py:236, in GptOssExperts_forward(self, hidden_states, router_indices, routing_weights)
234 glu = gate * torch.sigmoid(gate * self.alpha)
235 gated_output = (up + 1) * glu
--> 236 out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
237 weighted_output = out * routing_weights[token_idx, expert_idx, None]
238 next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16
but SFTTrainer worked; notebook here: unslothai/notebooks#126
… for improved compatibility with environment settings.
unsloth/models/gemma.py
Outdated
if _DISABLE_TRITON_RMSNORM or _LAYERNORM_IMPL == "python":
    hidden_states = self.input_layernorm(hidden_states)
else:
    hidden_states = fast_rms_layernorm(
Wait, are you certain this causes NaN issues? That's extremely weird.
Casting eps to float32 resolved that, but I'm still getting NaNs for:

from unsloth import FastModel

m, tok = FastModel.from_pretrained(
    "unsloth/gemma-3-4b-it",
    load_in_4bit=False, load_in_8bit=False, full_finetuning=False,
)
m.train().cuda()
b = tok(text=["hello world"]*2, return_tensors="pt", padding=True).to("cuda")
out = m(**b, labels=b["input_ids"])
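For reference, the eps fix mentioned above amounts to doing the variance/eps arithmetic in float32. A minimal sketch of the idea, not the PR's exact kernel code:

import torch

def rms_layernorm_fp32_eps(hidden_states, weight, eps):
    # Accumulate in float32 so a bfloat16 eps/variance does not lose
    # precision on HIP; cast back to the input dtype at the end.
    # (Gemma-style norms additionally scale by (1 + weight); omitted here.)
    in_dtype = hidden_states.dtype
    hs = hidden_states.to(torch.float32)
    variance = hs.pow(2).mean(-1, keepdim=True)
    hs = hs * torch.rsqrt(variance + float(eps))
    return weight * hs.to(in_dtype)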
unsloth/models/llama.py
Outdated
# else inplace_rope_embedding(Q, K, cos, sin, position_ids)
# )
Q, K = fast_rope_embedding(Q, K, cos, sin)
# Allow switching RoPE implementation to a non-Triton path on HIP if needed
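A minimal sketch of what that switch could look like — UNSLOTH_ROPE_IMPL appears later in this thread, fast_rope_embedding is assumed from the surrounding module, and slow_rope_embedding is a hypothetical pure-torch fallback:

import os

_ROPE_IMPL = os.environ.get("UNSLOTH_ROPE_IMPL", "fast").lower()  # read once at import

def apply_rope(Q, K, cos, sin, position_ids=None):
    if _ROPE_IMPL == "slow":
        # Hypothetical non-Triton fallback, e.g. a pure-torch rotary embedding.
        return slow_rope_embedding(Q, K, cos, sin, position_ids)
    return fast_rope_embedding(Q, K, cos, sin)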
This is very weird if even RMS LayerNorm doesn't work! MI300X seems to be fine for GPT-OSS, hmm.
Did you try https://docs.unsloth.ai/get-started/install-and-update/amd, i.e.
pip install --upgrade torch==2.8.0 pytorch-triton-rocm torchvision torchaudio torchao==0.13.0 xformers --index-url https://download.pytorch.org/whl/rocm6.4
pip install --no-deps unsloth unsloth-zoo
pip install --no-deps git+https://github.com/unslothai/unsloth-zoo.git
pip install "unsloth[amd] @ git+https://github.com/unslothai/unsloth"
@billishyahao This might be interesting ie normal Unsloth using Triton and FA2 gets NaNs for GPT-OSS for Strix Halo
Pretty sure the precompiled torch from the ROCm PyPI doesn't work. Is torch==2.8.0 a must?
I did try pip install "unsloth[amd] @ git+https://github.com/unslothai/unsloth"; let me try it again.
Yeah, it doesn't. The Dockerfile I linked works better. I used the same Dockerfile to install "unsloth[amd] @ git+https://github.com/unslothai/unsloth" and pip install --no-deps git+https://github.com/unslothai/unsloth-zoo.git.
import torch
torch.__version__
'2.8.0+rocm6.4'
import os, torch
from unsloth import FastModel
from transformers import AutoTokenizer
# os.environ['UNSLOTH_FA2_COMPUTE_DTYPE'] = 'bfloat16'
# os.environ['UNSLOTH_ROPE_IMPL'] = 'slow'
# os.environ['UNSLOTH_DISABLE_TRITON_RMSNORM'] = '1'
model_id = "unsloth/gpt-oss-20b-BF16"
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
tok.pad_token = tok.pad_token or tok.eos_token
m, _ = FastModel.from_pretrained(
model_id,
load_in_4bit=False,
load_in_8bit=False,
full_finetuning=False,
float32_mixed_precision=False,
use_gradient_checkpointing=False,
attn_implementation="eager",
)
m.train().cuda()
b = tok(["hello world"]*2, return_tensors="pt", padding=True).to("cuda")
out = m(**b, labels=b["input_ids"])
loss = out.loss
print("loss:", float(loss.detach()))
loss.backward() # exactly once per forward
has_nan = any(p.grad is not None and torch.isnan(p.grad).any() for p in m.parameters())
print("grad_has_nan:", has_nan)
---------------------------------------------------------------------------
AcceleratorError Traceback (most recent call last)
Cell In[1], line 13
10 tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
11 tok.pad_token = tok.pad_token or tok.eos_token
---> 13 m, _ = FastModel.from_pretrained(
14 model_id,
15 load_in_4bit=False,
16 load_in_8bit=False,
17 full_finetuning=False,
18 float32_mixed_precision=False,
19 use_gradient_checkpointing=False,
20 attn_implementation="eager",
21 )
23 m.train().cuda()
24 b = tok(["hello world"]*2, return_tensors="pt", padding=True).to("cuda")
File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/unsloth/models/loader.py:1065, in FastModel.from_pretrained(model_name, max_seq_length, dtype, load_in_4bit, load_in_8bit, load_in_16bit, full_finetuning, token, device_map, rope_scaling, fix_tokenizer, trust_remote_code, use_gradient_checkpointing, resize_model_vocab, revision, return_logits, fullgraph, use_exact_model_name, auto_model, whisper_language, whisper_task, unsloth_force_compile, offload_embedding, float32_mixed_precision, fast_inference, gpu_memory_utilization, float8_kv_cache, random_state, max_lora_rank, disable_log_stats, qat_scheme, *args, **kwargs)
1062 if auto_model is None:
1063 auto_model = AutoModelForVision2Seq if is_vlm else AutoModelForCausalLM
-> 1065 model, tokenizer = FastBaseModel.from_pretrained(
1066 model_name = model_name,
1067 max_seq_length = max_seq_length,
1068 dtype = _get_dtype(dtype),
1069 load_in_4bit = load_in_4bit,
1070 load_in_8bit = load_in_8bit,
1071 load_in_16bit = load_in_16bit,
1072 full_finetuning = full_finetuning,
1073 token = token,
1074 device_map = device_map,
1075 trust_remote_code = trust_remote_code,
1076 revision = revision if not is_peft else None,
1077 model_types = model_types,
1078 tokenizer_name = tokenizer_name,
1079 auto_model = auto_model,
1080 use_gradient_checkpointing = use_gradient_checkpointing,
1081 supports_sdpa = supports_sdpa,
1082 whisper_language = whisper_language,
1083 whisper_task = whisper_task,
1084 auto_config = model_config,
1085 offload_embedding = offload_embedding,
1086 float32_mixed_precision = float32_mixed_precision,
1087 # Pass vLLM/inference parameters
1088 fast_inference = fast_inference,
1089 gpu_memory_utilization = gpu_memory_utilization,
1090 float8_kv_cache = float8_kv_cache,
1091 random_state = random_state,
1092 max_lora_rank = max_lora_rank,
1093 disable_log_stats = disable_log_stats,
1094 *args,
1095 **kwargs,
1096 )
1098 if resize_model_vocab is not None:
1099 model.resize_token_embeddings(resize_model_vocab)
File /opt/conda/envs/py_3.12/lib/python3.12/site-packages/unsloth/models/vision.py:763, in FastBaseModel.from_pretrained(model_name, max_seq_length, dtype, load_in_4bit, load_in_8bit, load_in_16bit, full_finetuning, token, device_map, trust_remote_code, model_types, tokenizer_name, auto_model, use_gradient_checkpointing, supports_sdpa, whisper_language, whisper_task, auto_config, offload_embedding, float32_mixed_precision, fast_inference, gpu_memory_utilization, float8_kv_cache, random_state, max_lora_rank, disable_log_stats, unsloth_vllm_standby, **kwargs)
761 with torch.no_grad():
762 for jj, (name, module) in enumerate(model.named_modules()):
--> 763 exec(custom_datatype)
764 # Clear deleted GPU items
765 for _ in range(3):
File <string>:2
AcceleratorError: HIP error: invalid device function
HIP kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing AMD_SERIALIZE_KERNEL=3
Compile with `TORCH_USE_HIP_DSA` to enable device-side assertions.
ea6c89f to 2dc7ab9
- Updated RMS LayerNorm implementation to use float32 for improved precision.
- Introduced a safe mode for MLP operations in the Gemma and Llama models to prevent dtype mismatches.
- Replaced fused cross-entropy loss with the standard PyTorch CE loss in Mistral and CausalLM for better handling of NaNs on Strix Halo.
- Added environment variable checks to conditionally apply these changes based on the UNSLOTH_STRIX_HALO_SAFE setting.
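As a rough illustration of the CE-loss swap under the safe-mode flag — assuming standard shifted-label causal LM loss; this is a sketch, not the exact patch:

import os
import torch.nn.functional as F

_STRIX_HALO_SAFE = os.environ.get("UNSLOTH_STRIX_HALO_SAFE", "0") == "1"

def causal_lm_loss(logits, labels, fused_ce_loss_fn=None):
    if _STRIX_HALO_SAFE or fused_ce_loss_fn is None:
        # Plain PyTorch cross entropy in float32: slower, but avoids the
        # NaNs observed with the fused kernel on gfx1151.
        shift_logits = logits[..., :-1, :].contiguous().float()
        shift_labels = labels[..., 1:].contiguous()
        return F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
        )
    return fused_ce_loss_fn(logits, labels)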
for more information, see https://pre-commit.ci
@danielhanchen @Datta0 You were right, Mistral and Qwen were also broken in a few places. Let me know if this makes sense or if I'm still missing something.
@0xrushi Is the CE loss the issue? Re the temporary cache, use the below to force non-overwriting:

import os
os.environ["UNSLOTH_COMPILE_OVERWRITE"] = "0"
...
from unsloth import FastLanguageModel
Yes, this patch resolves the compilation issue: https://github.com/0xrushi/unsloth/blob/3aa027c8cc3e136bb40363cb94d9b04071e611e6/unsloth/models/_utils.py#L1848
@0xrushi I think @billishyahao is going to investigate this from the AMD side as well!
Exciting to see AMD devs getting involved!





Resolves #3385 (comment)
Summary
Root Cause
What We Changed
Validation