@@ -749,7 +749,6 @@ void brgemm_matmul_t<isa>::compute_kernel(
749749 = brgmm_ctx.get_zp_a_compensation_ptr (ithr, b_idx, n_blk_idx);
750750 const auto zp_comp_b
751751 = brgmm_ctx.get_zp_b_compensation_result_ptr (ithr, m_blk_idx);
752- const auto zp_c_val_ptr = brgmm_ctx.get_zp_c_val_ptr ();
753752 const auto &post_ops_binary_rhs_arg_vec
754753 = brgmm_ctx.get_post_ops_binary_rhs_arg_vec ();
755754 const bool post_ops_applicable = bgmmc.post_ops_applicable
@@ -794,8 +793,8 @@ void brgemm_matmul_t<isa>::compute_kernel(
794793 first_mb_matrix_addr_off,
795794 static_cast <const void *>(zp_comp_a),
796795 static_cast <const void *>(zp_comp_b),
797- static_cast < const void *>(zp_c_val_ptr ), false , 1 , false ,
798- false , brgmm_ctx.get_src_scales_ptr (),
796+ brgmm_ctx. get_zp_c_ptr ( ), false , 1 , false , false ,
797+ brgmm_ctx.get_src_scales_ptr (),
799798 brgmm_ctx.get_wei_scales_ptr (n),
800799 brgmm_ctx.get_dst_scales_inv_ptr (ithr)};
801800 brgemm_kernel_execute_postops (brg_kernel, gemm_batch, addr_batch,
@@ -850,8 +849,8 @@ void brgemm_matmul_t<isa>::compute_kernel(
850849 first_mb_matrix_addr_off,
851850 static_cast <const void *>(zp_comp_a),
852851 static_cast <const void *>(zp_comp_b),
853- static_cast < const void *>(zp_c_val_ptr ), false , 1 , false ,
854- false , brgmm_ctx.get_src_scales_ptr (),
852+ brgmm_ctx. get_zp_c_ptr ( ), false , 1 , false , false ,
853+ brgmm_ctx.get_src_scales_ptr (),
855854 brgmm_ctx.get_wei_scales_ptr (n),
856855 brgmm_ctx.get_dst_scales_inv_ptr (ithr)};
857856
@@ -1135,7 +1134,6 @@ void brgemm_matmul_t<isa>::maybe_reduce_partial_results_and_apply_postops(
11351134 const auto zp_comp_b
11361135 = brgmm_ctx.get_zp_b_compensation_result_ptr (
11371136 ithr, mb);
1138- const auto zp_c_val_ptr = brgmm_ctx.get_zp_c_val_ptr ();
11391137 const auto &post_ops_binary_rhs_arg_vec
11401138 = brgmm_ctx.get_post_ops_binary_rhs_arg_vec ();
11411139
@@ -1158,9 +1156,8 @@ void brgemm_matmul_t<isa>::maybe_reduce_partial_results_and_apply_postops(
11581156 dst_anchor_point, first_mb_matrix_addr_off,
11591157 static_cast <const void *>(zp_comp_a),
11601158 static_cast <const void *>(zp_comp_b),
1161- static_cast <const void *>(zp_c_val_ptr),
1162- skip_accumulation, 1 , false , false ,
1163- brgmm_ctx.get_src_scales_ptr (),
1159+ brgmm_ctx.get_zp_c_ptr (), skip_accumulation, 1 ,
1160+ false , false , brgmm_ctx.get_src_scales_ptr (),
11641161 brgmm_ctx.get_wei_scales_ptr (n),
11651162 brgmm_ctx.get_dst_scales_inv_ptr (ithr)};
11661163
@@ -1199,9 +1196,18 @@ void brgemm_matmul_t<isa>::copy_a_chunk_in_buffer(
11991196 ctx.zp_a_compensation_result_ptr
12001197 = (void *)brgmm_ctx.get_zp_b_compensation_result_ptr (
12011198 ithr, m_blk_idx);
1202- ctx.zp_ab_comp_ptr = (void *)brgmm_ctx.get_zp_ab_mixed_comp_ptr ();
12031199 ctx.dynamic_src_ld = brgmm_ctx.get_src_stride ();
1204- ctx.zp_b_neg_val_ptr = brgmm_ctx.get_wei_zp_neg_ptr ();
1200+
1201+ // Note: instead of passing an address to a stack variable, a kernel may be
1202+ // changed to take just zp_b value and perform negation itself, but updating
1203+ // kernels is not straightforward for all platforms.
1204+ int32_t neg_zp_b
1205+ = !bgmmc.with_wei_decompression ? brgmm_ctx.get_neg_zp_b () : 0 ;
1206+ int32_t neg_zp_ab_comp = !bgmmc.with_wei_decompression
1207+ ? bgmmc.K * brgmm_ctx.get_neg_zp_a ()
1208+ : 0 ;
1209+ ctx.zp_b_neg_val_ptr = &neg_zp_b;
1210+ ctx.zp_ab_comp_ptr = &neg_zp_ab_comp;
12051211
12061212 for (int gb = 0 ; gb < gemm_batch_iters; gb++) {
12071213 const int k = k_start + gb * bgmmc.K_blk ;
@@ -1260,7 +1266,13 @@ void brgemm_matmul_t<isa>::copy_b_chunk_in_buffer(
12601266
12611267 ctx.zp_a_compensation_ptr = (void *)brgmm_ctx.get_zp_a_compensation_ptr (
12621268 ithr, b_idx, n_blk_idx);
1263- ctx.zp_a_neg_value_ptr = (void *)brgmm_ctx.get_zp_a_neg_val_ptr ();
1269+
1270+ // Note: instead of passing an address to a stack variable, a kernel may be
1271+ // changed to take just zp_a value and perform negation itself, but updating
1272+ // kernels is not straightforward for all platforms.
1273+ int32_t neg_zp_a = brgmm_ctx.get_neg_zp_a ();
1274+ ctx.zp_a_neg_value_ptr = &neg_zp_a;
1275+
12641276 ctx.compensation_ptr
12651277 = (void *)brgmm_ctx.get_s8s8_comp_ptr (ithr, b_idx, n_blk_idx);
12661278
@@ -1379,31 +1391,13 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
13791391 bias_ptr_ = CTX_IN_MEM (const char *, DNNL_ARG_BIAS);
13801392
13811393 // setup scales / zp pointers
1382- const void *src_zero_points = CTX_IN_MEM (
1394+ src_zp_ptr_ = CTX_IN_MEM (
13831395 const void *, DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC);
13841396 wei_zp_ptr_ = CTX_IN_MEM (
13851397 const void *, DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS);
1386- const void *dst_zero_points = CTX_IN_MEM (
1398+ dst_zp_ptr_ = CTX_IN_MEM (
13871399 const void *, DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST);
13881400
1389- zero_point_a_negative_val_ = src_zero_points
1390- ? -cpu::io::load_int_value (
1391- pd->attr ()->zero_points_ .get_data_type (DNNL_ARG_SRC),
1392- src_zero_points, 0 )
1393- : 0 ;
1394- zero_point_c_val_ = dst_zero_points
1395- ? cpu::io::load_int_value (
1396- pd->attr ()->zero_points_ .get_data_type (DNNL_ARG_DST),
1397- dst_zero_points, 0 )
1398- : 0 ;
1399-
1400- wei_zp_neg_val_ = (-1 )
1401- * (wei_zp_ptr_ ? cpu::io::load_int_value (
1402- pd->attr ()->zero_points_ .get_data_type (
1403- DNNL_ARG_WEIGHTS),
1404- wei_zp_ptr_, 0 )
1405- : 0 );
1406-
14071401 const auto &scratchpad = ctx.get_scratchpad_grantor ();
14081402
14091403 const auto &bgmmc = pd->get_brgemm_matmul_conf ();
@@ -1470,9 +1464,6 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
14701464 key_brgemm_primitive_zp_comp_b)
14711465 : nullptr ;
14721466
1473- zero_point_mixed_ab_compensation_component_
1474- = bgmmc.K * zero_point_a_negative_val_;
1475-
14761467 post_ops_binary_rhs_arg_vec_ = binary_injector::prepare_binary_args (
14771468 pd->attr ()->post_ops_ , ctx);
14781469 base_brg_ker_idx_
@@ -2100,11 +2091,19 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
21002091 + ithr * sizeof (float );
21012092 }
21022093
2103- const int32_t *get_zp_a_neg_val_ptr () const {
2104- return &zero_point_a_negative_val_;
// Returns the negated source (A) zero-point value as an int32_t.
// Yields 0 when bgmmc_.has_zero_point_a is not set; otherwise loads element 0
// from src_zp_ptr_, interpreted as data type bgmmc_.src_zp_dt, and negates it.
// The zero-point buffer is read on each call; nothing is cached here.
// NOTE(review): presumably index 0 is the single common zero point - confirm.
2094+ int32_t get_neg_zp_a () const {
2095+ if (!bgmmc_.has_zero_point_a ) return 0 ;
2096+ return -cpu::io::load_int_value (bgmmc_.src_zp_dt , src_zp_ptr_, 0 );
21052097 }
21062098
2107- const void *get_wei_zp_neg_ptr () const { return &wei_zp_neg_val_; }
2099+ // Used to compute compensation. Can't initialize the value at construction
2100+ // time as memory buffers must be accessed inside a parallel task
2101+ // (asynchronous runtime requirement).
// Returns the negated weights (B) zero-point value as an int32_t.
// Yields 0 when bgmmc_.has_zero_point_b is not set; otherwise loads element 0
// from wei_zp_ptr_, interpreted as data type bgmmc_.wei_zp_dt, and negates it.
// The assert documents that this path expects bgmmc_.is_wei_zp_common to hold;
// it is a debug-only check and does not guard release builds.
2102+ int32_t get_neg_zp_b () const {
2103+ if (!bgmmc_.has_zero_point_b ) return 0 ;
2104+ assert (bgmmc_.is_wei_zp_common );
2105+ return -cpu::io::load_int_value (bgmmc_.wei_zp_dt , wei_zp_ptr_, 0 );
2106+ }
21082107
21092108 const void *get_wei_zp_ptr (int n, int k = 0 ) const {
21102109 if (!bgmmc_.has_zero_point_b ) return nullptr ;
@@ -2127,11 +2126,7 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
21272126 return (char *)wei_zp_ptr_ + offset;
21282127 }
21292128
2130- const int32_t *get_zp_ab_mixed_comp_ptr () const {
2131- return &zero_point_mixed_ab_compensation_component_;
2132- }
2133-
2134- const int32_t *get_zp_c_val_ptr () const { return &zero_point_c_val_; }
// Returns the stored dst_zp_ptr_ as an untyped pointer. No value is loaded or
// converted here; the pointer is handed back exactly as stored.
// NOTE(review): presumably may be null when no destination zero point is
// supplied - confirm against the constructor.
2129+ const void *get_zp_c_ptr () const { return dst_zp_ptr_; }
21352130
21362131 int32_t *get_zp_a_compensation_ptr (
21372132 int ithr, int b_idx, int n_blk_idx) const {
@@ -2152,7 +2147,7 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
21522147 + n_blk_idx * bgmmc_.wei_n_blk ;
21532148 PRAGMA_OMP_SIMD ()
21542149 for (int b = 0 ; b < bgmmc_.wei_n_blk ; b++)
2155- zp_comp[b] = -zero_point_a_negative_val_
2150+ zp_comp[b] = -get_neg_zp_a ()
21562151 * reorder_zp_a_comp_ptr_[base_offset + b];
21572152 }
21582153 return zp_comp;
@@ -2478,11 +2473,9 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
24782473 int32_t *zero_point_b_compensations_ptr_;
24792474 int32_t *reorder_zp_a_comp_ptr_;
24802475
2481- int32_t zero_point_a_negative_val_;
2482- int32_t zero_point_mixed_ab_compensation_component_;
2483- int32_t zero_point_c_val_;
2484- int32_t wei_zp_neg_val_;
2476+ const void *src_zp_ptr_;
24852477 const void *wei_zp_ptr_;
2478+ const void *dst_zp_ptr_;
24862479 std::vector<const void *> post_ops_binary_rhs_arg_vec_;
24872480
24882481 int base_brg_ker_idx_;
0 commit comments