33#include < openvino/op/constant.hpp>
44#include < openvino/op/divide.hpp>
55#include < openvino/op/multiply.hpp>
6+ #include < openvino/op/power.hpp>
67#include < openvino/op/reduce_mean.hpp>
78#include < openvino/op/sqrt.hpp>
89
@@ -19,18 +20,17 @@ OutputVector translate_rms_norm(const NodeContext& context) {
1920 num_inputs_check (context, 1 , 1 );
2021
2122 auto input_node = context.get_input (0 );
22- auto square = std::make_shared<ov::op::v1::Multiply>(input_node, input_node);
23+ auto square = std::make_shared<ov::op::v1::Power>(
24+ input_node, ov::op::v0::Constant::create (ov::element::f32 , ov::Shape{1 }, {2 .0f }));
2325
24- auto mean =
25- std::make_shared<ov::op::v1::ReduceMean>(square,
26- ov::op::v0::Constant::create (ov::element::i64 , ov::Shape{1 }, {2 }),
27- true );
26+ auto mean = std::make_shared<ov::op::v1::ReduceMean>(
27+ square, ov::op::v0::Constant::create (ov::element::i64 , ov::Shape{1 }, {-1 }), true );
2828
2929 float eps;
3030 memcpy (&eps, context.get_output_op_params (0 ), sizeof (float ));
3131
3232 auto rms = std::make_shared<ov::op::v0::Sqrt>(
33- std::make_shared<ov::op::v1::Add>(mean, ov::op::v0::Constant::create (ov::element::f32 , ov::Shape{}, {eps})));
33+ std::make_shared<ov::op::v1::Add>(mean, ov::op::v0::Constant::create (ov::element::f32 , ov::Shape{1 }, {eps})));
3434
3535 auto reciprocal =
3636 std::make_shared<ov::op::v1::Divide>(ov::op::v0::Constant::create (ov::element::f32 , ov::Shape{1 }, {1 .0f }), rms);
0 commit comments