Commit 06964b3

[Graph] Support float32/int32/int64 type for select fusions. (#148)
1 parent 1d1db47 commit 06964b3
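
This change teaches the select fusion template to handle more dtypes: TemplateSelectBase previously hard-coded DT_FLOAT when building the zero constant for the fused op, so only float32 graphs could be fused. It now reads the matched node's "T" attribute, builds the zero constant with that dtype, and accepts DT_FLOAT, DT_INT32, and DT_INT64, declining (and logging) the fusion for any other type. TemplateLogicSumBase only gets a more informative match log that also records the fused node's new name, and a Python test covering both the float training path and an int32 select pattern is added.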

File tree: 3 files changed, +336 −5 lines

tensorflow/core/graph/template_logicsum_base.h

Lines changed: 2 additions & 1 deletion
@@ -113,7 +113,8 @@ class TemplateLogicSumBase: public TemplateBase {
       LOG(WARNING) << "Input check failed";
       return false;
     }
-    LOG(INFO) << "Fusion template[" << name() << "] match op[" << nodes[first_key_].node->name() << "]";
+    LOG(INFO) << "Fusion template[" << name() << "] match op[" << nodes[first_key_].node->name() <<
+        "][new_name:" << name_prefix << "_" << name() << "]";
 
     Node* node_fused_logicsum = add_fused_logicsum_node(nodes, name_prefix, g, inputs, outputs);
     if (!node_fused_logicsum) {

tensorflow/core/graph/template_select_base.h

Lines changed: 29 additions & 4 deletions
@@ -33,7 +33,14 @@ class TemplateSelectBase: public TemplateBase {
       std::string name_prefix, Graph* g,
       std::vector<const Edge*>& inputs,
       std::vector<std::vector<const Edge*>>& outputs) override {
-    LOG(INFO) << "Fusion template[" << name() << "] match op[" << nodes[first_key_].node->name() << "]";
+    DataType datatype = get_data_type(nodes[first_key_].node);
+    if (datatype != DT_FLOAT && datatype != DT_INT32 && datatype != DT_INT64) {
+      LOG(INFO) << "Drop fusion template[" << name() << "] match op[" << nodes[first_key_].node->DebugString() << "]";
+      return false;
+    } else {
+      LOG(INFO) << "Fusion template[" << name() << "] match op[" << nodes[first_key_].node->name() <<
+          "][new_name:" << name_prefix << "_" << name() << "]";
+    }
 
     Node* node_const_zero = add_zero_like_node(nodes, name_prefix, g, inputs, outputs);
     if (!node_const_zero) {
@@ -66,6 +73,14 @@ class TemplateSelectBase: public TemplateBase {
     return false;
   }
 
+  DataType get_data_type(const Node* node) {
+    DataType datatype;
+    if (GetNodeAttr(node->def(), "T", &datatype) != Status::OK()) {
+      return DT_INVALID;
+    }
+    return datatype;
+  }
+
  protected:
   virtual Node* add_zero_like_node(
       std::map<std::string, MatchedNode>& nodes,
@@ -76,11 +91,21 @@ class TemplateSelectBase: public TemplateBase {
     NodeDef const_zero;
     const_zero.set_op("Const");
     const_zero.set_name(name_prefix + "_const_zero_" + name());
+
+    DataType datatype = get_data_type(nodes[first_key_].node);
     AttrValue attr_type;
-    attr_type.set_type(DT_FLOAT);
+    attr_type.set_type(datatype);
     const_zero.mutable_attr()->insert({"dtype", attr_type});
-    Tensor tensor_zero(DT_FLOAT, {});
-    tensor_zero.scalar<float>()() = 0.0;
+
+    Tensor tensor_zero(datatype, {});
+    if (datatype == DT_FLOAT) {
+      tensor_zero.scalar<float>()() = 0;
+    } else if (datatype == DT_INT32) {
+      tensor_zero.scalar<int32>()() = 0;
+    } else if (datatype == DT_INT64) {
+      tensor_zero.scalar<int64>()() = 0;
+    }
+
     AttrValue value_zero;
     tensor_zero.AsProtoTensorContent(value_zero.mutable_tensor());
     const_zero.mutable_attr()->insert({"value", value_zero});
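
For context, the pattern this template rewrites is a Select (array_ops.where) whose "then" or "else" operand is a ZerosLike of the other branch, as exercised by the new test file below. The following is a minimal sketch of that pattern with int32 operands (the case newly supported here); it assumes this repo's do_op_fusion graph option, which the test relies on, is available in the build.

# Minimal sketch of the Select + ZerosLike pattern the template matches,
# here with int32 operands. Assumes this repo's do_op_fusion graph option.
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables

graph_options = config_pb2.GraphOptions(
    optimizer_options=config_pb2.OptimizerOptions(do_op_fusion=True))
config = config_pb2.ConfigProto(graph_options=graph_options)

with session.Session(config=config) as sess:
  with sess.graph.as_default():
    cond = variables.Variable([[True, False]], dtype=dtypes.bool)
    vals = variables.Variable([[3, 4]], dtype=dtypes.int32)
    # Select(cond, zeros_like(vals), vals): a candidate for the fusion pass.
    out = array_ops.where(cond, array_ops.zeros_like(vals), vals)
    sess.run(variables.global_variables_initializer())
    print(sess.run(out))  # [[0 4]]

With do_op_fusion=True the graph pass may rewrite the Select/ZerosLike pair; the computed values should be unchanged, which is what the test below asserts.
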
Lines changed: 305 additions & 0 deletions
@@ -0,0 +1,305 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for the select (Where + ZerosLike) fusion template."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import numpy as np
import os
import shutil
# os.environ['TF_CPP_MIN_VLOG_LEVEL'] = '2'

from tensorflow.contrib import layers
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.training import adagrad


# run without auto-replacement of fused ops
def runNonFuse():
  g1 = ops.Graph()
  with g1.as_default():
    random_seed.set_random_seed(0)

    n_num = 1024
    q_num = 50
    k_num = 50
    c_num = 128  # c_num % split_num == 0
    split_num = 8

    data_float32_q = array_ops.placeholder(
        dtypes.float32, shape=(None, q_num, c_num))
    data_float32_k = array_ops.placeholder(
        dtypes.float32, shape=(None, k_num, c_num))

    x_float32 = data_float32_q
    y_float32 = data_float32_k
    m = variable_scope.get_variable(
        "m_non_fuse", [split_num, n_num, q_num, k_num],
        dtype=dtypes.int32,
        initializer=init_ops.random_uniform_initializer(0, 2))
    m_bool = math_ops.cast(m, dtype=dtypes.bool)
    m_bool = array_ops.reshape(m_bool, [-1, q_num, k_num])
    p_float32 = constant_op.constant(
        0, shape=[split_num*n_num, q_num, k_num],
        dtype=dtypes.float32)

    with ops.name_scope('NonFuseForward') as scope:
      with ops.device("/cpu:0"):

        x_float32 = layers.fully_connected(
            x_float32, c_num,
            activation_fn=nn_ops.leaky_relu, scope="X")

        y_float32 = layers.fully_connected(
            y_float32, c_num,
            activation_fn=nn_ops.leaky_relu, scope="Y")

        xs_float32 = array_ops.concat(
            array_ops.split(x_float32, split_num, axis=2), axis=0)
        ys_float32 = array_ops.concat(
            array_ops.split(y_float32, split_num, axis=2), axis=0)
        output_non_fuse_float32 = math_ops.matmul(
            xs_float32, ys_float32,
            transpose_a=False, transpose_b=True)

        zero_tensor = array_ops.zeros_like(array_ops.identity(output_non_fuse_float32))
        output_non_fuse_float32 = array_ops.where(
            m_bool, output_non_fuse_float32, zero_tensor)
        zero_tensor2 = array_ops.zeros_like(zero_tensor)

        layer1_non_fuse_float32 = layers.fully_connected(
            output_non_fuse_float32, 40,
            activation_fn=nn_ops.leaky_relu)
        layer2_non_fuse_float32 = layers.fully_connected(
            layer1_non_fuse_float32, 20,
            activation_fn=nn_ops.leaky_relu)
        layer2_non_fuse_float32 = array_ops.reshape(
            layer2_non_fuse_float32, [n_num, -1])
        layer3_non_fuse_float32 = layers.fully_connected(
            layer2_non_fuse_float32, 1,
            activation_fn=nn_ops.leaky_relu)
        labels_non_fuse_float32 = constant_op.constant(
            1, shape=[n_num, 1], dtype=dtypes.float32)
        loss_op_non_fuse_float32 = math_ops.reduce_mean(
            nn_impl.sigmoid_cross_entropy_with_logits(
                logits=layer3_non_fuse_float32,
                labels=labels_non_fuse_float32))

    with ops.name_scope('NonFuseBackward') as scope:
      with ops.device("/cpu:0"):
        train_op_non_fuse_float32 = adagrad.AdagradOptimizer(
            learning_rate=0.0001,
            initial_accumulator_value=0.1).minimize(
                loss_op_non_fuse_float32)

    init_global = variables.global_variables_initializer()
    init_local = variables.local_variables_initializer()

    # trigger fusion op or not
    graph_options = config_pb2.GraphOptions(
        optimizer_options=config_pb2.OptimizerOptions(
            do_op_fusion=False))
    config = config_pb2.ConfigProto(
        allow_soft_placement=False, graph_options=graph_options)
    with session.Session(config=config) as sess:
      from tensorflow.python.framework import graph_io
      graph_io.write_graph(sess.graph, './', 'train.pbtxt')

      # output the graph_def
      np.random.seed(0)
      feed_data_q = np.random.rand(n_num, q_num, c_num)
      feed_data_k = np.random.rand(n_num, k_num, c_num)

      sess.run([init_global, init_local])
      for step in range(50):
        loss_val_non_fuse, train_op_val = sess.run(
            [loss_op_non_fuse_float32,
             train_op_non_fuse_float32],
            feed_dict={data_float32_q: feed_data_q,
                       data_float32_k: feed_data_k})

      print("loss val non-fuse: %2.7f" % (loss_val_non_fuse))
      return loss_val_non_fuse


def runFuse():
  g2 = ops.Graph()
  with g2.as_default():
    random_seed.set_random_seed(0)

    n_num = 1024
    q_num = 50
    k_num = 50
    c_num = 128  # c_num % split_num == 0
    split_num = 8

    data_float32_q = array_ops.placeholder(
        dtypes.float32, shape=(None, q_num, c_num))
    data_float32_k = array_ops.placeholder(
        dtypes.float32, shape=(None, k_num, c_num))

    x_float32 = data_float32_q
    y_float32 = data_float32_k
    m = variable_scope.get_variable(
        "m_fuse", [split_num, n_num, q_num, k_num], dtype=dtypes.int32,
        initializer=init_ops.random_uniform_initializer(0, 2))
    m_bool = math_ops.cast(m, dtype=dtypes.bool)
    m_bool = array_ops.reshape(m_bool, [-1, q_num, k_num])
    p_float32 = constant_op.constant(
        0, shape=[split_num*n_num, q_num, k_num], dtype=dtypes.float32)

    with ops.name_scope('FuseForward') as scope:
      with ops.device("/cpu:0"):

        x_float32 = layers.fully_connected(
            x_float32, c_num,
            activation_fn=nn_ops.leaky_relu, scope="X")

        y_float32 = layers.fully_connected(
            y_float32, c_num,
            activation_fn=nn_ops.leaky_relu, scope="Y")

        xs_float32 = array_ops.concat(
            array_ops.split(x_float32, split_num, axis=2), axis=0)
        ys_float32 = array_ops.concat(
            array_ops.split(y_float32, split_num, axis=2), axis=0)
        output_fuse_float32 = math_ops.matmul(
            xs_float32, ys_float32,
            transpose_a=False, transpose_b=True)

        zero_tensor = array_ops.zeros_like(array_ops.identity(output_fuse_float32))
        output_fuse_float32 = array_ops.where(
            m_bool, output_fuse_float32, zero_tensor)
        zero_tensor2 = array_ops.zeros_like(zero_tensor)

        layer1_fuse_float32 = layers.fully_connected(
            output_fuse_float32, 40,
            activation_fn=nn_ops.leaky_relu)
        layer2_fuse_float32 = layers.fully_connected(
            layer1_fuse_float32, 20,
            activation_fn=nn_ops.leaky_relu)
        layer2_fuse_float32 = array_ops.reshape(
            layer2_fuse_float32, [n_num, -1])
        layer3_fuse_float32 = layers.fully_connected(
            layer2_fuse_float32, 1,
            activation_fn=nn_ops.leaky_relu)
        labels_fuse_float32 = constant_op.constant(
            1, shape=[n_num, 1], dtype=dtypes.float32)
        loss_op_fuse_float32 = math_ops.reduce_mean(
            nn_impl.sigmoid_cross_entropy_with_logits(
                logits=layer3_fuse_float32,
                labels=labels_fuse_float32))

    with ops.name_scope('FuseBackward') as scope:
      with ops.device("/cpu:0"):
        train_op_fuse_float32 = adagrad.AdagradOptimizer(
            learning_rate=0.0001,
            initial_accumulator_value=0.1).minimize(
                loss_op_fuse_float32)

    init_global = variables.global_variables_initializer()
    init_local = variables.local_variables_initializer()

    # trigger fusion op or not
    graph_options = config_pb2.GraphOptions(
        optimizer_options=config_pb2.OptimizerOptions(
            do_op_fusion=True))
    config = config_pb2.ConfigProto(
        allow_soft_placement=False, graph_options=graph_options)
    with session.Session(config=config) as sess:
      from tensorflow.python.framework import graph_io
      graph_io.write_graph(sess.graph, './', 'train2.pbtxt')

      np.random.seed(0)
      feed_data_q = np.random.rand(n_num, q_num, c_num)
      feed_data_k = np.random.rand(n_num, k_num, c_num)
      sess.run([init_global, init_local])

      for step in range(50):
        loss_val_replaced, train_op_val = sess.run(
            [loss_op_fuse_float32, train_op_fuse_float32],
            feed_dict={data_float32_q: feed_data_q,
                       data_float32_k: feed_data_k})

      print("loss val fuse: %2.7f" % loss_val_replaced)
      return loss_val_replaced


def runFuseForIntType():
  graph_options = config_pb2.GraphOptions(
      optimizer_options=config_pb2.OptimizerOptions(
          do_op_fusion=True))
  config = config_pb2.ConfigProto(
      allow_soft_placement=False, graph_options=graph_options)

  with session.Session(config=config) as sess:
    with sess.graph.as_default():

      # with ops.name_scope('FuseForward') as scope:
      t_cond = variables.Variable([[True, True], [False, False]], dtype=dtypes.bool)
      t_then = variables.Variable([[11, 12], [13, 14]], dtype=dtypes.int32)
      t_else = variables.Variable([[21, 22], [23, 24]], dtype=dtypes.int32)
      t_out = variables.Variable([[31, 32], [33, 34]], dtype=dtypes.int32)

      t_then = array_ops.zeros_like(array_ops.reshape(array_ops.unique(array_ops.reshape(t_then, [-1]))[0], [-1, 2]))
      t_select = array_ops.where(
          t_cond, t_then, t_else)
      t_result = t_out + t_select

      init_global = variables.global_variables_initializer()
      init_local = variables.local_variables_initializer()

      from tensorflow.python.framework import graph_io
      graph_io.write_graph(sess.graph, './', 'train_3.pbtxt')

      np.random.seed(0)
      feed_p_input = np.random.rand(2, 2)
      sess.run([init_global, init_local])

      result = sess.run([t_result, ])
      print("result:", result)
      return result


class SelectZeroLikeFusionTest(test.TestCase):
  def testFusion(self):
    res_non_fuse = runNonFuse()
    res_fuse = runFuse()
    self.assertAllCloseAccordingToType(res_non_fuse, res_fuse)

  def testFusionForIntType(self):
    result = runFuseForIntType()
    self.assertAllEqual(result, [[[31, 32], [56, 58]]])


if __name__ == "__main__":
  test.main()
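
As a sanity check on the value asserted in testFusionForIntType: t_then is replaced by an all-zero int32 tensor, where() keeps those zeros where t_cond is True and takes t_else where it is False, and the result is added to t_out. The same arithmetic in plain NumPy:

# NumPy re-computation of the value asserted in testFusionForIntType.
import numpy as np

t_cond = np.array([[True, True], [False, False]])
t_then = np.zeros((2, 2), dtype=np.int32)           # the zeros_like(...) branch
t_else = np.array([[21, 22], [23, 24]], dtype=np.int32)
t_out = np.array([[31, 32], [33, 34]], dtype=np.int32)

t_select = np.where(t_cond, t_then, t_else)         # [[0, 0], [23, 24]]
print(t_out + t_select)                             # [[31, 32], [56, 58]]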
