Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit 6e405b8

Browse files
authored
Atan2 op (#3859)
* Atan2 op
1 parent 1d9a495 commit 6e405b8

File tree

18 files changed

+301
-0
lines changed

18 files changed

+301
-0
lines changed

src/ngraph/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ set (SRC
106106
op/asin.hpp
107107
op/atan.cpp
108108
op/atan.hpp
109+
op/atan2.cpp
110+
op/atan2.hpp
109111
op/avg_pool.cpp
110112
op/avg_pool.hpp
111113
op/batch_norm.cpp

src/ngraph/ngraph.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
#include "ngraph/op/argmin.hpp"
7070
#include "ngraph/op/asin.hpp"
7171
#include "ngraph/op/atan.hpp"
72+
#include "ngraph/op/atan2.hpp"
7273
#include "ngraph/op/avg_pool.hpp"
7374
#include "ngraph/op/batch_norm.hpp"
7475
#include "ngraph/op/broadcast.hpp"

src/ngraph/op/atan2.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
//*****************************************************************************
2+
// Copyright 2017-2019 Intel Corporation
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
//*****************************************************************************
16+
17+
#include "ngraph/op/atan2.hpp"
18+
19+
using namespace std;
20+
using namespace ngraph;
21+
22+
const string op::Atan2::type_name{"Atan2"};
23+
24+
op::Atan2::Atan2(const Output<Node>& y, const Output<Node>& x, const AutoBroadcastSpec& autob)
25+
: BinaryElementwiseArithmetic(y, x, autob)
26+
{
27+
constructor_validate_and_infer_types();
28+
}
29+
30+
shared_ptr<Node> op::Atan2::copy_with_new_args(const NodeVector& new_args) const
31+
{
32+
check_new_args_count(this, new_args);
33+
return make_shared<Atan2>(new_args.at(0), new_args.at(1), this->get_autob());
34+
}
35+
36+
void op::Atan2::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
37+
{
38+
if (get_autob().m_type != op::AutoBroadcastType::NONE)
39+
{
40+
throw ngraph_error("Autodiff not supported with auto broadcasting");
41+
}
42+
throw ngraph_error("Autodiff not supported for Atan2");
43+
}

src/ngraph/op/atan2.hpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
//*****************************************************************************
2+
// Copyright 2017-2019 Intel Corporation
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
//*****************************************************************************
16+
17+
#pragma once
18+
19+
#include <memory>
20+
21+
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
22+
23+
namespace ngraph
24+
{
25+
namespace op
26+
{
27+
/// \brief Elementwise full arctan operation
28+
class Atan2 : public util::BinaryElementwiseArithmetic
29+
{
30+
public:
31+
NGRAPH_API
32+
static const std::string type_name;
33+
const std::string& description() const override { return type_name; }
34+
Atan2() = default;
35+
36+
/// \brief atan2(y,x) is the angle from the origin to the point (x,y) (note reversed order).
37+
///
38+
/// \param y
39+
/// \param x
40+
Atan2(const Output<Node>& y,
41+
const Output<Node>& x,
42+
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
43+
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
44+
45+
protected:
46+
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
47+
const NodeVector& deltas) override;
48+
};
49+
}
50+
}

src/ngraph/op/op_tbl.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ NGRAPH_OP(ArgMax, ngraph::op)
6161
NGRAPH_OP(ArgMin, ngraph::op)
6262
NGRAPH_OP(Asin, ngraph::op)
6363
NGRAPH_OP(Atan, ngraph::op)
64+
NGRAPH_OP(Atan2, ngraph::op)
6465
NGRAPH_OP(AvgPool, ngraph::op)
6566
NGRAPH_OP(AvgPoolBackprop, ngraph::op)
6667
NGRAPH_OP(BatchMatMul, ngraph::op)

src/ngraph/pass/cse.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "ngraph/op/add.hpp"
3030
#include "ngraph/op/asin.hpp"
3131
#include "ngraph/op/atan.hpp"
32+
#include "ngraph/op/atan2.hpp"
3233
#include "ngraph/op/broadcast.hpp"
3334
#include "ngraph/op/ceiling.hpp"
3435
#include "ngraph/op/constant.hpp"
@@ -151,6 +152,7 @@ static unordered_map<type_index, function<bool(shared_ptr<Node>, shared_ptr<Node
151152
{TI(op::Acos), cse_unarywise},
152153
{TI(op::Asin), cse_unarywise},
153154
{TI(op::Atan), cse_unarywise},
155+
{TI(op::Atan2), cse_binarywise},
154156
{TI(op::Ceiling), cse_unarywise},
155157
{TI(op::Constant), cse_constant},
156158
{TI(op::Cos), cse_unarywise},

src/ngraph/runtime/cpu/cpu_builder.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "ngraph/op/and.hpp"
3232
#include "ngraph/op/asin.hpp"
3333
#include "ngraph/op/atan.hpp"
34+
#include "ngraph/op/atan2.hpp"
3435
#include "ngraph/op/ceiling.hpp"
3536
#include "ngraph/op/constant.hpp"
3637
#include "ngraph/op/cos.hpp"
@@ -75,6 +76,7 @@
7576
#include "ngraph/runtime/cpu/kernel/and.hpp"
7677
#include "ngraph/runtime/cpu/kernel/asin.hpp"
7778
#include "ngraph/runtime/cpu/kernel/atan.hpp"
79+
#include "ngraph/runtime/cpu/kernel/atan2.hpp"
7880
#include "ngraph/runtime/cpu/kernel/broadcast.hpp"
7981
#include "ngraph/runtime/cpu/kernel/ceil.hpp"
8082
#include "ngraph/runtime/cpu/kernel/cos.hpp"
@@ -311,6 +313,12 @@ namespace ngraph
311313
BUILD_UNARY_ELEMWISE_FUNCTOR(runtime::cpu::kernel::atan);
312314
}
313315

316+
template <>
317+
void Builder::BUILDER_DECL(ngraph::op::Atan2)
318+
{
319+
BUILD_BINARY_ELEMWISE_FUNCTOR(runtime::cpu::kernel::atan2);
320+
}
321+
314322
template <>
315323
void Builder::BUILDER_DECL(ngraph::op::Ceiling)
316324
{
@@ -628,6 +636,7 @@ namespace ngraph
628636
REGISTER_OP_BUILDER(Acos);
629637
REGISTER_OP_BUILDER(Asin);
630638
REGISTER_OP_BUILDER(Atan);
639+
REGISTER_OP_BUILDER(Atan2);
631640
REGISTER_OP_BUILDER(Ceiling);
632641
REGISTER_OP_BUILDER(Cos);
633642
REGISTER_OP_BUILDER(Cosh)

src/ngraph/runtime/cpu/cpu_emitter.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "ngraph/op/argmin.hpp"
3535
#include "ngraph/op/asin.hpp"
3636
#include "ngraph/op/atan.hpp"
37+
#include "ngraph/op/atan2.hpp"
3738
#include "ngraph/op/avg_pool.hpp"
3839
#include "ngraph/op/batch_norm.hpp"
3940
#include "ngraph/op/broadcast.hpp"
@@ -1829,6 +1830,21 @@ namespace ngraph
18291830
writer.block_end();
18301831
}
18311832

1833+
template <>
1834+
void CPU_Emitter::EMITTER_DECL(ngraph::op::Atan2)
1835+
{
1836+
(void)external_function;
1837+
(void)node;
1838+
writer.block_begin();
1839+
writer << "#pragma omp parallel for\n";
1840+
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
1841+
writer.block_begin();
1842+
writer << out[0].get_name() << "[i] = atan2(" << args[0].get_name() << ", "
1843+
<< args[1].get_name() << "[i]);\n";
1844+
writer.block_end();
1845+
writer.block_end();
1846+
}
1847+
18321848
static void emitArgMinArgMax(const std::vector<TensorViewWrapper>& args,
18331849
const std::vector<TensorViewWrapper>& out,
18341850
size_t reduction_axis,

src/ngraph/runtime/cpu/cpu_external_function.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
#include "ngraph/op/argmin.hpp"
5757
#include "ngraph/op/asin.hpp"
5858
#include "ngraph/op/atan.hpp"
59+
#include "ngraph/op/atan2.hpp"
5960
#include "ngraph/op/avg_pool.hpp"
6061
#include "ngraph/op/batch_norm.hpp"
6162
#include "ngraph/op/broadcast.hpp"
@@ -365,6 +366,7 @@ static const runtime::cpu::OpMap dispatcher{
365366
{TI(ngraph::op::ArgMax), &runtime::cpu::CPU_Emitter::emit<op::ArgMax>},
366367
{TI(ngraph::op::Acos), &runtime::cpu::CPU_Emitter::emit<op::Acos>},
367368
{TI(ngraph::op::Atan), &runtime::cpu::CPU_Emitter::emit<op::Atan>},
369+
{TI(ngraph::op::Atan2), &runtime::cpu::CPU_Emitter::emit<op::Atan2>},
368370
{TI(ngraph::op::ReplaceSlice), &runtime::cpu::CPU_Emitter::emit<op::ReplaceSlice>},
369371
{TI(ngraph::op::UpdateSlice), &runtime::cpu::CPU_Emitter::emit<op::UpdateSlice>},
370372
{TI(ngraph::op::OneHot), &runtime::cpu::CPU_Emitter::emit<op::OneHot>},
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
//*****************************************************************************
2+
// Copyright 2017-2019 Intel Corporation
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
//*****************************************************************************
16+
17+
#pragma once
18+
19+
#include <cmath>
20+
21+
#define EIGEN_USE_THREADS
22+
#include <unsupported/Eigen/CXX11/Tensor>
23+
24+
#include "ngraph/runtime/cpu/cpu_executor.hpp"
25+
26+
namespace ngraph
27+
{
28+
namespace runtime
29+
{
30+
namespace cpu
31+
{
32+
namespace kernel
33+
{
34+
template <typename ElementType>
35+
void atan2(void* input0, void* input1, void* output, size_t count, int arena)
36+
{
37+
Eigen::array<Eigen::Index, 1> out_dims, in_dims;
38+
39+
out_dims[0] = in_dims[0] = count;
40+
41+
Eigen::TensorMap<Eigen::Tensor<ElementType, 1, Eigen::RowMajor>> out(
42+
static_cast<ElementType*>(output), out_dims);
43+
Eigen::TensorMap<Eigen::Tensor<ElementType, 1, Eigen::RowMajor>> in0(
44+
static_cast<ElementType*>(input0), in_dims);
45+
Eigen::TensorMap<Eigen::Tensor<ElementType, 1, Eigen::RowMajor>> in1(
46+
static_cast<ElementType*>(input1), in_dims);
47+
48+
out.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(arena)) =
49+
in0.binaryExpr(in1, [](ElementType y, ElementType x) {
50+
return static_cast<ElementType>(std::atan2(y, x));
51+
});
52+
}
53+
}
54+
}
55+
}
56+
}

0 commit comments

Comments
 (0)