1818import torch .distributed .checkpoint as dcp
1919import transformer_engine
2020from torch .distributed .checkpoint .state_dict import get_model_state_dict
21- from transformer_engine .common . recipe import DelayedScaling , MXFP8BlockScaling
22- from transformer_engine .pytorch . fp8 import check_fp8_support , check_mxfp8_support
23- from transformer_engine .pytorch .tensor .float8_tensor import Float8Tensor
21+ from transformer_engine .common import recipe as recipe_module
22+ from transformer_engine .pytorch import fp8
23+ from transformer_engine .pytorch .tensor .quantized_tensor import QuantizedTensor
2424
2525from esm .collator import MLMDataCollatorWithFlattening
2626from esm .modeling_esm_te import NVEsmConfig , NVEsmForMaskedLM
2727
2828
29- def requires_fp8 (func ):
30- """Decorator to skip tests that require FP8 support."""
31- fp8_available , reason = check_fp8_support ()
32- return pytest .mark .skipif (not fp8_available , reason = f"FP8 is not supported on this GPU: { reason } " )(func )
33-
34-
35- def requires_mxfp8 (func ):
36- """Decorator to skip tests that require MXFP8 support."""
37- mxfp8_available , reason = check_mxfp8_support ()
38- if torch .cuda .get_device_capability () == (12 , 0 ):
39- mxfp8_available = False
40- reason = "MXFP8 is not supported on sm120"
41- return pytest .mark .skipif (not mxfp8_available , reason = f"MXFP8 is not supported on this GPU: { reason } " )(func )
29+ ALL_RECIPES = [
30+ recipe_module .DelayedScaling (),
31+ recipe_module .Float8CurrentScaling (),
32+ recipe_module .Float8BlockScaling (),
33+ recipe_module .MXFP8BlockScaling (),
34+ # recipe_module.NVFP4BlockScaling(disable_rht=True, disable_stochastic_rounding=True),
35+ ]
36+
37+
38+ def _check_recipe_support (recipe : recipe_module .Recipe ):
39+ """Check if a recipe is supported and return (supported, reason)."""
40+ if isinstance (recipe , recipe_module .DelayedScaling ):
41+ recipe_supported , reason = fp8 .check_fp8_support ()
42+ elif isinstance (recipe , recipe_module .Float8CurrentScaling ):
43+ recipe_supported , reason = fp8 .check_fp8_support ()
44+ elif isinstance (recipe , recipe_module .Float8BlockScaling ):
45+ recipe_supported , reason = fp8 .check_fp8_block_scaling_support ()
46+ elif isinstance (recipe , recipe_module .MXFP8BlockScaling ):
47+ recipe_supported , reason = fp8 .check_mxfp8_support ()
48+ elif isinstance (recipe , recipe_module .NVFP4BlockScaling ):
49+ recipe_supported , reason = fp8 .check_nvfp4_support ()
50+ else :
51+ recipe_supported = False
52+ reason = "Unsupported recipe"
53+ return recipe_supported , reason
54+
55+
56+ def requires_recipe_support (recipe : recipe_module .Recipe ):
57+ """Decorator to skip tests that require recipe support."""
58+
59+ def requires_recipe_support_inner (func ):
60+ recipe_supported , reason = _check_recipe_support (recipe )
61+ return pytest .mark .skipif (not recipe_supported , reason = reason )(func )
62+
63+ return requires_recipe_support_inner
64+
65+
66+ def parametrize_recipes_with_support (recipes ):
67+ """Generate pytest.param objects with skip marks for unsupported recipes."""
68+ parametrized_recipes = []
69+ for recipe in recipes :
70+ recipe_supported , reason = _check_recipe_support (recipe )
71+ parametrized_recipes .append (
72+ pytest .param (
73+ recipe ,
74+ id = recipe .__class__ .__name__ ,
75+ marks = pytest .mark .skipif (
76+ not recipe_supported ,
77+ reason = reason ,
78+ ),
79+ )
80+ )
81+ return parametrized_recipes
4282
4383
4484@pytest .fixture
@@ -53,24 +93,30 @@ def input_data_thd(tokenizer, tokenized_proteins):
5393 return data_collator (tokenized_proteins )
5494
5595
56- @requires_fp8
57- def test_fp8_forward_and_backward_pass (te_model_checkpoint , input_data ):
96+ @pytest . mark . parametrize ( "fp8_recipe" , parametrize_recipes_with_support ( ALL_RECIPES ))
97+ def test_fp8_forward_and_backward_pass (te_model_checkpoint , input_data , fp8_recipe ):
5898 model_te = NVEsmForMaskedLM .from_pretrained (te_model_checkpoint , dtype = torch .bfloat16 )
5999 model_te .to ("cuda" )
60100
61101 input_data = {k : v .to ("cuda" ) for k , v in input_data .items ()}
62102 outputs = model_te (** input_data )
63103
64- fp8_recipe = DelayedScaling ()
65104 with transformer_engine .pytorch .fp8_autocast (enabled = True , fp8_recipe = fp8_recipe ):
66105 outputs_fp8 = model_te (** input_data )
67106 outputs_fp8 .loss .backward ()
68107
69- torch .testing .assert_close (outputs_fp8 .loss , outputs .loss )
108+ if isinstance (fp8_recipe , recipe_module .NVFP4BlockScaling ):
109+ atol = 0.2
110+ rtol = 0.05
111+ else :
112+ atol = None
113+ rtol = None
114+
115+ torch .testing .assert_close (outputs_fp8 .loss , outputs .loss , atol = atol , rtol = rtol )
70116
71117
72- @requires_fp8
73- def test_fp8_forward_and_backward_pass_thd (te_model_checkpoint , input_data_thd , monkeypatch ):
118+ @pytest . mark . parametrize ( "fp8_recipe" , parametrize_recipes_with_support ( ALL_RECIPES ))
119+ def test_fp8_forward_and_backward_pass_thd (te_model_checkpoint , input_data_thd , fp8_recipe , monkeypatch ):
74120 if torch .cuda .get_device_capability () == (12 , 0 ):
75121 # TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
76122 # but it's missing this THD implementation.
@@ -82,54 +128,27 @@ def test_fp8_forward_and_backward_pass_thd(te_model_checkpoint, input_data_thd,
82128 input_data = {k : v .to ("cuda" ) if isinstance (v , torch .Tensor ) else v for k , v in input_data_thd .items ()}
83129 outputs = model_te (** input_data )
84130
85- fp8_recipe = DelayedScaling ()
86131 with transformer_engine .pytorch .fp8_autocast (enabled = True , fp8_recipe = fp8_recipe ):
87132 outputs_fp8 = model_te (** input_data )
88133 outputs_fp8 .loss .backward ()
89134
90- torch .testing .assert_close (outputs_fp8 .loss , outputs .loss )
91-
92-
93- @requires_mxfp8
94- def test_mxfp8_forward_and_backward_pass (te_model_checkpoint , input_data ):
95- model_te = NVEsmForMaskedLM .from_pretrained (te_model_checkpoint , dtype = torch .bfloat16 )
96- model_te .to ("cuda" )
97-
98- input_data = {k : v .to ("cuda" ) for k , v in input_data .items ()}
99- outputs = model_te (** input_data )
100-
101- mxfp8_recipe = MXFP8BlockScaling ()
102- with transformer_engine .pytorch .fp8_autocast (enabled = True , fp8_recipe = mxfp8_recipe ):
103- outputs_fp8 = model_te (** input_data )
104- outputs_fp8 .loss .backward ()
105-
106- torch .testing .assert_close (outputs_fp8 .loss , outputs .loss )
135+ if isinstance (fp8_recipe , recipe_module .NVFP4BlockScaling ):
136+ atol = 0.2
137+ rtol = 0.05
138+ else :
139+ atol = None
140+ rtol = None
107141
142+ torch .testing .assert_close (outputs_fp8 .loss , outputs .loss , atol = atol , rtol = rtol )
108143
109- @requires_mxfp8
110- def test_mxfp8_forward_and_backward_pass_thd (te_model_checkpoint , input_data_thd ):
111- model_te = NVEsmForMaskedLM .from_pretrained (te_model_checkpoint , attn_input_format = "thd" , dtype = torch .bfloat16 )
112- model_te .to ("cuda" )
113-
114- input_data = {k : v .to ("cuda" ) if isinstance (v , torch .Tensor ) else v for k , v in input_data_thd .items ()}
115- outputs = model_te (** input_data )
116144
117- mxfp8_recipe = MXFP8BlockScaling ()
118- with transformer_engine .pytorch .fp8_autocast (enabled = True , fp8_recipe = mxfp8_recipe ):
119- outputs_fp8 = model_te (** input_data )
120- outputs_fp8 .loss .backward ()
121-
122- torch .testing .assert_close (outputs_fp8 .loss , outputs .loss )
123-
124-
125- @requires_fp8
126- def test_fp8_model_init_forward_and_backward (te_model_checkpoint , input_data ):
127- fp8_recipe = DelayedScaling ()
145+ @pytest .mark .parametrize ("fp8_recipe" , parametrize_recipes_with_support (ALL_RECIPES ))
146+ def test_fp8_model_init_forward_and_backward (te_model_checkpoint , input_data , fp8_recipe ):
128147 config = NVEsmConfig .from_pretrained (te_model_checkpoint , dtype = torch .bfloat16 )
129148 with transformer_engine .pytorch .fp8_model_init (enabled = True , recipe = fp8_recipe ):
130149 model_te = NVEsmForMaskedLM (config )
131150
132- assert isinstance (model_te .lm_head .dense .weight , Float8Tensor )
151+ assert isinstance (model_te .lm_head .dense .weight , QuantizedTensor )
133152
134153 model_te .to ("cuda" )
135154 input_data = {k : v .to ("cuda" ) for k , v in input_data .items ()}
@@ -140,39 +159,34 @@ def test_fp8_model_init_forward_and_backward(te_model_checkpoint, input_data):
140159 outputs_fp8 .loss .backward ()
141160
142161
143- @requires_fp8
144162@pytest .mark .xfail (reason = "BIONEMO-3055: fp8 model init and pretrained loading is not currently supported." )
145- def test_fp8_model_init_from_pretrained (te_model_checkpoint , input_data ):
146- fp8_recipe = DelayedScaling ()
147-
163+ @pytest .mark .parametrize ("fp8_recipe" , parametrize_recipes_with_support (ALL_RECIPES ))
164+ def test_fp8_model_init_from_pretrained (te_model_checkpoint , fp8_recipe ):
148165 # TODO: this will be renamed to quantized_model_init in the future, fp8_model_init will be removed in 3.0
149166 with transformer_engine .pytorch .fp8_model_init (enabled = True , recipe = fp8_recipe ):
150- model_te = NVEsmForMaskedLM .from_pretrained (te_model_checkpoint )
167+ model_te = NVEsmForMaskedLM .from_pretrained (te_model_checkpoint , dtype = torch . bfloat16 )
151168
152- assert isinstance (model_te .esm .encoder .layers [0 ].layernorm_mlp .fc2_weight , Float8Tensor )
153- assert isinstance (model_te .lm_head .dense .weight , Float8Tensor )
169+ assert isinstance (model_te .esm .encoder .layers [0 ].layernorm_mlp .fc2_weight , QuantizedTensor )
170+ assert isinstance (model_te .lm_head .dense .weight , QuantizedTensor )
154171
155172
156- @requires_fp8
157173@pytest .mark .xfail (reason = "BIONEMO-3055: fp8 model init and pretrained saving is not currently supported." )
158- def test_fp8_model_init_save_pretrained ( te_model_checkpoint , tmp_path ):
159- fp8_recipe = DelayedScaling ()
174+ @ pytest . mark . parametrize ( "fp8_recipe" , parametrize_recipes_with_support ( ALL_RECIPES ))
175+ def test_fp8_model_init_save_pretrained ( te_model_checkpoint , tmp_path , fp8_recipe ):
160176 config = NVEsmConfig .from_pretrained (te_model_checkpoint , dtype = torch .bfloat16 )
161177 with transformer_engine .pytorch .fp8_model_init (enabled = True , recipe = fp8_recipe ):
162178 model_fp8 = NVEsmForMaskedLM (config )
163179
164- assert isinstance (model_fp8 .esm .encoder .layers [0 ].layernorm_mlp .fc2_weight , Float8Tensor )
165- assert isinstance (model_fp8 .lm_head .dense .weight , Float8Tensor )
180+ assert isinstance (model_fp8 .esm .encoder .layers [0 ].layernorm_mlp .fc2_weight , QuantizedTensor )
181+ assert isinstance (model_fp8 .lm_head .dense .weight , QuantizedTensor )
166182
167183 model_fp8 .save_pretrained (tmp_path / "fp8_checkpoint" )
168184 del model_fp8
169185 NVEsmForMaskedLM .from_pretrained (tmp_path / "fp8_checkpoint" , dtype = torch .bfloat16 )
170186
171187
172- @requires_fp8
173- @pytest .mark .xfail (reason = "BIONEMO-3055: fp8 model init and distributed checkpointing is not currently supported." )
174- def test_fp8_model_distributed_checkpointing_save_and_load (te_model_checkpoint , tmp_path , input_data ):
175- fp8_recipe = DelayedScaling ()
188+ @pytest .mark .parametrize ("fp8_recipe" , parametrize_recipes_with_support (ALL_RECIPES ))
189+ def test_fp8_model_distributed_checkpointing_save_and_load (te_model_checkpoint , tmp_path , input_data , fp8_recipe ):
176190 config = NVEsmConfig .from_pretrained (te_model_checkpoint , dtype = torch .bfloat16 )
177191 with transformer_engine .pytorch .fp8_model_init (enabled = True , recipe = fp8_recipe ):
178192 model_fp8 = NVEsmForMaskedLM (config )
@@ -184,6 +198,7 @@ def test_fp8_model_distributed_checkpointing_save_and_load(te_model_checkpoint,
184198 outputs .loss .backward ()
185199
186200 state_dict = get_model_state_dict (model_fp8 )
201+ state_dict = {key : val for key , val in state_dict .items () if not key .endswith ("_extra_state" )}
187202 dcp .save (state_dict , checkpoint_id = tmp_path / "fp8_checkpoint" )
188203
189204 del model_fp8 , state_dict
@@ -192,4 +207,5 @@ def test_fp8_model_distributed_checkpointing_save_and_load(te_model_checkpoint,
192207 model_fp8 = NVEsmForMaskedLM (config )
193208
194209 state_dict = model_fp8 .state_dict ()
210+ state_dict = {key : val for key , val in state_dict .items () if not key .endswith ("_extra_state" )}
195211 dcp .load (state_dict , checkpoint_id = tmp_path / "fp8_checkpoint" )
0 commit comments