1111#include < openvino/op/shape_of.hpp>
1212#include < openvino/op/slice.hpp>
1313#include < openvino/op/squeeze.hpp>
14+ #include < openvino/op/transpose.hpp>
1415
1516#include " ../node_context.hpp"
1617#include " ../op_table.hpp"
@@ -25,21 +26,40 @@ OutputVector translate_set_rows(const NodeContext& context) {
2526 num_inputs_check (context, 2 , 2 );
2627
2728 auto data = context.get_input (0 );
28- auto indices = context.get_input ( 1 );
29- auto dst = context. get_input (context. get_output_name ());
29+ data = std::make_shared<ov::op::v0::Convert>(data, context.get_output_type ( 0 ) );
30+
3031 auto dst_shape = context.get_output_shape (0 ).to_shape ();
3132 FRONT_END_OP_CONVERSION_CHECK (dst_shape[0 ] == 1 , " Unsupported shape in SET_ROWS" );
3233
33- auto zero = ov::op::v0::Constant::create (ov::element::i64 , ov::Shape{1 }, {0 });
34+ if (context.is_static () && context.is_first_token ()) {
35+ Output<Node> res;
36+ if (context.get_op_case () == 2 ) {
37+ res = std::make_shared<ov::op::v1::Reshape>(
38+ data,
39+ ov::op::v0::Constant::create (
40+ ov::element::i64 ,
41+ {3 },
42+ {context.get_context_size (), context.get_num_heads_kv (), context.get_head_size ()}),
43+ false );
44+ res = std::make_shared<ov::op::v1::Transpose>(
45+ res, ov::op::v0::Constant::create (ov::element::i64 , {3 }, {1 , 2 , 0 }));
46+ } else {
47+ res = data;
48+ }
49+ return rename_outputs_with_suffix ({res}, context.get_name ());
50+ }
3451
52+ auto indices = context.get_input (1 );
53+ auto dst = context.get_input (context.get_output_name ());
54+
55+ auto zero = ov::op::v0::Constant::create (ov::element::i64 , ov::Shape{1 }, {0 });
3556 auto dst_reshaped = std::make_shared<ov::op::v1::Reshape>(
3657 dst,
3758 ov::op::v0::Constant::create (ov::element::i64 , {2 }, {(int64_t ) dst_shape[1 ], (int64_t ) dst_shape[2 ]}),
3859 false );
3960 auto indices_reshaped =
4061 std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create (ov::element::i64 , {2 }, {0 , 1 }));
41- auto data_converted = std::make_shared<ov::op::v0::Convert>(data, context.get_output_type (0 ));
42- auto data_reshaped = std::make_shared<ov::op::v0::Squeeze>(data_converted, zero);
62+ auto data_reshaped = std::make_shared<ov::op::v0::Squeeze>(data, zero);
4363 auto updated = std::make_shared<ov::op::v3::ScatterUpdate>(dst_reshaped, indices_reshaped, data_reshaped, zero);
4464 auto res = std::make_shared<ov::op::v1::Reshape>(updated, std::make_shared<ov::op::v0::ShapeOf>(dst), false );
4565 return rename_outputs_with_suffix ({res}, context.get_name ());
0 commit comments