88#include < openvino/op/broadcast.hpp>
99#include < openvino/op/concat.hpp>
1010#include < openvino/op/constant.hpp>
11- #include < openvino/op/convert_like .hpp>
11+ #include < openvino/op/convert .hpp>
1212#include < openvino/op/range.hpp>
1313#include < openvino/op/reshape.hpp>
1414#include < openvino/op/scatter_nd_update.hpp>
@@ -34,18 +34,26 @@ OutputVector translate_cpy(const NodeContext& context) {
3434
3535 auto src0 = context.get_input (0 );
3636 auto src1 = context.get_input (1 );
37- auto past_token_len = context.get_input (" past_token_len" );
37+ auto past_token_len_scalar = context.get_input (" past_token_len" );
38+
39+ src0 = std::make_shared<ov::op::v0::Convert>(src0, context.get_input_type (1 ));
3840 ov::Output<Node> res;
3941
42+ if (context.is_static () && context.is_first_token ()) {
43+ res = src0;
44+ return rename_outputs_with_suffix ({res}, context.get_name ());
45+ }
46+
4047 auto src0_shape = context.get_input_shape (0 ).to_shape ();
4148 auto output_shape = context.get_output_shape (0 ).to_shape ();
4249
4350 std::vector<size_t > input0_strides = context.get_input_stride (0 );
4451 std::vector<size_t > output_strides = context.get_output_stride (0 );
4552
46- auto one = ov::op::v0::Constant::create (ov::element::i64 , ov::Shape{}, {1 });
53+ auto zero = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {0 });
54+ auto one = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {1 });
55+ auto one_scalar = ov::op::v0::Constant::create (ov::element::i64 , ov::Shape{}, {1 });
4756
48- src0 = std::make_shared<ov::op::v1::ConvertLike>(src0, src1);
4957 if (op_case == 1 ) {
5058 // Write K to cache_k
5159 int64_t head_size = src0_shape[2 ];
@@ -56,69 +64,36 @@ OutputVector translate_cpy(const NodeContext& context) {
5664 auto reshaped_src1 = std::make_shared<ov::op::v1::Reshape>(src1, reshaped_src1_shape, false );
5765
5866 auto token_len = get_dimensions (src0.get_node_shared_ptr (), {0 });
59- token_len = std::make_shared<ov::op::v1::Reshape>(token_len,
60- ov::op::v0::Constant::create (ov::element::i64 , {0 }, {}),
61- false );
67+ auto token_len_scalar = std::make_shared<ov::op::v0::Squeeze>(token_len, zero);
6268
69+ std::shared_ptr<ov::Node> indices;
6370 if (context.is_static ()) {
64- int32_t * op_params = context.get_input_op_params (1 );
65- int64_t past_token_len_val = op_params[0 ] / context.get_input_stride (1 )[2 ] / num_heads / head_size;
66- past_token_len = ov::op::v0::Constant::create (ov::element::i64 , {}, {past_token_len_val});
71+ indices = past_token_len_scalar.get_node_shared_ptr ();
72+ indices = std::make_shared<ov::op::v0::Unsqueeze>(
73+ indices,
74+ ov::op::v0::Constant::create (ov::element::i64 , {2 }, std::vector<int64_t >{0 , 1 }));
75+ } else {
76+ auto total_token_len_scalar = std::make_shared<ov::op::v1::Add>(past_token_len_scalar, token_len_scalar);
77+ indices = std::make_shared<ov::op::v4::Range>(past_token_len_scalar,
78+ total_token_len_scalar,
79+ one_scalar,
80+ ov::element::i64 );
81+ indices = std::make_shared<ov::op::v0::Unsqueeze>(indices, one);
6782 }
6883
69- auto total_token_len = std::make_shared<ov::op::v1::Add>(past_token_len, token_len);
70- std::shared_ptr<ov::Node> indices =
71- std::make_shared<ov::op::v4::Range>(past_token_len, total_token_len, one, ov::element::i64 );
72- indices = std::make_shared<ov::op::v0::Unsqueeze>(
73- indices,
74- ov::op::v0::Constant::create (ov::element::i64 , {1 }, std::vector<int64_t >{1 }));
75-
7684 res = std::make_shared<ov::op::v3::ScatterNDUpdate>(reshaped_src1, indices, src0);
7785 } else {
7886 // Write V to cache_v
79- auto zero = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {0 });
8087 auto one = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {1 });
8188 auto two = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {2 });
82-
8389 auto zero_scalar = ov::op::v0::Constant::create (ov::element::i64 , {}, {0 });
84- auto one_scalar = ov::op::v0::Constant::create (ov::element::i64 , {}, {1 });
8590
8691 int64_t total_head_size = src0_shape[1 ];
8792 auto total_head_size_node = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {total_head_size});
8893 auto total_head_size_scalar = std::make_shared<ov::op::v0::Squeeze>(total_head_size_node, zero);
8994
9095 auto token_len = get_dimensions (src0.get_node_shared_ptr (), {2 });
9196 auto token_len_scalar = std::make_shared<ov::op::v0::Squeeze>(token_len, zero);
92- if (context.is_static ()) {
93- int32_t * op_params = context.get_input_op_params (1 );
94- int64_t past_token_len_val = op_params[0 ] / context.get_input_stride (1 )[2 ];
95- past_token_len = ov::op::v0::Constant::create (ov::element::i64 , {}, {past_token_len_val});
96- }
97- auto total_token_len_scalar = std::make_shared<ov::op::v1::Add>(past_token_len, token_len_scalar);
98-
99- // auto reshaped_src1 = std::make_shared<ov::op::v1::Reshape>(
100- // src1,
101- // ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{1, total_head_size, -1}),
102- // false);
103-
104- // auto src1_left = std::make_shared<ov::op::v8::Slice>(
105- // reshaped_src1,
106- // ov::op::v0::Constant::create(ov::element::i64, {3}, {0, 0, 0}),
107- // std::make_shared<ov::op::v0::Concat>(ov::OutputVector{one, total_head_size_node, past_token_len}, 0),
108- // ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 1, 1}));
109-
110- // auto src1_right = std::make_shared<ov::op::v8::Slice>(
111- // reshaped_src1,
112- // std::make_shared<ov::op::v0::Concat>(ov::OutputVector{zero, zero, total_token_len}, 0),
113- // ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{1, total_head_size, INT_MAX}),
114- // ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 1, 1}));
115-
116- // auto reshaped_src0 = std::make_shared<ov::op::v1::Reshape>(
117- // src0,
118- // ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{1, total_head_size, -1}),
119- // false);
120-
121- // auto res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{src1_left, reshaped_src0, src1_right}, 2);
12297
12398 // 1D tensor of shape [total_head_size], values starting from 0
12499 auto range_row =
@@ -131,8 +106,19 @@ OutputVector translate_cpy(const NodeContext& context) {
131106 std::make_shared<ov::op::v0::Concat>(ov::OutputVector{total_head_size_node, token_len, one}, 0 ));
132107
133108 // 1D tensor of shape [token_len], values starting from past_token_len
134- auto range_col =
135- std::make_shared<ov::op::v4::Range>(past_token_len, total_token_len_scalar, one_scalar, element::i64 );
109+ std::shared_ptr<ov::Node> range_col;
110+ if (context.is_static ()) {
111+ range_col = past_token_len_scalar.get_node_shared_ptr ();
112+ range_col = std::make_shared<ov::op::v0::Unsqueeze>(
113+ range_col,
114+ ov::op::v0::Constant::create (ov::element::i64 , {1 }, std::vector<int64_t >{0 }));
115+ } else {
116+ auto total_token_len_scalar = std::make_shared<ov::op::v1::Add>(past_token_len_scalar, token_len_scalar);
117+ range_col = std::make_shared<ov::op::v4::Range>(past_token_len_scalar,
118+ total_token_len_scalar,
119+ one_scalar,
120+ ov::element::i64 );
121+ }
136122 auto range_col_reshaped =
137123 std::make_shared<ov::op::v0::Unsqueeze>(range_col,
138124 ov::op::v0::Constant::create (ov::element::i64 , {2 }, {0 , 2 }));
0 commit comments