Commits (33 total; diff shown for 22)
5c7ffbd  check in collectives and add global symmetric memory class (sidsingh-nvidia, Nov 19, 2025)
8cab2a6  initialize global symm memory buffer (sidsingh-nvidia, Nov 19, 2025)
d3aa3c2  have working code (sidsingh-nvidia, Nov 20, 2025)
2ea9e5c  modify functional test (sidsingh-nvidia, Nov 20, 2025)
a604681  enable custom linear layers for mamba (sidsingh-nvidia, Nov 20, 2025)
e0cc39d  reformat (sidsingh-nvidia, Nov 20, 2025)
a2b8ba2  reformat mamba layer spec (sidsingh-nvidia, Nov 20, 2025)
afebe2f  Merge branch 'main' into siddharth/torch_symm (sidsingh-nvidia, Nov 20, 2025)
1c069aa  remove temp files (sidsingh-nvidia, Nov 20, 2025)
d236b97  guard triton imports (sidsingh-nvidia, Nov 20, 2025)
06ba2c8  more guards (sidsingh-nvidia, Nov 20, 2025)
716ffe2  guard triton properly (sidsingh-nvidia, Nov 20, 2025)
a71d2b3  decouple cuda graph max size from max requests (sidsingh-nvidia, Nov 20, 2025)
fdff4d7  reformat (sidsingh-nvidia, Nov 20, 2025)
9d25123  delete memory buffer in destroy_model_parallel, remove cuda graph max… (sidsingh-nvidia, Nov 20, 2025)
ea6656e  restore example (sidsingh-nvidia, Nov 20, 2025)
32d7adb  do not instantiate symm memory if triton is absent (sidsingh-nvidia, Nov 20, 2025)
c1e961d  Merge branch 'main' into siddharth/torch_symm (sidsingh-nvidia, Nov 20, 2025)
6b97ef3  pylint warning (sidsingh-nvidia, Nov 20, 2025)
cf0d56d  restore arguments.py (sidsingh-nvidia, Nov 20, 2025)
2fce4b9  Merge branch 'main' into siddharth/torch_symm (sidsingh-nvidia, Nov 20, 2025)
deb4bf0  Merge branch 'main' into siddharth/torch_symm (sidsingh-nvidia, Nov 21, 2025)
8f4f965  Merge branch 'main' into siddharth/torch_symm (sidsingh-nvidia, Dec 1, 2025)
6f08d97  cuda-graph dimension fix for TP (sidsingh-nvidia, Dec 1, 2025)
baab70c  consolidate load store instructions (sidsingh-nvidia, Dec 1, 2025)
49b23e3  fix example to create smaller traces and use kraken's heuristics for … (sidsingh-nvidia, Dec 1, 2025)
c738933  migrate to kraken barrier (sidsingh-nvidia, Dec 1, 2025)
252cc9c  correct copyright headers and reformat (sidsingh-nvidia, Dec 2, 2025)
f0b3c41  restore previous branching code and accumulate in fp32 (sidsingh-nvidia, Dec 2, 2025)
e2e17d9  reformat (sidsingh-nvidia, Dec 2, 2025)
f894c52  update golden values to account for fp32 accumulation (sidsingh-nvidia, Dec 2, 2025)
e2f3b5e  latest (sidsingh-nvidia, Dec 2, 2025)
b424a7c  Merge branch 'main' into siddharth/torch_symm (sidsingh-nvidia, Dec 2, 2025)
6 changes: 4 additions & 2 deletions mamba_builders.py
@@ -6,15 +6,17 @@
 from megatron.core.transformer.spec_utils import import_module
 from megatron.training import print_rank_0
 from megatron.training.arguments import core_transformer_config_from_args

+from megatron.core.models.mamba.mamba_layer_specs import mamba_inference_stack_spec

 def mamba_builder(args, pre_process, post_process, vp_stage=None, config=None):
     print_rank_0('building MAMBA model ...')
     if config is None:
         config = core_transformer_config_from_args(args, TransformerConfig)
     assert args.use_legacy_models is False, "Mamba only supported in Mcore!"

-    if args.spec is not None:
+    if config.transformer_impl == "inference_optimized":
+        mamba_stack_spec = mamba_inference_stack_spec
+    elif args.spec is not None:
         mamba_stack_spec = import_module(args.spec)
     else:
         raise ValueError("You must provide a valid Mamba layer spec via --spec")
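For context, the new branch keys the inference-optimized Mamba stack off `config.transformer_impl` rather than `--spec`. A minimal usage sketch is below; apart from `mamba_builder` and the `transformer_impl` field shown in the diff, the surrounding names are assumptions, not code from this PR.

```python
# Sketch only: exercising the new branch in mamba_builder.
# `args` and `config` are assumed to come from the usual Megatron argument
# parsing / core_transformer_config_from_args flow.
config.transformer_impl = "inference_optimized"  # selects mamba_inference_stack_spec
model = mamba_builder(args, pre_process=True, post_process=True, config=config)
# With any other transformer_impl, a valid --spec (args.spec) is still required.
```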
@@ -0,0 +1,3 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

from .multimem import multimem_all_gather, multimem_reduce_scatter
142 changes: 142 additions & 0 deletions megatron/core/inference/communication/torch_symm_triton/asm_utils.py
@@ -0,0 +1,142 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
from unittest.mock import MagicMock

from megatron.core.utils import null_decorator

try:
    import triton
    import triton.language as tl
except ImportError:
    triton = MagicMock()
    tl = MagicMock()
    triton.jit = null_decorator


@triton.jit
def multimem_ld_reduce_128(multicast_ptrs, mask):
    """
    Multicast load and reduce 128 bits (8 x bf16) from all peers over NVLink.
    Outputs are returned as 4 tl.uint32 registers, each containing 2 bf16 values.
    """
    return tl.inline_asm_elementwise(
        """
        {
            .reg .pred %p0;
            setp.eq.s32 %p0, $5, 1;
            @!%p0 bra end;
            multimem.ld_reduce.relaxed.sys.global.add.v4.bf16x2 {$0, $1, $2, $3}, [$4];
            end:
        }
        """,
        "=r,=r,=r,=r,l,r",
        args=[multicast_ptrs, mask.to(tl.int32)],
        dtype=(tl.uint32, tl.uint32, tl.uint32, tl.uint32),
        is_pure=True,
        pack=1,
    )


@triton.jit
def multimem_st_128(multicast_ptrs, x, y, z, w, mask):
    """
    Multicast store 128 bits (8 x bf16) to all peers over NVLink.
    Each of x, y, z, w is a tl.uint32 register containing 2 bf16 values.
    """
    return tl.inline_asm_elementwise(
        """
        {
            .reg .pred %p0;
            setp.eq.s32 %p0, $6, 1;
            @!%p0 bra end;
            multimem.st.relaxed.sys.global.v4.f32 [$1], {$2, $3, $4, $5};
            end:
        }
        """,
        "=r,l,r,r,r,r,r",
        args=[multicast_ptrs, x, y, z, w, mask.to(tl.int32)],
        dtype=tl.uint32,
        is_pure=False,
        pack=1,
    )
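Taken together, the two multicast primitives above are enough to sketch the inner loop of a one-shot all-reduce over a symmetric-memory buffer: each program load-reduces a 128-bit chunk from all peers and multicasts the sum back. The kernel below is an illustrative sketch only, not part of this diff; `multicast_ptr`, `numel_128b`, and `BLOCK` are assumed names, and peer synchronization (barriers) is omitted.

```python
@triton.jit
def _one_shot_all_reduce_sketch(
    multicast_ptr,  # int64 base address of the multicast (symmetric) buffer
    numel_128b,     # number of 128-bit chunks in the buffer
    BLOCK: tl.constexpr,
):
    # One program handles BLOCK chunks; each chunk is 128 bits = 8 bf16 = 16 bytes.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < numel_128b
    ptrs = multicast_ptr + offs * 16  # byte offsets into the multicast mapping
    # Reduce the chunk across all peers over NVLink ...
    x, y, z, w = multimem_ld_reduce_128(ptrs, mask)
    # ... and multicast the reduced values back to every peer's buffer.
    multimem_st_128(ptrs, x, y, z, w, mask)
```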


@triton.jit
def ld_128(ptr, mask):
    """
    Load 128 bits (8 x bf16) from ptr.
    Outputs are returned as 4 tl.uint32 registers, each containing 2 bf16 values.
    """
    return tl.inline_asm_elementwise(
        """
        {
            .reg .pred %p0;
            setp.eq.s32 %p0, $5, 1;
            @!%p0 bra end;
            ld.global.relaxed.sys.v4.u32 {$0, $1, $2, $3}, [$4];
            end:
        }
        """,
        "=r,=r,=r,=r,l,r",
        args=[ptr, mask.to(tl.int32)],
        dtype=(tl.uint32, tl.uint32, tl.uint32, tl.uint32),
        is_pure=True,
        pack=1,
    )


@triton.jit
def st_128(ptr, x, y, z, w, mask):
    """
    Store 128 bits (8 x bf16) to ptr.
    Each of x, y, z, w is a tl.uint32 register containing 2 bf16 values.
    """
    return tl.inline_asm_elementwise(
        """
        {
            .reg .pred %p0;
            setp.eq.s32 %p0, $6, 1;
            @!%p0 bra end;
            st.global.relaxed.sys.v4.f32 [$1], {$2, $3, $4, $5};
            end:
        }
        """,
        "=r,l,r,r,r,r,r",
        args=[ptr, x, y, z, w, mask.to(tl.int32)],
        dtype=tl.uint32,
        is_pure=False,
        pack=1,
    )


@triton.jit
def add_v8_bf16_from_u32(
    a0,
    a1,
    a2,
    a3,  # First vector of 8 bf16s, packed in 4 uint32s
    b0,
    b1,
    b2,
    b3,  # Second vector of 8 bf16s, packed in 4 uint32s
):
    """
    Adds two vectors of 8 bfloat16 numbers.
    Each vector is passed as four tl.uint32 tensors.
    Returns the result as a tuple of four tl.uint32 tensors.
    """
    return tl.inline_asm_elementwise(
        """
        {
            add.bf16x2 $0, $4, $8;
            add.bf16x2 $1, $5, $9;
            add.bf16x2 $2, $6, $10;
            add.bf16x2 $3, $7, $11;
        }
        """,
        # 4 outputs (=r), 8 inputs (r)
        "=r,=r,=r,=r,r,r,r,r,r,r,r,r",
        args=[a0, a1, a2, a3, b0, b1, b2, b3],
        dtype=(tl.uint32, tl.uint32, tl.uint32, tl.uint32),
        is_pure=True,
        pack=1,
    )
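The non-multicast primitives compose the same way for a software reduction path: load a 128-bit chunk from a local buffer and from a peer buffer, add the packed bf16 lanes, and store the result back. Again, this is only a sketch with assumed pointer arguments, not code from this change (and, as the commit history notes, the PR's real kernels were later switched to accumulate in fp32).

```python
@triton.jit
def _accumulate_peer_sketch(
    local_ptr,   # int64 base address of this rank's buffer
    peer_ptr,    # int64 base address of a peer's buffer
    numel_128b,  # number of 128-bit chunks to process
    BLOCK: tl.constexpr,
):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < numel_128b
    lptrs = local_ptr + offs * 16  # 16 bytes per 128-bit chunk
    pptrs = peer_ptr + offs * 16
    # Load 8 bf16 values (packed as 4 x uint32) from the local and peer buffers.
    a0, a1, a2, a3 = ld_128(lptrs, mask)
    b0, b1, b2, b3 = ld_128(pptrs, mask)
    # Packed bf16x2 adds, two values per register.
    c0, c1, c2, c3 = add_v8_bf16_from_u32(a0, a1, a2, a3, b0, b1, b2, b3)
    # Write the accumulated chunk back to the local buffer.
    st_128(lptrs, c0, c1, c2, c3, mask)
```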