1414# See the License for the specific language governing permissions and
1515# limitations under the License.
1616# ******************************************************************************
17- from ngraph import Type , Function
18- from ngraph import Node
19- from ngraph .op import Parameter , Maximum , Reshape , Dot , Broadcast
20- from ngraph .op import Constant , Exp , Log , Sum
21- from ngraph .op import Greater , Convert , Reduce
22- from ngraph .op import OneHot
17+ from ngraph . impl import Type , Function
18+ from ngraph . impl import Node , Shape , AxisVector , AxisSet
19+ from ngraph .impl . op import Parameter , Maximum , Reshape , Dot , Broadcast
20+ from ngraph .impl . op import Constant , Exp , Log , Sum
21+ from ngraph .impl . op import Greater , Convert , Reduce
22+ from ngraph .impl . op import OneHot
2323
2424from typing import List , Dict , Set
2525
2929bz = 53
3030lr = 0.2
3131
32- Input = Parameter (float_element_type , [bz , 28 , 28 ])
33- Label = Parameter (int_element_type , [bz ])
34- LabelOneHot = Convert ((OneHot (Label , [bz , 10 ], 1 )), float_element_type )
32+ Input = Parameter (float_element_type , Shape ( [bz , 28 , 28 ]) )
33+ Label = Parameter (int_element_type , Shape ( [bz ]) )
34+ LabelOneHot = Convert ((OneHot (Label , Shape ( [bz , 10 ]) , 1 )), float_element_type )
3535
36- MaxParam1 = Parameter (float_element_type , [] )
37- MaxParam2 = Parameter (float_element_type , [] )
38- MaxFn = Function ([ Maximum (MaxParam1 , MaxParam2 )] ,
36+ MaxParam1 = Parameter (float_element_type , Shape ([]) )
37+ MaxParam2 = Parameter (float_element_type , Shape ([]) )
38+ MaxFn = Function (Maximum (MaxParam1 , MaxParam2 ),
3939 [MaxParam1 , MaxParam2 ],
4040 'mnist' )
4141
@@ -44,10 +44,10 @@ def make_scalar_constant(elem_type, scalar, shape=None, axis_set=None):
4444 # type: (int, float, List[int], Set[int]) -> float
4545 """Create a Constant node for scalar value."""
4646 if shape is None :
47- shape = []
47+ shape = Shape ([])
4848 if axis_set is None :
49- axis_set = set ()
50- scalar_shape = [] # type: List[int]
49+ axis_set = AxisSet ( set () )
50+ scalar_shape = Shape ([]) # type: List[int]
5151 constant_op = Constant (elem_type , scalar_shape , [scalar ])
5252 constant_broadcast = Broadcast (constant_op , shape , axis_set )
5353 return constant_broadcast
@@ -60,7 +60,7 @@ def make_float32_constant(scalar, shape=None, axis_set=None):
6060 shape = []
6161 if axis_set is None :
6262 axis_set = set ()
63- return make_scalar_constant (Type .f32 , scalar , shape , axis_set )
63+ return make_scalar_constant (Type .f32 , scalar , Shape ( shape ), AxisSet ( axis_set ) )
6464
6565
6666def make_float32_constant_like (scalar , op ): # type: (float, Node) -> float
@@ -69,7 +69,7 @@ def make_float32_constant_like(scalar, op): # type: (float, Node) -> float
6969 shape = op .get_shape ()
7070 for i in range (len (shape )):
7171 v .add (i )
72- return make_float32_constant (scalar , shape , v )
72+ return make_float32_constant (scalar , Shape ( shape ), AxisSet ( v ) )
7373
7474
7575def transpose (op , order ): # type: (Node, List[int]) -> Node
@@ -78,7 +78,7 @@ def transpose(op, order): # type: (Node, List[int]) -> Node
7878 for i in range (len (order )):
7979 v .append (op .get_shape ()[order [i ]])
8080 new_shape = v
81- return Reshape (op , order , new_shape )
81+ return Reshape (op , AxisVector ( order ), Shape ( new_shape ) )
8282
8383
8484def relu (op ): # type: (Node) -> Node
@@ -87,45 +87,45 @@ def relu(op): # type: (Node) -> Node
8787
8888
8989# Flatten
90- X1 = Reshape (Input , [0 , 1 , 2 ], [bz , 784 ])
90+ X1 = Reshape (Input , AxisVector ( [0 , 1 , 2 ]), Shape ( [bz , 784 ]) )
9191
9292# Normalize
9393X2 = X1 / make_float32_constant_like (255. , X1 )
9494
9595# Affine 1
96- W1 = Parameter (float_element_type , [784 , 100 ])
97- b1 = Parameter (float_element_type , [100 ])
98- X3 = Dot (X2 , W1 ) + Broadcast (b1 , [bz , 100 ], {0 })
96+ W1 = Parameter (float_element_type , Shape ( [784 , 100 ]) )
97+ b1 = Parameter (float_element_type , Shape ( [100 ]) )
98+ X3 = Dot (X2 , W1 ) + Broadcast (b1 , Shape ( [bz , 100 ]), AxisSet ( {0 }) )
9999X4 = relu (X3 )
100100
101101# Affine 2
102- W2 = Parameter (float_element_type , [100 , 10 ])
103- b2 = Parameter (float_element_type , [10 ])
104- X5 = Dot (X4 , W2 ) + Broadcast (b2 , [bz , 10 ], {0 })
102+ W2 = Parameter (float_element_type , Shape ( [100 , 10 ]) )
103+ b2 = Parameter (float_element_type , Shape ( [10 ]) )
104+ X5 = Dot (X4 , W2 ) + Broadcast (b2 , Shape ( [bz , 10 ]), AxisSet ( {0 }) )
105105
106106# Softmax
107107Logits = X5
108108Exp = Exp (Logits )
109- Max = Reduce (Exp , make_float32_constant (0. , [], set ()), MaxFn , {1 })
110- MaxBroadcast = Broadcast (Max , [bz , 10 ], {1 })
109+ Max = Reduce (Exp , make_float32_constant (0. , [], set ()), MaxFn , AxisSet ( {1 }) )
110+ MaxBroadcast = Broadcast (Max , Shape ( [bz , 10 ]), AxisSet ( {1 }) )
111111Softmax = Exp / MaxBroadcast
112112
113113# Loss
114114LogSoftmax = Log (Softmax )
115- Loss = Sum (LogSoftmax * LabelOneHot , {0 , 1 }) / make_float32_constant (float (bz ), [], set ())
115+ Loss = Sum (LogSoftmax * LabelOneHot , AxisSet ( {0 , 1 }) ) / make_float32_constant (float (bz ), [], set ())
116116
117117# Derivatives
118118dLogits = Softmax - LabelOneHot
119119dX5 = dLogits
120120
121- dX4 = Dot (dX5 , transpose (W2 , [1 , 0 ]))
122- dW2 = Dot (transpose (X4 , [1 , 0 ]), dX5 )
123- db2 = Sum (dX5 , {0 })
121+ dX4 = Dot (dX5 , transpose (W2 , Shape ( [1 , 0 ]) ))
122+ dW2 = Dot (transpose (X4 , Shape ( [1 , 0 ]) ), dX5 )
123+ db2 = Sum (dX5 , AxisSet ( {0 }) )
124124
125125dX3 = Convert ((Greater (X3 , make_float32_constant (0. , [bz , 100 ], {0 , 1 }))), float_element_type ) * dX4
126- dX2 = Dot (dX3 , transpose (W1 , [1 , 0 ]))
127- dW1 = Dot (transpose (X2 , [1 , 0 ]), dX3 )
128- db1 = Sum (dX3 , {0 })
126+ dX2 = Dot (dX3 , transpose (W1 , Shape ( [1 , 0 ]) ))
127+ dW1 = Dot (transpose (X2 , Shape ( [1 , 0 ]) ), dX3 )
128+ db1 = Sum (dX3 , AxisSet ( {0 }) )
129129
130130nW1 = W1 - make_float32_constant_like (lr , dW1 ) * dW1
131131nb1 = b1 - make_float32_constant_like (lr , db1 ) * db1
0 commit comments