@@ -24,7 +24,7 @@ OutputVector translate_mulmat(const NodeContext& context) {
2424 num_inputs_check (context, 2 , 2 );
2525
2626 int op_case = context.get_op_case ();
27- FRONT_END_CHECK_IMPLEMENTED (op_case == 1 || op_case == 2 , " Unsupported MULMAT case" );
27+ FRONT_END_CHECK_IMPLEMENTED (op_case == 1 || op_case == 2 || op_case == 3 , " Unsupported MULMAT case" );
2828
2929 ov::Output<Node> res;
3030
@@ -59,8 +59,7 @@ OutputVector translate_mulmat(const NodeContext& context) {
5959 auto src0 = context.get_input (0 );
6060 auto src0_shape = context.get_input_shape (0 ).to_shape ();
6161 auto src0_stride = context.get_input_stride (0 );
62- auto permuted = is_permuted (src0_stride);
63- auto token_dim = permuted ? 0 : 2 ;
62+ auto token_dim = op_case == 2 ? 0 : 2 ;
6463
6564 auto attention_size = context.get_input (" attention_size" );
6665
@@ -81,7 +80,7 @@ OutputVector translate_mulmat(const NodeContext& context) {
8180 auto src0_reshape = std::make_shared<ov::op::v1::Reshape>(src0, src0_reshape_shape, false );
8281
8382 std::shared_ptr<ov::Node> slice_end;
84- if (permuted ) {
83+ if (op_case == 2 ) {
8584 slice_end = std::make_shared<ov::op::v0::Concat>(
8685 ov::OutputVector{attention_size, ov::op::v0::Constant::create (ov::element::i64 , {2 }, src0_slice_shape)},
8786 0 );
@@ -94,7 +93,7 @@ OutputVector translate_mulmat(const NodeContext& context) {
9493 auto slice_step = ov::op::v0::Constant::create (ov::element::i64 , {3 }, std::vector<int64_t >(3 , 1 ));
9594 auto src0_slice = std::make_shared<ov::op::v8::Slice>(src0_reshape, slice_start, slice_end, slice_step);
9695
97- if (permuted ) {
96+ if (op_case == 2 ) {
9897 B = std::make_shared<ov::op::v1::Transpose>(
9998 src0_slice,
10099 ov::op::v0::Constant::create (ov::element::i64 , {src0_perm.size ()}, src0_perm));
0 commit comments