Skip to content

Commit c64b809

Browse files
authored
GNA convert matmul to pointwise convolution transformation unit tests (openvinotoolkit#6524)
* ConvertMatmulToPointWiseConvolutionTest first test * add ConvertMatmulToPointWiseConvolutionFqTest * use general functions to create test subgraphs * use general funstion to append node; add ConvertMatmulWithBiasToPointWiseConvolutionTest * add ConvertMatmulWithBiasToPointWiseConvolutionFqTest * use decorator instead of bool function arguments * remove unused functions * cleanup * add ConvertMatmulWithFqToPointWiseConvolutionTest * add ConvertMatmulWithFqToPointWiseConvolutionFqTest * add ConvertMatmulWithFqToPointWiseConvolutionTestNoAddNode * remove debug * add ConvertMatmulToPointWiseConvolutionTestInputRank3 * use TEST_P for ConvertMatmulToPointWiseConvolution tests * use testing::values fixture instead of multiple tests * cleanup * use combine tests for invalid inputs * code style cleanup * fix unique_ptr build under Windows * code review fixes: function template params * code review fixes: remove duplicated test entry * fix function arguments alignments
1 parent f48ea5d commit c64b809

File tree

2 files changed

+426
-0
lines changed

2 files changed

+426
-0
lines changed

inference-engine/src/gna_plugin/transformations/convert_matmul_to_pointwise_convolution.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <ngraph/opsets/opset7.hpp>
99
#include <ngraph/pattern/op/or.hpp>
1010
#include <ngraph/pattern/op/wrap_type.hpp>
11+
#include <ngraph/rt_info.hpp>
1112

1213
#include "layers/gna_permute.hpp"
1314
#include "backend/gna_limitations.hpp"
@@ -62,37 +63,44 @@ static bool Convert(std::shared_ptr<ngraph::Node> matmul_node,
6263
ngraph::Shape{1, 1, width, in_channels});
6364
auto reshape_before = std::make_shared<ngraph::opset7::Reshape>(input_node, reshape_const_before, false);
6465
reshape_before->set_friendly_name(base_name + "/reshape_in");
66+
ngraph::copy_runtime_info(input_node, reshape_before);
6567

6668
auto transpose_before = std::make_shared<ngraph::opset7::Transpose>(reshape_before,
6769
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{4},
6870
GetPermuteOrder(InferenceEngine::Layout::NHWC, InferenceEngine::Layout::NCHW)));
6971
transpose_before->set_friendly_name(base_name + "/transpose_in");
72+
ngraph::copy_runtime_info(matmul_node, transpose_before);
7073

7174
auto weights_reshape_const = std::make_shared<ngraph::opset7::Constant>(ngraph::element::Type_t::i64,
7275
ngraph::Shape{4}, ngraph::Shape{out_channels, in_channels, 1, 1});
7376
auto weights_reshaped = std::make_shared<ngraph::opset7::Reshape>(weights_node, weights_reshape_const, false);
77+
ngraph::copy_runtime_info(weights_node, weights_reshaped);
7478

7579
std::shared_ptr<ngraph::Node> conv_node = std::make_shared<ngraph::opset7::Convolution>(transpose_before, weights_reshaped,
7680
ngraph::Strides{1, 1}, ngraph::CoordinateDiff{0, 0}, ngraph::CoordinateDiff{0, 0},
7781
ngraph::Strides{1, 1}, ngraph::op::PadType::VALID);
7882
conv_node->set_friendly_name(base_name + "/conv");
83+
ngraph::copy_runtime_info(transpose_before, conv_node);
7984

8085
std::shared_ptr<ngraph::Node> root_node = matmul_node;
8186
if (bias != nullptr) {
8287
conv_node = std::make_shared<ngraph::opset7::Add>(conv_node, bias);
88+
ngraph::copy_runtime_info(transpose_before, conv_node);
8389
root_node = add;
8490
}
8591

8692
if (fq != nullptr) {
8793
conv_node = fq->clone_with_new_inputs({conv_node, fq->input_value(1), fq->input_value(2),
8894
fq->input_value(3), fq->input_value(4)});
95+
ngraph::copy_runtime_info(fq, conv_node);
8996
root_node = fq;
9097
}
9198

9299
auto transpose_after = std::make_shared<ngraph::opset7::Transpose>(conv_node,
93100
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{4},
94101
GetPermuteOrder(InferenceEngine::Layout::NCHW, InferenceEngine::Layout::NHWC)));
95102
transpose_after->set_friendly_name(base_name + "/transpose_out");
103+
ngraph::copy_runtime_info(conv_node, transpose_after);
96104

97105
auto output_shape = matmul_node->get_output_shape(0);
98106
output_shape[output_shape.size() - 1] = out_channels;
@@ -102,6 +110,7 @@ static bool Convert(std::shared_ptr<ngraph::Node> matmul_node,
102110
output_shape);
103111
auto reshape_after = std::make_shared<ngraph::opset7::Reshape>(transpose_after, reshape_const_after, false);
104112
reshape_after->set_friendly_name(base_name);
113+
ngraph::copy_runtime_info(transpose_after, reshape_after);
105114

106115
ngraph::replace_node(root_node, reshape_after);
107116
return true;

0 commit comments

Comments
 (0)