
Commit b1c616c

[Dev release cherry pick] Fixes for gpt-oss (#2076)
Signed-off-by: Chen Cui <[email protected]>
1 parent 193a929 commit b1c616c

5 files changed: +43, -40 lines

megatron/core/models/common/embeddings/rope_utils.py

Lines changed: 26 additions & 25 deletions
@@ -268,46 +268,47 @@ def apply_rotary_pos_emb(
     if config.apply_rope_fusion:
         if cu_seqlens is None:
             # NOTE: TE backends do not support mRoPE in bshd format when bs > 1.
+            use_unfused = False
             if config.mrope_section is not None and freqs.shape[1] > 1:
                 # TODO: Add a check in TransformerConfig and remove this unfused implementation.
                 warnings.warn(
                     "apply_rope_fusion does not support mRoPE in bshd format when bs > 1. "
                     "Please set apply_rope_fusion to false. This will become an error in v0.16."
                 )
-                return _apply_rotary_pos_emb_bshd(
-                    t,
-                    freqs,
-                    rotary_interleaved=config.rotary_interleaved,
-                    multi_latent_attention=config.multi_latent_attention,
-                    mscale=mscale,
+                use_unfused = True
+            if mscale != 1.0:
+                warnings.warn(
+                    f"mscale={mscale} is not supported by TE's fused RoPE. "
+                    "Using unfused implementation."
                 )
-            else:
+                use_unfused = True
+            if not use_unfused:
                 assert fused_apply_rotary_pos_emb is not None, "apply_rope_fusion is not available."
                 return fused_apply_rotary_pos_emb(t, freqs, interleaved=config.rotary_interleaved)
         else:
             assert fused_apply_rotary_pos_emb_thd is not None, "apply_rope_fusion is not available."
             return fused_apply_rotary_pos_emb_thd(
                 t, cu_seqlens, freqs, cp_size=cp_group.size(), cp_rank=cp_group.rank()
             )
+    # use unfused implementation
+    if cu_seqlens is None:
+        return _apply_rotary_pos_emb_bshd(
+            t,
+            freqs,
+            rotary_interleaved=config.rotary_interleaved,
+            multi_latent_attention=config.multi_latent_attention,
+            mscale=mscale,
+        )
     else:
-        if cu_seqlens is None:
-            return _apply_rotary_pos_emb_bshd(
-                t,
-                freqs,
-                rotary_interleaved=config.rotary_interleaved,
-                multi_latent_attention=config.multi_latent_attention,
-                mscale=mscale,
-            )
-        else:
-            return _apply_rotary_pos_emb_thd(
-                t,
-                cu_seqlens,
-                freqs,
-                rotary_interleaved=config.rotary_interleaved,
-                multi_latent_attention=config.multi_latent_attention,
-                mscale=mscale,
-                cp_group=cp_group,
-            )
+        return _apply_rotary_pos_emb_thd(
+            t,
+            cu_seqlens,
+            freqs,
+            rotary_interleaved=config.rotary_interleaved,
+            multi_latent_attention=config.multi_latent_attention,
+            mscale=mscale,
+            cp_group=cp_group,
+        )
 
 
 def apply_rotary_pos_emb_with_cos_sin(
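
The net effect of this hunk: instead of returning early from inside the fused branch, every condition the TE fused kernel cannot handle (mRoPE in bshd format with bs > 1, or a non-unit mscale) now sets a use_unfused flag and falls through to a single unfused code path at the end of the function. A minimal, self-contained sketch of that control-flow pattern; _fused_rope and _unfused_rope below are placeholder stand-ins, not the Megatron kernels:

import warnings

def _fused_rope(t, freqs):
    # placeholder for the fused TE kernel
    return t

def _unfused_rope(t, freqs, mscale=1.0):
    # placeholder for the unfused fallback
    return t

def apply_rope(t, freqs, *, fusion_enabled, mrope_bshd=False, mscale=1.0):
    if fusion_enabled:
        use_unfused = False
        if mrope_bshd:
            warnings.warn("fused RoPE does not support mRoPE in bshd format; using unfused path.")
            use_unfused = True
        if mscale != 1.0:
            warnings.warn(f"mscale={mscale} is not supported by the fused kernel; using unfused path.")
            use_unfused = True
        if not use_unfused:
            return _fused_rope(t, freqs)  # fast path
    # single unfused fallback shared by every unsupported configuration
    return _unfused_rope(t, freqs, mscale=mscale)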

megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py

Lines changed: 7 additions & 4 deletions
@@ -228,22 +228,25 @@ def _yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
 
 @lru_cache(maxsize=8)
 def _yarn_get_concentration_factor(
-    scaling_factor: float, mscale: float, mscale_all_dim: float
+    scaling_factor: float, mscale: Optional[float], mscale_all_dim: Optional[float]
 ) -> float:
     """
     Get the concentration factor (factor multiplied to the sine and cosine components of the
     embedding). This factor is also known as attention factor, and sometimes homonymously known as
     "mscale"
     """
+    if mscale is None or mscale_all_dim is None:
+        return _yarn_get_mscale(scaling_factor)
     return float(
         _yarn_get_mscale(scaling_factor, mscale) / _yarn_get_mscale(scaling_factor, mscale_all_dim)
     )
 
 
 def _yarn_get_concentration_factor_from_config(config: TransformerConfig) -> float:
-    fields = ["yarn_rotary_scaling_factor", "yarn_mscale", "yarn_mscale_all_dim"]
-    if all(hasattr(config, f) for f in fields):
+    if hasattr(config, "yarn_rotary_scaling_factor"):
         return _yarn_get_concentration_factor(
-            config.yarn_rotary_scaling_factor, config.yarn_mscale, config.yarn_mscale_all_dim
+            config.yarn_rotary_scaling_factor,
+            getattr(config, "yarn_mscale", None),
+            getattr(config, "yarn_mscale_all_dim", None),
         )
     return 1.0
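
This makes the YaRN concentration factor tolerant of configs (presumably the gpt-oss configs this commit targets) that only define yarn_rotary_scaling_factor: if yarn_mscale or yarn_mscale_all_dim is missing, the helper falls back to the plain mscale of the scaling factor instead of raising. A hedged sketch of that fallback, using the standard YaRN mscale formula as an assumed stand-in for Megatron's _yarn_get_mscale:

import math
from typing import Optional

def yarn_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
    # standard YaRN attention-scaling formula (assumption: matches _yarn_get_mscale)
    return 1.0 if scale <= 1.0 else 0.1 * mscale * math.log(scale) + 1.0

def concentration_factor(
    scaling_factor: float, mscale: Optional[float], mscale_all_dim: Optional[float]
) -> float:
    # new behaviour: missing knobs fall back to the single-argument form
    if mscale is None or mscale_all_dim is None:
        return yarn_mscale(scaling_factor)
    return yarn_mscale(scaling_factor, mscale) / yarn_mscale(scaling_factor, mscale_all_dim)

print(concentration_factor(32.0, None, None))  # a config without the mscale fields still works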

megatron/core/transformer/dot_product_attention.py

Lines changed: 8 additions & 2 deletions
@@ -116,13 +116,19 @@ def __init__(
         if self.config.softmax_type == "vanilla":
             self.softmax_offset = None
         elif self.config.softmax_type == "off-by-one":
-            self.softmax_offset = torch.zeros(self.num_attention_heads_per_partition)
+            self.softmax_offset = torch.zeros(
+                self.num_attention_heads_per_partition,
+                device=torch.cuda.current_device(),
+                dtype=self.config.params_dtype,
+            )
         elif self.config.softmax_type == "learnable":
             self.register_parameter(
                 "softmax_offset",
                 torch.nn.Parameter(
                     torch.empty(
-                        self.num_attention_heads_per_partition, dtype=self.config.params_dtype
+                        self.num_attention_heads_per_partition,
+                        device=torch.cuda.current_device(),
+                        dtype=self.config.params_dtype,
                     )
                 ),
             )
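
The fix here is allocation placement: the per-head softmax offset is now created directly on the current CUDA device in params_dtype, rather than as a default float32 CPU tensor that would later mismatch the device and dtype of the attention scores. A small sketch of the pattern, with illustrative values (8 heads, bfloat16) rather than a real config:

import torch

num_heads = 8                      # illustrative, not from a real config
params_dtype = torch.bfloat16

if torch.cuda.is_available():
    # allocate on the GPU in the training dtype, as the commit now does
    softmax_offset = torch.zeros(
        num_heads, device=torch.cuda.current_device(), dtype=params_dtype
    )
else:
    # CPU fallback so the sketch runs anywhere
    softmax_offset = torch.zeros(num_heads, dtype=params_dtype)

print(softmax_offset.device, softmax_offset.dtype)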

megatron/core/transformer/utils.py

Lines changed: 0 additions & 3 deletions
@@ -1,7 +1,6 @@
 # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 
 """Utilities for transformer layers."""
-from functools import lru_cache
 from operator import itemgetter
 from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple, Union
 
@@ -29,13 +28,11 @@ def get_linear_layer(rows, columns, init_method, perform_initialization=True):
     return layer
 
 
-@lru_cache(maxsize=32)
 def get_default_causal_mask(sq: int) -> torch.Tensor:
     """Return the causal upper triangular mask for softmax input."""
     return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()
 
 
-@lru_cache(maxsize=32)
 def get_sliding_window_causal_mask(sq, skv, window_size):
     """Create the equivalent attention mask for SWA in [sq, skv] shape"""
     m = torch.ones(sq, skv, dtype=torch.bool, device="cuda")
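
With @lru_cache removed, each call now builds a fresh mask. Caching a tensor-returning helper hands every caller the same tensor object, so an in-place edit by one caller leaks into all later ones, and the cached CUDA tensors stay resident for the life of the process; removing the cache is also why the test file below drops its get_default_causal_mask.cache_clear() teardown. A CPU-only sketch of the aliasing hazard (device="cuda" omitted so it runs anywhere):

from functools import lru_cache

import torch

@lru_cache(maxsize=32)
def cached_causal_mask(sq: int) -> torch.Tensor:
    # cached helper in the style of the removed decorator (CPU for portability)
    return torch.triu(torch.ones(sq, sq), diagonal=1).bool()

m1 = cached_causal_mask(4)
m1[0, 1] = False                # in-place edit on the shared, cached object
m2 = cached_causal_mask(4)
print(m2[0, 1])                 # tensor(False): the "fresh" mask is already corrupted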

tests/unit_tests/fusions/test_torch_softmax.py

Lines changed: 2 additions & 6 deletions
@@ -1,3 +1,5 @@
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
 import pytest
 import torch
 
@@ -21,9 +23,6 @@ def setup_method(self, method):
             scale=None,
         )
 
-    def teardown_method(self):
-        get_default_causal_mask.cache_clear()
-
     def test_output_shape(self):
         x = torch.randn(8, 2, 4, 4, device="cuda")
         y = self.softmax(x, None, None)
@@ -126,9 +125,6 @@ def test_causal_mask_equal_scores(self):
 class TestFusedScaleMaskSoftmaxComprehensive:
     """Comprehensive tests for FusedScaleMaskSoftmax including window attention and scaling."""
 
-    def teardown_method(self):
-        get_default_causal_mask.cache_clear()
-
     def test_scaling_factor(self):
         """Test softmax with different scaling factors."""
         x = torch.randn(2, 4, 8, 8, device="cuda")
