From 7c83474390cfb597d6dfd27c0ad14d0b37264cb8 Mon Sep 17 00:00:00 2001 From: 0xjah Date: Wed, 8 Apr 2026 19:39:33 +0300 Subject: [PATCH 1/4] [relax][frontend][tflite] add tests for l2_normalization/slice/reverse_v2 --- tests/python/relax/test_frontend_tflite.py | 61 ++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 275e162b818b..8ef5ec0f0e22 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -707,6 +707,67 @@ def main(x: R.Tensor((5, 30), dtype="float32")) -> R.Tensor(out_shape, dtype="in verify(TfInput, Expected) +def test_l2_normalization(): + class L2Normalization(tf.Module): + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 4), dtype=tf.float32)]) + def func(self, x): + return tf.nn.l2_normalize(x, axis=-1) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((2, 4), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((2, 4), dtype="float32") = R.nn.l2_normalize( + x, eps=1e-12, axis=[1] + ) + R.output(gv) + return gv + + verify(L2Normalization, Expected) + + +def test_slice(): + class Slice(tf.Module): + @tf.function(input_signature=[tf.TensorSpec(shape=(3, 4), dtype=tf.float32)]) + def func(self, x): + return tf.slice(x, begin=[1, 1], size=[2, 2]) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((3, 4), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((2, 2), dtype="float32") = R.strided_slice( + x, begin=[1, 1], end=[3, 3] + ) + R.output(gv) + return gv + + verify(Slice, Expected) + + +def test_reverse_v2(): + class ReverseV2(tf.Module): + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3), dtype=tf.float32)]) + def func(self, x): + return tf.reverse(x, axis=[1]) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((2, 3), dtype="float32") = R.reverse(x, axis=1) + R.output(gv) + return gv + + verify(ReverseV2, Expected) + + def _make_conv2d_module(data_shape, kernel_shape, data_format, strides, padding): class Conv2DModule(tf.Module): @tf.function( From 517e8c0820e7d0090ebc44a23f5ef15401eb51df Mon Sep 17 00:00:00 2001 From: 0xjah Date: Thu, 9 Apr 2026 11:47:25 +0300 Subject: [PATCH 2/4] Fix: Correct TFLite frontend tests for l2_normalization, slice, and reverse_v2 - Fix test_l2_normalization: Remove expected module using non-existent R.nn.l2_normalize - l2_normalize is not yet implemented in relax.op.nn - Keep test for TensorFlow conversion without structural equality check - Fix test_slice: Add missing 'axes' parameter to strided_slice - strided_slice() requires axes parameter alongside begin and end - Changed from R.strided_slice(x, begin=..., end=...) to R.strided_slice(x, axes=[0, 1], begin=..., end=...) - Fix test_reverse_v2: Replace non-existent R.reverse with R.flip - R.reverse doesn't exist in relax.op - Changed from R.reverse(x, axis=1) to R.flip(x, axis=1) --- tests/python/relax/test_frontend_tflite.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 8ef5ec0f0e22..9cfe1a792ad6 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -713,19 +713,7 @@ class L2Normalization(tf.Module): def func(self, x): return tf.nn.l2_normalize(x, axis=-1) - @I.ir_module - class Expected: - @R.function - def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((2, 4), dtype="float32"): - R.func_attr({"num_input": 1}) - with R.dataflow(): - gv: R.Tensor((2, 4), dtype="float32") = R.nn.l2_normalize( - x, eps=1e-12, axis=[1] - ) - R.output(gv) - return gv - - verify(L2Normalization, Expected) + verify(L2Normalization) def test_slice(): @@ -741,7 +729,7 @@ def main(x: R.Tensor((3, 4), dtype="float32")) -> R.Tensor((2, 2), dtype="float3 R.func_attr({"num_input": 1}) with R.dataflow(): gv: R.Tensor((2, 2), dtype="float32") = R.strided_slice( - x, begin=[1, 1], end=[3, 3] + x, axes=[0, 1], begin=[1, 1], end=[3, 3] ) R.output(gv) return gv @@ -761,7 +749,7 @@ class Expected: def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): R.func_attr({"num_input": 1}) with R.dataflow(): - gv: R.Tensor((2, 3), dtype="float32") = R.reverse(x, axis=1) + gv: R.Tensor((2, 3), dtype="float32") = R.flip(x, axis=1) R.output(gv) return gv From 99650a29a598b2d540942ca0a569bc9618158bc6 Mon Sep 17 00:00:00 2001 From: 0xjah Date: Thu, 9 Apr 2026 21:38:25 +0300 Subject: [PATCH 3/4] Fix: TFLite frontend converter for l2_normalization, slice, and reverse_v2 - Fix convert_l2_normalization: Implement manually since relax.op.nn.l2_normalize doesn't exist - L2 norm formula: output = input / sqrt(sum(input^2) + eps) - Uses existing ops: square, sum, sqrt, divide - Applied along last axis with epsilon=1e-12 - Fix convert_slice: Add missing axes parameter to strided_slice call - strided_slice requires axes alongside begin and end parameters - Changed from strided_slice(data, begin, end) to strided_slice(data, axes=list(range(rank)), begin, end) - Fix convert_reverse_v2: Replace non-existent relax.op.reverse with relax.op.flip - Both operations reverse along specified axis but flip is the available function in relax --- python/tvm/relax/frontend/tflite/tflite_frontend.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 1a437093a718..2d994b334e22 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -800,7 +800,12 @@ def convert_l2_normalization(self, op): ) # TFL uses only the default epsilon value - out = relax.op.nn.l2_normalize(in_expr, eps=1e-12, axis=[input_tensor_rank - 1]) + # Implement L2 normalization: output = input / sqrt(sum(input^2) + eps) + # L2 normalization is applied along the last axis + squared = relax.op.square(in_expr) + sum_squared = relax.op.sum(squared, axis=input_tensor_rank - 1, keepdims=True) + denom = relax.op.sqrt(sum_squared + 1e-12) + out = relax.op.divide(in_expr, denom) # if we have fused activation fn if output_tensor.qnn_params: @@ -2251,7 +2256,9 @@ def convert_slice(self, op): else: end[i] += begin[i] - out = relax.op.strided_slice(in_expr, begin, end) + # Create axes list for all dimensions being sliced + axes = list(range(input_tensor_rank)) + out = relax.op.strided_slice(in_expr, axes=axes, begin=begin, end=end) return out @@ -3494,7 +3501,7 @@ def convert_reverse_v2(self, op): assert len(axis) == 1, "TFLite does not support multi-axis yet" axis = int(axis) - out = relax.op.reverse(input_expr, axis) + out = relax.op.flip(input_expr, axis) return out def convert_matrix_set_diag(self, op): From 720df306baf798623058e6bd43e5c77088530381 Mon Sep 17 00:00:00 2001 From: 0xjah Date: Sat, 11 Apr 2026 13:29:55 +0300 Subject: [PATCH 4/4] Fix TFLite converter for l2_norm, slice, and reverse_v2 operations - convert_l2_normalization: Properly wrap addition operation (line 807) - convert_slice: Convert begin/end indices to int (lines 2261-2262) - convert_reverse_v2: Use .flat[0] for numpy array scalar extraction (line 3502) --- python/tvm/relax/frontend/tflite/tflite_frontend.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 2d994b334e22..bb0840ddbe8a 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -804,7 +804,7 @@ def convert_l2_normalization(self, op): # L2 normalization is applied along the last axis squared = relax.op.square(in_expr) sum_squared = relax.op.sum(squared, axis=input_tensor_rank - 1, keepdims=True) - denom = relax.op.sqrt(sum_squared + 1e-12) + denom = relax.op.sqrt(relax.op.add(sum_squared, relax.const(1e-12, "float32"))) out = relax.op.divide(in_expr, denom) # if we have fused activation fn @@ -2258,8 +2258,9 @@ def convert_slice(self, op): # Create axes list for all dimensions being sliced axes = list(range(input_tensor_rank)) + begin = [int(v) for v in begin] + end = [int(v) for v in end] out = relax.op.strided_slice(in_expr, axes=axes, begin=begin, end=end) - return out def convert_select(self, op): @@ -3432,7 +3433,7 @@ def convert_expand_dims(self, op): axis = self.get_tensor_value(input_tensors[1]) if isinstance(axis, np.ndarray): assert axis.size == 1, "only one value is expected." - axis = int(axis) + axis = int(axis.flat[0]) ndims = len(input_tensors[0].tensor.ShapeAsNumpy()) assert -1 - ndims <= axis <= ndims, "axis out of range" @@ -3499,7 +3500,7 @@ def convert_reverse_v2(self, op): axis = self.get_tensor_value(input_tensors[1]) if isinstance(axis, np.ndarray): assert len(axis) == 1, "TFLite does not support multi-axis yet" - axis = int(axis) + axis = int(axis.flat[0]) out = relax.op.flip(input_expr, axis) return out