@@ -9,6 +9,7 @@
 import pytest
 import torch
 from tqdm import tqdm
+from transformer_engine.pytorch.fp8 import check_fp8_support

 from megatron.core import parallel_state
 from megatron.core.inference.contexts.dynamic_context import (
@@ -31,7 +32,10 @@
 from megatron.core.inference.text_generation_controllers.text_generation_controller import (
     TextGenerationController,
 )
-from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
+from megatron.core.models.gpt.gpt_layer_specs import (
+    get_gpt_layer_local_spec,
+    get_gpt_layer_with_transformer_engine_spec,
+)
 from megatron.core.models.gpt.gpt_model import GPTModel
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
 from megatron.core.transformer.cuda_graphs import CudaGraphManager, _CudagraphGlobalRecord
@@ -89,6 +93,8 @@ class DynamicEngineTestConfig:
     # relevant to the test. The tests only check if the required
     # context attributes are set correctly.

+    fp8: bool = False
+
     def __post_init__(self):

         # Compute max_sequence_length.
@@ -236,7 +242,7 @@ def _build_test_env(cls, test_config):
         transformer_config = TransformerConfig(
             params_dtype=torch.bfloat16,
             num_layers=4,
-            hidden_size=32,
+            hidden_size=128 if test_config.fp8 else 32,
             num_attention_heads=4,
             use_cpu_initialization=True,
             cuda_graph_impl=(
@@ -259,14 +265,21 @@ def _build_test_env(cls, test_config):
             inference_sampling_seed=test_config.random_seed,
             cuda_graph_scope=test_config.cuda_graph_scope,
         )
+        if test_config.fp8:
+            transformer_config.fp8 = "hybrid"
+            transformer_config.fp8_recipe = "tensorwise"
+            # transformer_config.fp8_param = True
+            layer_spec = get_gpt_layer_with_transformer_engine_spec()
+        else:
+            layer_spec = get_gpt_layer_local_spec()

         # Requests.
         requests = cls._build_requests(test_config)

         # GPT model.
         model = GPTModel(
             config=transformer_config,
-            transformer_layer_spec=get_gpt_layer_local_spec(),
+            transformer_layer_spec=layer_spec,
             vocab_size=test_config.vocab_size,
             max_sequence_length=test_config.max_sequence_length,
             parallel_output=True,
@@ -286,6 +299,7 @@ def _build_test_env(cls, test_config):
             fp32_residual_connection=False,
             params_dtype=transformer_config.params_dtype,
             padded_vocab_size=test_config.vocab_size,
+            fp8="hybrid" if test_config.fp8 else None,
         )

         # Inference context.
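
Side note on the hidden_size bump in the hunks above: Transformer Engine's FP8 GEMMs generally need matrix dimensions that stay multiples of 16 after tensor-parallel sharding, which hidden_size=32 split across TP=4 would not satisfy. A quick illustrative check (the divisor 16 is an assumed TE alignment requirement, not stated in this diff):

    # Hypothetical sanity check for FP8-friendly shapes; not part of the test suite.
    TP = 4        # tensor_model_parallel_size used by the new FP8 test below
    ALIGN = 16    # assumed FP8 GEMM alignment requirement in Transformer Engine
    for hidden in (32, 128):
        per_rank = hidden // TP
        status = "OK" if per_rank % ALIGN == 0 else "too small for FP8"
        print(f"hidden={hidden}: per-rank dim {per_rank} -> {status}")
    # hidden=32: per-rank dim 8 -> too small for FP8
    # hidden=128: per-rank dim 32 -> OK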
@@ -799,6 +813,25 @@ def test_parallel_inference(
             materialize_only_last_token_logits=materialize_only_last_token_logits,
         )

+    @pytest.mark.internal
+    @pytest.mark.skipif(
+        not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching"
+    )
+    @pytest.mark.parametrize("materialize_only_last_token_logits", [False, True])
+    def test_sequence_parallel_fp8_inference(self, materialize_only_last_token_logits: bool):
+        fp8_available, reason_for_no_fp8 = check_fp8_support()
+        if not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+
+        self._run_test(
+            min_prompt_length=19,
+            max_prompt_length=19,
+            tensor_model_parallel_size=4,
+            sequence_parallel=True,
+            materialize_only_last_token_logits=materialize_only_last_token_logits,
+            fp8=True,
+        )
+
     @pytest.mark.internal
     @pytest.mark.skipif(
         not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching"
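
For context on the skip guard in the new test: check_fp8_support() from Transformer Engine returns an (available, reason) pair, so the same gate works outside pytest as well. A minimal usage sketch (the prints are illustrative only):

    import torch
    from transformer_engine.pytorch.fp8 import check_fp8_support

    fp8_available, reason = check_fp8_support()
    if fp8_available:
        print("FP8 supported on", torch.cuda.get_device_name())
    else:
        # Typically requires Hopper/Ada-class (or newer) GPUs and a recent Transformer Engine.
        print("FP8 unavailable:", reason)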