     DTYPE_FP6_E3M2,
     SUPPORTED_ELEM_DTYPES,
 )
-from torchao.prototype.mx_formats.custom_cast import pack_uint4
+from torchao.prototype.mx_formats.custom_cast import pack_uint4, pack_uint6
 from torchao.prototype.mx_formats.mx_tensor import (
     E8M0_EXPONENT_NAN_VAL,
     MXTensor,
@@ -70,15 +70,15 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_hello_world(elem_dtype):
     data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
-    block_size = 2
+    block_size = 4
     _test_mx(data, elem_dtype, block_size)


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_all_zeros(elem_dtype):
     data = torch.zeros(4, 4, device="cuda", dtype=torch.bfloat16)
-    block_size = 2
+    block_size = 4
     _test_mx(data, elem_dtype, block_size)


@@ -88,7 +88,7 @@ def test_some_zeros(elem_dtype):
     data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
     data[0, :] = 0.0
     data[:, 2] = 0.0
-    block_size = 2
+    block_size = 4
     _test_mx(data, elem_dtype, block_size)


@@ -100,9 +100,9 @@ def test_exponent_nan_in(elem_dtype):
     value is set to is NaN
     """
     tensor_hp = torch.tensor(
-        [float("nan"), 1, 2, 3, 4, 5], device="cuda", dtype=torch.bfloat16
+        [float("nan"), 1, 2, 3, 4, 5, 6, 7], device="cuda", dtype=torch.bfloat16
     )
-    block_size = 2
+    block_size = 4
     tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size)
     assert torch.all(tensor_mx._scale_e8m0[0] == E8M0_EXPONENT_NAN_VAL)
     assert not torch.any(tensor_mx._scale_e8m0[1:] == E8M0_EXPONENT_NAN_VAL)
@@ -115,24 +115,30 @@ def test_exponent_nan_out(elem_dtype):
     If block exponent value is NaN, the MX tensor block value is NaN
     """
     scale_e8m0_bits = torch.tensor(
-        [E8M0_EXPONENT_NAN_VAL, 23, 42], dtype=torch.uint8, device="cuda"
+        [E8M0_EXPONENT_NAN_VAL, 23], dtype=torch.uint8, device="cuda"
     )
+
+    block_size = 4
+
     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
-        data_bits = torch.tensor([0, 1, 2, 3, 4, 5], dtype=elem_dtype, device="cuda")  # noqa: E501
+        data_bits = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=elem_dtype, device="cuda")  # noqa: E501
     elif elem_dtype in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2):
-        data_bits = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.uint8, device="cuda")  # noqa: E501
+        data_bits = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda")  # noqa: E501
+        if config.pack_fp6:
+            data_bits = data_bits.reshape(-1, block_size)
+            data_bits = pack_uint6(data_bits)
     elif elem_dtype == DTYPE_FP4:
-        data_bits = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.uint8, device="cuda")  # noqa: E501
+        data_bits = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda")  # noqa: E501
         data_bits = pack_uint4(data_bits)
     else:
         raise AssertionError("unsupported")
-    block_size = 2
+
     tensor_mx = MXTensor(
         scale_e8m0_bits, data_bits, elem_dtype, block_size, torch.float
     )
     tensor_hp = tensor_mx.to_dtype(torch.float)
-    assert torch.all(torch.isnan(tensor_hp[0:1]))
-    assert not torch.any(torch.isnan(tensor_hp[2:]))
+    assert torch.all(torch.isnan(tensor_hp.flatten()[0:4]))
+    assert not torch.any(torch.isnan(tensor_hp.flatten()[4:]))


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -141,24 +147,26 @@ def test_ranks(elem_dtype):
     """
     The reshaping logic works for various ranks
     """
-    B = 2
-    shapes = ((B * 4,), (B * 4, 2), (B * 4, 2, 2), (B * 4, 2, 2, 2))
+    B = 4
+    shapes = ((B * 4,), (B * 4, 4), (B * 4, 4, 4), (B * 4, 4, 4, 4))
     for s in shapes:
         tensor_hp = torch.randn(*s, device="cuda", dtype=torch.bfloat16)
         _test_mx(tensor_hp, elem_dtype, B)


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
-def test_block_sizes(elem_dtype):
+@pytest.mark.parametrize("B", [1, 4, 32])
+def test_block_sizes(elem_dtype, B):
     """
     Smoke test for various block sizes
     """
-    for B in (1, 2, 32):
-        if B == 1 and elem_dtype == DTYPE_FP4:
-            pytest.skip("unsupported configuration")
-        tensor_hp = torch.randn(B, device="cuda", dtype=torch.bfloat16)
-        _test_mx(tensor_hp, elem_dtype, B)
+    if B == 1 and elem_dtype == DTYPE_FP4:
+        pytest.skip("unsupported configuration")
+    elif B % 4 != 0 and elem_dtype in [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2]:
+        pytest.skip("unsupported configuration")
+    tensor_hp = torch.randn(B, device="cuda", dtype=torch.bfloat16)
+    _test_mx(tensor_hp, elem_dtype, B)


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -202,10 +210,11 @@ def test_cast_autograd(elem_dtype):
     torch.testing.assert_close(grad, x.grad, atol=0, rtol=0)


+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_view(elem_dtype):
-    x = torch.randn(1, 2, 4)
-    block_size = 2
+    x = torch.randn(1, 2, 4, device="cuda")
+    block_size = 4
     x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
     x_mx_2 = x_mx.view(2, 4)  # noqa: F841

@@ -231,7 +240,7 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
         x = torch.randn(*shape, dtype=hp_dtype, device="cuda")
     else:
         x = torch.zeros(*shape, dtype=hp_dtype, device="cuda")
-    block_size = 2
+    block_size = 4
     to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True)

     x_mx = MXTensor.to_mx(x, elem_dtype, block_size)