Commit 4f779ef

cpu: rv64: add rvv batch normalization integration
1 parent e45c87c

4 files changed: +345 −0 lines changed

src/cpu/cpu_batch_normalization_list.cpp

Lines changed: 6 additions & 0 deletions
@@ -34,6 +34,11 @@ using namespace dnnl::impl::cpu::x64;
 #include "cpu/aarch64/acl_batch_normalization.hpp"
 #endif
 using namespace dnnl::impl::cpu::aarch64;
+#elif DNNL_RV64
+#if defined(DNNL_RISCV_USE_RVV_INTRINSICS)
+#include "cpu/rv64/rvv_batch_normalization.hpp"
+using namespace dnnl::impl::cpu::rv64;
+#endif
 #endif

 namespace dnnl {

@@ -59,6 +64,7 @@ const std::map<pk_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map() {
         CPU_INSTANCE_AARCH64(jit_uni_batch_normalization_fwd_t<sve_256>)
         CPU_INSTANCE_AARCH64(jit_uni_batch_normalization_fwd_t<asimd>)
         CPU_INSTANCE_AARCH64_ACL(acl_batch_normalization_fwd_t)
+        CPU_INSTANCE_RV64GCV(rvv_batch_normalization_fwd_t)
         CPU_INSTANCE(ncsp_batch_normalization_fwd_t<f32>)
         CPU_INSTANCE(ncsp_batch_normalization_fwd_t<bf16>)
         CPU_INSTANCE(ncsp_batch_normalization_fwd_t<f16>)
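
For context, CPU_INSTANCE_RV64GCV follows the same guarded-macro pattern as the other CPU_INSTANCE_* entries: on builds without RVV support the entry expands to nothing, so dispatch falls through to the generic ncsp/nspc implementations listed after it. A rough sketch of the pattern (the actual macro is defined elsewhere in oneDNN and may differ in detail):

// Sketch only: the real definition is not part of this diff. On non-RVV
// builds the macro expands to nothing, removing the entry from the list.
#if defined(DNNL_RV64) && defined(DNNL_RISCV_USE_RVV_INTRINSICS)
#define CPU_INSTANCE_RV64GCV(...) \
    impl_list_item_t(impl_list_item_t::type_deduction_helper_t<__VA_ARGS__::pd_t>()),
#else
#define CPU_INSTANCE_RV64GCV(...)
#endif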
src/cpu/rv64/rvv_batch_normalization.cpp

Lines changed: 195 additions & 0 deletions (new file)

/******************************************************************************
 * Copyright 2025
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ******************************************************************************/

#include <assert.h>
#include <math.h>
#include <vector>
#include <riscv_vector.h>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"

#include "cpu/rv64/rvv_batch_normalization.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace rv64 {

namespace {

// If per_elem_params is false, uses broadcast scalars mean/sm/sv
// (mean[0], sm[0], sv[0]). If true, loads per-element mean/sm/sv from the
// provided arrays.
static inline void bn_fwd_kernel_f32(const void *s_base, void *d_base,
        size_t len, const float *mean, const float *sm, const float *sv,
        bool per_elem_params, const rv64::rvv_postops_t &po) {
    const size_t data_size = types::data_type_size(data_type::f32);
    for (size_t i = 0; i < len;) {
        size_t vl = __riscv_vsetvl_e32m1(len - i);

        const float *s_ptr = reinterpret_cast<const float *>(
                reinterpret_cast<const char *>(s_base) + i * data_size);
        float *d_ptr = reinterpret_cast<float *>(
                reinterpret_cast<char *>(d_base) + i * data_size);

        vfloat32m1_t vx = __riscv_vle32_v_f32m1(s_ptr, vl);

        vfloat32m1_t vmean_v;
        vfloat32m1_t vsm_v;
        vfloat32m1_t vsv_v;
        if (per_elem_params) {
            vmean_v = __riscv_vle32_v_f32m1(mean + i, vl);
            vsm_v = __riscv_vle32_v_f32m1(sm + i, vl);
            vsv_v = __riscv_vle32_v_f32m1(sv + i, vl);
        } else {
            vmean_v = __riscv_vfmv_v_f_f32m1(mean[0], vl);
            vsm_v = __riscv_vfmv_v_f_f32m1(sm[0], vl);
            vsv_v = __riscv_vfmv_v_f_f32m1(sv[0], vl);
        }

        vfloat32m1_t vtmp = __riscv_vfsub_vv_f32m1(vx, vmean_v, vl);
        vfloat32m1_t vout = __riscv_vfmul_vv_f32m1(vtmp, vsm_v, vl);
        vout = __riscv_vfadd_vv_f32m1(vout, vsv_v, vl);
        vout = po.apply(vout, vl);

        __riscv_vse32_v_f32m1(d_ptr, vout, vl);
        i += vl;
    }
}

} // namespace

status_t rvv_batch_normalization_fwd_t::execute_forward(
        const exec_ctx_t &ctx) const {
    const memory_desc_wrapper data_d(pd()->src_md());
    const auto dtsrc = pd()->src_md()->data_type;
    const int ndims = data_d.ndims();

    const dim_t N = pd()->MB();
    const dim_t C = pd()->C();
    const dim_t D = pd()->D();
    const dim_t H = pd()->H();
    const dim_t W = pd()->W();

    const float eps = pd()->desc()->batch_norm_epsilon;

    void *dst = CTX_OUT_MEM(void *, DNNL_ARG_DST);
    const void *src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
    const float *mean = CTX_IN_MEM(const float *, DNNL_ARG_MEAN);
    const float *var = CTX_IN_MEM(const float *, DNNL_ARG_VARIANCE);
    const float *scale = pd()->use_scale()
            ? CTX_IN_MEM(const float *, DNNL_ARG_SCALE)
            : nullptr;
    const float *shift = pd()->use_shift()
            ? CTX_IN_MEM(const float *, DNNL_ARG_SHIFT)
            : nullptr;

    rv64::rvv_postops_t po = pd()->fused_relu_in_kernel()
            ? rv64::rvv_postops_t(alg_kind::eltwise_relu)
            : rv64::rvv_postops_t(pd()->attr()->post_ops_);

    auto off = [&](dim_t n, dim_t c, dim_t d, dim_t h, dim_t w) -> size_t {
        switch (ndims) {
            case 3: return data_d.off(n, c, w);
            case 4: return data_d.off(n, c, h, w);
            case 5: return data_d.off(n, c, d, h, w);
            default: assert(!"unsupported ndims"); return dim_t(0);
        }
    };

    const bool channels_dense = data_d.blocking_desc().strides[1] == 1;

    if (!channels_dense) {
        // abx data tag: vectorize over W for fixed channel
        parallel_nd(C, N, D, H, [&](dim_t c, dim_t n, dim_t d, dim_t h) {
            const float vmean = mean[c];
            const float inv_std = 1.0f / sqrtf(var[c] + eps);
            const float vscale = scale ? scale[c] : 1.0f;
            const float vshift = shift ? shift[c] : 0.0f;
            const float sm = vscale * inv_std;
            const float sv = vshift;
            size_t base_off = off(n, c, d, h, 0);

            switch (dtsrc) {
                case data_type::f32: {
                    const size_t data_size
                            = types::data_type_size(data_type::f32);
                    const void *s_ptr = reinterpret_cast<const void *>(
                            reinterpret_cast<const char *>(src)
                            + base_off * data_size);
                    void *d_ptr = reinterpret_cast<void *>(
                            reinterpret_cast<char *>(dst)
                            + base_off * data_size);
                    const float mean_b[1] = {vmean};
                    const float sm_b[1] = {sm};
                    const float sv_b[1] = {sv};
                    bn_fwd_kernel_f32(s_ptr, d_ptr, static_cast<size_t>(W),
                            mean_b, sm_b, sv_b, /*per_elem_params=*/false, po);
                    break;
                }
                default:
                    assert(!"Unsupported data type for RVV batch "
                            "normalization");
            }
        });
    } else {
        // axb data tag: vectorize across channels
        auto &grantor = ctx.get_scratchpad_grantor();
        float *sm_arr = grantor.template get<float>(
                memory_tracking::names::key_bnorm_tmp_mean);
        float *sv_arr = grantor.template get<float>(
                memory_tracking::names::key_bnorm_tmp_var);
        for (dim_t c = 0; c < C; ++c) {
            const float inv_std = 1.0f / sqrtf(var[c] + eps);
            const float vscale = scale ? scale[c] : 1.0f;
            const float vshift = shift ? shift[c] : 0.0f;
            sm_arr[static_cast<size_t>(c)] = vscale * inv_std;
            sv_arr[static_cast<size_t>(c)] = vshift;
        }

        parallel_nd(N, D, H, W, [&](dim_t n, dim_t d, dim_t h, dim_t w) {
            switch (dtsrc) {
                case data_type::f32: {
                    const size_t data_size
                            = types::data_type_size(data_type::f32);
                    size_t base_off = off(n, 0, d, h, w);
                    const void *s_ptr = reinterpret_cast<const void *>(
                            reinterpret_cast<const char *>(src)
                            + base_off * data_size);
                    void *d_ptr = reinterpret_cast<void *>(
                            reinterpret_cast<char *>(dst)
                            + base_off * data_size);

                    bn_fwd_kernel_f32(s_ptr, d_ptr, static_cast<size_t>(C),
                            mean, sm_arr, sv_arr,
                            /*per_elem_params=*/true, po);
                    break;
                }
                default:
                    assert(!"Unsupported data type for RVV batch "
                            "normalization");
            }
        });
    }

    return status::success;
}

} // namespace rv64
} // namespace cpu
} // namespace impl
} // namespace dnnl
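
The kernel folds each channel's statistics into two scalars ahead of the vector loop: sm = scale / sqrt(var + eps) and sv = shift, so the loop body is one subtract, one multiply, and one add per element. A minimal scalar sketch of the same transform (bn_fwd_ref_f32 is a hypothetical name for illustration, not part of the commit):

#include <cmath>
#include <cstddef>

// Scalar reference of the transform the RVV kernel vectorizes:
// dst[i] = (src[i] - mean) * sm + sv, with sm = scale / sqrt(var + eps)
// and sv = shift, optionally followed by the parameter-free ReLU post-op.
static void bn_fwd_ref_f32(const float *src, float *dst, std::size_t len,
        float mean, float sm, float sv, bool fuse_relu) {
    for (std::size_t i = 0; i < len; ++i) {
        float v = (src[i] - mean) * sm + sv;
        if (fuse_relu && v < 0.f) v = 0.f; // eltwise_relu with alpha == 0
        dst[i] = v;
    }
}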
src/cpu/rv64/rvv_batch_normalization.hpp

Lines changed: 142 additions & 0 deletions (new file)

/******************************************************************************
 * Copyright 2025
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ******************************************************************************/

#ifndef CPU_RV64_RVV_BATCH_NORMALIZATION_HPP
#define CPU_RV64_RVV_BATCH_NORMALIZATION_HPP

#include "common/memory_tracking.hpp"
#include "common/primitive.hpp"

#include "cpu/cpu_batch_normalization_pd.hpp"
#include "cpu/platform.hpp"
#include "cpu/rv64/rvv_postops.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace rv64 {

struct rvv_batch_normalization_fwd_t : public primitive_t {
    struct pd_t : public cpu_batch_normalization_fwd_pd_t {
        using cpu_batch_normalization_fwd_pd_t::
                cpu_batch_normalization_fwd_pd_t;

        DECLARE_COMMON_PD_T_("RISCV64GCV", rvv_batch_normalization_fwd_t);

        status_t init(engine_t *engine) {
            UNUSED(engine);

            using namespace data_type;

            VDISPATCH_BNORM(is_fwd(), VERBOSE_BAD_PROPKIND);

            const data_type_t dtsrc = src_md()->data_type;
            const data_type_t dtdst = dst_md()->data_type;
            bool types_ok = (dtsrc == f32 && dtdst == f32)
                    && platform::has_data_type_support(dtsrc)
                    && IMPLICATION(is_training(),
                            platform::has_training_support(dtsrc));
            VDISPATCH_BNORM(types_ok, VERBOSE_UNSUPPORTED_DT);

            VDISPATCH_BNORM(!has_zero_dim_memory(), VERBOSE_EMPTY_TENSOR, "");

            // Require global stats (G). Flags C/H/R (inference) are optional.
            // Disallow none and A.
            VDISPATCH_BNORM(!fuse_norm_add_relu(), VERBOSE_UNSUPPORTED_FEATURE,
                    "fuse_norm_add_relu not supported");
            VDISPATCH_BNORM(use_global_stats(), VERBOSE_UNSUPPORTED_FEATURE,
                    "stats must already have been computed (use global stats)");
            using smask_t = primitive_attr_t::skip_mask_t;
            VDISPATCH_BNORM(!(fuse_norm_relu()
                                    && desc()->prop_kind
                                            == prop_kind::forward_training),
                    VERBOSE_UNSUPPORTED_FEATURE,
                    "forward training with fused ReLU is not supported");
            // Only support eltwise ReLU without alpha/beta post-op, as the
            // current rvv_postops requires.
            VDISPATCH_BNORM(attr()->has_default_values(smask_t::post_ops),
                    VERBOSE_UNSUPPORTED_ATTR);
            {
                const post_ops_t &po = attr()->post_ops_;
                bool relu_no_params_ok = true;
                if (po.len() == 1) {
                    const auto &e = po.entry_[0];
                    relu_no_params_ok = e.is_eltwise()
                            && e.eltwise.alg == alg_kind::eltwise_relu
                            && e.eltwise.alpha == 0.f && e.eltwise.beta == 0.f;
                } else if (po.len() > 1) {
                    relu_no_params_ok = false;
                }
                VDISPATCH_BNORM(relu_no_params_ok, VERBOSE_UNSUPPORTED_ATTR);
            }
            VDISPATCH_BNORM(rv64::rvv_postops_t::post_ops_ok(attr()->post_ops_),
                    VERBOSE_UNSUPPORTED_ATTR);

            // Simplest memory layouts only: plain, dense, same layout src/dst,
            // no blocking/padding.
            VDISPATCH_BNORM(
                    set_default_formats_common(), VERBOSE_UNSUPPORTED_TAG);
            const memory_desc_wrapper src_d(src_md());
            const memory_desc_wrapper dst_d(dst_md());
            VDISPATCH_BNORM(
                    check_layouts(src_d, dst_d), VERBOSE_UNSUPPORTED_TAG);

            fused_relu_in_kernel_ = fuse_norm_relu();
            init_scratchpad();

            return status::success;
        }

        bool check_layouts(const memory_desc_wrapper &src_d,
                const memory_desc_wrapper &dst_d) const {
            // Require plain, dense, no blocking/padding, same plain layout.
            bool ndims_ok = utils::one_of(ndims(), 3, 4, 5);
            bool plain_dense = src_d.blocking_desc().inner_nblks == 0
                    && dst_d.blocking_desc().inner_nblks == 0
                    && src_d.is_dense(/*with_padding=*/false)
                    && dst_d.is_dense(/*with_padding=*/false)
                    && src_d.is_plain() && dst_d.is_plain();
            bool same_layouts = src_d.similar_to(dst_d, /*with_strides=*/true,
                    /*with_pads=*/false);
            return ndims_ok && plain_dense && same_layouts;
        }

        bool fused_relu_in_kernel() const { return fused_relu_in_kernel_; }

    private:
        void init_scratchpad() {
            using namespace memory_tracking::names;
            auto scratchpad = scratchpad_registry().registrar();
            // Reserve per-channel temporary buffers for axb (channels-dense)
            // path
            scratchpad.template book<float>(key_bnorm_tmp_mean, C());
            scratchpad.template book<float>(key_bnorm_tmp_var, C());
        }
        bool fused_relu_in_kernel_ = false;
    };

    rvv_batch_normalization_fwd_t(const pd_t *apd) : primitive_t(apd) {}

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_forward(ctx);
    }

private:
    status_t execute_forward(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
};

} // namespace rv64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif // CPU_RV64_RVV_BATCH_NORMALIZATION_HPP
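
A minimal usage sketch that should satisfy this implementation's dispatch checks, assuming a oneDNN v3.x build with this backend enabled: f32 src/dst in a plain dense layout, forward inference with global stats, plus the optional scale/shift and fused-ReLU flags. Shapes and epsilon are illustrative only.

#include "dnnl.hpp"

int main() {
    using namespace dnnl;
    engine eng(engine::kind::cpu, 0);
    stream strm(eng);

    const memory::dim N = 2, C = 16, H = 8, W = 8;
    // Plain dense f32 layouts are what pd_t::init() accepts.
    memory::desc data_md({N, C, H, W}, memory::data_type::f32,
            memory::format_tag::nchw);
    memory::desc stat_md({C}, memory::data_type::f32, memory::format_tag::x);

    // use_global_stats is mandatory for this impl; scale/shift and the
    // fused-ReLU flag are optional (fused ReLU only for inference).
    auto flags = normalization_flags::use_global_stats
            | normalization_flags::use_scale
            | normalization_flags::use_shift
            | normalization_flags::fuse_norm_relu;

    batch_normalization_forward::primitive_desc pd(eng,
            prop_kind::forward_inference, data_md, data_md, 1e-5f, flags);

    memory src(data_md, eng), dst(data_md, eng);
    memory mean(stat_md, eng), var(stat_md, eng);
    memory scale(stat_md, eng), shift(stat_md, eng);

    batch_normalization_forward(pd).execute(strm,
            {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst},
                    {DNNL_ARG_MEAN, mean}, {DNNL_ARG_VARIANCE, var},
                    {DNNL_ARG_SCALE, scale}, {DNNL_ARG_SHIFT, shift}});
    strm.wait();
    return 0;
}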

src/cpu/rv64/rvv_postops.hpp

Lines changed: 2 additions & 0 deletions
@@ -86,6 +86,8 @@ struct rvv_postops_t {
         return status::success;
     }

+    explicit rvv_postops_t(alg_kind_t alg) : alg_(alg) {}
+
     static bool post_ops_ok(const post_ops_t &po) {
         if (po.len() == 0) return true;
         if (po.len() > 1) return false;