Commit 2cd2104

Resolved merge conflicts, linter issues, added pytest for packed fp6 dims
1 parent 4a3a54b commit 2cd2104

File tree

5 files changed: +27 -11 lines changed


test/prototype/mx_formats/test_mx_linear.py

Lines changed: 6 additions & 7 deletions
@@ -59,7 +59,7 @@ def test_linear_eager(elem_dtype, bias, input_shape):
         nn.Linear(8, 8, bias=bias, device="cuda"),
     )
     m_mx = copy.deepcopy(m)
-    block_size = 2
+    block_size = 4
     swap_linear_with_mx_linear(m_mx, *elem_dtype, block_size=block_size)

     x_ref = torch.randn(*input_shape, device="cuda").requires_grad_()
@@ -94,10 +94,10 @@ def test_activation_checkpointing():
     elem_dtype = torch.float8_e4m3fn

     m = nn.Sequential(
-        nn.Linear(4, 6, bias=True, device="cuda"),
-        nn.Linear(6, 6, bias=True, device="cuda"),
+        nn.Linear(4, 8, bias=True, device="cuda"),
+        nn.Linear(8, 8, bias=True, device="cuda"),
     )
-    block_size = 2
+    block_size = 4
     swap_linear_with_mx_linear(m, elem_dtype, block_size=block_size)

     x = torch.randn(*input_shape, device="cuda").requires_grad_()
@@ -133,7 +133,7 @@ def test_linear_compile(elem_dtype, bias, use_autocast):
     m_mx = nn.Sequential(
         nn.Linear(K, N, bias=bias, device="cuda"),
     )
-    block_size = 2
+    block_size = 4
     swap_linear_with_mx_linear(m_mx, elem_dtype, block_size=block_size)
     m_mx_c = copy.deepcopy(m_mx)
     m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")
@@ -188,7 +188,6 @@ def test_inference_linear(elem_dtype, bias, input_shape):
     y_ref = m(x)
     y_mx = m_mx(x)
     sqnr = compute_error(y_ref, y_mx)
-    print(sqnr)
     if elem_dtype is torch.float8_e4m3fn:
         assert sqnr >= 20.0
     else:
@@ -254,4 +253,4 @@ def test_filter_fn():

     swap_linear_with_mx_inference_linear(m2, torch.float8_e4m3fn, 32, filter_fn) # noqa: E501
     assert type(m2[0]) == MXInferenceLinear
-    assert type(m2[1]) == torch.nn.Linear
+    assert type(m2[1]) == torch.nn.Linear
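
Note: the tests above now use block_size = 4 and layer widths that are multiples of 4. The commit message does not say why, but one plausible reading is that packed fp6 stores four 6-bit values in three bytes, so the quantized dimension must be divisible by 4, which block_size = 2 would not guarantee. Below is a minimal sketch of the swap pattern these tests exercise, assuming the prototype API used in the diff (swap_linear_with_mx_linear), a CUDA device, and an import path that is a guess rather than something stated in the commit:

import copy

import torch
import torch.nn as nn

# Import path is an assumption; the tests import this symbol from the
# torchao MX formats prototype package.
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

# Mirrors the test setup: 8-wide layers so that block_size=4 divides them.
m = nn.Sequential(nn.Linear(8, 8, bias=True, device="cuda"))
m_mx = copy.deepcopy(m)
swap_linear_with_mx_linear(m_mx, torch.float8_e4m3fn, block_size=4)

x = torch.randn(2, 8, device="cuda")
y = m_mx(x)  # forward pass through the MX-quantized linear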

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 17 additions & 0 deletions
@@ -219,6 +219,23 @@ def test_view(elem_dtype):
     x_mx_2 = x_mx.view(2, 4) # noqa: F841


+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("elem_dtype", [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2])
+@pytest.mark.parametrize("do_fp6_packing", [False, True])
+def test_fp6_packing(elem_dtype, do_fp6_packing):
+    config.pack_fp6 = do_fp6_packing
+    x = torch.randn(1, 2, 4, device="cuda")
+    block_size = 4
+    x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
+    if config.pack_fp6:
+        expected_packed_shape = torch.Size([*x.shape[:-1], 3 * x.shape[-1] // 4])
+    else:
+        expected_packed_shape = x.shape
+    config.pack_fp6 = True
+
+    assert x_mx._data.shape == expected_packed_shape
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
     is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
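
The expected shape in the new test_fp6_packing follows from the fp6 packing ratio: four 6-bit values (24 bits) fit exactly into three bytes, so packing shrinks the last dimension to 3/4 of its original size while the leading dimensions are unchanged. A standalone sketch of just that shape arithmetic (plain Python; packed_fp6_shape is a hypothetical helper, not part of torchao):

def packed_fp6_shape(shape):
    # 4 fp6 values (4 * 6 = 24 bits) pack into 3 bytes, so the last
    # dimension shrinks by a factor of 3/4; leading dims are untouched.
    *leading, last = shape
    assert last % 4 == 0, "last dim must be divisible by 4 to pack fp6"
    return (*leading, 3 * last // 4)

# Matches the expectation in the test above: (1, 2, 4) -> (1, 2, 3)
assert packed_fp6_shape((1, 2, 4)) == (1, 2, 3)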

third_party/cutlass

Submodule cutlass updated 2031 files

torchao/prototype/mx_formats/mx_ops.py

Lines changed: 1 addition & 1 deletion
@@ -25,8 +25,8 @@
 from torchao.prototype.mx_formats.constants import DTYPE_FP4, DTYPE_FP6_E3M2, DTYPE_FP6_E2M3
 from torchao.prototype.mx_formats.mx_tensor import ( # noqa: E501
     MXTensor,
-    tensor_size_hp_to_fp4x2,
     tensor_size_hpx3_to_fp6x4,
+    tensor_size_hp_to_fp4x2,
 )

 aten = torch.ops.aten

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 2 additions & 2 deletions
@@ -45,14 +45,14 @@
     f4_unpacked_to_f32,
     f6_e2m3_unpacked_to_f32,
     f6_e3m2_unpacked_to_f32,
-    triton_f6_e2m3_to_scaled_bf16,
-    triton_f6_e3m2_to_scaled_bf16,
     f32_to_f4_unpacked,
     f32_to_f6_e2m3_unpacked,
     f32_to_f6_e3m2_unpacked,
     pack_uint4,
     pack_uint6,
     triton_f4_to_scaled_bf16,
+    triton_f6_e2m3_to_scaled_bf16,
+    triton_f6_e3m2_to_scaled_bf16,
     unpack_uint4,
 )

