diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 1a437093a718..bb0840ddbe8a 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -800,7 +800,12 @@ def convert_l2_normalization(self, op): ) # TFL uses only the default epsilon value - out = relax.op.nn.l2_normalize(in_expr, eps=1e-12, axis=[input_tensor_rank - 1]) + # Implement L2 normalization: output = input / sqrt(sum(input^2) + eps) + # L2 normalization is applied along the last axis + squared = relax.op.square(in_expr) + sum_squared = relax.op.sum(squared, axis=input_tensor_rank - 1, keepdims=True) + denom = relax.op.sqrt(relax.op.add(sum_squared, relax.const(1e-12, "float32"))) + out = relax.op.divide(in_expr, denom) # if we have fused activation fn if output_tensor.qnn_params: @@ -2251,8 +2256,11 @@ def convert_slice(self, op): else: end[i] += begin[i] - out = relax.op.strided_slice(in_expr, begin, end) - + # Create axes list for all dimensions being sliced + axes = list(range(input_tensor_rank)) + begin = [int(v) for v in begin] + end = [int(v) for v in end] + out = relax.op.strided_slice(in_expr, axes=axes, begin=begin, end=end) return out def convert_select(self, op): @@ -3425,7 +3433,7 @@ def convert_expand_dims(self, op): axis = self.get_tensor_value(input_tensors[1]) if isinstance(axis, np.ndarray): assert axis.size == 1, "only one value is expected." - axis = int(axis) + axis = int(axis.flat[0]) ndims = len(input_tensors[0].tensor.ShapeAsNumpy()) assert -1 - ndims <= axis <= ndims, "axis out of range" @@ -3492,9 +3500,9 @@ def convert_reverse_v2(self, op): axis = self.get_tensor_value(input_tensors[1]) if isinstance(axis, np.ndarray): assert len(axis) == 1, "TFLite does not support multi-axis yet" - axis = int(axis) + axis = int(axis.flat[0]) - out = relax.op.reverse(input_expr, axis) + out = relax.op.flip(input_expr, axis) return out def convert_matrix_set_diag(self, op): diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 275e162b818b..9cfe1a792ad6 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -707,6 +707,55 @@ def main(x: R.Tensor((5, 30), dtype="float32")) -> R.Tensor(out_shape, dtype="in verify(TfInput, Expected) +def test_l2_normalization(): + class L2Normalization(tf.Module): + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 4), dtype=tf.float32)]) + def func(self, x): + return tf.nn.l2_normalize(x, axis=-1) + + verify(L2Normalization) + + +def test_slice(): + class Slice(tf.Module): + @tf.function(input_signature=[tf.TensorSpec(shape=(3, 4), dtype=tf.float32)]) + def func(self, x): + return tf.slice(x, begin=[1, 1], size=[2, 2]) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((3, 4), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((2, 2), dtype="float32") = R.strided_slice( + x, axes=[0, 1], begin=[1, 1], end=[3, 3] + ) + R.output(gv) + return gv + + verify(Slice, Expected) + + +def test_reverse_v2(): + class ReverseV2(tf.Module): + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3), dtype=tf.float32)]) + def func(self, x): + return tf.reverse(x, axis=[1]) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((2, 3), dtype="float32") = R.flip(x, axis=1) + R.output(gv) + return gv + + verify(ReverseV2, Expected) + + def _make_conv2d_module(data_shape, kernel_shape, data_format, strides, padding): class Conv2DModule(tf.Module): @tf.function(