
Commit 6c8cdd5

Add MambaInferenceStateConfig dataclass (#2265)
Signed-off-by: Keshav Santhanam <[email protected]>
1 parent 712dff8 commit 6c8cdd5

10 files changed: +96 −111 lines


examples/inference/gpt/gpt_dynamic_inference.py

Lines changed: 8 additions & 19 deletions
@@ -30,6 +30,9 @@
     ContextOverflowError,
     DynamicInferenceContext,
 )
+from megatron.core.inference.context.attention_context.mamba_metadata import (
+    MambaInferenceStateConfig,
+)
 from megatron.core.inference.engines import DynamicInferenceEngine, EngineSuspendedError
 from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
     GPTInferenceWrapper,
@@ -38,10 +41,9 @@
 from megatron.core.inference.text_generation_controllers.text_generation_controller import (
     TextGenerationController,
 )
-from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols
 from megatron.core.tokenizers.text.utils.build_tokenizer import build_tokenizer
 from megatron.core.transformer.module import MegatronModule
-from megatron.core.utils import get_attr_wrapped_model
+from megatron.core.utils import get_mamba_inference_state_config_from_model

 sys.path.append(
     os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))
@@ -150,9 +152,7 @@ def get_inference_context(
     requests: List[Request],
     sampling_params: Optional[SamplingParams] = None,
     calculate_max_sequence_length_from_requests: bool = True,
-    layer_type_list: Optional[List[str]] = None,
-    mamba_conv_states_shape: Optional[Tuple[int]] = None,
-    mamba_ssm_states_shape: Optional[Tuple[int]] = None,
+    mamba_inference_state_config: Optional[MambaInferenceStateConfig] = None,
 ):
     """The inference context manages the KV cache and other inference state."""

@@ -189,9 +189,7 @@ def get_inference_context(
         max_tokens=args.inference_dynamic_batching_max_tokens,
         tensor_model_parallel_size=args.tensor_model_parallel_size,
         materialize_only_last_token_logits=not args.return_log_probs,
-        layer_type_list=layer_type_list,
-        mamba_conv_states_shape=mamba_conv_states_shape,
-        mamba_ssm_states_shape=mamba_ssm_states_shape,
+        mamba_inference_state_config=mamba_inference_state_config,
         cache_mla_latent=args.multi_latent_attention and args.cache_mla_latents,
         kv_lora_rank=args.kv_lora_rank if args.multi_latent_attention else None,
         qk_pos_emb_head_dim=args.qk_pos_emb_head_dim,
@@ -443,23 +441,14 @@ def main():

     model = get_model()

-    # Layer type list for hybrid models
-    decoder = get_attr_wrapped_model(model, "decoder")
-    layer_type_list = getattr(decoder, "layer_type_list", None)
-    if layer_type_list is not None and Symbols.MAMBA in layer_type_list:
-        (mamba_conv_states_shape, mamba_ssm_states_shape) = decoder.mamba_state_shapes_per_request()
-    else:
-        mamba_conv_states_shape = None
-        mamba_ssm_states_shape = None
+    mamba_inference_state_config = get_mamba_inference_state_config_from_model(model)

     # Requests, context, controller.
     requests = build_requests(args, tokenizer, sampling_params)
     context = get_inference_context(
         requests,
         sampling_params,
-        layer_type_list=layer_type_list,
-        mamba_conv_states_shape=mamba_conv_states_shape,
-        mamba_ssm_states_shape=mamba_ssm_states_shape,
+        mamba_inference_state_config=mamba_inference_state_config,
     )
     controller = get_inference_controller(model, context)
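Taken together, this file's change collapses the three Mamba-specific keyword arguments into one optional config object. A hedged caller-side sketch of the resulting pattern, using only names that appear in the diff above (`model`, `requests`, and `sampling_params` are assumed to be set up as in `main()`):

    # Old: three separate keyword arguments threaded through every call site.
    # context = get_inference_context(
    #     requests,
    #     sampling_params,
    #     layer_type_list=layer_type_list,
    #     mamba_conv_states_shape=mamba_conv_states_shape,
    #     mamba_ssm_states_shape=mamba_ssm_states_shape,
    # )

    # New: one optional config object, built once from the model.
    mamba_inference_state_config = get_mamba_inference_state_config_from_model(model)
    context = get_inference_context(
        requests,
        sampling_params,
        mamba_inference_state_config=mamba_inference_state_config,
    )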

examples/inference/gpt/gpt_dynamic_inference_with_coordinator.py

Lines changed: 3 additions & 16 deletions
@@ -30,8 +30,7 @@
 from megatron.core.inference.inference_client import InferenceClient
 from megatron.core.inference.inference_request import DynamicInferenceRequestRecord
 from megatron.core.inference.sampling_params import SamplingParams
-from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols
-from megatron.core.utils import get_attr_wrapped_model
+from megatron.core.utils import get_mamba_inference_state_config_from_model

 from megatron.training import get_args, get_tokenizer, initialize_megatron
 from megatron.training.arguments import parse_args
@@ -225,28 +224,16 @@ async def main(

     # Requests, context, conroller.
     model = get_model()
+    mamba_inference_state_config = get_mamba_inference_state_config_from_model(model)
     requests = (
         build_requests(args, tokenizer, sampling_params) if dist.get_rank() == 0 else None
     )

-    # Layer type list for hybrid models
-    decoder = get_attr_wrapped_model(model, "decoder")
-    layer_type_list = getattr(decoder, "layer_type_list", None)
-    if layer_type_list is not None and Symbols.MAMBA in layer_type_list:
-        (mamba_conv_states_shape, mamba_ssm_states_shape) = (
-            decoder.mamba_state_shapes_per_request()
-        )
-    else:
-        mamba_conv_states_shape = None
-        mamba_ssm_states_shape = None
-
     context = get_inference_context(
         None,
         None,
         calculate_max_sequence_length_from_requests=False,
-        layer_type_list=layer_type_list,
-        mamba_conv_states_shape=mamba_conv_states_shape,
-        mamba_ssm_states_shape=mamba_ssm_states_shape,
+        mamba_inference_state_config=mamba_inference_state_config,
     )

     controller = get_inference_controller(model, context)

megatron/core/inference/contexts/attention_context/mamba_metadata.py

Lines changed: 21 additions & 1 deletion
@@ -1,8 +1,28 @@
 # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

+from dataclasses import dataclass
+from typing import List, Optional, Tuple
+
 import torch


+@dataclass
+class MambaInferenceStateConfig:
+    """Config for initializing Mamba model inference state tensors."""
+
+    layer_type_list: List[str]
+    """
+    A list of strings that indicates the layer type (Mamba / Attention / MLP) for each layer.
+    See `megatron/core/ssm/mamba_hybrid_layer_allocation.py` for the list of symbols.
+    """
+
+    mamba_conv_states_shape: Tuple[int]
+    """Mamba conv states shape per request."""
+
+    mamba_ssm_states_shape: Tuple[int]
+    """Mamba ssm states shape per request."""
+
+
 class MambaMetadata:
     """Manages the metadata tensors required for Mamba layers during inference."""

@@ -64,7 +84,7 @@ def update_cudagraph_mapping(
         """
         self.request_to_mamba_state_idx_cudagraph_only[0:num_active_requests] = active_mamba_indices

-    def allocate_slot(self) -> int:
+    def allocate_slot(self) -> Optional[int]:
         """
         Allocates a new slot for a request in the Mamba state buffers.

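For reference, a minimal construction sketch for the new dataclass. The import path and `Symbols` values follow the unit-test diff later in this commit; the two shapes are the illustrative per-request values used there, not values a real model is required to have:

    from megatron.core.inference.contexts.attention_context.mamba_metadata import (
        MambaInferenceStateConfig,
    )
    from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols

    # One layer-type symbol per decoder layer; see mamba_hybrid_layer_allocation.py.
    layer_type_list = [Symbols.MAMBA, Symbols.MLP, Symbols.ATTENTION, Symbols.MLP]

    config = MambaInferenceStateConfig(
        layer_type_list=layer_type_list,
        mamba_conv_states_shape=(544, 4),    # per-request conv state shape (example values)
        mamba_ssm_states_shape=(8, 64, 16),  # per-request SSM state shape (example values)
    )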

megatron/core/inference/contexts/dynamic_context.py

Lines changed: 11 additions & 18 deletions
@@ -24,14 +24,11 @@
 from megatron.core.inference.utils import tensor_swap
 from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb
 from megatron.core.package_info import __version__ as mcore_version
-from megatron.core.ssm.mamba_hybrid_layer_allocation import (
-    Symbols,
-    get_layer_maps_from_layer_type_list,
-)
+from megatron.core.ssm.mamba_hybrid_layer_allocation import get_layer_maps_from_layer_type_list
 from megatron.core.transformer import TransformerConfig
 from megatron.core.utils import divide as core_divide

-from .attention_context.mamba_metadata import MambaMetadata
+from .attention_context.mamba_metadata import MambaInferenceStateConfig, MambaMetadata
 from .attention_context.mha_metadata import GraphedMHAMetadata, NonGraphedMHAMetadata
 from .base_context import BaseInferenceContext
 from .dynamic_block_allocator import BlockAllocator
@@ -231,14 +228,8 @@ class DynamicInferenceContext(BaseInferenceContext):
         materialize_only_last_token_logits (Optional[bool]): Whether to only
             materialize logits for the last token. This should be set to False
             if returning log probs.
-        layer_type_list (Optional[List[str]]): A list of strings that indicates
-            the layer type (Mamba / Attention / MLP) for each layer.
-            See `megatron/core/ssm/mamba_hybrid_layer_allocation.py` for the list
-            of symbols. This must be provided for hybrid models.
-        mamba_conv_states_shape: (Optional[Tuple[int]]): Mamba conv states shape per request.
-            This must be provided for hybrid models.
-        mamba_ssm_states_shape: (Optional[Tuple[int]]): Mamba ssm states shape per request.
-            This must be provided for hybrid models.
+        mamba_inference_state_config (Optional[MambaInferenceStateConfig]): The Mamba
+            inference state config if the model is a hybrid model.
         use_cuda_graphs_for_non_decode_steps (bool): If True, use cuda graphs for non-decode
             engine steps.
         unified_memory_level (Optional[int]): Set unified memory usage within the
@@ -274,9 +265,7 @@ def __init__(
         qk_pos_emb_head_dim: Optional[int] = None,
         num_cuda_graphs: Optional[int] = None,
         materialize_only_last_token_logits: Optional[bool] = True,
-        layer_type_list: Optional[List[str]] = None,
-        mamba_conv_states_shape: Optional[Tuple[int]] = None,
-        mamba_ssm_states_shape: Optional[Tuple[int]] = None,
+        mamba_inference_state_config: Optional[MambaInferenceStateConfig] = None,
         use_cuda_graphs_for_non_decode_steps: bool = True,
         use_flashinfer_fused_rope: bool = False,
         unified_memory_level: Optional[int] = 1,
@@ -303,8 +292,10 @@ def __init__(
         self.num_attention_heads_per_partition = core_divide(num_attention_heads, tp_size)

         # Mamba states.
-        self.is_hybrid_model = layer_type_list is not None and Symbols.MAMBA in layer_type_list
+        self.is_hybrid_model = mamba_inference_state_config is not None
         if self.is_hybrid_model:
+            mamba_conv_states_shape = mamba_inference_state_config.mamba_conv_states_shape
+            mamba_ssm_states_shape = mamba_inference_state_config.mamba_ssm_states_shape
             assert (
                 mamba_conv_states_shape is not None
             ), "`mamba_conv_states_shape` must be specified for hybrid models"
@@ -319,7 +310,7 @@ def __init__(
             # corresponding attention layer index or Mamba layer index depending on the
             # layer type.
             attention_layer_map, mamba_layer_map, _, _ = get_layer_maps_from_layer_type_list(
-                layer_type_list
+                mamba_inference_state_config.layer_type_list
             )
             self.num_attention_layers = len(attention_layer_map)
             self.num_mamba_layers = len(mamba_layer_map)
@@ -728,6 +719,7 @@ def from_config(
         max_batch_size: int,
         buffer_size_gb: float = 40,
         num_cuda_graphs: int = None,
+        mamba_inference_state_config: Optional[MambaInferenceStateConfig] = None,
     ):
         """
         Instantiate a `DynamicInferenceContext` from a `TransformerConfig` and an `InferenceWrapperConfig`.
@@ -749,6 +741,7 @@
             materialize_only_last_token_logits=False,
             num_cuda_graphs=num_cuda_graphs,
             use_flashinfer_fused_rope=None,
+            mamba_inference_state_config=mamba_inference_state_config,
         )

     @classmethod

megatron/core/inference/engines/static_engine.py

Lines changed: 6 additions & 1 deletion
@@ -17,7 +17,7 @@
 from megatron.core.inference.text_generation_controllers.text_generation_controller import (
     TextGenerationController,
 )
-from megatron.core.utils import get_asyncio_loop
+from megatron.core.utils import get_asyncio_loop, get_mamba_inference_state_config_from_model

 try:
     from tqdm import tqdm
@@ -93,6 +93,10 @@ def __init__(
         # Store original context in case we need to fall back to legacy static engine
         original_context = text_generation_controller.inference_wrapped_model.inference_context

+        mamba_inference_state_config = get_mamba_inference_state_config_from_model(
+            text_generation_controller.inference_wrapped_model.model
+        )
+
         try:
             if not legacy:
                 dynamic_context = DynamicInferenceContext.from_config(
@@ -101,6 +105,7 @@
                     max_batch_size=max_batch_size,
                     buffer_size_gb=buffer_size_gb,
                     num_cuda_graphs=1,
+                    mamba_inference_state_config=mamba_inference_state_config,
                 )
                 self.controller.inference_wrapped_model.inference_context = dynamic_context
                 self.controller.inference_wrapped_model.prep_model_for_inference()
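A hedged sketch of the call pattern the static engine now uses. The keyword arguments shown for `from_config` come from the hunk above; its leading positional arguments are elided in the diff and are left elided here, and `inference_wrapped_model` stands in for `text_generation_controller.inference_wrapped_model`:

    mamba_inference_state_config = get_mamba_inference_state_config_from_model(
        inference_wrapped_model.model
    )
    dynamic_context = DynamicInferenceContext.from_config(
        ...,  # transformer / inference-wrapper configs, not shown in this hunk
        max_batch_size=max_batch_size,
        buffer_size_gb=buffer_size_gb,
        num_cuda_graphs=1,
        mamba_inference_state_config=mamba_inference_state_config,  # None for non-hybrid models
    )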

megatron/core/utils.py

Lines changed: 19 additions & 0 deletions
@@ -2154,6 +2154,25 @@ async def wrapper(*args, **kwargs):
     return _decorate if func is None else _decorate(func)


+def get_mamba_inference_state_config_from_model(model) -> Optional["MambaInferenceStateConfig"]:
+    """Returns Mamba inference state config from the model if it is a hybrid model."""
+    from megatron.core.inference.contexts.attention_context.mamba_metadata import (
+        MambaInferenceStateConfig,
+    )
+    from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols
+
+    decoder = get_attr_wrapped_model(model, "decoder")
+    layer_type_list = getattr(decoder, "layer_type_list", None)
+    if layer_type_list is not None and Symbols.MAMBA in layer_type_list:
+        (mamba_conv_states_shape, mamba_ssm_states_shape) = decoder.mamba_state_shapes_per_request()
+        return MambaInferenceStateConfig(
+            layer_type_list=layer_type_list,
+            mamba_conv_states_shape=mamba_conv_states_shape,
+            mamba_ssm_states_shape=mamba_ssm_states_shape,
+        )
+    return None
+
+
 # ============================================================================
 # Backward Compatibility Decorators
 # ============================================================================
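The helper's contract, summarized as a short hedged sketch (`model` is assumed to be any Megatron model wrapper that exposes a `decoder` via `get_attr_wrapped_model`, as in the callers updated elsewhere in this commit):

    from megatron.core.utils import get_mamba_inference_state_config_from_model

    mamba_inference_state_config = get_mamba_inference_state_config_from_model(model)
    if mamba_inference_state_config is None:
        # Pure-attention model: no Mamba layers, so no Mamba state buffers to allocate.
        pass
    else:
        # Hybrid model: the per-request state shapes travel with the config.
        conv_shape = mamba_inference_state_config.mamba_conv_states_shape
        ssm_shape = mamba_inference_state_config.mamba_ssm_states_shape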

megatron/rl/inference/megatron.py

Lines changed: 3 additions & 12 deletions
@@ -25,7 +25,7 @@
 from megatron.core.models.gpt.gpt_model import GPTModel
 from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols
 from megatron.core.transformer.module import MegatronModule
-from megatron.core.utils import get_attr_wrapped_model, log_single_rank
+from megatron.core.utils import get_mamba_inference_state_config_from_model, log_single_rank
 from megatron.training.global_vars import get_args, get_tokenizer

 from ..inference.inference_interface import (
@@ -107,14 +107,7 @@ def get_dynamic_inference_engine(args: Namespace, model: MegatronModule, inferen
     if args.enable_cuda_graph:
         num_cuda_graphs = args.inference_dynamic_batching_num_cuda_graphs

-    # Layer type list for hybrid models
-    decoder = get_attr_wrapped_model(model, "decoder")
-    layer_type_list = getattr(decoder, "layer_type_list", None)
-    if layer_type_list is not None and Symbols.MAMBA in layer_type_list:
-        (mamba_conv_states_shape, mamba_ssm_states_shape) = decoder.mamba_state_shapes_per_request()
-    else:
-        mamba_conv_states_shape = None
-        mamba_ssm_states_shape = None
+    mamba_inference_state_config = get_mamba_inference_state_config_from_model(model)

     # Inference context.
     inference_context = DynamicInferenceContext(
@@ -135,9 +128,7 @@ def get_dynamic_inference_engine(args: Namespace, model: MegatronModule, inferen
         tensor_model_parallel_size=args.tensor_model_parallel_size,
         materialize_only_last_token_logits=True,
         unified_memory_kvcache=args.inference_dynamic_batching_unified_memory_kvcache,
-        layer_type_list=layer_type_list,
-        mamba_conv_states_shape=mamba_conv_states_shape,
-        mamba_ssm_states_shape=mamba_ssm_states_shape,
+        mamba_inference_state_config=mamba_inference_state_config,
         metrics_writer=metrics_writer,
     )


tests/unit_tests/inference/contexts/test_dynamic_context.py

Lines changed: 14 additions & 5 deletions
@@ -5,6 +5,9 @@
 import pytest
 import torch

+from megatron.core.inference.contexts.attention_context.mamba_metadata import (
+    MambaInferenceStateConfig,
+)
 from megatron.core.inference.contexts.dynamic_context import (
     DynamicInferenceContext,
     RequestOverflowError,
@@ -52,8 +55,16 @@ def _get_dynamic_context(
 ):
     set_rounder(rounder)

-    if is_hybrid_model and layer_type_list is None:
-        layer_type_list = [Symbols.MAMBA, Symbols.MLP, Symbols.ATTENTION, Symbols.MLP]
+    if is_hybrid_model:
+        if layer_type_list is None:
+            layer_type_list = [Symbols.MAMBA, Symbols.MLP, Symbols.ATTENTION, Symbols.MLP]
+        mamba_conv_states_shape = (544, 4)
+        mamba_ssm_states_shape = (8, 64, 16)
+        mamba_inference_state_config = MambaInferenceStateConfig(
+            layer_type_list, mamba_conv_states_shape, mamba_ssm_states_shape
+        )
+    else:
+        mamba_inference_state_config = None

     dynamic_context = DynamicInferenceContext(
         params_dtype=params_dtype,
@@ -66,9 +77,7 @@ def _get_dynamic_context(
         buffer_size_gb=buffer_size_gb,
         block_size_tokens=block_size_tokens,
         max_tokens=max_tokens,
-        layer_type_list=layer_type_list,
-        mamba_conv_states_shape=(544, 4),
-        mamba_ssm_states_shape=(8, 64, 16),
+        mamba_inference_state_config=mamba_inference_state_config,
         use_flashinfer_fused_rope=None,  # default to using flash-infer if available
         # this is for compatibility with the LTS environment
         unified_memory_level=0,  # unit tests currently broken with UVM
