
Commit 3d80d73

x64: matmul: Update weight decompression docs & examples

Parent: b196ee1

2 files changed: +19, -16 lines

doc/primitives/matmul.md (12 additions, 9 deletions)
@@ -98,12 +98,11 @@ types for source, destination, weights, and bias tensors:
 | Source           | Weights                            | Destination                      | Bias                        |
 |:-----------------|:-----------------------------------|:---------------------------------|:----------------------------|
 | f64              | f64                                | f64                              | f64, f32, f16, bf16, s8, u8 |
-| f32              | f32                                | f32                              | f32, bf16, f16, u8, s8      |
+| f32              | f32, u8, s8, u4, s4                | f32                              | f32, bf16, f16, u8, s8      |
 | f16              | f16, u8, s8, u4, s4                | f16, u8, s8                      | f32                         |
-| f16              | f16, u8, s8                        | f32                              | f32, f16                    |
+| f16              | f16, u8, s8, u4, s4                | f32, f16                         | f32, f16                    |
 | bf16             | bf16, u8, s8, u4, s4               | f32, bf16                        | f32, bf16                   |
-| f32, bf16, f16   | u8, s8                             | f32, bf16, f16                   | f32, bf16, f16              |
-| f32, bf16, f16   | u8, s8                             | f32, bf16, f16                   | f32, bf16, f16              |
+| f32, bf16, f16   | u8, s8, u4, s4                     | f32, bf16, f16                   | f32, bf16, f16              |
 | bf16, f16        | f8_e5m2, f8_e4m3, f4_e2m1, f4_e3m0 | f32, f16, bf16                   | f32, bf16, f16              |
 | f8_e5m2, f8_e4m3 | f8_e5m2, f8_e4m3                   | f32, f16, bf16, f8_e5m2, f8_e4m3 | f32, bf16, f16              |
 | f4_e2m1, f4_e3m0 | f4_e2m1, f4_e3m0                   | f32, f16, bf16, f4_e2m1, f4_e3m0 | f32, bf16, f16              |
@@ -146,8 +145,8 @@ The following attributes and post-ops are supported:

 | Type      | Operation                                                                        | Description                                                                 | Restrictions                                     |
 |:----------|:---------------------------------------------------------------------------------|:-----------------------------------------------------------------------------|:--------------------------------------------------|
-| Attribute | [Scales](@ref dnnl::primitive_attr::set_scales_mask)                             | Scales the result by given scaling factor(s)                                |                                                  |
-| Attribute | [Zero-points](@ref dnnl::primitive_attr::set_zero_points_mask)                   | Sets zero-point(s) for the corresponding tensors                            | `int8` computations only                         |
+| Attribute | [Scales](@ref dnnl::primitive_attr::set_scales_mask)                             | Scales the result by given scaling factor(s)                                |                                                  |
+| Attribute | [Zero-points](@ref dnnl::primitive_attr::set_zero_points_mask)                   | Sets zero-point(s) for the corresponding tensors                            |                                                  |
 | Attribute | [Dropout](@ref dnnl::primitive_attr::set_dropout)                                | Applies pseudo-random dropout to destination buffer, also fills mask buffer |                                                  |
 | Attribute | [Precomputed reductions](@ref dnnl::primitive_attr::set_precomputed_reductions)  | Sets precomputed reductions for the corresponding tensors                   | Requires weight zero-points and full matrix mask |
 | Post-op   | [Eltwise](@ref dnnl::post_ops::append_eltwise)                                   | Applies an @ref dnnl_api_eltwise operation to the result                    |                                                  |
@@ -156,9 +155,13 @@ The following attributes and post-ops are supported:
 | Post-op   | [Prelu](@ref dnnl::post_ops::append_prelu)                                       | Applies an @ref dnnl_api_prelu operation to the result                      |                                                  |

 The following masks are supported by the primitive:
-- 0, which applies one scale / zero point value to an entire tensor, and
-- 2, which applies a scale value per column along the
-  `n`dimension for `DNNL_ARG_WEIGHTS`.
+- 0, which applies one scale / zero point value to an entire tensor
+- 1, which applies scale / zero point values along the `k` dimension
+  for `DNNL_ARG_WEIGHTS`. Values can be grouped along this dimension
+  by specifying the scales / zero points shapes for the attribute
+  (see the code example @ref weights_decompression_matmul_cpp).
+- 2, which applies scale / zero point values per column along the
+  `n` dimension for `DNNL_ARG_WEIGHTS`.

 When scales and/or zero-points masks are specified, the user must
 provide the corresponding scales and/or zero-points as additional
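To make the grouped mask-1 semantics above concrete, here is a minimal sketch of building such decompression attributes with the public `dnnl::primitive_attr` API; the helper name and the group size `G` are placeholders for illustration, not part of this patch:

```cpp
#include "dnnl.hpp"

using namespace dnnl;

// Minimal sketch: matmul attributes that decompress int8 weights using
// scales and zero points grouped along the k dimension (group size G)
// and laid out per column along n, i.e. mask bits (1 << 0) + (1 << 1).
inline primitive_attr make_decompression_attr(memory::dim G) {
    primitive_attr attr;
    attr.set_scales(DNNL_ARG_WEIGHTS,
            /* mask = k | n */ (1 << 0) + (1 << 1), {G, 1},
            memory::data_type::f32);
    attr.set_zero_points(DNNL_ARG_WEIGHTS,
            /* mask = k | n */ (1 << 0) + (1 << 1), {G, 1},
            memory::data_type::s8);
    // apply_to_int = true makes the fpmath mode apply to integral
    // primitives as well, enabling the weight-decompression behavior.
    attr.set_fpmath_mode(fpmath_mode::strict, /* apply_to_int = */ true);
    return attr;
}
```

The updated example file below performs essentially the same setup inside `matmul_pd_create`.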

examples/tutorials/matmul/weights_decompression_matmul.cpp (7 additions, 7 deletions)
@@ -78,7 +78,7 @@ void init_vector(std::vector<float> &v) {
 int number_of_runs = 1;

 // Create a MatMul primitive descriptor for the following op:
-// C_f32 = A_f32 * (B_s8 - zp_B) * sc_B[:]
+// C_f32 = A_f32 * (B_s8 - zp_B[:]) * sc_B[:]
 //
 // Here:
 // - Matrices A and C are of f32 data type.
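As a reading aid, that formula can be spelled out as a naive reference loop; this is a hypothetical sketch over plain arrays, not the code path the example uses. Group `g = k / G` selects the scale and zero point that apply to row `k` of `B`:

```cpp
#include <cstdint>

// Hypothetical reference for C_f32 = A_f32 * (B_s8 - zp_B[:]) * sc_B[:].
// A: M x K (f32, row-major), B: K x N (s8, row-major), C: M x N (f32).
// sc_B / zp_B hold one value per column n and per group g = k / G, stored
// with the {1, N} strides the example uses, i.e. offset = g * N + n.
void matmul_decompress_ref(int64_t M, int64_t N, int64_t K, int64_t G,
        const float *A, const int8_t *B, const float *sc_B,
        const int8_t *zp_B, float *C) {
    for (int64_t m = 0; m < M; ++m)
        for (int64_t n = 0; n < N; ++n) {
            float acc = 0.f;
            for (int64_t k = 0; k < K; ++k) {
                const int64_t g = k / G; // group index along k
                const float b = sc_B[g * N + n]
                        * (float(B[k * N + n]) - float(zp_B[g * N + n]));
                acc += A[m * K + k] * b;
            }
            C[m * N + n] = acc;
        }
}
```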
@@ -96,15 +96,15 @@ matmul::primitive_desc matmul_pd_create(
     // Create attributes and indicate that the alpha and zero points are
     // runtime parameters
     primitive_attr attr;
-    // Set scales with multiple scales along K and N dimensions and with groups along K.
+    // Set scales with multiple values along K and N dimensions and with groups along K.
     attr.set_scales(DNNL_ARG_WEIGHTS,
             /* mask */ (1 << 0) + (1 << 1), {G, 1}, memory::data_type::f32);
-    // Set a single zero point with s8 data type.
-    attr.set_zero_points(
-            DNNL_ARG_WEIGHTS, /* mask */ 0, {}, memory::data_type::s8);
+    // Set zero points with multiple values along K and N dimensions and with groups along K.
+    attr.set_zero_points(DNNL_ARG_WEIGHTS, /* mask */ (1 << 0) + (1 << 1),
+            {G, 1}, memory::data_type::s8);
     // Set fpmath mode with `apply_to_int=true` to apply fpmath mode behavior to
     // integral primitives (in this example, matmul).
-    attr.set_fpmath_mode(fpmath_mode::bf16, true);
+    attr.set_fpmath_mode(fpmath_mode::strict, true);

     // Create a MatMul primitive descriptor
     return matmul::primitive_desc(eng, a_md, b_md, c_md, attr);
@@ -136,7 +136,7 @@ void infer(const matmul &matmul_p, int64_t M, int64_t N, int64_t K, int64_t G,
     // De-quantization parameters (eg. Scale and Shift)
     const int64_t n_groups = K / G;
     memory sc_B_mem({{N, n_groups}, memory::data_type::f32, {1, N}}, eng);
-    memory zp_B_mem({{1}, memory::data_type::s8, {1}}, eng);
+    memory zp_B_mem({{N, n_groups}, memory::data_type::s8, {1, N}}, eng);

     // the function below fills dnnl::memory with some values
     // these memories, typically, come from the previous layers / operations
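For completeness, runtime scales and zero points like these are bound at execution time through the standard `DNNL_ARG_ATTR_*` argument keys. A short sketch, where `A_mem`, `B_mem`, `C_mem`, and `engine_stream` are assumed placeholders not shown in the hunk:

```cpp
// Sketch: pass the runtime weight scales and zero points alongside the
// regular matmul arguments (placeholder memory objects assumed).
matmul_p.execute(engine_stream,
        {{DNNL_ARG_SRC, A_mem}, {DNNL_ARG_WEIGHTS, B_mem},
                {DNNL_ARG_DST, C_mem},
                {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, sc_B_mem},
                {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, zp_B_mem}});
```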
