Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit 22e9847

Browse files
ayzhuangdi
authored and yessi committed
Migrate #3736 from master. (#3745)
* Use Eigen kernel for REFLECT mode Pad.
* Do not call is_optimized_et.
1 parent a68a3fa commit 22e9847

File tree

5 files changed

+137
-35
lines changed

5 files changed

+137
-35
lines changed

src/ngraph/runtime/cpu/builder/pad.cpp

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ namespace ngraph
5050
auto padding_above = pad->get_padding_above();
5151
auto pad_mode = pad->get_pad_mode();
5252

53-
if (pad_mode == ngraph::op::PadMode::CONSTANT)
53+
if (pad_mode == ngraph::op::PadMode::CONSTANT ||
54+
pad_mode == ngraph::op::PadMode::REFLECT)
5455
{
5556
std::function<decltype(runtime::cpu::kernel::pad_and_slice<float, 1>)> kernel;
5657

@@ -65,6 +66,7 @@ namespace ngraph
6566
out_shape,
6667
padding_below,
6768
padding_above,
69+
pad_mode,
6870
arg_buffer_index,
6971
padding_value_index,
7072
out_buffer_index](CPURuntimeContext* ctx,
@@ -76,6 +78,7 @@ namespace ngraph
7678
out_shape,
7779
CoordinateDiff(padding_below.begin(), padding_below.end()),
7880
CoordinateDiff(padding_above.begin(), padding_above.end()),
81+
pad_mode,
7982
ectx->arena);
8083
};
8184
functors.emplace_back(functor);
@@ -123,7 +126,8 @@ namespace ngraph
123126
auto padding_above = pad->get_padding_above();
124127
auto pad_mode = pad->get_pad_mode();
125128

126-
if (pad_mode == ngraph::op::PadMode::CONSTANT)
129+
if (pad_mode == ngraph::op::PadMode::CONSTANT ||
130+
pad_mode == ngraph::op::PadMode::REFLECT)
127131
{
128132
std::function<decltype(runtime::cpu::kernel::pad_and_slice<float, 1>)> kernel;
129133

@@ -132,17 +136,19 @@ namespace ngraph
132136
arg_shape.size(),
133137
runtime::cpu::kernel::pad_and_slice);
134138

135-
auto functor = [kernel, arg_shape, out_shape, padding_below, padding_above](
136-
const std::vector<void*>& inputs, std::vector<void*>& outputs) {
137-
kernel(inputs[0],
138-
outputs[0],
139-
inputs[1],
140-
arg_shape,
141-
out_shape,
142-
CoordinateDiff(padding_below.begin(), padding_below.end()),
143-
CoordinateDiff(padding_above.begin(), padding_above.end()),
144-
0);
145-
};
139+
auto functor =
140+
[kernel, arg_shape, out_shape, padding_below, padding_above, pad_mode](
141+
const std::vector<void*>& inputs, std::vector<void*>& outputs) {
142+
kernel(inputs[0],
143+
outputs[0],
144+
inputs[1],
145+
arg_shape,
146+
out_shape,
147+
CoordinateDiff(padding_below.begin(), padding_below.end()),
148+
CoordinateDiff(padding_above.begin(), padding_above.end()),
149+
pad_mode,
150+
0);
151+
};
146152
return functor;
147153
}
148154
else

src/ngraph/runtime/cpu/cpu_emitter.cpp

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3107,8 +3107,26 @@ namespace ngraph
31073107
auto arg0_shape = args[0].get_shape();
31083108
auto result_shape = out[0].get_shape();
31093109

