Commit 555c845

Added MXFP6 packing and fused unpack-dequantise kernels; amended pytests to use appropriate tensor dimensions.
1 parent 867a91f commit 555c845

File tree

7 files changed

+638 −48 lines changed
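
The packing helper and the fused unpack-dequantise kernels the commit message mentions are not part of the hunks shown below; the tests import pack_uint6 from custom_cast.py, and only the updated tests, the cutlass submodule bump, and a new config flag appear here. For orientation, FP6 packing stores four 6-bit codes in three bytes (4 × 6 = 24 bits), which is consistent with the tests now using block sizes and dimensions that are multiples of 4. A minimal plain-PyTorch sketch of such a layout follows; pack_uint6_ref and unpack_uint6_ref are illustrative names, and the exact bit order used by torchao's pack_uint6 may differ.

import torch

def pack_uint6_ref(x: torch.Tensor) -> torch.Tensor:
    # x: uint8 tensor of 6-bit codes, last dim a multiple of 4
    # (the tests reshape to (-1, block_size) with block_size % 4 == 0).
    a, b, c, d = x[..., 0::4], x[..., 1::4], x[..., 2::4], x[..., 3::4]
    byte0 = (a << 2) | (b >> 4)
    byte1 = ((b & 0xF) << 4) | (c >> 2)
    byte2 = ((c & 0x3) << 6) | d
    # interleave so each group of 4 codes becomes 3 consecutive bytes
    return torch.stack([byte0, byte1, byte2], dim=-1).flatten(start_dim=-2)

def unpack_uint6_ref(packed: torch.Tensor) -> torch.Tensor:
    # inverse of pack_uint6_ref: 3 bytes -> 4 six-bit codes
    byte0, byte1, byte2 = packed[..., 0::3], packed[..., 1::3], packed[..., 2::3]
    a = byte0 >> 2
    b = ((byte0 & 0x3) << 4) | (byte1 >> 4)
    c = ((byte1 & 0xF) << 2) | (byte2 >> 6)
    d = byte2 & 0x3F
    return torch.stack([a, b, c, d], dim=-1).flatten(start_dim=-2)

codes = torch.randint(0, 64, (2, 8), dtype=torch.uint8)
assert torch.equal(unpack_uint6_ref(pack_uint6_ref(codes)), codes)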

test/prototype/mx_formats/test_mx_linear.py

Lines changed: 14 additions & 14 deletions
@@ -49,13 +49,13 @@ def test_linear_eager(elem_dtype, bias, input_shape):
     Smoke test for training linear module with mx weight
     """
     grad_shape = list(input_shape)
-    grad_shape[-1] = 6
+    grad_shape[-1] = 8

     m = nn.Sequential(
-        nn.Linear(8, 6, bias=bias, device="cuda"),
+        nn.Linear(8, 8, bias=bias, device="cuda"),
     )
     m_mx = copy.deepcopy(m)
-    block_size = 2
+    block_size = 4
     swap_linear_with_mx_linear(m_mx, elem_dtype, block_size)

     x_ref = torch.randn(*input_shape, device="cuda").requires_grad_()
@@ -86,14 +86,13 @@ def test_linear_eager(elem_dtype, bias, input_shape):
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 def test_activation_checkpointing():
     input_shape = (2, 4)
-    grad_shape = (2, 6)
+    grad_shape = (2, 8)
     elem_dtype = torch.float8_e4m3fn

     m = nn.Sequential(
-        nn.Linear(4, 6, bias=True, device="cuda"),
-        nn.Linear(6, 6, bias=True, device="cuda"),
-    )
-    block_size = 2
+        nn.Linear(4, 8, bias=True, device="cuda"),
+        nn.Linear(8, 8, bias=True, device="cuda"), )
+    block_size = 4
     swap_linear_with_mx_linear(m, elem_dtype, block_size)

     x = torch.randn(*input_shape, device="cuda").requires_grad_()
@@ -123,13 +122,13 @@ def test_linear_compile(elem_dtype, bias, use_autocast):
     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
         if not is_sm_at_least_89():
             pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
-    M, K, N = 4, 8, 6
+    M, K, N = 4, 8, 8
     input_shape = (M, K)
     grad_shape = (M, N)
     m_mx = nn.Sequential(
         nn.Linear(K, N, bias=bias, device="cuda"),
     )
-    block_size = 2
+    block_size = 4
     swap_linear_with_mx_linear(m_mx, elem_dtype, block_size)
     m_mx_c = copy.deepcopy(m_mx)
     m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")
@@ -174,16 +173,17 @@ def test_inference_linear(elem_dtype, bias, input_shape):
     """
     Smoke test for inference linear module with mx weight
     """
-    m = nn.Sequential(nn.Linear(4, 6, bias=bias, dtype=torch.bfloat16))
+    m = nn.Sequential(nn.Linear(4, 8, bias=bias, dtype=torch.bfloat16))
     m = m.cuda()
     m_mx = copy.deepcopy(m)
-    block_size = 2
+    block_size = 4
     swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size)

     x = torch.randn(*input_shape, device="cuda", dtype=torch.bfloat16)
     y_ref = m(x)
     y_mx = m_mx(x)
     sqnr = compute_error(y_ref, y_mx)
+    print(sqnr)
     if elem_dtype is torch.float8_e4m3fn:
         assert sqnr >= 20.0
     else:
@@ -202,10 +202,10 @@ def test_inference_compile_simple(elem_dtype):
     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
         if not is_sm_at_least_89():
             pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
-    m = nn.Sequential(nn.Linear(4, 6, bias=False, dtype=torch.bfloat16))
+    m = nn.Sequential(nn.Linear(4, 8, bias=False, dtype=torch.bfloat16))
     m = m.cuda()
     m_mx = copy.deepcopy(m)
-    block_size = 2
+    block_size = 4
     swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size)
     m_mx = torch.compile(m_mx, fullgraph="true")
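
The pattern across this file: every linear feature dimension and the MX block size becomes a multiple of 4, so a block of FP6 elements packs cleanly into whole bytes. Condensed, the updated setup looks roughly like this sketch; the import paths are assumed from the repository layout (the test file's own imports are not shown in this diff), and CUDA plus the torchao prototype are required to run it.

import copy
import torch
import torch.nn as nn
# module paths assumed, not shown in this diff
from torchao.prototype.mx_formats.constants import DTYPE_FP6_E2M3
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

m = nn.Sequential(nn.Linear(8, 8, bias=False, device="cuda"))
m_mx = copy.deepcopy(m)
block_size = 4  # multiple of 4, so FP6 blocks pack into whole bytes
swap_linear_with_mx_linear(m_mx, DTYPE_FP6_E2M3, block_size)

x = torch.randn(2, 8, device="cuda").requires_grad_()
y = m_mx(x)
y.backward(torch.randn_like(y))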

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 33 additions & 24 deletions
@@ -14,7 +14,7 @@
     DTYPE_FP6_E3M2,
     SUPPORTED_ELEM_DTYPES,
 )
-from torchao.prototype.mx_formats.custom_cast import pack_uint4
+from torchao.prototype.mx_formats.custom_cast import pack_uint4, pack_uint6
 from torchao.prototype.mx_formats.mx_tensor import (
     E8M0_EXPONENT_NAN_VAL,
     MXTensor,
@@ -70,15 +70,15 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_hello_world(elem_dtype):
     data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
-    block_size = 2
+    block_size = 4
     _test_mx(data, elem_dtype, block_size)


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_all_zeros(elem_dtype):
     data = torch.zeros(4, 4, device="cuda", dtype=torch.bfloat16)
-    block_size = 2
+    block_size = 4
     _test_mx(data, elem_dtype, block_size)


@@ -88,7 +88,7 @@ def test_some_zeros(elem_dtype):
     data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
     data[0, :] = 0.0
     data[:, 2] = 0.0
-    block_size = 2
+    block_size = 4
     _test_mx(data, elem_dtype, block_size)


@@ -100,9 +100,9 @@ def test_exponent_nan_in(elem_dtype):
     value is set to is NaN
     """
     tensor_hp = torch.tensor(
-        [float("nan"), 1, 2, 3, 4, 5], device="cuda", dtype=torch.bfloat16
+        [float("nan"), 1, 2, 3, 4, 5, 6, 7], device="cuda", dtype=torch.bfloat16
     )
-    block_size = 2
+    block_size = 4
     tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size)
     assert torch.all(tensor_mx._scale_e8m0[0] == E8M0_EXPONENT_NAN_VAL)
     assert not torch.any(tensor_mx._scale_e8m0[1:] == E8M0_EXPONENT_NAN_VAL)
@@ -115,24 +115,30 @@ def test_exponent_nan_out(elem_dtype):
     If block exponent value is NaN, the MX tensor block value is NaN
     """
     scale_e8m0_bits = torch.tensor(
-        [E8M0_EXPONENT_NAN_VAL, 23, 42], dtype=torch.uint8, device="cuda"
+        [E8M0_EXPONENT_NAN_VAL, 23], dtype=torch.uint8, device="cuda"
     )
+
+    block_size = 4
+
     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
-        data_bits = torch.tensor([0, 1, 2, 3, 4, 5], dtype=elem_dtype, device="cuda")  # noqa: E501
+        data_bits = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=elem_dtype, device="cuda")  # noqa: E501
     elif elem_dtype in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2):
-        data_bits = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.uint8, device="cuda")  # noqa: E501
+        data_bits = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda")  # noqa: E501
+        if config.pack_fp6:
+            data_bits = data_bits.reshape(-1, block_size)
+            data_bits = pack_uint6(data_bits)
     elif elem_dtype == DTYPE_FP4:
-        data_bits = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.uint8, device="cuda")  # noqa: E501
+        data_bits = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda")  # noqa: E501
         data_bits = pack_uint4(data_bits)
     else:
         raise AssertionError("unsupported")
-    block_size = 2
+
     tensor_mx = MXTensor(
         scale_e8m0_bits, data_bits, elem_dtype, block_size, torch.float
     )
     tensor_hp = tensor_mx.to_dtype(torch.float)
-    assert torch.all(torch.isnan(tensor_hp[0:1]))
-    assert not torch.any(torch.isnan(tensor_hp[2:]))
+    assert torch.all(torch.isnan(tensor_hp.flatten()[0:4]))
+    assert not torch.any(torch.isnan(tensor_hp.flatten()[4:]))


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -141,24 +147,26 @@ def test_ranks(elem_dtype):
     """
     The reshaping logic works for various ranks
     """
-    B = 2
-    shapes = ((B * 4,), (B * 4, 2), (B * 4, 2, 2), (B * 4, 2, 2, 2))
+    B = 4
+    shapes = ((B * 4,), (B * 4, 4), (B * 4, 4, 4), (B * 4, 4, 4, 4))
     for s in shapes:
         tensor_hp = torch.randn(*s, device="cuda", dtype=torch.bfloat16)
         _test_mx(tensor_hp, elem_dtype, B)


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
-def test_block_sizes(elem_dtype):
+@pytest.mark.parametrize("B", [1, 4, 32])
+def test_block_sizes(elem_dtype, B):
     """
     Smoke test for various block sizes
     """
-    for B in (1, 2, 32):
-        if B == 1 and elem_dtype == DTYPE_FP4:
-            pytest.skip("unsupported configuration")
-        tensor_hp = torch.randn(B, device="cuda", dtype=torch.bfloat16)
-        _test_mx(tensor_hp, elem_dtype, B)
+    if B == 1 and elem_dtype == DTYPE_FP4:
+        pytest.skip("unsupported configuration")
+    elif B % 4 != 0 and elem_dtype in [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2]:
+        pytest.skip("unsupported configuration")
+    tensor_hp = torch.randn(B, device="cuda", dtype=torch.bfloat16)
+    _test_mx(tensor_hp, elem_dtype, B)


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -202,10 +210,11 @@ def test_cast_autograd(elem_dtype):
     torch.testing.assert_close(grad, x.grad, atol=0, rtol=0)


+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_view(elem_dtype):
-    x = torch.randn(1, 2, 4)
-    block_size = 2
+    x = torch.randn(1, 2, 4, device="cuda")
+    block_size = 4
     x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
     x_mx_2 = x_mx.view(2, 4)  # noqa: F841

@@ -231,7 +240,7 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
         x = torch.randn(*shape, dtype=hp_dtype, device="cuda")
     else:
         x = torch.zeros(*shape, dtype=hp_dtype, device="cuda")
-    block_size = 2
+    block_size = 4
     to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True)

     x_mx = MXTensor.to_mx(x, elem_dtype, block_size)

third_party/cutlass

Submodule cutlass updated 2031 files
Lines changed: 1 addition & 0 deletions
@@ -1,2 +1,3 @@
 # If True, uses a custom triton kernel for fp4 dequantize
 use_fp4_custom_triton_dequant_kernel = False
+pack_fp6 = True
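
The new pack_fp6 flag is consumed the way the test_mx_tensor.py hunk above does it: when the flag is set, raw FP6 codes are reshaped into blocks and packed before the MXTensor is constructed. A minimal sketch of that call pattern follows; the config import path is assumed, as the config file's path is not shown in this excerpt.

import torch
from torchao.prototype.mx_formats import config  # import path assumed
from torchao.prototype.mx_formats.custom_cast import pack_uint6

block_size = 4
data_bits = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda")
if config.pack_fp6:
    # four 6-bit codes per group are packed into three bytes
    data_bits = data_bits.reshape(-1, block_size)
    data_bits = pack_uint6(data_bits)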

0 commit comments
