
Commit c9d2c8f

Explicitly zero out padding token activations for dynamic inference (#2008)
Signed-off-by: Keshav Santhanam <[email protected]>
1 parent c90160d commit c9d2c8f
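
Why this matters, as a minimal standalone sketch (illustrative only, not Megatron code; the tensor shapes and the stale value are made up): under FP8/FP4 recipes the quantization scale is derived from an amax reduction over the activation tensor, and with dynamic batching the token dimension is padded up to a bucket size, so whatever stale values sit in the padding rows can inflate that amax. Explicitly zeroing the padding slice keeps the scale driven by real tokens only.

import torch

# Toy activations: [padded_tokens, hidden]; only the first `active` rows hold real tokens.
active, padded, hidden = 10, 64, 8
acts = torch.randn(padded, hidden)
acts[active:] = 1e4  # stale values left behind in the padding rows

amax_dirty = acts.abs().max()      # dominated by the padding rows -> oversized quantization scale
acts[slice(active, padded)] = 0.0  # what this commit does before the quantized GEMMs see the tensor
amax_clean = acts.abs().max()      # now reflects real tokens only
print(amax_dirty.item(), amax_clean.item())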


8 files changed (+332 / -63 lines)

megatron/core/inference/contexts/dynamic_context.py

Lines changed: 1 addition & 0 deletions
@@ -1079,6 +1079,7 @@ def initialize_attention_state(
         self.padded_active_token_count = min(
             self.padded_active_token_count, self.max_active_requests
         )
+        self.padding_slice = slice(active_token_count, self.padded_active_token_count)

         # How are we calculating the padded active request count?
         # Case 1: Using cuda graphs:
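
For orientation, the new padding_slice simply covers the gap between the real token count and the rounded-up count. A rough sketch of that relationship (the rounding helper below is illustrative, not the actual DynamicInferenceContext logic; TOKEN_ROUNDER = 64 is assumed, matching the comment in the new unit test further down):

TOKEN_ROUNDER = 64  # assumed bucket size

def round_up_tokens(active_token_count: int) -> int:
    # Round the active token count up to the next multiple of TOKEN_ROUNDER.
    return -(-active_token_count // TOKEN_ROUNDER) * TOKEN_ROUNDER

active_token_count = 10
padded_active_token_count = round_up_tokens(active_token_count)  # 64
padding_slice = slice(active_token_count, padded_active_token_count)
assert padding_slice == slice(10, 64)  # tokens 10..63 are padding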

megatron/core/models/gpt/gpt_model.py

Lines changed: 18 additions & 6 deletions
@@ -33,7 +33,11 @@
 from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_block import TransformerBlock
 from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.utils import WrappedTensor, deprecate_inference_params
+from megatron.core.utils import (
+    WrappedTensor,
+    deprecate_inference_params,
+    is_using_quantization_scales,
+)


 class GPTModel(LanguageModule):
@@ -386,11 +390,19 @@ def _preprocess(
         else:
             sequence_len_offset = None

-        # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
-        # reference held by this caller function, enabling early garbage collection for
-        # inference. Skip wrapping if decoder_input is logged after decoder completion.
-        if in_inference_mode and not has_config_logger_enabled(self.config):
-            decoder_input = WrappedTensor(decoder_input)
+        if in_inference_mode:
+            # Clear the outputs for padding tokens when using dynamic batching with
+            # quantization scales to avoid corrupting amax calculations
+            if inference_context.is_dynamic_batching() and is_using_quantization_scales(
+                self.config
+            ):
+                decoder_input[inference_context.padding_slice] = 0.0
+
+            # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
+            # reference held by this caller function, enabling early garbage collection for
+            # inference. Skip wrapping if decoder_input is logged after decoder completion.
+            if not has_config_logger_enabled(self.config):
+                decoder_input = WrappedTensor(decoder_input)

         preproc_output = (
             decoder_input,

megatron/core/models/mamba/mamba_model.py

Lines changed: 14 additions & 1 deletion
@@ -16,7 +16,11 @@
 from megatron.core.transformer import TransformerConfig
 from megatron.core.transformer.enums import ModelType
 from megatron.core.transformer.spec_utils import ModuleSpec, build_module
-from megatron.core.utils import WrappedTensor, deprecate_inference_params
+from megatron.core.utils import (
+    WrappedTensor,
+    deprecate_inference_params,
+    is_using_quantization_scales,
+)


 class MambaModel(LanguageModule):
@@ -201,6 +205,15 @@ def forward(
             pass
         elif self.pre_process:
            decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
+
+            # Clear the outputs for padding tokens when using dynamic batching with
+            # quantization scales to avoid corrupting amax calculations
+            if (
+                in_inference_mode
+                and inference_context.is_dynamic_batching()
+                and is_using_quantization_scales(self.config)
+            ):
+                decoder_input[inference_context.padding_slice] = 0.0
         else:
             # intermediate stage of pipeline
             # decoder will get hidden_states from encoder.input_tensor

megatron/core/transformer/attention.py

Lines changed: 6 additions & 1 deletion
@@ -33,6 +33,7 @@
     get_pg_size,
     is_fa_min_version,
     is_te_min_version,
+    is_using_quantization_scales,
     nvtx_range_pop,
     nvtx_range_push,
 )
@@ -949,6 +950,11 @@ def forward(
             )
             core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)')

+            # Clear the outputs for padding tokens when using quantization scales
+            # to avoid corrupting amax calculations
+            if is_using_quantization_scales(self.config):
+                core_attn_out[inference_context.padding_slice] = 0.0
+
         if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd':
             # reshape to same output shape as unpacked case
             # (t, np, hn) -> (t, b=1, h=np*hn)
@@ -960,7 +966,6 @@
         # =================
         # Output. [sq, b, h]
         # =================
-
         nvtx_range_push(suffix="linear_proj")
         output, bias = self.linear_proj(core_attn_out)
         nvtx_range_pop(suffix="linear_proj")

megatron/core/utils.py

Lines changed: 5 additions & 0 deletions
@@ -2099,6 +2099,11 @@ def get_asyncio_loop(loop: asyncio.AbstractEventLoop | None = None) -> asyncio.A
     return loop


+def is_using_quantization_scales(config):
+    """Returns whether the model is using quantization scales based on the config."""
+    return getattr(config, "fp8", False) or getattr(config, "fp4", False)
+
+
 _ASYNC_TASK_STATS = defaultdict(lambda: [0, 0.0])  # cnt, total_time

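A quick usage sketch of the new helper with stand-in config objects (SimpleNamespace here; the real callers pass a TransformerConfig, whose fp8 field is typically None when quantization is disabled):

from types import SimpleNamespace

def is_using_quantization_scales(config):
    """Returns whether the model is using quantization scales based on the config."""
    return getattr(config, "fp8", False) or getattr(config, "fp4", False)

assert is_using_quantization_scales(SimpleNamespace(fp8="hybrid", fp4=None))
assert not is_using_quantization_scales(SimpleNamespace(fp8=None, fp4=None))
assert not is_using_quantization_scales(SimpleNamespace())  # attributes absent -> falls back to False
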
tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_fp8_logitsmatch/golden_values_dev_dgx_h100.json

Lines changed: 53 additions & 53 deletions
@@ -1,38 +1,38 @@
 {
     "0": {
         "input_prompt": "Time travel to 2008, and go to a bar or a club or one of the myriad disco-basements on the Lower East Side that does not quite know which of those it is. Dance awkwardly in a room full of other glittered-up nerds, and wait for something to happen, buoyed on the feeling that this is the big swollen heart of life, that this is New York like the movies.",
-        "generated_text": " And then you get to the end of the movie, and you realize that this is not New York at all. This is New York at the end",
+        "generated_text": " And that this is the place where you can be yourself, and be yourself in the most beautiful way possible. And that this is the place where you",
         "generated_tokens": [
             3060,
-            2430,
-            1636,
-            2012,
-            1317,
-            1278,
-            2362,
-            1307,
+            1455,
+            1593,
+            1395,
             1278,
-            16070,
+            3535,
+            2478,
+            1636,
+            1710,
+            1402,
+            14019,
             1044,
             1321,
-            1636,
-            23067,
+            1402,
+            14019,
+            1294,
+            1278,
+            2725,
+            15568,
+            3039,
+            4171,
+            1046,
+            3060,
             1455,
             1593,
             1395,
-            1605,
-            3140,
-            5152,
-            1513,
-            1747,
-            1046,
-            2409,
-            1395,
-            3140,
-            5152,
-            1513,
             1278,
-            2362
+            3535,
+            2478,
+            1636
         ],
         "latency": 0.2963709831237793,
         "cuda_graph_request_count_map": {
@@ -143,35 +143,35 @@
            -4.629369258880615,
            -3.4186267852783203,
            -1.9727531671524048,
-            -2.354729652404785,
-            -1.474542498588562,
-            -2.48478364944458,
-            -1.7641210556030273,
-            -1.1853944063186646,
-            -2.8624324798583984,
-            -0.5740103125572205,
-            -0.4542185962200165,
-            -1.4300930500030518,
-            -0.8807456493377686,
-            -0.4597663879394531,
-            -0.9252307415008545,
-            -1.648141860961914,
-            -0.44453874230384827,
-            -1.818476915359497,
-            -0.5714479088783264,
-            -1.2115143537521362,
-            -1.0910619497299194,
-            -0.0023161747958511114,
-            -1.3206473588943481,
-            -0.008621376007795334,
-            -0.7551823854446411,
-            -0.9404395818710327,
-            -0.07279698550701141,
-            -0.9365248680114746,
-            -0.03344438225030899,
-            -1.9720849990844727,
-            -1.3928067684173584,
-            -0.7453650832176208
+            -2.3787546157836914,
+            -2.1559927463531494,
+            -0.3042203187942505,
+            -1.5030715465545654,
+            -2.2185328006744385,
+            -1.163987398147583,
+            -1.6270570755004883,
+            -1.382636547088623,
+            -2.0105135440826416,
+            -1.4725666046142578,
+            -0.8826082944869995,
+            -1.2427284717559814,
+            -1.966317892074585,
+            -2.320446252822876,
+            -1.6230499744415283,
+            -0.9565537571907043,
+            -1.8637971878051758,
+            -2.4722466468811035,
+            -1.0165706872940063,
+            -1.5226590633392334,
+            -0.9222670793533325,
+            -0.3918261229991913,
+            -0.4523355960845947,
+            -0.2231457382440567,
+            -0.014109030365943909,
+            -0.36200955510139465,
+            -0.2032179832458496,
+            -0.043058376759290695,
+            -0.09619185328483582
         ]
     },
     "throughput": [
@@ -184,4 +184,4 @@
         100.30260782227111,
         100.30996418216475
     ]
-}
+}

tests/unit_tests/models/test_gpt_model.py

Lines changed: 111 additions & 1 deletion
@@ -8,18 +8,23 @@
 import torch
 from packaging import version
 from pytest import approx
+from transformer_engine.pytorch.fp8 import check_fp8_support

 from megatron.core import parallel_state
 from megatron.core.hyper_comm_grid import HyperCommGrid
+from megatron.core.inference.contexts.dynamic_context import DynamicInferenceContext
+from megatron.core.inference.inference_request import DynamicInferenceRequest
+from megatron.core.inference.sampling_params import SamplingParams
 from megatron.core.models.gpt.gpt_layer_specs import (
     get_gpt_layer_with_transformer_engine_spec,
     get_mlp_module_spec,
 )
 from megatron.core.models.gpt.gpt_model import GPTModel
 from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
+from megatron.core.transformer.module import Float16Module
 from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.utils import is_te_min_version
+from megatron.core.utils import is_fa_min_version, is_te_min_version
 from tests.unit_tests.test_utilities import Utils


@@ -333,3 +338,108 @@ def test_gpt_model_with_custom_pg(self, tp_size, dp_size, cp_size):
         assert logits.shape[0] == sequence_length
         assert logits.shape[1] == micro_batch_size
         assert logits.shape[2] == self.gpt_model.config.hidden_size
+
+
+class TestGPTWithDynamicInference:
+    """Tests GPTModel with dynamic inference."""
+
+    @torch.inference_mode()
+    def setup_method(self, method):
+        fp8_available, reason_for_no_fp8 = check_fp8_support()
+        if not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+
+        os.environ.pop('NVTE_FUSED_ATTN', None)
+        os.environ.pop('NVTE_FLASH_ATTN', None)
+        os.environ.pop('NVTE_UNFUSED_ATTN', None)
+        Utils.initialize_model_parallel(1, 1)
+        model_parallel_cuda_manual_seed(123)
+
+        transformer_config = TransformerConfig(
+            num_layers=4,
+            hidden_size=64,
+            num_attention_heads=8,
+            use_cpu_initialization=True,
+            params_dtype=torch.bfloat16,
+            bf16=True,
+            fp8="hybrid",
+            fp8_recipe="tensorwise",
+        )
+
+        self.gpt_model = GPTModel(
+            config=transformer_config,
+            transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec(),
+            vocab_size=128,
+            max_sequence_length=DynamicInferenceContext.TOKEN_ROUNDER,
+            parallel_output=True,
+        )
+        self.gpt_model = Float16Module(self.gpt_model.config, self.gpt_model)
+
+    def teardown_method(self, method):
+        Utils.destroy_model_parallel()
+
+    @pytest.mark.internal
+    @pytest.mark.skipif(
+        not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching"
+    )
+    @torch.inference_mode()
+    def test_dynamic_inference_padding_with_fp8(self):
+        """
+        Tests that logits for padded tokens are zeroed out for fp8 inference.
+        """
+        self.gpt_model.cuda()
+        self.gpt_model.eval()
+        config = self.gpt_model.config
+
+        inference_context = DynamicInferenceContext(
+            params_dtype=config.params_dtype,
+            num_layers=config.num_layers,
+            kv_channels=config.hidden_size // config.num_attention_heads,
+            num_attention_heads=config.num_attention_heads,
+            max_sequence_length=self.gpt_model.module.max_sequence_length,
+            buffer_size_gb=1.0,
+            block_size_tokens=256,
+            materialize_only_last_token_logits=False,
+        )
+
+        # Add a request with 10 tokens. Since 10 is not a multiple of 64,
+        # this will create padding up to the padded length of 64.
+        active_token_count = 10
+        request = DynamicInferenceRequest(
+            request_id=0,
+            prompt_tokens=torch.arange(0, active_token_count, dtype=torch.long, device='cuda'),
+            sampling_params=SamplingParams(num_tokens_to_generate=1),
+        )
+        inference_context.add_request(request)
+
+        # Prepares the context, including calculating the padded token count.
+        inference_context.initialize_attention_state()
+
+        assert inference_context.active_token_count == active_token_count
+        assert inference_context.padded_active_token_count == DynamicInferenceContext.TOKEN_ROUNDER
+
+        # Prepare inputs for the forward pass.
+        padded_token_count = inference_context.padded_active_token_count
+        input_ids, position_ids = inference_context.current_input_and_position_ids()
+
+        # Run the forward pass with inference parameters.
+        logits = self.gpt_model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            attention_mask=None,
+            inference_context=inference_context,
+            runtime_gather_output=True,
+        )
+
+        # Verify the output shape.
+        assert logits.shape[0] == 1
+        assert logits.shape[1] == padded_token_count
+        assert logits.shape[2] == self.gpt_model.module.vocab_size
+
+        # Extract the logits corresponding to the padding tokens (from index 10 to 63).
+        padding_start_idx = inference_context.active_token_count
+        padding_end_idx = inference_context.padded_active_token_count
+        padding_logits = logits[0, padding_start_idx:padding_end_idx, :]
+
+        # Assert that all padding logits are zero.
+        assert torch.all(padding_logits == 0.0), "Logits for padding tokens are not all zero."
