fix: unsloth fixes for gfx1151 #3588
base: main
Changes from 8 commits
Commits: 9d0a807, 71e810f, 61b9621, ca91fa4, d574b1f, b8ec53f, dbf2e52, 3c87aa4, 2dc7ab9, 3aa027c
@@ -94,8 +94,29 @@
 from peft import PeftModelForCausalLM, PeftModelForSequenceClassification
 from ..save import patch_saving_functions
 import re, os, inspect, math, sys
+
+# One-time debug flags to avoid repeated logs in hot paths
+_LOGGED_ROPE = False
+_LOGGED_ATTENTION_FA2 = False
+_LOGGED_RMSNORM = False
+import types
+
+_FA2_COMPUTE_DTYPE_MAP = {
+    "bf16": torch.bfloat16,
+    "bfloat16": torch.bfloat16,
+    "f16": torch.float16,
+    "float16": torch.float16,
+    "fp16": torch.float16,
+}
+
+_ROPE_IMPL = os.getenv("UNSLOTH_ROPE_IMPL", "").lower()
+_DISABLE_TRITON_ROPE = os.getenv("UNSLOTH_DISABLE_TRITON_ROPE", "0") == "1"
+
+_LAYERNORM_IMPL = os.getenv("UNSLOTH_LAYERNORM_IMPL", "").lower()
+_DISABLE_TRITON_RMSNORM = os.getenv("UNSLOTH_DISABLE_TRITON_RMSNORM", "0") == "1"
+
+_FA2_COMPUTE_DTYPE_TARGET = os.getenv("UNSLOTH_FA2_COMPUTE_DTYPE", "").lower()
+
 try:
     from huggingface_hub.utils import get_token
 except:
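
As a usage note (not part of the diff), these switches would be set before unsloth is imported; a hedged sketch, assuming the variables above are read at import time and that these fallbacks are the ones needed on gfx1151:

```python
# Hypothetical usage sketch for the env switches introduced above.
# Which combination is actually required on gfx1151 is an assumption to verify.
import os

os.environ["UNSLOTH_DISABLE_TRITON_ROPE"]    = "1"     # fall back to the torch RoPE path
os.environ["UNSLOTH_DISABLE_TRITON_RMSNORM"] = "1"     # fall back to the module RMSNorm path
os.environ["UNSLOTH_FA2_COMPUTE_DTYPE"]      = "fp16"  # force FlashAttention-2 compute dtype

import unsloth  # must come after the variables are set, since they are read at import
```
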
@@ -549,8 +570,13 @@ def LlamaAttention_fast_forward(
         del self.RH_Q
         del self.attention

+    global _LOGGED_ROPE, _LOGGED_ATTENTION_FA2
+
     bsz, q_len, _ = hidden_states.size()

+    # Preserve original dtype for attention output
+    _orig_out_dtype = hidden_states.dtype
+
     n_heads    = self.config.num_attention_heads
     n_groups   = self.num_key_value_groups
     n_kv_heads = self.config.num_key_value_heads
@@ -580,12 +606,24 @@ def LlamaAttention_fast_forward(
         # cos, sin = rotary_emb.get_cached(seq_len = kv_seq_len, device = Q.device)
         cos, sin = rotary_emb.get_cached(kv_seq_len, Q.device.index)

-    # Q, K = (
-    #     fast_rope_embedding(Q, K, cos, sin)
-    #     if position_ids is None
-    #     else inplace_rope_embedding(Q, K, cos, sin, position_ids)
-    # )
-    Q, K = fast_rope_embedding(Q, K, cos, sin)
+    # Allow switching RoPE implementation to a non-Triton path on HIP if needed
+    if _ROPE_IMPL == "slow" or _DISABLE_TRITON_ROPE:
+        Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
+        _rope_impl_name = "slow (torch)"
+    else:
+        Q, K = fast_rope_embedding(Q, K, cos, sin)
+        _rope_impl_name = "triton (fast)"
+
+    if not _LOGGED_ROPE:
+        logger.debug(
+            "Unsloth: RoPE=%s. device=%s cos_dtype=%s env(UNSLOTH_ROPE_IMPL=%s, UNSLOTH_DISABLE_TRITON_ROPE=%s)",
+            _rope_impl_name,
+            DEVICE_TYPE_TORCH,
+            str(cos.dtype),
+            os.getenv("UNSLOTH_ROPE_IMPL"),
+            os.getenv("UNSLOTH_DISABLE_TRITON_ROPE"),
+        )
+        _LOGGED_ROPE = True

     if past_key_value is not None:
         K = torch.cat([past_key_value[0], K], dim = 2)
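
For readers who want to sanity-check the fallback numerically, the torch-only path amounts to the standard rotary embedding; below is a minimal self-contained sketch (a textbook formulation, not unsloth's `inplace_rope_embedding`, and the `[bsz, n_heads, seq, head_dim]` layout is an assumption):

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the head dimension in two and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_reference(Q, K, cos, sin, position_ids):
    # cos/sin: [seq_len, head_dim]; gather the rows for the requested positions,
    # broadcast over the head dimension, and rotate Q and K like the Triton kernel would.
    cos = cos[position_ids].unsqueeze(1)  # [bsz, 1, q_len, head_dim]
    sin = sin[position_ids].unsqueeze(1)
    Q_rot = (Q * cos) + (rotate_half(Q) * sin)
    K_rot = (K * cos) + (rotate_half(K) * sin)
    return Q_rot, K_rot
```
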
@@ -618,7 +656,33 @@ def LlamaAttention_fast_forward(
         Q = Q.transpose(1, 2)
         K = K.transpose(1, 2)
         V = V.transpose(1, 2)
+        # Allow selecting a compute dtype, defaulting to fp16 if inputs are bf16.
+        _fa2_in_dtype = Q.dtype
+        if DEVICE_TYPE == "hip":
+            if _FA2_COMPUTE_DTYPE_TARGET in _FA2_COMPUTE_DTYPE_MAP:
+                target_dtype = _FA2_COMPUTE_DTYPE_MAP[_FA2_COMPUTE_DTYPE_TARGET]
+                if Q.dtype != target_dtype:
+                    Q = Q.to(target_dtype)
+                    K = K.to(target_dtype)
+                    V = V.to(target_dtype)
+            elif Q.dtype == torch.bfloat16:
+                Q = Q.to(torch.float16)
+                K = K.to(torch.float16)
+                V = V.to(torch.float16)
         A = flash_attn_func(Q, K, V, causal = True)
+        if A.dtype != _orig_out_dtype:
+            A = A.to(_orig_out_dtype)
+        if not _LOGGED_ATTENTION_FA2:
+            logger.debug(
+                "Unsloth: Attention=FlashAttention2. device=%s in_dtype=%s used_dtype=%s out_dtype=%s env(UNSLOTH_FA2_COMPUTE_DTYPE=%s, UNSLOTH_DISABLE_FLASH_ATTENTION=%s)",
+                DEVICE_TYPE,
+                str(_fa2_in_dtype),
+                str(Q.dtype),
+                str(A.dtype),
+                os.getenv("UNSLOTH_FA2_COMPUTE_DTYPE"),
+                os.getenv("UNSLOTH_DISABLE_FLASH_ATTENTION"),
+            )
+            _LOGGED_ATTENTION_FA2 = True
     else:
         # when qlen==vlen and attn_mask is None, we should use causal attention
         Q_len = Q.shape[-2]
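
The pattern above (cast bf16 inputs to a supported compute dtype, run the kernel, cast the output back to the caller's dtype) can be illustrated with PyTorch's SDPA standing in for `flash_attn_func`; a minimal sketch, not the PR's code, and note SDPA expects `[bsz, n_heads, seq, head_dim]` rather than FlashAttention's `[bsz, seq, n_heads, head_dim]`:

```python
import torch
import torch.nn.functional as F

def attention_with_compute_dtype(Q, K, V, compute_dtype=torch.float16):
    # Remember the caller's dtype, compute in `compute_dtype`, then restore it.
    orig_dtype = Q.dtype
    if Q.dtype != compute_dtype:
        Q, K, V = Q.to(compute_dtype), K.to(compute_dtype), V.to(compute_dtype)
    A = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
    return A if A.dtype == orig_dtype else A.to(orig_dtype)
```
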
@@ -697,6 +761,8 @@ def LlamaDecoderLayer_fast_forward(
             (see `past_key_values`).
         past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
     """
+    global _LOGGED_RMSNORM
+
     if use_cache and hasattr(self, "_flag_for_generation"):
         residual = hidden_states
         hidden_states = fast_rms_layernorm_inference(
@@ -724,7 +790,21 @@ def LlamaDecoderLayer_fast_forward(
         hidden_states += residual
     else:
         residual = hidden_states
-        hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
+        if _DISABLE_TRITON_RMSNORM or _LAYERNORM_IMPL == "python":
+            hidden_states = self.input_layernorm(hidden_states)
+            _rms_impl_name = "torch"
+        else:
+            hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
+            _rms_impl_name = "triton"
+
+        if not _LOGGED_RMSNORM:
+            logger.debug(
+                "Unsloth: RMSNorm=%s. env(UNSLOTH_DISABLE_TRITON_RMSNORM=%s, UNSLOTH_LAYERNORM_IMPL=%s)",
+                _rms_impl_name,
+                os.getenv("UNSLOTH_DISABLE_TRITON_RMSNORM"),
+                os.getenv("UNSLOTH_LAYERNORM_IMPL"),
+            )
+            _LOGGED_RMSNORM = True
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
             hidden_states = hidden_states,
             causal_mask = causal_mask,
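
When the Triton kernel is bypassed, `self.input_layernorm(hidden_states)` runs the module's own RMSNorm. As a reference point for the dtype discussion in the review comments below, here is a minimal sketch of that computation with the statistics (and eps) kept in float32; the exact upstream module may differ:

```python
import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Accumulate the variance in float32 even for fp16/bf16 inputs, then cast back.
    in_dtype = x.dtype
    x32 = x.to(torch.float32)
    variance = x32.pow(2).mean(dim=-1, keepdim=True)
    x32 = x32 * torch.rsqrt(variance + eps)  # eps participates in float32 here
    return weight * x32.to(in_dtype)
```
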
@@ -740,7 +820,12 @@ def LlamaDecoderLayer_fast_forward(

         # Fully Connected
         residual = hidden_states
-        hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
+        if _DISABLE_TRITON_RMSNORM or _LAYERNORM_IMPL == "python":
+            hidden_states = self.post_attention_layernorm(hidden_states)
+        else:
+            hidden_states = fast_rms_layernorm(
+                self.post_attention_layernorm, hidden_states
+            )
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
@@ -853,7 +938,9 @@ def LlamaModel_fast_forward(
     if inputs_embeds is None:
         inputs_embeds = self.embed_tokens(input_ids)

-    inputs_embeds = inputs_embeds.to(_get_dtype(dtype_from_config(self.config)))
+    # Allow overriding the automatic dtype cast via env for diagnostics
+    if os.getenv("UNSLOTH_DISABLE_AUTODTYPE_CAST", "0") != "1":
+        inputs_embeds = inputs_embeds.to(_get_dtype(dtype_from_config(self.config)))

     # Normalized from Gemma
     IS_GEMMA = self.config.model_type.startswith("gemma")
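
A small illustration of what the new guard changes (the env-var name comes from the diff; the dtypes below are made up for the example):

```python
import os
import torch

inputs_embeds = torch.randn(1, 4, 8, dtype=torch.float16)
config_dtype  = torch.bfloat16  # stand-in for _get_dtype(dtype_from_config(self.config))

if os.getenv("UNSLOTH_DISABLE_AUTODTYPE_CAST", "0") != "1":
    inputs_embeds = inputs_embeds.to(config_dtype)  # default: follow the config dtype
# With UNSLOTH_DISABLE_AUTODTYPE_CAST=1 the embeddings keep their original dtype,
# which helps isolate whether the automatic cast contributes to a numerical issue.
```
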
@@ -1349,7 +1436,7 @@ def _CausalLM_fast_forward(
        labels = labels.to(lm_head_device)

        # Output last hidden states without logits if asked
-       if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1":
+       if os.getenv("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1":
            if num_logits_to_keep != 0:
                hidden_states = hidden_states[:, -num_logits_to_keep:, :]
            return CausalLMOutputWithPast(
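
`os.getenv` is a thin wrapper around `os.environ.get`, so this hunk is cosmetic; the flag itself would be exercised like this (sketch only, surrounding call omitted):

```python
import os

os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
# Subsequent forward passes return hidden states in place of logits,
# trimmed to the last num_logits_to_keep positions when that is non-zero.
```
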
Wait, are you certain this causes NaN issues? That's extremely weird.
Google Colab NVIDIA T4: [screenshot]
StrixHalo: [screenshot]
Let me pinpoint it more.
Casting eps to float32 resolved that, but I'm still getting NaNs for