Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit 9ffb514

Browse files
tsochaKrovatkin
authored andcommitted
[Py] Enable ngraph-cpp ops in Python API (#820)
* Enable BatchNorm op * Enable function call op * Enable get output element op
1 parent eec1922 commit 9ffb514

File tree

4 files changed

+68
-5
lines changed

4 files changed

+68
-5
lines changed

python/ngraph/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from ngraph.ops import asin
2323
from ngraph.ops import atan
2424
from ngraph.ops import avg_pool
25+
from ngraph.ops import batch_norm
2526
from ngraph.ops import broadcast
2627
from ngraph.ops import ceiling
2728
from ngraph.ops import ceiling as ceil
@@ -35,7 +36,9 @@
3536
from ngraph.ops import dot
3637
from ngraph.ops import equal
3738
from ngraph.ops import exp
39+
from ngraph.ops import function_call
3840
from ngraph.ops import floor
41+
from ngraph.ops import get_output_element
3942
from ngraph.ops import greater
4043
from ngraph.ops import greater_eq
4144
from ngraph.ops import less

python/ngraph/ops.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@
2020
from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Node, NodeVector, \
2121
Shape, Strides
2222

23-
from ngraph.impl.op import Abs, Acos, Add, Asin, Atan, AvgPool, Broadcast, Ceiling, Concat, \
24-
Constant, Convert, Convolution, Cos, Cosh, Divide, Dot, Equal, Exp, Floor, Greater, GreaterEq,\
25-
Less, LessEq, Log, Max, Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, \
26-
OneHot, Pad, Parameter, Product, Power, Relu, ReplaceSlice, Reshape, Reverse, Select, \
27-
Sign, Sin, Sinh, Slice, Softmax, Sqrt, Subtract, Sum, Tan, Tanh
23+
from ngraph.impl.op import Abs, Acos, Add, Asin, Atan, AvgPool, BatchNorm, Broadcast, Ceiling,\
24+
Concat, Constant, Convert, Convolution, Cos, Cosh, Divide, Dot, Equal, Exp, Floor, \
25+
FunctionCall, GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, Max, Maximum, MaxPool, \
26+
Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Pad, Parameter, Product, Power, Relu, \
27+
ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, Slice, Softmax, Sqrt, Subtract, Sum, \
28+
Tan, Tanh
2829

2930
from typing import Iterable, List
3031

@@ -761,3 +762,33 @@ def reverse(node, reversed_axes, name=None): # type: (Node, List[int], str) ->
761762
:return: The new node with reversed axes.
762763
"""
763764
return Reverse(node, AxisSet(reversed_axes))
765+
766+
767+
@nameable_op
768+
def batch_norm(eps, # type: float
769+
gamma, # type: Node
770+
beta, # type: Node
771+
data, # type: Node
772+
mean=None, # type: Node
773+
variance=None, # type: Node
774+
training=False, # type: bool
775+
name=None, # type: str
776+
):
777+
# type: (...) -> Node
778+
"""Return batch normalization node."""
779+
if mean is None and variance is None:
780+
return BatchNorm(eps, gamma, beta, data)
781+
else:
782+
return BatchNorm(eps, gamma, beta, data, mean, variance, training)
783+
784+
785+
@nameable_op
786+
def function_call(function_to_call, args): # type: (Node, NodeVector) -> Node
787+
"""Return Function call op."""
788+
return FunctionCall(function_to_call, args)
789+
790+
791+
@nameable_op
792+
def get_output_element(data, index): # type: (Node, int) -> Node
793+
"""Return the `n`th element of the input tuple."""
794+
return GetOutputElement(data, index)

python/pyngraph/ops/batch_norm.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,14 @@ void regclass_pyngraph_op_BatchNorm(py::module m)
3333
const std::shared_ptr<ngraph::Node>&,
3434
const std::shared_ptr<ngraph::Node>&,
3535
const std::shared_ptr<ngraph::Node>&>());
36+
37+
batch_norm.def(py::init<double,
38+
const std::shared_ptr<ngraph::Node>&,
39+
const std::shared_ptr<ngraph::Node>&,
40+
const std::shared_ptr<ngraph::Node>&,
41+
const std::shared_ptr<ngraph::Node>&,
42+
const std::shared_ptr<ngraph::Node>&,
43+
bool&>());
3644
}
3745

3846
void regclass_pyngraph_op_BatchNormBackprop(py::module m)

python/test/ngraph/test_basic.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import ngraph as ng
2121
from test.ngraph.util import get_runtime, run_op_node
22+
from ngraph.impl import Function, NodeVector
2223

2324

2425
@pytest.mark.parametrize('dtype', [np.float32, np.float64,
@@ -48,6 +49,26 @@ def test_simple_computation_on_ndarrays(dtype):
4849
assert np.allclose(result, np.array([[630, 704], [782, 864]], dtype=dtype))
4950

5051

52+
def test_function_call():
53+
runtime = get_runtime()
54+
dtype = int
55+
shape = [2, 2]
56+
parameter_a = ng.parameter(shape, dtype=dtype, name='A')
57+
parameter_b = ng.parameter(shape, dtype=dtype, name='B')
58+
parameter_c = ng.parameter(shape, dtype=dtype, name='C')
59+
parameter_list = [parameter_a, parameter_b, parameter_c]
60+
ops = ((parameter_a + parameter_b) * parameter_c)
61+
func = Function(NodeVector([ops]), parameter_list, 'addmul')
62+
fc = ng.function_call(func, NodeVector(parameter_list))
63+
computation = runtime.computation(fc, parameter_a, parameter_b, parameter_c)
64+
65+
value_a = np.array([[1, 2], [3, 4]], dtype=dtype)
66+
value_b = np.array([[5, 6], [7, 8]], dtype=dtype)
67+
value_c = np.array([[9, 10], [11, 12]], dtype=dtype)
68+
result = computation(value_a, value_b, value_c)
69+
assert np.allclose(result, np.array([[54, 80], [110, 144]], dtype=dtype))
70+
71+
5172
def test_serialization():
5273
dtype = np.float32
5374
manager_name = pytest.config.getoption('backend', default='CPU')

0 commit comments

Comments
 (0)