diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py
index 48e9e9b48f..b0175de6f8 100644
--- a/megatron/core/inference/contexts/dynamic_context.py
+++ b/megatron/core/inference/contexts/dynamic_context.py
@@ -1079,6 +1079,7 @@ def initialize_attention_state(
         self.padded_active_token_count = min(
             self.padded_active_token_count, self.max_active_requests
         )
+        self.padding_slice = slice(active_token_count, self.padded_active_token_count)

         # How are we calculating the padded active request count?
         # Case 1: Using cuda graphs:
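For context, padding_slice marks the rows between the real (active) tokens and the rounded-up padded token count, so downstream code can blank them with a single indexing operation. A minimal standalone sketch of that pattern in plain PyTorch with made-up sizes (not Megatron code):

import torch

# 10 real tokens padded up to 64 positions, hidden size 8 (all values hypothetical).
active_token_count, padded_active_token_count, hidden_size = 10, 64, 8
padding_slice = slice(active_token_count, padded_active_token_count)

# [padded_tokens, hidden] activations; the padded rows may hold stale values.
decoder_input = torch.randn(padded_active_token_count, hidden_size)
decoder_input[padding_slice] = 0.0  # zero the padded rows so they cannot affect later statistics
assert bool(torch.all(decoder_input[active_token_count:] == 0.0))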
diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py
index 25546d3662..4453a8c213 100644
--- a/megatron/core/models/gpt/gpt_model.py
+++ b/megatron/core/models/gpt/gpt_model.py
@@ -33,7 +33,11 @@
 from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_block import TransformerBlock
 from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.utils import WrappedTensor, deprecate_inference_params
+from megatron.core.utils import (
+    WrappedTensor,
+    deprecate_inference_params,
+    is_using_quantization_scales,
+)


 class GPTModel(LanguageModule):
@@ -386,11 +390,19 @@ def _preprocess(
         else:
             sequence_len_offset = None

-        # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
-        # reference held by this caller function, enabling early garbage collection for
-        # inference. Skip wrapping if decoder_input is logged after decoder completion.
-        if in_inference_mode and not has_config_logger_enabled(self.config):
-            decoder_input = WrappedTensor(decoder_input)
+        if in_inference_mode:
+            # Clear the outputs for padding tokens when using dynamic batching with
+            # quantization scales to avoid corrupting amax calculations
+            if inference_context.is_dynamic_batching() and is_using_quantization_scales(
+                self.config
+            ):
+                decoder_input[inference_context.padding_slice] = 0.0
+
+            # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
+            # reference held by this caller function, enabling early garbage collection for
+            # inference. Skip wrapping if decoder_input is logged after decoder completion.
+            if not has_config_logger_enabled(self.config):
+                decoder_input = WrappedTensor(decoder_input)

         preproc_output = (
             decoder_input,
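The zeroing matters because Transformer Engine's fp8 recipes derive scaling factors from per-tensor amax statistics, so stale values left in padded rows can inflate amax and degrade the scales applied to the real tokens. A toy illustration in plain PyTorch (the 100.0 is an arbitrary stand-in for garbage data, not a measured value):

import torch

acts = torch.randn(64, 8)   # 10 real token rows followed by 54 padded rows
acts[10:] = 100.0           # pretend the padded region holds large stale values
print(acts.abs().amax())    # ~100: the quantization scale would be sized for garbage
acts[10:] = 0.0             # what the patch does via inference_context.padding_slice
print(acts.abs().amax())    # amax now reflects only the 10 real rows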
diff --git a/megatron/core/models/mamba/mamba_model.py b/megatron/core/models/mamba/mamba_model.py
index 378cf7e47d..51b9bd1677 100644
--- a/megatron/core/models/mamba/mamba_model.py
+++ b/megatron/core/models/mamba/mamba_model.py
@@ -16,7 +16,11 @@
 from megatron.core.transformer import TransformerConfig
 from megatron.core.transformer.enums import ModelType
 from megatron.core.transformer.spec_utils import ModuleSpec, build_module
-from megatron.core.utils import WrappedTensor, deprecate_inference_params
+from megatron.core.utils import (
+    WrappedTensor,
+    deprecate_inference_params,
+    is_using_quantization_scales,
+)


 class MambaModel(LanguageModule):
@@ -201,6 +205,15 @@ def forward(
             pass
         elif self.pre_process:
             decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
+
+            # Clear the outputs for padding tokens when using dynamic batching with
+            # quantization scales to avoid corrupting amax calculations
+            if (
+                in_inference_mode
+                and inference_context.is_dynamic_batching()
+                and is_using_quantization_scales(self.config)
+            ):
+                decoder_input[inference_context.padding_slice] = 0.0
         else:
             # intermediate stage of pipeline
             # decoder will get hidden_states from encoder.input_tensor
diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py
index 1051799db9..d839af3790 100644
--- a/megatron/core/transformer/attention.py
+++ b/megatron/core/transformer/attention.py
@@ -33,6 +33,7 @@
     get_pg_size,
     is_fa_min_version,
     is_te_min_version,
+    is_using_quantization_scales,
    nvtx_range_pop,
    nvtx_range_push,
 )
@@ -949,6 +950,11 @@ def forward(
             )
             core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)')

+        # Clear the outputs for padding tokens when using quantization scales
+        # to avoid corrupting amax calculations
+        if is_using_quantization_scales(self.config):
+            core_attn_out[inference_context.padding_slice] = 0.0
+
         if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd':
             # reshape to same output shape as unpacked case
             # (t, np, hn) -> (t, b=1, h=np*hn)
@@ -960,7 +966,6 @@
         # =================
         # Output. [sq, b, h]
         # =================
-
         nvtx_range_push(suffix="linear_proj")
         output, bias = self.linear_proj(core_attn_out)
         nvtx_range_pop(suffix="linear_proj")
diff --git a/megatron/core/utils.py b/megatron/core/utils.py
index 19015af62d..003c44ab29 100644
--- a/megatron/core/utils.py
+++ b/megatron/core/utils.py
@@ -2099,6 +2099,11 @@ def get_asyncio_loop(loop: asyncio.AbstractEventLoop | None = None) -> asyncio.AbstractEventLoop:
     return loop


+def is_using_quantization_scales(config):
+    """Returns whether the model is using quantization scales based on the config."""
+    return getattr(config, "fp8", False) or getattr(config, "fp4", False)
+
+
 _ASYNC_TASK_STATS = defaultdict(lambda: [0, 0.0])  # cnt, total_time
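A quick standalone check of the new helper's intended behavior (assumes the patched megatron.core.utils is importable; SimpleNamespace stands in for a real TransformerConfig):

from types import SimpleNamespace

from megatron.core.utils import is_using_quantization_scales

assert is_using_quantization_scales(SimpleNamespace(fp8="hybrid"))  # any truthy fp8 recipe
assert is_using_quantization_scales(SimpleNamespace(fp4=True))      # or a truthy fp4 setting
assert not is_using_quantization_scales(SimpleNamespace(fp8=None))  # neither set -> falsy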
diff --git a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_fp8_logitsmatch/golden_values_dev_dgx_h100.json b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_fp8_logitsmatch/golden_values_dev_dgx_h100.json
index cbc5f4fa3a..ed4f7ce423 100644
--- a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_fp8_logitsmatch/golden_values_dev_dgx_h100.json
+++ b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_fp8_logitsmatch/golden_values_dev_dgx_h100.json
@@ -1,38 +1,38 @@
 {
     "0": {
         "input_prompt": "Time travel to 2008, and go to a bar or a club or one of the myriad disco-basements on the Lower East Side that does not quite know which of those it is. Dance awkwardly in a room full of other glittered-up nerds, and wait for something to happen, buoyed on the feeling that this is the big swollen heart of life, that this is New York like the movies.",
-        "generated_text": " And then you get to the end of the movie, and you realize that this is not New York at all. This is New York at the end",
+        "generated_text": " And that this is the place where you can be yourself, and be yourself in the most beautiful way possible. And that this is the place where you",
         "generated_tokens": [
             3060,
-            2430,
-            1636,
-            2012,
-            1317,
-            1278,
-            2362,
-            1307,
+            1455,
+            1593,
+            1395,
             1278,
-            16070,
+            3535,
+            2478,
+            1636,
+            1710,
+            1402,
+            14019,
             1044,
             1321,
-            1636,
-            23067,
+            1402,
+            14019,
+            1294,
+            1278,
+            2725,
+            15568,
+            3039,
+            4171,
+            1046,
+            3060,
             1455,
             1593,
             1395,
-            1605,
-            3140,
-            5152,
-            1513,
-            1747,
-            1046,
-            2409,
-            1395,
-            3140,
-            5152,
-            1513,
             1278,
-            2362
+            3535,
+            2478,
+            1636
         ],
         "latency": 0.2963709831237793,
         "cuda_graph_request_count_map": {
@@ -143,35 +143,35 @@
             -4.629369258880615,
             -3.4186267852783203,
             -1.9727531671524048,
-            -2.354729652404785,
-            -1.474542498588562,
-            -2.48478364944458,
-            -1.7641210556030273,
-            -1.1853944063186646,
-            -2.8624324798583984,
-            -0.5740103125572205,
-            -0.4542185962200165,
-            -1.4300930500030518,
-            -0.8807456493377686,
-            -0.4597663879394531,
-            -0.9252307415008545,
-            -1.648141860961914,
-            -0.44453874230384827,
-            -1.818476915359497,
-            -0.5714479088783264,
-            -1.2115143537521362,
-            -1.0910619497299194,
-            -0.0023161747958511114,
-            -1.3206473588943481,
-            -0.008621376007795334,
-            -0.7551823854446411,
-            -0.9404395818710327,
-            -0.07279698550701141,
-            -0.9365248680114746,
-            -0.03344438225030899,
-            -1.9720849990844727,
-            -1.3928067684173584,
-            -0.7453650832176208
+            -2.3787546157836914,
+            -2.1559927463531494,
+            -0.3042203187942505,
+            -1.5030715465545654,
+            -2.2185328006744385,
+            -1.163987398147583,
+            -1.6270570755004883,
+            -1.382636547088623,
+            -2.0105135440826416,
+            -1.4725666046142578,
+            -0.8826082944869995,
+            -1.2427284717559814,
+            -1.966317892074585,
+            -2.320446252822876,
+            -1.6230499744415283,
+            -0.9565537571907043,
+            -1.8637971878051758,
+            -2.4722466468811035,
+            -1.0165706872940063,
+            -1.5226590633392334,
+            -0.9222670793533325,
+            -0.3918261229991913,
+            -0.4523355960845947,
+            -0.2231457382440567,
+            -0.014109030365943909,
+            -0.36200955510139465,
+            -0.2032179832458496,
+            -0.043058376759290695,
+            -0.09619185328483582
         ]
     },
     "throughput": [
@@ -184,4 +184,4 @@
         100.30260782227111,
         100.30996418216475
     ]
-}
\ No newline at end of file
+}
diff --git a/tests/unit_tests/models/test_gpt_model.py b/tests/unit_tests/models/test_gpt_model.py
index 6936cfbe60..469a04f6f3 100644
--- a/tests/unit_tests/models/test_gpt_model.py
+++ b/tests/unit_tests/models/test_gpt_model.py
@@ -8,9 +8,13 @@
 import torch
 from packaging import version
 from pytest import approx
+from transformer_engine.pytorch.fp8 import check_fp8_support

 from megatron.core import parallel_state
 from megatron.core.hyper_comm_grid import HyperCommGrid
+from megatron.core.inference.contexts.dynamic_context import DynamicInferenceContext
+from megatron.core.inference.inference_request import DynamicInferenceRequest
+from megatron.core.inference.sampling_params import SamplingParams
 from megatron.core.models.gpt.gpt_layer_specs import (
     get_gpt_layer_with_transformer_engine_spec,
     get_mlp_module_spec,
@@ -18,8 +22,9 @@
 from megatron.core.models.gpt.gpt_model import GPTModel
 from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
+from megatron.core.transformer.module import Float16Module
 from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.utils import is_te_min_version
+from megatron.core.utils import is_fa_min_version, is_te_min_version

 from tests.unit_tests.test_utilities import Utils

@@ -333,3 +338,108 @@ def test_gpt_model_with_custom_pg(self, tp_size, dp_size, cp_size):
         assert logits.shape[0] == sequence_length
         assert logits.shape[1] == micro_batch_size
         assert logits.shape[2] == self.gpt_model.config.hidden_size
+
+
+class TestGPTWithDynamicInference:
+    """Tests GPTModel with dynamic inference."""
+
+    @torch.inference_mode()
+    def setup_method(self, method):
+        fp8_available, reason_for_no_fp8 = check_fp8_support()
+        if not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+
+        os.environ.pop('NVTE_FUSED_ATTN', None)
+        os.environ.pop('NVTE_FLASH_ATTN', None)
+        os.environ.pop('NVTE_UNFUSED_ATTN', None)
+        Utils.initialize_model_parallel(1, 1)
+        model_parallel_cuda_manual_seed(123)
+
+        transformer_config = TransformerConfig(
+            num_layers=4,
+            hidden_size=64,
+            num_attention_heads=8,
+            use_cpu_initialization=True,
+            params_dtype=torch.bfloat16,
+            bf16=True,
+            fp8="hybrid",
+            fp8_recipe="tensorwise",
+        )
+
+        self.gpt_model = GPTModel(
+            config=transformer_config,
+            transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec(),
+            vocab_size=128,
+            max_sequence_length=DynamicInferenceContext.TOKEN_ROUNDER,
+            parallel_output=True,
+        )
+        self.gpt_model = Float16Module(self.gpt_model.config, self.gpt_model)
+
+    def teardown_method(self, method):
+        Utils.destroy_model_parallel()
+
+    @pytest.mark.internal
+    @pytest.mark.skipif(
+        not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching"
+    )
+    @torch.inference_mode()
+    def test_dynamic_inference_padding_with_fp8(self):
+        """
+        Tests that logits for padded tokens are zeroed out for fp8 inference.
+        """
+        self.gpt_model.cuda()
+        self.gpt_model.eval()
+        config = self.gpt_model.config
+
+        inference_context = DynamicInferenceContext(
+            params_dtype=config.params_dtype,
+            num_layers=config.num_layers,
+            kv_channels=config.hidden_size // config.num_attention_heads,
+            num_attention_heads=config.num_attention_heads,
+            max_sequence_length=self.gpt_model.module.max_sequence_length,
+            buffer_size_gb=1.0,
+            block_size_tokens=256,
+            materialize_only_last_token_logits=False,
+        )
+
+        # Add a request with 10 tokens. Since 10 is not a multiple of 64,
+        # this will create padding up to the padded length of 64.
+        active_token_count = 10
+        request = DynamicInferenceRequest(
+            request_id=0,
+            prompt_tokens=torch.arange(0, active_token_count, dtype=torch.long, device='cuda'),
+            sampling_params=SamplingParams(num_tokens_to_generate=1),
+        )
+        inference_context.add_request(request)
+
+        # Prepares the context, including calculating the padded token count.
+        inference_context.initialize_attention_state()
+
+        assert inference_context.active_token_count == active_token_count
+        assert inference_context.padded_active_token_count == DynamicInferenceContext.TOKEN_ROUNDER
+
+        # Prepare inputs for the forward pass.
+        padded_token_count = inference_context.padded_active_token_count
+        input_ids, position_ids = inference_context.current_input_and_position_ids()
+
+        # Run the forward pass with inference parameters.
+        logits = self.gpt_model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            attention_mask=None,
+            inference_context=inference_context,
+            runtime_gather_output=True,
+        )
+
+        # Verify the output shape.
+        assert logits.shape[0] == 1
+        assert logits.shape[1] == padded_token_count
+        assert logits.shape[2] == self.gpt_model.module.vocab_size
+
+        # Extract the logits corresponding to the padding tokens (from index 10 to 63).
+        padding_start_idx = inference_context.active_token_count
+        padding_end_idx = inference_context.padded_active_token_count
+        padding_logits = logits[0, padding_start_idx:padding_end_idx, :]
+
+        # Assert that all padding logits are zero.
+        assert torch.all(padding_logits == 0.0), "Logits for padding tokens are not all zero."
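The assertion that padded_active_token_count equals DynamicInferenceContext.TOKEN_ROUNDER relies on the context rounding the active token count up to a multiple of the rounder (64, per the comment in the test). A sketch of that rounding with a hypothetical helper name (the real logic lives inside DynamicInferenceContext):

def round_up_to_multiple(n: int, multiple: int = 64) -> int:
    # Hypothetical stand-in for the context's internal rounding of the token count.
    return ((n + multiple - 1) // multiple) * multiple

assert round_up_to_multiple(10) == 64   # 10 active tokens -> 64 padded, 54 zeroed positions
assert round_up_to_multiple(64) == 64   # already a multiple -> no padding
assert round_up_to_multiple(65) == 128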
diff --git a/tests/unit_tests/models/test_mamba_model.py b/tests/unit_tests/models/test_mamba_model.py
index ca42ae496b..1d05f6e1ac 100644
--- a/tests/unit_tests/models/test_mamba_model.py
+++ b/tests/unit_tests/models/test_mamba_model.py
@@ -1,18 +1,29 @@
 # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

+import os
 from datetime import timedelta

 import pytest
 import torch
+from transformer_engine.pytorch.fp8 import check_fp8_support

 from megatron.core import parallel_state
 from megatron.core.hyper_comm_grid import HyperCommGrid
 from megatron.core.inference.contexts import BaseInferenceContext, StaticInferenceContext
+from megatron.core.inference.contexts.dynamic_context import DynamicInferenceContext
+from megatron.core.inference.inference_request import DynamicInferenceRequest
+from megatron.core.inference.sampling_params import SamplingParams
 from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec
 from megatron.core.models.mamba.mamba_model import MambaModel
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
 from megatron.core.transformer import TransformerConfig
-from megatron.core.utils import divide, is_torch_min_version
+from megatron.core.transformer.module import Float16Module
+from megatron.core.utils import (
+    divide,
+    get_mamba_inference_state_config_from_model,
+    is_fa_min_version,
+    is_torch_min_version,
+)

 from tests.unit_tests.test_utilities import Utils

@@ -218,3 +229,115 @@ def test_with_custom_process_groups(self, tmp_path, tp_size, cp_size, pp_size):
         assert logits.shape[0] == micro_batch_size
         assert logits.shape[1] == sequence_length
         assert logits.shape[2] == divide(model.vocab_size, tp_size)
+
+
+class TestMambaWithDynamicInference:
+    """Tests MambaModel with dynamic inference."""
+
+    @torch.inference_mode()
+    def setup_method(self, method):
+        fp8_available, reason_for_no_fp8 = check_fp8_support()
+        if not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+
+        Utils.initialize_model_parallel(1, 1)
+        model_parallel_cuda_manual_seed(123)
+
+        transformer_config = TransformerConfig(
+            num_layers=3,  # 1 Mamba layer, 1 attention layer, 1 MLP layer
+            hidden_size=256,
+            mamba_num_heads=16,
+            num_attention_heads=16,
+            use_cpu_initialization=True,
+            params_dtype=torch.bfloat16,
+            bf16=True,
+            fp8="hybrid",
+            fp8_recipe="tensorwise",
+        )
+
+        self.mamba_model = MambaModel(
+            config=transformer_config,
+            mamba_stack_spec=mamba_stack_spec,
+            vocab_size=128,
+            max_sequence_length=DynamicInferenceContext.TOKEN_ROUNDER,
+            hybrid_attention_ratio=0.3,
+            hybrid_mlp_ratio=0.3,
+            parallel_output=True,
+        )
+        self.mamba_model = Float16Module(self.mamba_model.config, self.mamba_model)
+
+    def teardown_method(self, method):
+        Utils.destroy_model_parallel()
+
+    @pytest.mark.internal
+    @pytest.mark.skipif(
+        not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching"
+    )
+    @torch.inference_mode()
+    def test_dynamic_inference_padding_with_fp8(self):
+        """
+        Tests that logits for padded tokens are zeroed out for fp8 inference.
+ """ + self.mamba_model.cuda() + self.mamba_model.eval() + config = self.mamba_model.config + + # Mamba specific: Retrieve inference state config + mamba_inference_state_config = get_mamba_inference_state_config_from_model( + self.mamba_model.module + ) + + inference_context = DynamicInferenceContext( + params_dtype=config.params_dtype, + num_layers=config.num_layers, + kv_channels=config.hidden_size // config.num_attention_heads, + num_attention_heads=config.num_attention_heads, + max_sequence_length=self.mamba_model.module.max_sequence_length, + buffer_size_gb=1.0, + block_size_tokens=256, + materialize_only_last_token_logits=False, + mamba_inference_state_config=mamba_inference_state_config, + use_cuda_graphs_for_non_decode_steps=False, + ) + + # Add a request with 10 tokens. Since 10 is not a multiple of the rounder, + # this will create padding up to the padded length. + active_token_count = 10 + request = DynamicInferenceRequest( + request_id=0, + prompt_tokens=torch.arange(0, active_token_count, dtype=torch.long, device='cuda'), + sampling_params=SamplingParams(num_tokens_to_generate=1), + ) + inference_context.add_request(request) + + # Prepares the context, including calculating the padded token count. + inference_context.initialize_attention_state() + + assert inference_context.active_token_count == active_token_count + assert inference_context.padded_active_token_count == DynamicInferenceContext.TOKEN_ROUNDER + + # Prepare inputs for the forward pass. + padded_token_count = inference_context.padded_active_token_count + input_ids, position_ids = inference_context.current_input_and_position_ids() + + # Run the forward pass with inference parameters. + logits = self.mamba_model.forward( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=None, + inference_context=inference_context, + runtime_gather_output=True, + ) + + # Verify the output shape. + assert logits.shape[0] == 1 + assert logits.shape[1] == padded_token_count + assert logits.shape[2] == self.mamba_model.module.vocab_size + + # Extract the logits corresponding to the padding tokens. + padding_start_idx = inference_context.active_token_count + padding_end_idx = inference_context.padded_active_token_count + padding_logits = logits[0, padding_start_idx:padding_end_idx, :] + + # Assert that all padding logits are zero. + assert torch.all(padding_logits == 0.0), "Logits for padding tokens are not all zero."