@@ -77,23 +77,28 @@ void add_token_len(TensorMap& tensor_map) {
7777 tensor_map.insert ({" token_len" , token_len->output (0 )});
7878}
7979
80- void add_sliced_mask (TensorMap& tensor_map) {
80+ void add_sliced_mask (TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder ) {
8181 auto token_len = tensor_map.at (" token_len" ).get_node_shared_ptr ();
8282
83- auto create_sliced_mask = [&](const std::string& mask_name, const std::string& sliced_name) {
83+ auto create_sliced_mask = [&](const std::string& mask_name, const std::string& sliced_name, bool is_static ) {
8484 if (tensor_map.find (mask_name) != tensor_map.end ()) {
85- auto zero = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {0 });
86- auto one = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {1 });
8785 auto mask = tensor_map.at (mask_name).get_node_shared_ptr ();
88- std::shared_ptr<ov::Node> mask_sliced =
89- std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, one);
90- mask_sliced->set_friendly_name (sliced_name);
86+ std::shared_ptr<ov::Node> mask_sliced;
87+ if (is_static) {
88+ mask_sliced = mask;
89+ } else {
90+ auto zero = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {0 });
91+ auto one = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {1 });
92+ mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, one);
93+ mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16 );
94+ mask_sliced->set_friendly_name (sliced_name);
95+ }
9196 tensor_map.insert ({sliced_name, mask_sliced->output (0 )});
9297 }
9398 };
9499
95- create_sliced_mask (" KQ_mask" , " KQ_mask_sliced" );
96- create_sliced_mask (" KQ_mask_swa" , " KQ_mask_swa_sliced" );
100+ create_sliced_mask (" KQ_mask" , " KQ_mask_sliced" , ggml_model_decoder. is_static () );
101+ create_sliced_mask (" KQ_mask_swa" , " KQ_mask_swa_sliced" , ggml_model_decoder. is_static () );
97102}
98103
99104void add_rope_sin_cos (TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
@@ -117,7 +122,7 @@ void add_rope_sin_cos(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
117122// Create common patterns
118123void preprocess (TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
119124 add_token_len (tensor_map);
120- add_sliced_mask (tensor_map);
125+ add_sliced_mask (tensor_map, ggml_model_decoder );
121126 add_rope_sin_cos (tensor_map, ggml_model_decoder);
122127}
123128
0 commit comments