
Commit a5436a5

Xia-Weiwen authored and pytorchmergebot committed
[CPU] add onednn context cache for qlinear to improve performance (pytorch#168150)
**Summary**

We observed significant framework overhead in `qlinear`. Calling a oneDNN primitive requires preparing a number of data structures as its arguments, and that preparation is expensive. Previously, these structures were cached in a context attached to the TorchScript graph, but Inductor does not support non-tensor data on its graph. This PR instead caches them in a static `std::unordered_map` whose key is the weight's data address (as an `int64_t`) and whose value is a struct holding everything needed to run the primitive. The cache is safe in the common case where the weight's data address does not change during inference and the weight data is not reused by different layers. Since we cannot guarantee that assumption, the feature is gated behind the environment variable `ONEDNN_CACHE_CONTEXT_UNSAFE`; users enable it at their own risk. We measured a >5% end-to-end performance gain running ViT with PT2E static quantization on a 6th-gen Intel Xeon CPU.

**Test plan**

```
pytest -sv test/test_quantization.py -k "qlinear and pt2e"
```

Pull Request resolved: pytorch#168150
Approved by: https://github.com/mingfeima, https://github.com/jerryzh168
1 parent ca3e8b3 commit a5436a5
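
The core mechanism the commit message describes is small: an environment-variable gate plus a function-local static `std::unordered_map` keyed by the weight's data address, consulted before the expensive primitive setup. Below is a minimal, self-contained C++ sketch of that pattern. It deliberately uses placeholder names (`Params`, `build_params`, `run_with_params`, `qlinear_like_op`) rather than the real PyTorch/oneDNN APIs, so it illustrates the caching idea only, not the actual implementation shown in the diff below.

```cpp
// Sketch of an env-gated, address-keyed parameter cache (placeholders only;
// Params / build_params / run_with_params are NOT PyTorch or oneDNN APIs).
#include <cstdint>
#include <cstdlib>
#include <string>
#include <unordered_map>

struct Params {
  // In the real code this holds the matmul primitive, exec args,
  // packed weight, scales/zero points, and scratchpad.
  int dummy_state = 0;
};

static bool cache_enabled() {
  // Gate the unsafe cache behind an environment variable, as the PR does.
  static const char* env = std::getenv("ONEDNN_CACHE_CONTEXT_UNSAFE");
  static const bool enabled = env != nullptr && std::string(env) == "1";
  return enabled;
}

static Params build_params(int64_t weight_addr) {
  // Placeholder for the expensive setup (primitive + args) on the slow path.
  return Params{static_cast<int>(weight_addr % 97)};
}

static void run_with_params(const Params& /*params*/) {
  // Placeholder for primitive.execute(...).
}

void qlinear_like_op(int64_t weight_addr) {
  // Key the cache by the weight's data address; this assumes the address is
  // stable during inference and not shared across layers (hence "unsafe").
  static std::unordered_map<int64_t, Params> params_map;
  if (cache_enabled()) {
    auto it = params_map.find(weight_addr);
    if (it != params_map.end()) {
      run_with_params(it->second);  // fast path: reuse cached context
      return;
    }
  }
  Params params = build_params(weight_addr);  // slow path: rebuild everything
  run_with_params(params);
  if (cache_enabled()) {
    params_map[weight_addr] = params;  // populate cache for the next call
  }
}

int main() {
  // With ONEDNN_CACHE_CONTEXT_UNSAFE=1 in the environment, the second call
  // for the same address takes the fast path.
  qlinear_like_op(0x1000);
  qlinear_like_op(0x1000);
  return 0;
}
```

As in the PR, the map is never invalidated, so reuse of a weight address by a different layer would silently return stale parameters; that is why the flag carries the "UNSAFE" suffix.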

File tree: 3 files changed (+125, -33 lines)


aten/src/ATen/native/quantized/cpu/OnednnUtils.h

Lines changed: 36 additions & 0 deletions
```diff
@@ -462,4 +462,40 @@ at::Tensor _qconv_prepack_onednn(
 
 #define FP8E4M3_MAX 448.0
 
+#define CACHE_ONEDNN_CONTEXT_FLAG "ONEDNN_CACHE_CONTEXT_UNSAFE"
+
+struct QlinearForwardParams {
+  dnnl::matmul primitive;
+  ideep::exec_args args;
+  ideep::tensor packed_weight;
+  ideep::tensor weight_scales;
+  std::optional<ideep::tensor> src_scale;
+  std::optional<ideep::tensor> src_zero_point;
+  std::optional<ideep::tensor> dst_scale;
+  std::optional<ideep::tensor> dst_zero_point;
+  std::optional<ideep::tensor> bias;
+  ideep::tensor scratchpad;
+
+  void init_args() {
+    args.insert({DNNL_ARG_WEIGHTS, packed_weight});
+    args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
+    if (bias.has_value()) {
+      args.insert({DNNL_ARG_BIAS, bias.value()});
+    }
+    if (src_scale.has_value()) {
+      args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scale.value()});
+    }
+    if (dst_scale.has_value()) {
+      args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scale.value()});
+    }
+    args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, weight_scales});
+    if (src_zero_point.has_value()) {
+      args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zero_point.value()});
+    }
+    if (dst_zero_point.has_value()) {
+      args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zero_point.value()});
+    }
+  }
+};
+
 #endif // #if AT_MKLDNN_ENABLED()
```

aten/src/ATen/native/quantized/cpu/qlinear.cpp

Lines changed: 53 additions & 16 deletions
```diff
@@ -1147,24 +1147,13 @@ static at::Tensor linear_int8_with_onednn_weight(
       dim == 2 ? input.contiguous() : input.reshape({-1, input.size(dim - 1)}).contiguous();
 
   auto src = at::native::itensor_from_tensor(input_contig);
-  auto packed_weight = at::native::itensor_from_mkldnn(onednn_weight);
-  int64_t K = input.size(dim - 1), M = input.numel() / K, N = packed_weight.get_dim(1);
+  int64_t K = input.size(dim - 1), M = input.numel() / K, N = onednn_weight.size(1);
 
   auto output_size = input.sizes().vec();
   output_size[dim - 1] = N;
 
-  std::optional<ideep::tensor> onednn_bias{std::nullopt};
   bool with_bias = bias.has_value();
-  at::Tensor bias_val_float;
-  if (with_bias) {
-    bias_val_float = bias.value().to(at::kFloat);
-    if (bias_val_float.dim() == 1) {
-      auto b_reshape = bias_val_float.reshape({1, bias_val_float.size(0)});
-      onednn_bias = at::native::itensor_view_from_dense(b_reshape);
-    } else {
-      onednn_bias = at::native::itensor_view_from_dense(bias_val_float);
-    }
-  }
+
   std::vector<int64_t> src_dims = {M, K};
   std::vector<int64_t> dst_dims = {M, N};
   auto out_dtype = output_dtype.has_value() ? output_dtype.value() : input.scalar_type();
@@ -1185,14 +1174,47 @@ static at::Tensor linear_int8_with_onednn_weight(
       at::native::itensor_view_from_dense(other.value().reshape({-1, other.value().size(dim - 1)})) :
       empty_tensor;
 
+  // Fast path with cache of params
+  static const char* env_var = std::getenv(CACHE_ONEDNN_CONTEXT_FLAG);
+  static const std::string cache_flag_str = env_var ? std::string(env_var) : "";
+  static const bool context_cache_enabled = cache_flag_str != "" && cache_flag_str == "1";
+  static std::unordered_map<int64_t, QlinearForwardParams> qlinear_forward_params_map;
+  int64_t weight_addr = at::native::data_ptr_from_mkldnn(onednn_weight);
+  if (context_cache_enabled) {
+    auto it = qlinear_forward_params_map.find(weight_addr);
+    if (it != qlinear_forward_params_map.end()) {
+      auto& params = it->second;
+      auto& args = params.args;
+      args[DNNL_ARG_SRC] = std::move(src);
+      args[DNNL_ARG_DST] = std::move(dst);
+      if (binary_post_op == "add") {
+        args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1] = std::move(src1);
+      }
+      params.primitive.execute(ideep::stream::default_stream(), args);
+      return dim == 2 ? output : output.resize_(output_size);
+    }
+  }
+
+  // Regular path
+  auto packed_weight = at::native::itensor_from_mkldnn(onednn_weight);
+  tensor onednn_bias;
+  if (with_bias) {
+    at::Tensor bias_val_float = bias.value();
+    if (bias_val_float.dim() == 1) {
+      auto b_reshape = bias_val_float.reshape({1, bias_val_float.size(0)});
+      onednn_bias = at::native::itensor_view_from_dense(b_reshape);
+    } else {
+      onednn_bias = at::native::itensor_view_from_dense(bias_val_float);
+    }
+  }
   // Create onednn primitive
   auto src_dtype = at::native::get_mkldnn_dtype(input.scalar_type());
   auto src_desc = tensor::desc(src_dims, src_dtype, ideep::format_tag::any);
   auto weights_desc = packed_weight.get_desc();
   auto dst_dtype = dst.get_data_type();
   auto dst_desc = tensor::desc(dst_dims, dst_dtype, ideep::format_tag::any);
   auto bias_desc = with_bias ?
-      tensor::desc(onednn_bias.value().get_dims(), ideep::data_type::f32, ideep::format_tag::any) :
+      tensor::desc(onednn_bias.get_dims(), onednn_bias.get_data_type(), ideep::format_tag::any) :
       empty_tensor_desc;
   // Get op attr for primitive
   // Note: output_scale & output_zero_point are for re-quantization of the final output.
@@ -1249,7 +1271,7 @@ static at::Tensor linear_int8_with_onednn_weight(
   args.insert({DNNL_ARG_DST, dst});
   args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
   if (with_bias) {
-    args.insert({DNNL_ARG_BIAS, onednn_bias.value()});
+    args.insert({DNNL_ARG_BIAS, onednn_bias});
   }
   tensor src_scales_t = tensor(ideep::scale_t(1, input_scale));
   tensor wei_scales_t = at::native::itensor_from_tensor(weight_scales);
@@ -1273,7 +1295,22 @@ static at::Tensor linear_int8_with_onednn_weight(
     args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, src1});
   }
   primitive.execute(ideep::stream::default_stream(), args);
-  return dim == 2 ? output : output.reshape(output_size);
+  // Update cache if needed
+  if (context_cache_enabled) {
+    QlinearForwardParams params;
+    params.primitive = primitive;
+    params.packed_weight = expected_weight;
+    params.weight_scales = wei_scales_t;
+    params.src_scale = input_scale != 1.0f ? std::make_optional<tensor>(src_scales_t) : std::nullopt;
+    params.dst_scale = output_scale != 1.0f ? std::make_optional<tensor>(dst_scales_t) : std::nullopt;
+    params.src_zero_point = input_zero_point != 0 ? std::make_optional<tensor>(src_zp_t) : std::nullopt;
+    params.dst_zero_point = output_zero_point != 0 ? std::make_optional<tensor>(dst_zp_t) : std::nullopt;
+    params.bias = with_bias ? std::make_optional<tensor>(onednn_bias) : std::nullopt;
+    params.scratchpad = scratchpad;
+    params.init_args();
+    qlinear_forward_params_map[weight_addr] = params;
+  }
+  return dim == 2 ? output : output.resize_(output_size);
 }
 
 #if AT_MKLDNN_ACL_ENABLED()
```

test/quantization/core/test_quantized_op.py

Lines changed: 36 additions & 17 deletions
```diff
@@ -4563,7 +4563,11 @@ def _test_qlinear_pt2e_helper(
         post_op="none",
         unary_post_op_args=(),
         post_op_algorithms=("none",),
+        test_fast_path=False,
     ):
+        if test_fast_path:
+            import os
+            os.environ["ONEDNN_CACHE_CONTEXT_UNSAFE"] = "1"
         qlinear_prepack = torch.ops.onednn.qlinear_prepack
         linear_op = F.linear
         in_channels_list = [4, 8]
@@ -4615,12 +4619,14 @@ def _test_qlinear_pt2e_helper(
             qw_cpu = qw.int_repr()
             qw_packed = qlinear_prepack(qw_cpu, x.shape)
 
+            num_iter = 2 if test_fast_path else 1  # rerun to use cache
             if post_op in ("none", "relu", "gelu"):
-                qy_cpu = qlinear_op(
-                    qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps,
-                    b, used_y_scale, used_y_zp, output_dtype,
-                    post_op, unary_post_op_args, post_op_algo
-                )
+                for _ in range(num_iter):
+                    qy_cpu = qlinear_op(
+                        qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps,
+                        b, used_y_scale, used_y_zp, output_dtype,
+                        post_op, unary_post_op_args, post_op_algo
+                    )
                 if post_op == "relu":
                     y_ref = F.relu(y_ref)
                 elif post_op == "gelu":
@@ -4637,12 +4643,14 @@ def _test_qlinear_pt2e_helper(
                 accum = qx2.int_repr() if output_dtype is None else qx2.dequantize()
                 if bfloat16_out:
                     accum = accum.bfloat16()
-                qy_cpu = qlinear_op(
-                    qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps,
-                    accum, b, used_y_scale, used_y_zp, output_dtype,
-                    x2_scale, x2_zp, "sum", binary_alpha,
-                    unary_post_op, unary_post_op_args, post_op_algo
-                )
+                for _ in range(num_iter):
+                    # clone accum otherwise it gets accumulated multiple times
+                    qy_cpu = qlinear_op(
+                        qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps,
+                        accum.clone(), b, used_y_scale, used_y_zp, output_dtype,
+                        x2_scale, x2_zp, "sum", binary_alpha,
+                        unary_post_op, unary_post_op_args, post_op_algo
+                    )
                 y_ref = y_ref + x2 * binary_alpha
                 if unary_post_op == "relu":
                     y_ref = F.relu(y_ref)
@@ -4655,12 +4663,13 @@ def _test_qlinear_pt2e_helper(
                 x2 = torch.randn(y_ref.size()) * 10
                 unary_post_op = "relu" if post_op == "add_relu" else "none"
                 binary_alpha = 1.0  # we only support alpha=1.0 now
-                qy_cpu = qlinear_op(
-                    qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps,
-                    x2, b, used_y_scale, used_y_zp, output_dtype,
-                    1.0, 0, "add", binary_alpha,
-                    unary_post_op, unary_post_op_args, post_op_algo
-                )
+                for _ in range(num_iter):
+                    qy_cpu = qlinear_op(
+                        qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps,
+                        x2, b, used_y_scale, used_y_zp, output_dtype,
+                        1.0, 0, "add", binary_alpha,
+                        unary_post_op, unary_post_op_args, post_op_algo
+                    )
                 y_ref = y_ref + x2 * binary_alpha
                 if unary_post_op == "relu":
                     y_ref = F.relu(y_ref)
@@ -4686,48 +4695,58 @@ def _test_qlinear_pt2e_helper(
                     y_s: {y_scale}, y_zp: {y_zp}""",
                 )
 
+        if test_fast_path:
+            del os.environ["ONEDNN_CACHE_CONTEXT_UNSAFE"]
+
     @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
     @skipIfNoONEDNN
     def test_qlinear_pt2e(self):
         qlinear = torch.ops.onednn.qlinear_pointwise
         self._test_qlinear_pt2e_helper(qlinear, "none")
+        self._test_qlinear_pt2e_helper(qlinear, "none", test_fast_path=True)
 
     @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
     @skipIfNoONEDNN
     def test_qlinear_relu_pt2e(self):
         qlinear = torch.ops.onednn.qlinear_pointwise
         self._test_qlinear_pt2e_helper(qlinear, "relu")
+        self._test_qlinear_pt2e_helper(qlinear, "relu", test_fast_path=True)
 
     @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
     @skipIfNoONEDNN
     def test_qlinear_gelu_pt2e(self):
         qlinear = torch.ops.onednn.qlinear_pointwise
         post_op_algorithms = ['none', 'tanh']
         self._test_qlinear_pt2e_helper(qlinear, "gelu", post_op_algorithms=post_op_algorithms)
+        self._test_qlinear_pt2e_helper(qlinear, "gelu", post_op_algorithms=post_op_algorithms, test_fast_path=True)
 
     @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
     @skipIfNoONEDNN
     def test_qlinear_sum_pt2e(self):
         qlinear = torch.ops.onednn.qlinear_pointwise.binary
         self._test_qlinear_pt2e_helper(qlinear, "sum")
+        self._test_qlinear_pt2e_helper(qlinear, "sum", test_fast_path=True)
 
     @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
     @skipIfNoONEDNN
     def test_qlinear_sum_relu_pt2e(self):
         qlinear = torch.ops.onednn.qlinear_pointwise.binary
         self._test_qlinear_pt2e_helper(qlinear, "sum_relu")
+        self._test_qlinear_pt2e_helper(qlinear, "sum_relu", test_fast_path=True)
 
     @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
     @skipIfNoONEDNN
     def test_qlinear_add_pt2e(self):
         qlinear = torch.ops.onednn.qlinear_pointwise.binary
         self._test_qlinear_pt2e_helper(qlinear, "add")
+        self._test_qlinear_pt2e_helper(qlinear, "add", test_fast_path=True)
 
     @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
     @skipIfNoONEDNN
     def test_qlinear_add_relu_pt2e(self):
         qlinear = torch.ops.onednn.qlinear_pointwise.binary
         self._test_qlinear_pt2e_helper(qlinear, "add_relu")
+        self._test_qlinear_pt2e_helper(qlinear, "add_relu", test_fast_path=True)
 
     def _test_qlinear_fp8_helper(
         self,
```
