1 change: 1 addition & 0 deletions megatron/core/inference/contexts/dynamic_context.py
@@ -1040,6 +1040,7 @@ def initialize_attention_state(
self.padded_active_token_count = min(
self.padded_active_token_count, self.max_requests
)
self.padding_slice = slice(active_token_count, self.padded_active_token_count)

# How are we calculating the padded active request count?
# Case 1: Using cuda graphs:
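For reference, the new attribute is just a Python slice covering the padded tail of the active-token dimension. A minimal, self-contained sketch of what it represents (the counts and hidden size below are illustrative, not taken from the PR):

import torch

# Suppose 10 real tokens are padded up to 16 (e.g. for cuda-graph alignment).
active_token_count = 10
padded_active_token_count = 16
padding_slice = slice(active_token_count, padded_active_token_count)

hidden_states = torch.randn(padded_active_token_count, 1, 4096)
hidden_states[padding_slice] = 0.0  # zeroes rows 10..15, the padded tail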
24 changes: 18 additions & 6 deletions megatron/core/models/gpt/gpt_model.py
@@ -33,7 +33,11 @@
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import WrappedTensor, deprecate_inference_params
from megatron.core.utils import (
WrappedTensor,
deprecate_inference_params,
is_using_quantization_scales,
)


class GPTModel(LanguageModule):
@@ -386,11 +390,19 @@ def _preprocess(
else:
sequence_len_offset = None

# Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
# reference held by this caller function, enabling early garbage collection for
# inference. Skip wrapping if decoder_input is logged after decoder completion.
if in_inference_mode and not has_config_logger_enabled(self.config):
decoder_input = WrappedTensor(decoder_input)
if in_inference_mode:
# Clear the outputs for padding tokens when using dynamic batching with
# quantization scales to avoid corrupting amax calculations
if inference_context.is_dynamic_batching() and is_using_quantization_scales(
self.config
):
decoder_input[inference_context.padding_slice] = 0.0

# Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
# reference held by this caller function, enabling early garbage collection for
# inference. Skip wrapping if decoder_input is logged after decoder completion.
if not has_config_logger_enabled(self.config):
decoder_input = WrappedTensor(decoder_input)

preproc_output = (
decoder_input,
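Why the clearing matters: FP8/FP4 recipes derive per-tensor scales from amax(|x|), so stale values left in padded token slots can inflate the scale for the whole tensor. A toy illustration of the failure mode in plain PyTorch (not Megatron code):

import torch

active, padded = 10, 16
x = torch.randn(padded, 4096)
x[active:] = 1e4                      # garbage left in the padded tail
print(x.abs().amax())                 # ~10000: amax corrupted by padding
x[slice(active, padded)] = 0.0        # what the PR does via padding_slice
print(x.abs().amax())                 # amax now reflects real tokens only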
16 changes: 15 additions & 1 deletion megatron/core/models/mamba/mamba_model.py
@@ -16,7 +16,11 @@
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.utils import WrappedTensor, deprecate_inference_params
from megatron.core.utils import (
WrappedTensor,
deprecate_inference_params,
is_using_quantization_scales,
)


class MambaModel(LanguageModule):
@@ -201,6 +205,16 @@ def forward(
pass
elif self.pre_process:
decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)

# Clear the outputs for padding tokens when using dynamic batching with
# quantization scales to avoid corrupting amax calculations
# TODO(ksanthanam): Add unit test once dynamic engine supports hybrid models
if (
in_inference_mode
and inference_context.is_dynamic_batching()
and is_using_quantization_scales(self.config)
):
decoder_input[inference_context.padding_slice] = 0.0
else:
# intermediate stage of pipeline
# decoder will get hidden_states from encoder.input_tensor
7 changes: 6 additions & 1 deletion megatron/core/transformer/attention.py
@@ -32,6 +32,7 @@
get_pg_size,
is_fa_min_version,
is_te_min_version,
is_using_quantization_scales,
nvtx_range_pop,
nvtx_range_push,
)
@@ -915,6 +916,11 @@ def forward(
)
core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)')

# Clear the outputs for padding tokens when using quantization scales
# to avoid corrupting amax calculations
if is_using_quantization_scales(self.config):
core_attn_out[inference_context.padding_slice] = 0.0

if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd':
# reshape to same output shape as unpacked case
# (t, np, hn) -> (t, b=1, h=np*hn)
@@ -926,7 +932,6 @@
# =================
# Output. [sq, b, h]
# =================

nvtx_range_push(suffix="linear_proj")
output, bias = self.linear_proj(core_attn_out)
nvtx_range_pop(suffix="linear_proj")
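A note on what the new indexing touches, assuming dynamic batching flattens all requests' tokens into the leading dimension so that core_attn_out is effectively [total_tokens, 1, hidden] at this point; this is an illustrative sketch, not Megatron code:

import torch

total_tokens, hidden = 16, 8
core_attn_out = torch.randn(total_tokens, 1, hidden)
padding_slice = slice(10, 16)          # would come from the inference context
core_attn_out[padding_slice] = 0.0     # zero whole padded-token rows before linear_proj
assert core_attn_out[10:].abs().sum() == 0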
5 changes: 5 additions & 0 deletions megatron/core/utils.py
@@ -2099,6 +2099,11 @@ def get_asyncio_loop(loop: asyncio.AbstractEventLoop | None = None) -> asyncio.AbstractEventLoop
return loop


def is_using_quantization_scales(config):
"""Returns whether the model is using quantization scales based on the config."""
return getattr(config, "fp8", False) or getattr(config, "fp4", False)


_ASYNC_TASK_STATS = defaultdict(lambda: [0, 0.0]) # cnt, total_time


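Usage sketch for the new helper; the config objects below are stand-ins (SimpleNamespace), since any object carrying fp8/fp4 attributes will do. Note the helper returns the truthy attribute value itself, hence the bool() wrapping for display:

from types import SimpleNamespace
from megatron.core.utils import is_using_quantization_scales

cfg_fp8 = SimpleNamespace(fp8="hybrid", fp4=None)
cfg_bf16 = SimpleNamespace(fp8=None, fp4=None)

print(bool(is_using_quantization_scales(cfg_fp8)))   # True  (fp8 recipe set)
print(bool(is_using_quantization_scales(cfg_bf16)))  # False (no fp8/fp4)
print(bool(is_using_quantization_scales(object())))  # False (attributes absent, getattr defaults apply)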