3110+
std::string pad_mode_string;
3111+
switch (pad->get_pad_mode())
3112+
{
3113+
case ngraph::op::PadMode::CONSTANT:
3114+
pad_mode_string = "ngraph::op::PadMode::CONSTANT";
3115+
break;
3116+
case ngraph::op::PadMode::EDGE:
3117+
pad_mode_string = "ngraph::op::PadMode::EDGE";
3118+
break;
3119+
case ngraph::op::PadMode::REFLECT:
3120+
pad_mode_string = "ngraph::op::PadMode::REFLECT";
3121+
break;
3122+
case ngraph::op::PadMode::SYMMETRIC:
3123+
pad_mode_string = "ngraph::op::PadMode::SYMMETRIC";
3124+
break;
3125+
}
3126+
31103127
if (arg0_shape.size() == 4 && args[0].get_element_type() == element::f32 &&
3111-
pad->get_pad_mode() == ngraph::op::PadMode::CONSTANT)
3128+
(pad->get_pad_mode() == ngraph::op::PadMode::CONSTANT ||
3129+
pad->get_pad_mode() == ngraph::op::PadMode::REFLECT))
31123130
{
31133131
writer << "cpu::kernel::pad_4d_float32(" << args[0].get_name() << ",\n"
31143132
<< " " << out[0].get_name() << ",\n"
@@ -3118,26 +3136,12 @@ namespace ngraph
31183136
<< " {" << join(pad->get_padding_below())
31193137
<< "},\n"
31203138
<< " {" << join(pad->get_padding_above())
3121-
<< "}, 0);\n";
3139+
<< "}, \n"
3140+
<< " " << pad_mode_string << ",\n"
3141+
<< " 0);\n";
31223142
}
31233143
else
31243144
{
3125-
std::string pad_mode_string;
3126-
switch (pad->get_pad_mode())
3127-
{
3128-
case ngraph::op::PadMode::CONSTANT:
3129-
pad_mode_string = "ngraph::op::PadMode::CONSTANT";
3130-
break;
3131-
case ngraph::op::PadMode::EDGE:
3132-
pad_mode_string = "ngraph::op::PadMode::EDGE";
3133-
break;
3134-
case ngraph::op::PadMode::REFLECT:
3135-
pad_mode_string = "ngraph::op::PadMode::REFLECT";
3136-
break;
3137-
case ngraph::op::PadMode::SYMMETRIC:
3138-
pad_mode_string = "ngraph::op::PadMode::SYMMETRIC";
3139-
break;
3140-
}
31413145
writer << "reference::pad<" << out[0].get_type() << ">(" << args[0].get_name()
31423146
<< ",\n";
31433147
writer << " " << args[1].get_name() << ",\n";

src/ngraph/runtime/cpu/cpu_kernels.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
#include <random>
2222
#include <vector>
2323

24+
#include "ngraph/op/pad.hpp"
25+
2426
// CBLAS types and wrappers
2527

2628
namespace cblas
@@ -146,6 +148,7 @@ namespace ngraph
146148
const Shape& output_shape,
147149
const CoordinateDiff& padding_below,
148150
const CoordinateDiff& padding_above,
151+
const ngraph::op::PadMode pad_mode,
149152
int arena);
150153

151154
void reduce_sum_all_1d_float32(float* input,

src/ngraph/runtime/cpu/kernel/pad.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ namespace ngraph
3131
const Shape& output_shape,
3232
const CoordinateDiff& padding_below,
3333
const CoordinateDiff& padding_above,
34+
const ngraph::op::PadMode pad_mode,
3435
int arena)
3536
{
3637
pad_and_slice<float, 4>(input,
@@ -40,6 +41,7 @@ namespace ngraph
4041
output_shape,
4142
padding_below,
4243
padding_above,
44+
pad_mode,
4345
arena);
4446
}
4547
}

src/ngraph/runtime/cpu/kernel/pad.hpp

