1111#include < openvino/op/convert.hpp>
1212#include < openvino/op/cos.hpp>
1313#include < openvino/op/divide.hpp>
14- #include < openvino/op/gather.hpp>
1514#include < openvino/op/multiply.hpp>
1615#include < openvino/op/parameter.hpp>
1716#include < openvino/op/range.hpp>
1817#include < openvino/op/reshape.hpp>
1918#include < openvino/op/result.hpp>
2019#include < openvino/op/sin.hpp>
20+ #include < openvino/op/slice.hpp>
2121#include < openvino/op/squeeze.hpp>
22+ #include < openvino/op/strided_slice.hpp>
2223#include < openvino/op/transpose.hpp>
2324#include < openvino/op/unsqueeze.hpp>
2425#include < openvino/pass/constant_folding.hpp>
@@ -88,15 +89,27 @@ void add_sliced_mask(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
8889 if (is_static) {
8990 mask_sliced = mask;
9091 } else {
91- auto zero_2d = ov::op::v0::Constant::create (ov::element::i64 , {2 }, {0 ,0 });
92- auto one_2d = ov::op::v0::Constant::create (ov::element::i64 , {2 }, {1 ,1 });
92+ auto zero_2d = ov::op::v0::Constant::create (ov::element::i64 , {2 }, {0 , 0 });
93+ auto one_2d = ov::op::v0::Constant::create (ov::element::i64 , {2 }, {1 , 1 });
94+ auto one_1d = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {1 });
9395 auto zero_1d = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {0 });
9496 auto two_1d = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {2 });
95- auto axes = ov::op::v0::Constant::create (ov::element::i64 , {2 }, {1 ,2 });
96- auto inp_pos = tensor_map.at (" inp_pos" ).get_node_shared_ptr ();
97- auto shape_of_inp_pos = std::make_shared<ov::op::v3::ShapeOf>(inp_pos);
98- auto gather_inp_pos = std::make_shared<ov::op::v8::Gather>(shape_of_inp_pos, two_1d, zero_1d);
99- auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len, gather_inp_pos}, 0 );
97+ auto axes = ov::op::v0::Constant::create (ov::element::i64 , {2 }, {1 , 2 });
98+
99+ std::shared_ptr<ov::Node> kv_len;
100+ {
101+ auto start = ov::op::v0::Constant::create (element::i64 , Shape{3 }, {0 , 0 , -1 });
102+ auto stride = ov::op::v0::Constant::create (element::i64 , Shape{3 }, {1 , 1 , 1 });
103+ auto inp_pos = tensor_map.at (" inp_pos" ).get_node_shared_ptr ();
104+ kv_len = std::make_shared<ov::op::v1::StridedSlice>(
105+ inp_pos, start, start, stride, std::vector<int64_t >{0 , 0 , 0 }, std::vector<int64_t >{1 , 1 , 1 });
106+ }
107+ kv_len = std::make_shared<ov::op::v0::Squeeze>(
108+ kv_len, ov::op::v0::Constant::create (ov::element::i64 , {2 }, {0 , 1 }));
109+ kv_len = std::make_shared<ov::op::v0::Convert>(kv_len, ov::element::i64 );
110+ kv_len = std::make_shared<ov::op::v1::Add>(kv_len, one_1d);
111+ auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len, kv_len}, 0 );
112+
100113 mask_sliced =
101114 std::make_shared<ov::op::v8::Slice>(mask, zero_2d, stop, one_2d, axes);
102115 mask_sliced = std::make_shared<ov::op::v0::Unsqueeze>(mask_sliced, zero_1d);
@@ -108,7 +121,8 @@ void add_sliced_mask(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
108121 };
109122
110123 create_sliced_mask (" KQ_mask" , " KQ_mask_sliced" , ggml_model_decoder.is_static ());
111- create_sliced_mask (" KQ_mask_swa" , " KQ_mask_swa_sliced" , ggml_model_decoder.is_static ());
124+ // swa is not working for the `kv_len` is not correct
125+ // create_sliced_mask("KQ_mask_swa", "KQ_mask_swa_sliced", ggml_model_decoder.is_static());
112126}
113127
114128void add_rope_sin_cos (TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
@@ -132,7 +146,7 @@ void add_rope_sin_cos(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
132146// Create common patterns
133147void preprocess (TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
134148 add_token_len (tensor_map);
135- // add_sliced_mask(tensor_map, ggml_model_decoder);
149+ add_sliced_mask (tensor_map, ggml_model_decoder);
136150 add_rope_sin_cos (tensor_map, ggml_model_decoder);
137151}
138152
0 commit comments