|
6 | 6 |
|
7 | 7 | #include "pool2d.hpp" |
8 | 8 | #include <ngraph/opsets/opset6.hpp> |
| 9 | +#include <ngraph/opsets/opset8.hpp> |
9 | 10 |
|
10 | 11 | namespace ngraph |
11 | 12 | { |
@@ -138,56 +139,39 @@ namespace ngraph |
138 | 139 | { |
139 | 140 | PDPD_ASSERT(input_shape[2].is_static() && input_shape[3].is_static(), |
140 | 141 | "pool2d: spatial dim must be static when using adaptive pool"); |
141 | | - uint64_t pool_size_Height, pool_size_Width; |
142 | | - uint64_t input_h = input_shape[input_rank - 2].get_length(); |
143 | | - uint64_t input_w = input_shape[input_rank - 1].get_length(); |
| 142 | + auto pool_size = std::vector<int64_t>(2, 0); |
144 | 143 |
|
145 | 144 | if (kernel_shape.size() == 1) |
146 | 145 | { |
147 | 146 | // Not tested: implemented according to spec, but can't generate real |
148 | 147 | // model to test |
149 | | - pool_size_Height = pool_size_Width = kernel_shape[0]; |
| 148 | + pool_size[0] = pool_size[1] = kernel_shape[0]; |
150 | 149 | } |
151 | 150 | else |
152 | 151 | { |
153 | | - pool_size_Height = kernel_shape[0]; |
154 | | - pool_size_Width = kernel_shape[1]; |
| 152 | + pool_size[0] = kernel_shape[0]; |
| 153 | + pool_size[1] = kernel_shape[1]; |
155 | 154 | } |
156 | 155 |
|
157 | | - uint64_t stride_h = int64_t(input_h / pool_size_Height); |
158 | | - uint64_t stride_w = int64_t(input_w / pool_size_Width); |
159 | | - uint64_t kernel_h = input_h - (pool_size_Height - 1) * stride_h; |
160 | | - uint64_t kernel_w = input_w - (pool_size_Width - 1) * stride_w; |
161 | | - |
162 | | - PDPD_ASSERT(stride_h >= 1 && stride_w >= 1, |
163 | | - "Pool2d stride must be greater than 1"); |
| 156 | + const Output<ngraph::Node> output_shape = ngraph::opset6::Constant::create( |
| 157 | + ngraph::element::i64, {pool_size.size()}, pool_size); |
164 | 158 |
|
165 | 159 | if (pooling_type == "max") |
166 | 160 | { |
167 | | - return node.default_single_output_mapping( |
168 | | - {std::make_shared<ngraph::opset6::MaxPool>( |
169 | | - data, |
170 | | - ngraph::Strides{stride_h, stride_w}, |
171 | | - pad_begin, |
172 | | - pad_end, |
173 | | - ngraph::Shape{kernel_h, kernel_w}, |
174 | | - rounding_type, |
175 | | - auto_pad)}, |
176 | | - {"Out"}); |
| 161 | + std::vector<Output<Node>> pool_outputs; |
| 162 | + pool_outputs = std::make_shared<ngraph::opset8::AdaptiveMaxPool>( |
| 163 | + data, output_shape, ngraph::element::i64) |
| 164 | + ->outputs(); |
| 165 | + NamedOutputs outputs; |
| 166 | + outputs["Out"] = {pool_outputs[0]}; |
| 167 | + outputs["Mask"] = {pool_outputs[1]}; |
| 168 | + return outputs; |
177 | 169 | } |
178 | 170 | else |
179 | 171 | { |
180 | | - bool exclude_pad = node.get_attribute<bool>("exclusive", false); |
181 | 172 | return node.default_single_output_mapping( |
182 | | - {std::make_shared<ngraph::opset6::AvgPool>( |
183 | | - data, |
184 | | - ngraph::Strides{stride_h, stride_w}, |
185 | | - pad_begin, |
186 | | - pad_end, |
187 | | - ngraph::Shape{kernel_h, kernel_w}, |
188 | | - exclude_pad, |
189 | | - rounding_type, |
190 | | - auto_pad)}, |
| 173 | + {std::make_shared<ngraph::opset8::AdaptiveAvgPool>(data, |
| 174 | + output_shape)}, |
191 | 175 | {"Out"}); |
192 | 176 | } |
193 | 177 | } |
|
0 commit comments