Lines changed: 91 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,15 +67,19 @@ namespace ngraph
6767
const Shape& output_shape,
6868
const CoordinateDiff& padding_below,
6969
const CoordinateDiff& padding_above,
70+
const ngraph::op::PadMode pad_mode,
7071
int arena)
7172
{
72-
Eigen::array<Eigen::Index, Rank> out_dims, in_dims;
73+
Eigen::array<Eigen::Index, Rank> out_dims, in_dims, temp_dims;
7374
Eigen::array<Eigen::IndexPair<size_t>, Rank> padding;
7475
Eigen::array<Eigen::Index, Rank> indices;
7576

77+
bool has_negative_below_padding = false;
78+
7679
for (int i = 0; i < Rank; i++)
7780
{
7881
out_dims[i] = output_shape[i];
82+
temp_dims[i] = output_shape[i];
7983
in_dims[i] = input_shape[i];
8084

8185
padding[i] = {
@@ -88,6 +92,8 @@ namespace ngraph
8892
{
8993
NGRAPH_CHECK(padding_below[i] > INT_MIN);
9094
indices[i] = -padding_below[i];
95+
temp_dims[i] -= padding_below[i];
96+
has_negative_below_padding = true;
9197
}
9298
else
9399
{
@@ -97,12 +103,93 @@ namespace ngraph
97103

98104
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> out(
99105
static_cast<ElementType*>(output), out_dims);
106+
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> temp(
107+
static_cast<ElementType*>(output), temp_dims);
100108
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> in(
101109
static_cast<ElementType*>(input), in_dims);
102110

103-
out.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(arena)) =
104-
in.pad(padding, *static_cast<ElementType*>(pad_value))
105-
.slice(indices, out_dims);
111+
if (pad_mode == ngraph::op::PadMode::CONSTANT)
112+
{
113+
out.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
114+
arena)) = in.pad(padding, *static_cast<ElementType*>(pad_value))
115+
.slice(indices, out_dims);
116+
}
117+
else
118+
{
119+
// clang-format off
120+
// PadMode::REFLECT
121+
// We should have dim >= 2 for each dim.
122+
// Example:
123+
//
124+
// Input shape: [4]
125+
// Padding: 6 below, 13 above
126+
// Output shape: [23]
127+
//
128+
// Input: 1 2 3 4
129+
// Expected output: 1 2 3 4 3 2 1 2 3 4 3 2 1 2 3 4 3 2 1 2 3 4 3
130+
// Pattern: ... | original n elements | middle (n - 2) elements of original n in reverse order |
131+
// original n elements | middle (n - 2) elements of original n in reverse order | ...
132+
// | 1 2 3 4 | 3 2 | 1 2 3 4 | 3 2 | 1 2 3 4 | 3 2 | 1 2 3 4 | 3
133+
// clang-format on
134+
auto generator =
135+
[&](const Eigen::array<Eigen::DenseIndex, Rank>& out_index) {
136+
Eigen::array<Eigen::DenseIndex, Rank> in_index;
137+
for (size_t i = 0; i < Rank; i++)
138+
{
139+
auto origin_length = in_dims[i];
140+
auto p_below = padding_below[i] >= 0 ? padding_below[i] : 0;
141+
if (out_index[i] < p_below)
142+
{
143+
// padding below
144+
auto reverse = p_below - out_index[i];
145+
auto res = reverse % (origin_length * 2 - 2);
146+
if (res <= origin_length - 2)
147+
{
148+
// copy one of the middle n-2 items
149+
in_index[i] = res;
150+
}
151+
else
152+
{
153+
// copy one of the n items
154+
in_index[i] = origin_length * 2 - 2 - res;
155+
}
156+
}
157+
else if (out_index[i] < in_dims[i] + p_below)
158+
{
159+
// original
160+
in_index[i] = out_index[i] - p_below;
161+
}
162+
else
163+
{
164+
// padding above
165+
auto pos = out_index[i] - in_dims[i] - p_below;
166+
auto res = pos % (origin_length * 2 - 2);
167+
if (res < origin_length - 2)
168+
{
169+
// copy one of the middle n-2 items
170+
in_index[i] = origin_length - 2 - res;
171+
}
172+
else
173+
{
174+
// copy one of the n items
175+
in_index[i] = res - (origin_length - 2);
176+
}
177+
}
178+
}
179+
return in(in_index);
180+
};
181+
182+
if (has_negative_below_padding)
183+
{
184+
out.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
185+
arena)) = temp.generate(generator).slice(indices, out_dims);
186+
}
187+
else
188+
{
189+
out.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
190+
arena)) = out.generate(generator);
191+
}
192+
}
106193
}
107194

108195
template <typename ElementType>

0 commit comments

Comments
 (0)