
Commit 17ae12c

Fix fp8 inference with sequence parallelism
1 parent 9a79a45 commit 17ae12c

File tree

megatron/core/fp8_utils.py
tests/unit_tests/inference/engines/test_dynamic_engine.py

2 files changed: +85 -3 lines


megatron/core/fp8_utils.py

Lines changed: 49 additions & 0 deletions
@@ -10,6 +10,12 @@
 import torch
 
 from megatron.core.enums import Fp8Recipe
+from megatron.core.tensor_parallel import (
+    ColumnParallelLinear,
+    RowParallelLinear,
+    gather_from_sequence_parallel_region,
+    reduce_scatter_to_sequence_parallel_region,
+)
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.utils import get_te_version, is_te_min_version
 
@@ -112,6 +118,27 @@ def get_fp8_align_size(fp8_recipe: Fp8Recipe) -> int:
     return 16
 
 
+def is_column_parallel_linear(module):
+    """Returns whether the given module is a ColumnParallelLinear layer."""
+    if HAVE_TE and (
+        isinstance(module, TEColumnParallelLinear)
+        or isinstance(module, TELayerNormColumnParallelLinear)
+    ):
+        return True
+    elif isinstance(module, ColumnParallelLinear):
+        return True
+    return False
+
+
+def is_row_parallel_linear(module):
+    """Returns whether the given module is a RowParallelLinear layer."""
+    if HAVE_TE and isinstance(module, TERowParallelLinear):
+        return True
+    elif isinstance(module, RowParallelLinear):
+        return True
+    return False
+
+
 """
 The code below abstracts the functionalities needed for implementing "--fp8-param-gather" into
 several functions. It provides different implementations for each function based on different
@@ -587,6 +614,18 @@ def padded_forward(input_tensor, *args, **kwargs):
         if not FP8GlobalStateManager.is_fp8_enabled():
             return original_forward(input_tensor, *args, **kwargs)
 
+        # With sequence parallelism we need to all-gather before padding
+        # and reduce-scatter after unpadding
+        if is_sequence_parallel := getattr(module, "sequence_parallel", False):
+            if is_column_parallel_linear(module):
+                input_tensor = gather_from_sequence_parallel_region(
+                    input_tensor, group=module.tp_group
+                )
+
+            # Disable sequence parallelism on the module because we are handling the
+            # all-gather and reduce-scatter externally
+            module.sequence_parallel = False
+
         seq_len, batch_size, hidden_size = input_tensor.shape
         # Reshape to (S, B*H) to pad sequence dimension
         input_2d = input_tensor.reshape(seq_len, -1)
@@ -612,6 +651,16 @@ def padded_forward(input_tensor, *args, **kwargs):
         unpadded_output_2d = _unpad_func(output_2d, [seq_len])
         unpadded_output = unpadded_output_2d.reshape(seq_len, batch_size, output_hidden_size)
 
+        if is_sequence_parallel:
+            # Reduce-scatter after unpadding
+            if is_row_parallel_linear(module):
+                unpadded_output = reduce_scatter_to_sequence_parallel_region(
+                    unpadded_output, group=module.tp_group
+                )
+
+            # Reset sequence parallelism flag on the module
+            module.sequence_parallel = True
+
         if other_outputs:
             return (unpadded_output,) + other_outputs
         else:
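
Taken together, the two padded_forward hunks wrap the existing pad/unpad logic with the sequence-parallel collectives: all-gather the sequence shards before padding for a column-parallel layer, and reduce-scatter after unpadding for a row-parallel layer. Below is a minimal single-process sketch of that data movement, with the collectives emulated by cat/sum/chunk and the FP8 linear replaced by a plain matmul; it is an illustration only, not Megatron's implementation, and the helper pad_seq is invented for the sketch.

import torch

ALIGN = 16   # what get_fp8_align_size() returns in the hunk above
TP = 2       # emulated tensor/sequence-parallel group size


def pad_seq(x: torch.Tensor, align: int) -> torch.Tensor:
    """Zero-pad the sequence dimension (dim 0) up to a multiple of `align`."""
    seq_len = x.shape[0]
    target = -(-seq_len // align) * align
    if target == seq_len:
        return x
    return torch.cat([x, x.new_zeros(target - seq_len, *x.shape[1:])], dim=0)


# Column-parallel layer under sequence parallelism: the input arrives
# sequence-sharded as (S/TP, B, H) per rank, so it must be all-gathered
# (gather_from_sequence_parallel_region) before the sequence dim is padded.
S_PER_RANK, B, H, OUT = 19, 2, 32, 64
shards = [torch.randn(S_PER_RANK, B, H) for _ in range(TP)]
gathered = torch.cat(shards, dim=0)            # emulated all-gather -> (38, B, H)

padded = pad_seq(gathered, ALIGN)              # (48, B, H): FP8 GEMM wants seq % 16 == 0
col_weight = torch.randn(OUT, H)
col_out = padded @ col_weight.t()              # stand-in for the FP8 linear
col_out = col_out[: gathered.shape[0]]         # unpad back to the true length (38)

# Row-parallel layer under sequence parallelism: each rank produces a partial
# result over its input slice; after unpadding, the reduce-scatter
# (reduce_scatter_to_sequence_parallel_region) sums the partials and re-shards
# the sequence dimension. Emulated here with stack/sum + chunk.
partials = [col_out.clone() for _ in range(TP)]   # pretend each rank computed one
reduced = torch.stack(partials).sum(dim=0)
out_shards = list(reduced.chunk(TP, dim=0))       # (19, B, OUT) shard per rank

print(gathered.shape, padded.shape, out_shards[0].shape)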

tests/unit_tests/inference/engines/test_dynamic_engine.py

Lines changed: 36 additions & 3 deletions
@@ -9,6 +9,7 @@
 import pytest
 import torch
 from tqdm import tqdm
+from transformer_engine.pytorch.fp8 import check_fp8_support
 
 from megatron.core import parallel_state
 from megatron.core.inference.contexts.dynamic_context import (
@@ -31,7 +32,10 @@
 from megatron.core.inference.text_generation_controllers.text_generation_controller import (
     TextGenerationController,
 )
-from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
+from megatron.core.models.gpt.gpt_layer_specs import (
+    get_gpt_layer_local_spec,
+    get_gpt_layer_with_transformer_engine_spec,
+)
 from megatron.core.models.gpt.gpt_model import GPTModel
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
 from megatron.core.transformer.cuda_graphs import CudaGraphManager, _CudagraphGlobalRecord
@@ -89,6 +93,8 @@ class DynamicEngineTestConfig:
     # relevant to the test. The tests only check if the required
     # context attributes are set correctly.
 
+    fp8: bool = False
+
     def __post_init__(self):
 
         # Compute max_sequence_length.
@@ -236,7 +242,7 @@ def _build_test_env(cls, test_config):
         transformer_config = TransformerConfig(
             params_dtype=torch.bfloat16,
             num_layers=4,
-            hidden_size=32,
+            hidden_size=128 if test_config.fp8 else 32,
             num_attention_heads=4,
             use_cpu_initialization=True,
             cuda_graph_impl=(
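
An aside on the hidden_size bump in the hunk above: the commit does not say why, but FP8 GEMMs need dimensions aligned to the 16-element boundary that get_fp8_align_size() reports, and with tensor_model_parallel_size=4 each rank only holds hidden_size / 4 features. The quick check below is my inference, not part of the commit.

# Back-of-the-envelope check (inference, not from the commit): per-rank weight
# partitions should stay aligned to the 16-element FP8 boundary, which rules
# out hidden_size=32 at TP=4 but allows hidden_size=128.
FP8_ALIGN = 16
TP = 4

for hidden_size in (32, 128):
    per_rank = hidden_size // TP
    status = "aligned" if per_rank % FP8_ALIGN == 0 else "misaligned"
    print(f"hidden_size={hidden_size}: {per_rank} features per TP rank -> {status}")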
@@ -259,14 +265,21 @@ def _build_test_env(cls, test_config):
             inference_sampling_seed=test_config.random_seed,
             cuda_graph_scope=test_config.cuda_graph_scope,
         )
+        if test_config.fp8:
+            transformer_config.fp8 = "hybrid"
+            transformer_config.fp8_recipe = "tensorwise"
+            # transformer_config.fp8_param = True
+            layer_spec = get_gpt_layer_with_transformer_engine_spec()
+        else:
+            layer_spec = get_gpt_layer_local_spec()
 
         # Requests.
         requests = cls._build_requests(test_config)
 
         # GPT model.
         model = GPTModel(
             config=transformer_config,
-            transformer_layer_spec=get_gpt_layer_local_spec(),
+            transformer_layer_spec=layer_spec,
             vocab_size=test_config.vocab_size,
             max_sequence_length=test_config.max_sequence_length,
             parallel_output=True,
@@ -286,6 +299,7 @@ def _build_test_env(cls, test_config):
             fp32_residual_connection=False,
             params_dtype=transformer_config.params_dtype,
             padded_vocab_size=test_config.vocab_size,
+            fp8="hybrid" if test_config.fp8 else None,
         )
 
         # Inference context.
@@ -799,6 +813,25 @@ def test_parallel_inference(
             materialize_only_last_token_logits=materialize_only_last_token_logits,
         )
 
+    @pytest.mark.internal
+    @pytest.mark.skipif(
+        not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching"
+    )
+    @pytest.mark.parametrize("materialize_only_last_token_logits", [False, True])
+    def test_sequence_parallel_fp8_inference(self, materialize_only_last_token_logits: bool):
+        fp8_available, reason_for_no_fp8 = check_fp8_support()
+        if not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+
+        self._run_test(
+            min_prompt_length=19,
+            max_prompt_length=19,
+            tensor_model_parallel_size=4,
+            sequence_parallel=True,
+            materialize_only_last_token_logits=True,
+            fp8=True,
+        )
+
     @pytest.mark.internal
     @pytest.mark.skipif(
         not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching"
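
For context on what the new test exercises, the sketch below shows roughly what fp8="hybrid" means at the Transformer Engine level, gated the same way the test is gated. te.Linear, fp8_autocast, check_fp8_support, and Format.HYBRID are TE's public API; DelayedScaling is used only as a stand-in recipe and the shapes are illustrative, so treat this as a sketch rather than a reproduction of what Megatron configures internally.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.pytorch.fp8 import check_fp8_support

# Skip on hardware without FP8 support, mirroring the test's guard.
fp8_available, reason = check_fp8_support()
if not fp8_available:
    raise SystemExit(f"FP8 not available: {reason}")

# "hybrid" = E4M3 for forward activations/weights, E5M2 for backward gradients.
recipe = DelayedScaling(fp8_format=Format.HYBRID)

layer = te.Linear(128, 128, params_dtype=torch.bfloat16).cuda()
# Sequence length 32 and hidden size 128 are already multiples of 16, so no
# padding is needed in this toy case; the commit's padded_forward handles the
# general (unaligned, sequence-parallel) case.
x = torch.randn(32, 4, 128, dtype=torch.bfloat16, device="cuda")

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)

print(y.shape, y.dtype)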
