Skip to content

Commit 73dfc75

Browse files
committed
Perf: RMS fused to OV internal RMS op
1 parent d861305 commit 73dfc75

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

ggml/src/ggml-openvino/openvino/op/rms_norm.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <openvino/op/constant.hpp>
44
#include <openvino/op/divide.hpp>
55
#include <openvino/op/multiply.hpp>
6+
#include <openvino/op/power.hpp>
67
#include <openvino/op/reduce_mean.hpp>
78
#include <openvino/op/sqrt.hpp>
89

@@ -19,18 +20,17 @@ OutputVector translate_rms_norm(const NodeContext& context) {
1920
num_inputs_check(context, 1, 1);
2021

2122
auto input_node = context.get_input(0);
22-
auto square = std::make_shared<ov::op::v1::Multiply>(input_node, input_node);
23+
auto square = std::make_shared<ov::op::v1::Power>(
24+
input_node, ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1}, {2.0f}));
2325

24-
auto mean =
25-
std::make_shared<ov::op::v1::ReduceMean>(square,
26-
ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {2}),
27-
true);
26+
auto mean = std::make_shared<ov::op::v1::ReduceMean>(
27+
square, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {-1}), true);
2828

2929
float eps;
3030
memcpy(&eps, context.get_output_op_params(0), sizeof(float));
3131

3232
auto rms = std::make_shared<ov::op::v0::Sqrt>(
33-
std::make_shared<ov::op::v1::Add>(mean, ov::op::v0::Constant::create(ov::element::f32, ov::Shape{}, {eps})));
33+
std::make_shared<ov::op::v1::Add>(mean, ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1}, {eps})));
3434

3535
auto reciprocal =
3636
std::make_shared<ov::op::v1::Divide>(ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1}, {1.0f}), rms);

0 commit comments

Comments
 (0)