Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit feeaed5

Browse files
silee2rkimballn1
authored andcommitted
[v0.1.0] Silee2/python binding (#674)
* Remove unsupported linker flags for Mac build. * Restructure python binding. Put low level direct wrapper for ngraph c++ API into ngraph/impl Move high level API from ngraph_api to ngraph * Move CMakeLists.txt to its own PR
1 parent fad8569 commit feeaed5

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

107 files changed

+266
-265
lines changed

doc/examples/onnx_example.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121
# Convert a serialized ONNX model to an ngraph model
2222
from ngraph_onnx.onnx_importer.importer import import_onnx_model
2323
ng_model = import_onnx_model(onnx_protobuf)[0]
24-
2524

26-
# Using ngraph_api, create a callable computation object
27-
import ngraph_api as ng
25+
26+
# Using an ngraph runtime (CPU backend), create a callable computation
27+
import ngraph as ng
2828
runtime = ng.runtime(manager_name='CPU')
2929
resnet = runtime.computation(ng_model['output'], *ng_model['inputs'])
3030

python/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ pip install -r test_requirements.txt
4444
Then run a test.
4545
```
4646
pytest test/test_ops.py
47-
pytest test/ngraph_api/*
47+
pytest test/ngraph/*
4848
```
4949

5050
## Running tests with tox
@@ -70,7 +70,7 @@ You can run tests using only Python 3 or 2 using the `-e` (environment) switch:
7070

7171
You can check styles in a particular code directory by specifying the path:
7272

73-
tox ngraph_api/
73+
tox ngraph/
7474

7575
In case of problems, try to recreate the virtual environments by deleting the `.tox` directory:
7676

python/examples/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
"""Usage example for the ngraph Pythonic API."""
1717

1818
import numpy as np
19-
import ngraph_api as ng
19+
import ngraph as ng
2020

2121
shape = [2, 2]
2222
A = ng.parameter(shape, name='A')

python/examples/mnist_mlp.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
# ******************************************************************************
17-
from ngraph import Type, Function
18-
from ngraph import Node
19-
from ngraph.op import Parameter, Maximum, Reshape, Dot, Broadcast
20-
from ngraph.op import Constant, Exp, Log, Sum
21-
from ngraph.op import Greater, Convert, Reduce
22-
from ngraph.op import OneHot
17+
from ngraph.impl import Type, Function
18+
from ngraph.impl import Node, Shape, AxisVector, AxisSet
19+
from ngraph.impl.op import Parameter, Maximum, Reshape, Dot, Broadcast
20+
from ngraph.impl.op import Constant, Exp, Log, Sum
21+
from ngraph.impl.op import Greater, Convert, Reduce
22+
from ngraph.impl.op import OneHot
2323

2424
from typing import List, Dict, Set
2525

@@ -29,13 +29,13 @@
2929
bz = 53
3030
lr = 0.2
3131

32-
Input = Parameter(float_element_type, [bz, 28, 28])
33-
Label = Parameter(int_element_type, [bz])
34-
LabelOneHot = Convert((OneHot(Label, [bz, 10], 1)), float_element_type)
32+
Input = Parameter(float_element_type, Shape([bz, 28, 28]))
33+
Label = Parameter(int_element_type, Shape([bz]))
34+
LabelOneHot = Convert((OneHot(Label, Shape([bz, 10]), 1)), float_element_type)
3535

36-
MaxParam1 = Parameter(float_element_type, [])
37-
MaxParam2 = Parameter(float_element_type, [])
38-
MaxFn = Function([Maximum(MaxParam1, MaxParam2)],
36+
MaxParam1 = Parameter(float_element_type, Shape([]))
37+
MaxParam2 = Parameter(float_element_type, Shape([]))
38+
MaxFn = Function(Maximum(MaxParam1, MaxParam2),
3939
[MaxParam1, MaxParam2],
4040
'mnist')
4141

@@ -44,10 +44,10 @@ def make_scalar_constant(elem_type, scalar, shape=None, axis_set=None):
4444
# type: (int, float, List[int], Set[int]) -> float
4545
"""Create a Constant node for scalar value."""
4646
if shape is None:
47-
shape = []
47+
shape = Shape([])
4848
if axis_set is None:
49-
axis_set = set()
50-
scalar_shape = [] # type: List[int]
49+
axis_set = AxisSet(set())
50+
scalar_shape = Shape([]) # type: List[int]
5151
constant_op = Constant(elem_type, scalar_shape, [scalar])
5252
constant_broadcast = Broadcast(constant_op, shape, axis_set)
5353
return constant_broadcast
@@ -60,7 +60,7 @@ def make_float32_constant(scalar, shape=None, axis_set=None):
6060
shape = []
6161
if axis_set is None:
6262
axis_set = set()
63-
return make_scalar_constant(Type.f32, scalar, shape, axis_set)
63+
return make_scalar_constant(Type.f32, scalar, Shape(shape), AxisSet(axis_set))
6464

6565

6666
def make_float32_constant_like(scalar, op): # type: (float, Node) -> float
@@ -69,7 +69,7 @@ def make_float32_constant_like(scalar, op): # type: (float, Node) -> float
6969
shape = op.get_shape()
7070
for i in range(len(shape)):
7171
v.add(i)
72-
return make_float32_constant(scalar, shape, v)
72+
return make_float32_constant(scalar, Shape(shape), AxisSet(v))
7373

7474

7575
def transpose(op, order): # type: (Node, List[int]) -> Node
@@ -78,7 +78,7 @@ def transpose(op, order): # type: (Node, List[int]) -> Node
7878
for i in range(len(order)):
7979
v.append(op.get_shape()[order[i]])
8080
new_shape = v
81-
return Reshape(op, order, new_shape)
81+
return Reshape(op, AxisVector(order), Shape(new_shape))
8282

8383

8484
def relu(op): # type: (Node) -> Node
@@ -87,45 +87,45 @@ def relu(op): # type: (Node) -> Node
8787

8888

8989
# Flatten
90-
X1 = Reshape(Input, [0, 1, 2], [bz, 784])
90+
X1 = Reshape(Input, AxisVector([0, 1, 2]), Shape([bz, 784]))
9191

9292
# Normalize
9393
X2 = X1 / make_float32_constant_like(255., X1)
9494

9595
# Affine 1
96-
W1 = Parameter(float_element_type, [784, 100])
97-
b1 = Parameter(float_element_type, [100])
98-
X3 = Dot(X2, W1) + Broadcast(b1, [bz, 100], {0})
96+
W1 = Parameter(float_element_type, Shape([784, 100]))
97+
b1 = Parameter(float_element_type, Shape([100]))
98+
X3 = Dot(X2, W1) + Broadcast(b1, Shape([bz, 100]), AxisSet({0}))
9999
X4 = relu(X3)
100100

101101
# Affine 2
102-
W2 = Parameter(float_element_type, [100, 10])
103-
b2 = Parameter(float_element_type, [10])
104-
X5 = Dot(X4, W2) + Broadcast(b2, [bz, 10], {0})
102+
W2 = Parameter(float_element_type, Shape([100, 10]))
103+
b2 = Parameter(float_element_type, Shape([10]))
104+
X5 = Dot(X4, W2) + Broadcast(b2, Shape([bz, 10]), AxisSet({0}))
105105

106106
# Softmax
107107
Logits = X5
108108
Exp = Exp(Logits)
109-
Max = Reduce(Exp, make_float32_constant(0., [], set()), MaxFn, {1})
110-
MaxBroadcast = Broadcast(Max, [bz, 10], {1})
109+
Max = Reduce(Exp, make_float32_constant(0., [], set()), MaxFn, AxisSet({1}))
110+
MaxBroadcast = Broadcast(Max, Shape([bz, 10]), AxisSet({1}))
111111
Softmax = Exp / MaxBroadcast
112112

113113
# Loss
114114
LogSoftmax = Log(Softmax)
115-
Loss = Sum(LogSoftmax * LabelOneHot, {0, 1}) / make_float32_constant(float(bz), [], set())
115+
Loss = Sum(LogSoftmax * LabelOneHot, AxisSet({0, 1})) / make_float32_constant(float(bz), [], set())
116116

117117
# Derivatives
118118
dLogits = Softmax - LabelOneHot
119119
dX5 = dLogits
120120

121-
dX4 = Dot(dX5, transpose(W2, [1, 0]))
122-
dW2 = Dot(transpose(X4, [1, 0]), dX5)
123-
db2 = Sum(dX5, {0})
121+
dX4 = Dot(dX5, transpose(W2, Shape([1, 0])))
122+
dW2 = Dot(transpose(X4, Shape([1, 0])), dX5)
123+
db2 = Sum(dX5, AxisSet({0}))
124124

125125
dX3 = Convert((Greater(X3, make_float32_constant(0., [bz, 100], {0, 1}))), float_element_type) * dX4
126-
dX2 = Dot(dX3, transpose(W1, [1, 0]))
127-
dW1 = Dot(transpose(X2, [1, 0]), dX3)
128-
db1 = Sum(dX3, {0})
126+
dX2 = Dot(dX3, transpose(W1, Shape([1, 0])))
127+
dW1 = Dot(transpose(X2, Shape([1, 0])), dX3)
128+
db1 = Sum(dX3, AxisSet({0}))
129129

130130
nW1 = W1 - make_float32_constant_like(lr, dW1) * dW1
131131
nb1 = b1 - make_float32_constant_like(lr, db1) * db1

python/ngraph/__init__.py

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# ******************************************************************************
2-
# Copyright 2017-2018 Intel Corporation
2+
# Copyright 2018 Intel Corporation
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
@@ -13,36 +13,43 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
# ******************************************************************************
16-
"""
17-
Package: ngraph
18-
Low level wrappers for the nGraph c++ api.
19-
"""
16+
"""ngraph module namespace, exposing factory functions for all ops and other classes."""
2017

21-
# flake8: noqa
18+
from ngraph.ops import absolute
19+
from ngraph.ops import absolute as abs
20+
from ngraph.ops import add
21+
from ngraph.ops import avg_pool
22+
from ngraph.ops import broadcast
23+
from ngraph.ops import ceiling
24+
from ngraph.ops import ceiling as ceil
25+
from ngraph.ops import constant
26+
from ngraph.ops import convert
27+
from ngraph.ops import convolution
28+
from ngraph.ops import divide
29+
from ngraph.ops import dot
30+
from ngraph.ops import equal
31+
from ngraph.ops import exp
32+
from ngraph.ops import floor
33+
from ngraph.ops import greater
34+
from ngraph.ops import greater_eq
35+
from ngraph.ops import log
36+
from ngraph.ops import less
37+
from ngraph.ops import less_eq
38+
from ngraph.ops import logical_not
39+
from ngraph.ops import max
40+
from ngraph.ops import maximum
41+
from ngraph.ops import max_pool
42+
from ngraph.ops import min
43+
from ngraph.ops import minimum
44+
from ngraph.ops import multiply
45+
from ngraph.ops import negative
46+
from ngraph.ops import not_equal
47+
from ngraph.ops import parameter
48+
from ngraph.ops import prod
49+
from ngraph.ops import reshape
50+
from ngraph.ops import sqrt
51+
from ngraph.ops import subtract
52+
from ngraph.ops import sum
53+
from ngraph.ops import tanh
2254

23-
import sys
24-
import six
25-
26-
# workaround to load the libngraph.so with RTLD_GLOBAL
27-
if six.PY3:
28-
import os
29-
flags = os.RTLD_NOW | os.RTLD_GLOBAL
30-
else:
31-
import ctypes
32-
flags = sys.getdlopenflags() | ctypes.RTLD_GLOBAL
33-
sys.setdlopenflags(flags)
34-
35-
from _pyngraph import Function
36-
from _pyngraph import Node
37-
from _pyngraph import NodeVector
38-
from _pyngraph import Type
39-
from _pyngraph import TensorViewType
40-
from _pyngraph import Shape
41-
from _pyngraph import Strides
42-
from _pyngraph import CoordinateDiff
43-
from _pyngraph import AxisSet
44-
from _pyngraph import AxisVector
45-
from _pyngraph import Coordinate
46-
47-
from _pyngraph import serialize
48-
from _pyngraph import util
55+
from ngraph.runtime import runtime
File renamed without changes.

python/ngraph/impl/__init__.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# ******************************************************************************
2+
# Copyright 2017-2018 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ******************************************************************************
16+
"""
17+
Package: ngraph
18+
Low level wrappers for the nGraph c++ api.
19+
"""
20+
21+
# flake8: noqa
22+
23+
import sys
24+
import six
25+
26+
# workaround to load the libngraph.so with RTLD_GLOBAL
27+
if six.PY3:
28+
import os
29+
flags = os.RTLD_NOW | os.RTLD_GLOBAL
30+
else:
31+
import ctypes
32+
flags = sys.getdlopenflags() | ctypes.RTLD_GLOBAL
33+
sys.setdlopenflags(flags)
34+
35+
from _pyngraph import Function
36+
from _pyngraph import Node
37+
from _pyngraph import NodeVector
38+
from _pyngraph import Type
39+
from _pyngraph import TensorViewType
40+
from _pyngraph import Shape
41+
from _pyngraph import Strides
42+
from _pyngraph import CoordinateDiff
43+
from _pyngraph import AxisSet
44+
from _pyngraph import AxisVector
45+
from _pyngraph import Coordinate
46+
47+
from _pyngraph import serialize
48+
from _pyngraph import util
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)