@@ -42,15 +42,15 @@ OutputVector translate_flash_attn_ext(const NodeContext& context) {
4242 mask_sliced = context.get_input (mask_name);
4343 } else {
4444 auto token_len = get_dimensions (q, {2 });
45+ auto kv_len = get_dimensions (k.get_node_shared_ptr (), {2 });
46+
4547 auto zero_2d = ov::op::v0::Constant::create (ov::element::i64 , {2 }, {0 ,0 });
4648 auto one_2d = ov::op::v0::Constant::create (ov::element::i64 , {2 }, {1 ,1 });
4749 auto zero_1d = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {0 });
4850 auto two_1d = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {2 });
4951 auto axes = ov::op::v0::Constant::create (ov::element::i64 , {2 }, {1 ,2 });
50- auto inp_pos = context.get_input (" inp_pos" );
51- auto shape_of_inp_pos = std::make_shared<ov::op::v3::ShapeOf>(inp_pos);
52- auto gather_inp_pos = std::make_shared<ov::op::v8::Gather>(shape_of_inp_pos, two_1d, zero_1d);
53- auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len, gather_inp_pos}, 0 );
52+
53+ auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len, kv_len}, 0 );
5454 mask_sliced =
5555 std::make_shared<ov::op::v8::Slice>(mask, zero_2d, stop, one_2d, axes);
5656 mask_sliced = std::make_shared<ov::op::v0::Unsqueeze>(mask_sliced, zero_1d);
0 commit comments