Skip to content

Commit 327fbbf

Browse files
committed
cpu: x64: brg_matmul: postpone access to zp buffers until inside the parallel task
1 parent b37ca7c commit 327fbbf

File tree

4 files changed: +87 -86 lines changed

src/cpu/x64/matmul/brgemm_matmul.cpp

Lines changed: 41 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -749,7 +749,6 @@ void brgemm_matmul_t<isa>::compute_kernel(
749749
= brgmm_ctx.get_zp_a_compensation_ptr(ithr, b_idx, n_blk_idx);
750750
const auto zp_comp_b
751751
= brgmm_ctx.get_zp_b_compensation_result_ptr(ithr, m_blk_idx);
752-
const auto zp_c_val_ptr = brgmm_ctx.get_zp_c_val_ptr();
753752
const auto &post_ops_binary_rhs_arg_vec
754753
= brgmm_ctx.get_post_ops_binary_rhs_arg_vec();
755754
const bool post_ops_applicable = bgmmc.post_ops_applicable
@@ -794,8 +793,8 @@ void brgemm_matmul_t<isa>::compute_kernel(
794793
first_mb_matrix_addr_off,
795794
static_cast<const void *>(zp_comp_a),
796795
static_cast<const void *>(zp_comp_b),
797-
static_cast<const void *>(zp_c_val_ptr), false, 1, false,
798-
false, brgmm_ctx.get_src_scales_ptr(),
796+
brgmm_ctx.get_zp_c_ptr(), false, 1, false, false,
797+
brgmm_ctx.get_src_scales_ptr(),
799798
brgmm_ctx.get_wei_scales_ptr(n),
800799
brgmm_ctx.get_dst_scales_inv_ptr(ithr)};
801800
brgemm_kernel_execute_postops(brg_kernel, gemm_batch, addr_batch,
@@ -850,8 +849,8 @@ void brgemm_matmul_t<isa>::compute_kernel(
850849
first_mb_matrix_addr_off,
851850
static_cast<const void *>(zp_comp_a),
852851
static_cast<const void *>(zp_comp_b),
853-
static_cast<const void *>(zp_c_val_ptr), false, 1, false,
854-
false, brgmm_ctx.get_src_scales_ptr(),
852+
brgmm_ctx.get_zp_c_ptr(), false, 1, false, false,
853+
brgmm_ctx.get_src_scales_ptr(),
855854
brgmm_ctx.get_wei_scales_ptr(n),
856855
brgmm_ctx.get_dst_scales_inv_ptr(ithr)};
857856

@@ -1135,7 +1134,6 @@ void brgemm_matmul_t<isa>::maybe_reduce_partial_results_and_apply_postops(
11351134
const auto zp_comp_b
11361135
= brgmm_ctx.get_zp_b_compensation_result_ptr(
11371136
ithr, mb);
1138-
const auto zp_c_val_ptr = brgmm_ctx.get_zp_c_val_ptr();
11391137
const auto &post_ops_binary_rhs_arg_vec
11401138
= brgmm_ctx.get_post_ops_binary_rhs_arg_vec();
11411139

@@ -1158,9 +1156,8 @@ void brgemm_matmul_t<isa>::maybe_reduce_partial_results_and_apply_postops(
11581156
dst_anchor_point, first_mb_matrix_addr_off,
11591157
static_cast<const void *>(zp_comp_a),
11601158
static_cast<const void *>(zp_comp_b),
1161-
static_cast<const void *>(zp_c_val_ptr),
1162-
skip_accumulation, 1, false, false,
1163-
brgmm_ctx.get_src_scales_ptr(),
1159+
brgmm_ctx.get_zp_c_ptr(), skip_accumulation, 1,
1160+
false, false, brgmm_ctx.get_src_scales_ptr(),
11641161
brgmm_ctx.get_wei_scales_ptr(n),
11651162
brgmm_ctx.get_dst_scales_inv_ptr(ithr)};
11661163

@@ -1199,9 +1196,18 @@ void brgemm_matmul_t<isa>::copy_a_chunk_in_buffer(
11991196
ctx.zp_a_compensation_result_ptr
12001197
= (void *)brgmm_ctx.get_zp_b_compensation_result_ptr(
12011198
ithr, m_blk_idx);
1202-
ctx.zp_ab_comp_ptr = (void *)brgmm_ctx.get_zp_ab_mixed_comp_ptr();
12031199
ctx.dynamic_src_ld = brgmm_ctx.get_src_stride();
1204-
ctx.zp_b_neg_val_ptr = brgmm_ctx.get_wei_zp_neg_ptr();
1200+
1201+
// Note: instead of passing an address to a stack variable, a kernel may be
1202+
// changed to take just zp_b value and perform negation itself, but updating
1203+
// kernels is not straightforward for all platforms.
1204+
int32_t neg_zp_b
1205+
= !bgmmc.with_wei_decompression ? brgmm_ctx.get_neg_zp_b() : 0;
1206+
int32_t neg_zp_ab_comp = !bgmmc.with_wei_decompression
1207+
? bgmmc.K * brgmm_ctx.get_neg_zp_a()
1208+
: 0;
1209+
ctx.zp_b_neg_val_ptr = &neg_zp_b;
1210+
ctx.zp_ab_comp_ptr = &neg_zp_ab_comp;
12051211

12061212
for (int gb = 0; gb < gemm_batch_iters; gb++) {
12071213
const int k = k_start + gb * bgmmc.K_blk;
@@ -1260,7 +1266,13 @@ void brgemm_matmul_t<isa>::copy_b_chunk_in_buffer(
12601266

12611267
ctx.zp_a_compensation_ptr = (void *)brgmm_ctx.get_zp_a_compensation_ptr(
12621268
ithr, b_idx, n_blk_idx);
1263-
ctx.zp_a_neg_value_ptr = (void *)brgmm_ctx.get_zp_a_neg_val_ptr();
1269+
1270+
// Note: instead of passing an address to a stack variable, a kernel may be
1271+
// changed to take just zp_a value and perform negation itself, but updating
1272+
// kernels is not straightforward for all platforms.
1273+
int32_t neg_zp_a = brgmm_ctx.get_neg_zp_a();
1274+
ctx.zp_a_neg_value_ptr = &neg_zp_a;
1275+
12641276
ctx.compensation_ptr
12651277
= (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx);
12661278

@@ -1379,31 +1391,13 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
13791391
bias_ptr_ = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
13801392

13811393
// setup scales / zp pointers
1382-
const void *src_zero_points = CTX_IN_MEM(
1394+
src_zp_ptr_ = CTX_IN_MEM(
13831395
const void *, DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC);
13841396
wei_zp_ptr_ = CTX_IN_MEM(
13851397
const void *, DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS);
1386-
const void *dst_zero_points = CTX_IN_MEM(
1398+
dst_zp_ptr_ = CTX_IN_MEM(
13871399
const void *, DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST);
13881400

1389-
zero_point_a_negative_val_ = src_zero_points
1390-
? -cpu::io::load_int_value(
1391-
pd->attr()->zero_points_.get_data_type(DNNL_ARG_SRC),
1392-
src_zero_points, 0)
1393-
: 0;
1394-
zero_point_c_val_ = dst_zero_points
1395-
? cpu::io::load_int_value(
1396-
pd->attr()->zero_points_.get_data_type(DNNL_ARG_DST),
1397-
dst_zero_points, 0)
1398-
: 0;
1399-
1400-
wei_zp_neg_val_ = (-1)
1401-
* (wei_zp_ptr_ ? cpu::io::load_int_value(
1402-
pd->attr()->zero_points_.get_data_type(
1403-
DNNL_ARG_WEIGHTS),
1404-
wei_zp_ptr_, 0)
1405-
: 0);
1406-
14071401
const auto &scratchpad = ctx.get_scratchpad_grantor();
14081402

14091403
const auto &bgmmc = pd->get_brgemm_matmul_conf();
@@ -1470,9 +1464,6 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
14701464
key_brgemm_primitive_zp_comp_b)
14711465
: nullptr;
14721466

1473-
zero_point_mixed_ab_compensation_component_
1474-
= bgmmc.K * zero_point_a_negative_val_;
1475-
14761467
post_ops_binary_rhs_arg_vec_ = binary_injector::prepare_binary_args(
14771468
pd->attr()->post_ops_, ctx);
14781469
base_brg_ker_idx_
@@ -2100,11 +2091,19 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
21002091
+ ithr * sizeof(float);
21012092
}
21022093

2103-
const int32_t *get_zp_a_neg_val_ptr() const {
2104-
return &zero_point_a_negative_val_;
2094+
// Returns the negated source (A) zero-point value, or 0 when
// bgmmc_.has_zero_point_a is false. The value is loaded from src_zp_ptr_
// (interpreted as bgmmc_.src_zp_dt, index 0) on every call; nothing is cached
// in the context object.
int32_t get_neg_zp_a() const {
2095+
if (!bgmmc_.has_zero_point_a) return 0;
2096+
return -cpu::io::load_int_value(bgmmc_.src_zp_dt, src_zp_ptr_, 0);
21052097
}
21062098

2107-
const void *get_wei_zp_neg_ptr() const { return &wei_zp_neg_val_; }
2099+
// Returns the negated weights (B) zero-point value, or 0 when
// bgmmc_.has_zero_point_b is false. Used to compute compensation.
// The value can't be initialized at construction
2100+
// time as memory buffers must be accessed inside a parallel task
2101+
// (asynchronous runtime requirement), so it is loaded on every call.
2102+
int32_t get_neg_zp_b() const {
2103+
if (!bgmmc_.has_zero_point_b) return 0;
2104+
// Only index 0 of wei_zp_ptr_ is loaded below, so a single common
// zero point is expected here.
assert(bgmmc_.is_wei_zp_common);
2105+
return -cpu::io::load_int_value(bgmmc_.wei_zp_dt, wei_zp_ptr_, 0);
2106+
}
21082107

21092108
const void *get_wei_zp_ptr(int n, int k = 0) const {
21102109
if (!bgmmc_.has_zero_point_b) return nullptr;
@@ -2127,11 +2126,7 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
21272126
return (char *)wei_zp_ptr_ + offset;
21282127
}
21292128

2130-
const int32_t *get_zp_ab_mixed_comp_ptr() const {
2131-
return &zero_point_mixed_ab_compensation_component_;
2132-
}
2133-
2134-
const int32_t *get_zp_c_val_ptr() const { return &zero_point_c_val_; }
2129+
// Returns the stored destination (C) zero-point pointer (dst_zp_ptr_) as-is,
// without dereferencing it; may be null if no buffer was provided
// (NOTE(review): null case is presumed from the untyped pointer - confirm).
const void *get_zp_c_ptr() const { return dst_zp_ptr_; }
21352130

21362131
int32_t *get_zp_a_compensation_ptr(
21372132
int ithr, int b_idx, int n_blk_idx) const {
@@ -2152,7 +2147,7 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
21522147
+ n_blk_idx * bgmmc_.wei_n_blk;
21532148
PRAGMA_OMP_SIMD()
21542149
for (int b = 0; b < bgmmc_.wei_n_blk; b++)
2155-
zp_comp[b] = -zero_point_a_negative_val_
2150+
zp_comp[b] = -get_neg_zp_a()
21562151
* reorder_zp_a_comp_ptr_[base_offset + b];
21572152
}
21582153
return zp_comp;
@@ -2478,11 +2473,9 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
24782473
int32_t *zero_point_b_compensations_ptr_;
24792474
int32_t *reorder_zp_a_comp_ptr_;
24802475

2481-
int32_t zero_point_a_negative_val_;
2482-
int32_t zero_point_mixed_ab_compensation_component_;
2483-
int32_t zero_point_c_val_;
2484-
int32_t wei_zp_neg_val_;
2476+
const void *src_zp_ptr_;
24852477
const void *wei_zp_ptr_;
2478+
const void *dst_zp_ptr_;
24862479
std::vector<const void *> post_ops_binary_rhs_arg_vec_;
24872480

24882481
int base_brg_ker_idx_;

src/cpu/x64/matmul/brgemm_matmul_copy_utils.hpp

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,20 @@ namespace matmul {
2727

2828
struct jit_brgemm_matmul_copy_b_t {
2929
struct ctx_t {
30-
const void *src;
31-
const void *tr_src;
32-
const void *compensation_ptr;
33-
const void *zp_a_compensation_ptr;
34-
const void *zp_a_neg_value_ptr;
35-
const void *zp_b_value_ptr;
36-
const void *src_scales_ptr;
37-
const void *wei_scales_ptr;
38-
39-
dim_t current_K_start;
40-
dim_t current_K_iters;
41-
dim_t current_K_pad {0};
42-
dim_t current_N_blk;
43-
dim_t dynamic_src_stride;
30+
const void *src = nullptr;
31+
const void *tr_src = nullptr;
32+
const void *compensation_ptr = nullptr;
33+
const void *zp_a_compensation_ptr = nullptr;
34+
const void *zp_a_neg_value_ptr = nullptr;
35+
const void *zp_b_value_ptr = nullptr;
36+
const void *src_scales_ptr = nullptr;
37+
const void *wei_scales_ptr = nullptr;
38+
39+
dim_t current_K_start = 0;
40+
dim_t current_K_iters = 0;
41+
dim_t current_K_pad = 0;
42+
dim_t current_N_blk = 0;
43+
dim_t dynamic_src_stride = 0;
4444
};
4545

4646
virtual void operator()(ctx_t *ctx) = 0;
@@ -55,17 +55,17 @@ struct jit_brgemm_matmul_copy_b_t {
5555

5656
struct jit_brgemm_matmul_copy_a_t {
5757
struct ctx_t {
58-
const void *src;
59-
const void *tr_src;
60-
const void *zp_b_compensation_buffer_ptr;
61-
const void *zp_a_compensation_result_ptr;
62-
const void *zp_b_neg_val_ptr;
63-
const void *zp_ab_comp_ptr;
64-
65-
dim_t current_K_start;
66-
dim_t current_K_blk;
67-
dim_t current_M_blk;
68-
dim_t dynamic_src_ld;
58+
const void *src = nullptr;
59+
const void *tr_src = nullptr;
60+
const void *zp_b_compensation_buffer_ptr = nullptr;
61+
const void *zp_a_compensation_result_ptr = nullptr;
62+
const void *zp_b_neg_val_ptr = nullptr;
63+
const void *zp_ab_comp_ptr = nullptr;
64+
65+
dim_t current_K_start = 0;
66+
dim_t current_K_blk = 0;
67+
dim_t current_M_blk = 0;
68+
dim_t dynamic_src_ld = 0;
6969
};
7070

7171
virtual void operator()(ctx_t *ctx) = 0;

src/cpu/x64/matmul/brgemm_matmul_utils.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1335,7 +1335,6 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
13351335
bgmmc.dst_dt = dst_d.data_type();
13361336
bgmmc.wei_dt = weights_d.data_type();
13371337
bgmmc.orig_wei_dt = weights_d.data_type();
1338-
bgmmc.wei_zp_dt = attr.zero_points_.get(DNNL_ARG_WEIGHTS).get_data_type();
13391338

13401339
bgmmc.with_reduce = mmd.reduce_desc.format_kind != format_kind::undef;
13411340
bgmmc.reduce_dt
@@ -1466,6 +1465,10 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
14661465
VCONDCHECK_BG(!(bgmmc.with_dst_scales && dst_scales.get_mask() > 0),
14671466
VERBOSE_UNSUPPORTED_SCALES_CFG);
14681467

1468+
const auto &src_zp = attr.zero_points_.get(DNNL_ARG_SRC);
1469+
const auto has_src_zp = !src_zp.has_default_values();
1470+
if (has_src_zp) { bgmmc.src_zp_dt = src_zp.get_data_type(); }
1471+
14691472
const auto &wei_zp = attr.zero_points_.get(DNNL_ARG_WEIGHTS);
14701473
const auto has_wei_zp = !wei_zp.has_default_values();
14711474

src/cpu/x64/matmul/brgemm_matmul_utils.hpp

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,6 @@ struct brgemm_matmul_conf_t {
115115
bool packed_sparse_weights;
116116
bool with_wei_decompression;
117117
int postops_inst_count;
118-
brgemm_broadcast_t src_zp_type;
119-
brgemm_broadcast_t wei_zp_type;
120-
brgemm_broadcast_t dst_zp_type;
121118

122119
bool use_buffer_a;
123120
bool use_buffer_a_tail_only;
@@ -189,7 +186,6 @@ struct brgemm_matmul_conf_t {
189186
dim_t s8s8_comp_ithr_str;
190187
dim_t s8s8_comp_b_str;
191188
dim_t s8s8_comp_n_str;
192-
bool has_zero_point_a, has_zero_point_b, has_zero_point_c;
193189
bool post_ops_applicable;
194190
bool transposed_A;
195191
bool transposed_B;
@@ -202,14 +198,6 @@ struct brgemm_matmul_conf_t {
202198
// were changed.
203199
bool adjust_a_strides = false;
204200

205-
dim_t zp_a_comp_shift_n;
206-
dim_t zp_a_comp_elems_per_thr;
207-
208-
dim_t zp_b_comp_result_shift_m;
209-
dim_t zp_b_comp_buffer_start;
210-
dim_t zp_b_comp_buffer_shift_m;
211-
dim_t zp_b_comp_elems_per_thr;
212-
213201
int wsp_tile_per_thr_bytes;
214202
int brgemm_batch_element_per_thr_sz;
215203
bool is_amx;
@@ -244,12 +232,29 @@ struct brgemm_matmul_conf_t {
244232
data_type_t wei_scales_dt = data_type::undef;
245233

246234
// Zero points
235+
bool has_zero_point_a;
236+
bool has_zero_point_b;
237+
bool has_zero_point_c;
238+
brgemm_broadcast_t src_zp_type;
239+
brgemm_broadcast_t wei_zp_type;
240+
brgemm_broadcast_t dst_zp_type;
241+
242+
data_type_t src_zp_dt = data_type::undef;
243+
247244
dim_t wei_zp_k_gsize = 0;
248245
bool is_wei_zp_per_k = false;
249246
bool is_wei_zp_per_n = false;
250247
bool is_wei_zp_common = false;
251248
data_type_t wei_zp_dt = data_type::undef;
252249

250+
dim_t zp_a_comp_shift_n;
251+
dim_t zp_a_comp_elems_per_thr;
252+
253+
dim_t zp_b_comp_result_shift_m;
254+
dim_t zp_b_comp_buffer_start;
255+
dim_t zp_b_comp_buffer_shift_m;
256+
dim_t zp_b_comp_elems_per_thr;
257+
253258
bool is_gemv = false;
254259
// Currently, it's only used to enable the N=1 code path for M=1, when B
255260
// is transposed.

0 commit comments

Comments
 (0)