Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions python/tvm/relax/frontend/tflite/tflite_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3120,8 +3120,6 @@ def convert_transpose_conv(self, op):
weight_expr_iohw,
strides=(stride_h, stride_w),
padding=padding,
channels=int(out_channels),
kernel_size=(int(kernel_h), int(kernel_w)),
data_layout="NHWC",
kernel_layout="IOHW",
out_dtype=output_tensor_type_str,
Expand Down
83 changes: 82 additions & 1 deletion tests/python/relax/test_frontend_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,88 @@ def main(x: R.Tensor((5, 30), dtype="float32")) -> R.Tensor(out_shape, dtype="in
verify(TfInput, Expected)


def test_fully_connected():
    """FULLY_CONNECTED (matmul against a transposed weight, plus bias) converts."""

    class FullyConnected(tf.Module):
        @tf.function(input_signature=[tf.TensorSpec(shape=(1, 8), dtype=tf.float32)])
        def func(self, x):
            # 3x8 constant weight; transpose_b makes the product (1, 8) @ (8, 3).
            w = tf.constant(np.arange(24, dtype=np.float32).reshape((3, 8)))
            b = tf.constant(np.array([0.5, 1.0, -1.0], dtype=np.float32))
            return tf.nn.bias_add(tf.matmul(x, w, transpose_b=True), b)

    verify(FullyConnected)


def test_depthwise_conv2d():
    """DEPTHWISE_CONV_2D with SAME padding and unit strides converts."""

    class DepthwiseConv2D(tf.Module):
        @tf.function(
            input_signature=[
                tf.TensorSpec(shape=(1, 8, 8, 2), dtype=tf.float32),
                tf.TensorSpec(shape=(3, 3, 2, 1), dtype=tf.float32),
            ]
        )
        def func(self, data, kernel):
            # Depth multiplier 1: filter layout is (H, W, in_channels, 1).
            return tf.nn.depthwise_conv2d(
                input=data, filter=kernel, strides=[1, 1, 1, 1], padding="SAME"
            )

    verify(DepthwiseConv2D)


def test_transpose_conv():
    """TRANSPOSE_CONV (conv2d_transpose with an explicit output shape) converts."""

    class TransposeConv(tf.Module):
        @tf.function(
            input_signature=[
                tf.TensorSpec(shape=(1, 8, 8, 2), dtype=tf.float32),
                tf.TensorSpec(shape=(3, 3, 3, 2), dtype=tf.float32),
            ]
        )
        def func(self, data, kernel):
            # TFLite stores the target output shape as a constant input tensor.
            out_shape = tf.constant([1, 8, 8, 3], dtype=tf.int32)
            return tf.nn.conv2d_transpose(
                input=data,
                filters=kernel,
                output_shape=out_shape,
                strides=[1, 1, 1, 1],
                padding="SAME",
            )

    verify(TransposeConv)

def test_l2_pool2d():
    """L2_POOL_2D lowers to square -> avg_pool2d -> sqrt in Relax.

    TFLite's converter fuses the sqrt(avg_pool(square(x))) pattern below into a
    single L2_POOL_2D op; the frontend is expected to expand it back into the
    three Relax ops checked in ``Expected``.
    """

    class L2Pool2D(tf.Module):
        @tf.function(input_signature=[tf.TensorSpec(shape=(1, 8, 8, 2), dtype=tf.float32)])
        def func(self, data):
            squared = tf.math.square(data)
            pooled = tf.nn.avg_pool2d(squared, ksize=[2, 2], strides=[1, 1], padding="SAME")
            return tf.math.sqrt(pooled)

    @I.ir_module
    class Expected:
        @R.function
        def main(
            data: R.Tensor((1, 8, 8, 2), dtype="float32")
        ) -> R.Tensor((1, 8, 8, 2), dtype="float32"):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                squared = R.power(data, R.const(2.0, "float32"))
                # SAME padding for a 2x2 window at stride 1 pads one row/column
                # on the bottom/right only: padding is [top, left, bottom, right].
                pooled = R.nn.avg_pool2d(
                    squared,
                    pool_size=[2, 2],
                    strides=[1, 1],
                    padding=[0, 0, 1, 1],
                    layout="NHWC",
                )
                gv = R.sqrt(pooled)
                R.output(gv)
            return gv

    verify(L2Pool2D, Expected)


def test_l2_normalization():
class L2Normalization(tf.Module):
@tf.function(input_signature=[tf.TensorSpec(shape=(2, 4), dtype=tf.float32)])
Expand Down Expand Up @@ -757,7 +839,6 @@ def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float3

verify(ReverseV2, Expected)


def _make_conv2d_module(data_shape, kernel_shape, data_format, strides, padding):
class Conv2DModule(tf.Module):
@tf.function(
Expand Down
Loading