55#include < openvino/core/node_output.hpp>
66#include < openvino/core/node_vector.hpp>
77#include < openvino/op/add.hpp>
8+ #include < openvino/op/broadcast.hpp>
89#include < openvino/op/concat.hpp>
910#include < openvino/op/constant.hpp>
1011#include < openvino/op/convert_like.hpp>
1112#include < openvino/op/range.hpp>
1213#include < openvino/op/reshape.hpp>
1314#include < openvino/op/scatter_nd_update.hpp>
1415#include < openvino/op/slice.hpp>
16+ #include < openvino/op/squeeze.hpp>
1517#include < openvino/op/transpose.hpp>
1618#include < openvino/op/unsqueeze.hpp>
1719#include < vector>
@@ -57,6 +59,13 @@ OutputVector translate_cpy(const NodeContext& context) {
5759 token_len = std::make_shared<ov::op::v1::Reshape>(token_len,
5860 ov::op::v0::Constant::create (ov::element::i64 , {0 }, {}),
5961 false );
62+
63+ if (context.is_static ()) {
64+ int32_t * op_params = context.get_input_op_params (1 );
65+ int64_t past_token_len_val = op_params[0 ] / context.get_input_stride (1 )[2 ] / num_heads / head_size;
66+ past_token_len = ov::op::v0::Constant::create (ov::element::i64 , {}, {past_token_len_val});
67+ }
68+
6069 auto total_token_len = std::make_shared<ov::op::v1::Add>(past_token_len, token_len);
6170 std::shared_ptr<ov::Node> indices =
6271 std::make_shared<ov::op::v4::Range>(past_token_len, total_token_len, one, ov::element::i64 );
@@ -67,39 +76,88 @@ OutputVector translate_cpy(const NodeContext& context) {
6776 res = std::make_shared<ov::op::v3::ScatterNDUpdate>(reshaped_src1, indices, src0);
6877 } else {
6978 // Write V to cache_v
70- int64_t total_head_size = src0_shape[1 ];
71- auto total_head_size_node = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {total_head_size});
72-
7379 auto zero = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {0 });
7480 auto one = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {1 });
81+ auto two = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {2 });
82+
83+ auto zero_scalar = ov::op::v0::Constant::create (ov::element::i64 , {}, {0 });
84+ auto one_scalar = ov::op::v0::Constant::create (ov::element::i64 , {}, {1 });
85+
86+ int64_t total_head_size = src0_shape[1 ];
87+ auto total_head_size_node = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {total_head_size});
88+ auto total_head_size_scalar = std::make_shared<ov::op::v0::Squeeze>(total_head_size_node, zero);
7589
7690 auto token_len = get_dimensions (src0.get_node_shared_ptr (), {2 });
77- past_token_len = std::make_shared<ov::op::v0::Unsqueeze>(past_token_len, zero);
78- auto total_token_len = std::make_shared<ov::op::v1::Add>(past_token_len, token_len);
91+ auto token_len_scalar = std::make_shared<ov::op::v0::Squeeze>(token_len, zero);
92+ if (context.is_static ()) {
93+ int32_t * op_params = context.get_input_op_params (1 );
94+ int64_t past_token_len_val = op_params[0 ] / context.get_input_stride (1 )[2 ];
95+ past_token_len = ov::op::v0::Constant::create (ov::element::i64 , {}, {past_token_len_val});
96+ }
97+ auto total_token_len_scalar = std::make_shared<ov::op::v1::Add>(past_token_len, token_len_scalar);
98+
99+ // auto reshaped_src1 = std::make_shared<ov::op::v1::Reshape>(
100+ // src1,
101+ // ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{1, total_head_size, -1}),
102+ // false);
103+
104+ // auto src1_left = std::make_shared<ov::op::v8::Slice>(
105+ // reshaped_src1,
106+ // ov::op::v0::Constant::create(ov::element::i64, {3}, {0, 0, 0}),
107+ // std::make_shared<ov::op::v0::Concat>(ov::OutputVector{one, total_head_size_node, past_token_len}, 0),
108+ // ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 1, 1}));
109+
110+ // auto src1_right = std::make_shared<ov::op::v8::Slice>(
111+ // reshaped_src1,
112+ // std::make_shared<ov::op::v0::Concat>(ov::OutputVector{zero, zero, total_token_len}, 0),
113+ // ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{1, total_head_size, INT_MAX}),
114+ // ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 1, 1}));
115+
116+ // auto reshaped_src0 = std::make_shared<ov::op::v1::Reshape>(
117+ // src0,
118+ // ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{1, total_head_size, -1}),
119+ // false);
120+
121+ // auto res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{src1_left, reshaped_src0, src1_right}, 2);
122+
123+ // 1D tensor of shape [total_head_size], values starting from 0
124+ auto range_row =
125+ std::make_shared<ov::op::v4::Range>(zero_scalar, total_head_size_scalar, one_scalar, ov::element::i64 );
126+ auto range_row_reshaped =
127+ std::make_shared<ov::op::v0::Unsqueeze>(range_row,
128+ ov::op::v0::Constant::create (ov::element::i64 , {2 }, {1 , 2 }));
129+ auto row_indices = std::make_shared<ov::op::v3::Broadcast>(
130+ range_row_reshaped,
131+ std::make_shared<ov::op::v0::Concat>(ov::OutputVector{total_head_size_node, token_len, one}, 0 ));
132+
133+ // 1D tensor of shape [token_len], values starting from past_token_len
134+ auto range_col =
135+ std::make_shared<ov::op::v4::Range>(past_token_len, total_token_len_scalar, one_scalar, element::i64 );
136+ auto range_col_reshaped =
137+ std::make_shared<ov::op::v0::Unsqueeze>(range_col,
138+ ov::op::v0::Constant::create (ov::element::i64 , {2 }, {0 , 2 }));
139+ auto col_indices = std::make_shared<ov::op::v3::Broadcast>(
140+ range_col_reshaped,
141+ std::make_shared<ov::op::v0::Concat>(ov::OutputVector{total_head_size_node, token_len, one}, 0 ));
142+
143+ // Stack row_indices and col_indices along last axis: [total_head_size, token_len, 2]
144+ auto indices = std::make_shared<ov::op::v0::Concat>(OutputVector{row_indices, col_indices}, 2 );
145+ auto indices_final = std::make_shared<ov::op::v1::Reshape>(
146+ indices,
147+ ov::op::v0::Constant::create (ov::element::i64 , {2 }, std::vector<int64_t >{-1 , 2 }),
148+ false );
79149
150+ auto flattend_src0 =
151+ std::make_shared<ov::op::v1::Reshape>(src0,
152+ ov::op::v0::Constant::create (element::i64 , Shape{1 }, {-1 }),
153+ false );
80154 auto reshaped_src1 = std::make_shared<ov::op::v1::Reshape>(
81155 src1,
82- ov::op::v0::Constant::create (ov::element::i64 , {3 }, std::vector<int64_t >{1 , total_head_size, -1 }),
83- false );
84-
85- auto src1_left = std::make_shared<ov::op::v8::Slice>(
86- reshaped_src1,
87- ov::op::v0::Constant::create (ov::element::i64 , {3 }, {0 , 0 , 0 }),
88- std::make_shared<ov::op::v0::Concat>(ov::OutputVector{one, total_head_size_node, past_token_len}, 0 ),
89- ov::op::v0::Constant::create (ov::element::i64 , {3 }, {1 , 1 , 1 }));
90-
91- auto src1_right = std::make_shared<ov::op::v8::Slice>(
92- reshaped_src1,
93- std::make_shared<ov::op::v0::Concat>(ov::OutputVector{zero, zero, total_token_len}, 0 ),
94- ov::op::v0::Constant::create (ov::element::i64 , {3 }, std::vector<int64_t >{1 , total_head_size, INT_MAX}),
95- ov::op::v0::Constant::create (ov::element::i64 , {3 }, {1 , 1 , 1 }));
96-
97- auto reshaped_src0 = std::make_shared<ov::op::v1::Reshape>(
98- src0,
99- ov::op::v0::Constant::create (ov::element::i64 , {3 }, std::vector<int64_t >{1 , total_head_size, -1 }),
156+ ov::op::v0::Constant::create (ov::element::i64 , {2 }, std::vector<int64_t >{total_head_size, -1 }),
100157 false );
101158
102- res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{src1_left, reshaped_src0, src1_right}, 2 );
159+ auto updated = std::make_shared<ov::op::v3::ScatterNDUpdate>(reshaped_src1, indices_final, flattend_src0);
160+ res = std::make_shared<ov::op::v0::Unsqueeze>(updated, zero);
103161 }
104162
105163 return rename_outputs_with_suffix ({res}, context.get_name ());
0 commit comments