diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index ebfbcacf9c87..8d112b91d642 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -223,6 +223,9 @@ def __init__(self, model, subgraph, exp_tab, ctx):
             "SCATTER_ND": self.convert_scatter_nd,
             "SELECT": self.convert_select,
             "SELECT_V2": self.convert_select,
+            "SEGMENT_SUM": functools.partial(
+                self._convert_segment_op, op_name="SEGMENT_SUM", reduction="add"
+            ),
             "SHAPE": self.convert_shape,
             "SIN": functools.partial(self._convert_unary_elemwise, relax_op=_op.sin),
             "SLICE": self.convert_slice,
@@ -246,6 +249,12 @@ def __init__(self, model, subgraph, exp_tab, ctx):
             "TRANSPOSE_CONV": self.convert_transpose_conv,
             "TRANSPOSE": self.convert_transpose,
             "UNPACK": self.convert_unpack,
+            "UNSORTED_SEGMENT_MIN": functools.partial(
+                self._convert_segment_op, op_name="UNSORTED_SEGMENT_MIN", reduction="min"
+            ),
+            "UNSORTED_SEGMENT_PROD": functools.partial(
+                self._convert_segment_op, op_name="UNSORTED_SEGMENT_PROD", reduction="mul"
+            ),
             # "UNIDIRECTIONAL_SEQUENCE_LSTM": self.convert_unidirectional_sequence_lstm,
             "WHERE": self.convert_select,
             "ZEROS_LIKE": self.convert_zeros_like,
@@ -2586,6 +2595,127 @@ def convert_scatter_nd(self, op):
         data = relax.op.zeros(shape, updates_dtype)
         return relax.op.scatter_nd(data, indices, updates, "update")
 
+    def _get_segment_scatter_base(self, output_shape, output_dtype, reduction):
+        """Create the identity base tensor for scatter-based segment reductions.
+
+        The base is filled with the identity of ``reduction``: 0 for "add",
+        1 for "mul" and the ``np.finfo``/``np.iinfo`` maximum of
+        ``output_dtype`` for "min". Any other mode raises ``ValueError``.
+        """
+        if reduction == "add":
+            return relax.op.zeros(output_shape, output_dtype)
+        if reduction == "mul":
+            return relax.op.full(output_shape, relax.const(1, output_dtype), output_dtype)
+        if reduction == "min":
+            np_dtype = np.dtype(output_dtype)
+            if np.issubdtype(np_dtype, np.floating):
+                identity = np.finfo(np_dtype).max
+            elif np.issubdtype(np_dtype, np.integer):
+                identity = np.iinfo(np_dtype).max
+            else:
+                raise tvm.error.OpNotImplemented(
+                    f"UNSORTED_SEGMENT_MIN does not support output dtype {output_dtype}."
+                )
+            return relax.op.full(output_shape, relax.const(identity, output_dtype), output_dtype)
+
+        raise ValueError(f"Unsupported segment reduction mode: {reduction}")
+
+    def _get_segment_num_segments(self, op_name, input_tensors):
+        """Return the static number of output segments as a Python int.
+
+        SEGMENT_SUM derives it from constant segment ids (max id + 1, or 0 when
+        empty); every other op reads the constant ``num_segments`` input.
+        Runtime-provided values raise ``tvm.error.OpNotImplemented``.
+        """
+        if op_name == "SEGMENT_SUM":
+            segment_ids_tensor = input_tensors[1]
+            if self.has_expr(segment_ids_tensor.tensor_idx):
+                raise tvm.error.OpNotImplemented(
+                    "TFLite SEGMENT_SUM with runtime segment_ids is not supported, "
+                    "because TFLite does not encode a reliable output segment count."
+                )
+            segment_ids = self.get_tensor_value(segment_ids_tensor)
+            if np.any(segment_ids < 0):
+                raise tvm.error.OpNotImplemented(
+                    "TFLite SEGMENT_SUM with negative segment ids is not supported."
+                )
+            return int(np.max(segment_ids)) + 1 if segment_ids.size else 0
+
+        num_segments_tensor = input_tensors[2]
+        if self.has_expr(num_segments_tensor.tensor_idx):
+            raise tvm.error.OpNotImplemented(
+                f"TFLite {op_name} with runtime num_segments is not supported."
+            )
+        num_segments_value = self.get_tensor_value(num_segments_tensor)
+        assert num_segments_value.size == 1, f"{op_name} num_segments should be a scalar tensor"
+        num_segments = int(num_segments_value.item())
+        assert num_segments >= 0, f"{op_name} num_segments should be non-negative"
+        return num_segments
+
+    def _convert_segment_op(self, op, op_name, reduction):
+        """Convert TFLite segment ops through relax.op.scatter_nd.
+
+        The output starts as the identity of ``reduction`` and ``data`` is
+        scattered into it using the segment ids as indices along axis 0.
+        Constant segment ids are validated at conversion time: negative ids
+        raise ``tvm.error.OpNotImplemented`` and ids outside
+        ``[0, num_segments)`` raise ``ValueError``.
+        """
+        from tflite.TensorType import TensorType
+
+        input_tensors = self.get_input_tensors(op)
+        expected_inputs = 2 if op_name == "SEGMENT_SUM" else 3
+        assert len(input_tensors) == expected_inputs, (
+            f"{op_name} should have {expected_inputs} input tensors"
+        )
+
+        data_tensor = input_tensors[0]
+        segment_ids_tensor = input_tensors[1]
+        for t in input_tensors:
+            assert not t.qnn_params, "Quantized input is not expected."
+
+        segment_ids_type = segment_ids_tensor.tensor.Type()
+        assert segment_ids_type in (TensorType.INT32, TensorType.INT64)
+        if op_name != "SEGMENT_SUM":
+            num_segments_type = input_tensors[2].tensor.Type()
+            assert num_segments_type in (TensorType.INT32, TensorType.INT64)
+        # Stays None when segment ids are only known at runtime.
+        segment_ids_value = None
+        if not self.has_expr(segment_ids_tensor.tensor_idx):
+            segment_ids_value = self.get_tensor_value(segment_ids_tensor)
+            if np.any(segment_ids_value < 0):
+                raise tvm.error.OpNotImplemented(
+                    f"TFLite {op_name} with negative segment ids is not supported."
+                )
+
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) == 1, f"{op_name} should have 1 output tensor"
+        output_tensor = output_tensors[0]
+        output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type())
+
+        data_shape = to_int_list(self.get_tensor_shape(data_tensor))
+        segment_ids_shape = to_int_list(self.get_tensor_shape(segment_ids_tensor))
+        segment_ids_rank = len(segment_ids_shape)
+        assert data_shape[:segment_ids_rank] == segment_ids_shape, (
+            f"{op_name} requires segment_ids shape to match a prefix of data shape"
+        )
+        num_segments = self._get_segment_num_segments(op_name, input_tensors)
+        if segment_ids_value is not None and segment_ids_value.size:
+            # An id >= num_segments would index a row output_shape does not have.
+            max_segment_id = int(np.max(segment_ids_value))
+            if max_segment_id >= num_segments:
+                raise ValueError(
+                    f"TFLite {op_name} segment id {max_segment_id} is out of range "
+                    f"for num_segments={num_segments}."
+                )
+        output_shape = [num_segments] + data_shape[segment_ids_rank:]
+
+        data = self.get_tensor_expr(data_tensor)
+        segment_ids = self.get_tensor_expr(segment_ids_tensor)
+        indices = relax.op.expand_dims(segment_ids, axis=[segment_ids_rank])
+
+        base = self._get_segment_scatter_base(output_shape, output_dtype, reduction)
+        return relax.op.scatter_nd(base, indices, data, reduction)
+
     def convert_select(self, op):
         """Convert TFLite SELECT"""
         input_tensors = self.get_input_tensors(op)
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index c5531ccf73bd..a2d2612232c0 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -1783,6 +1783,101 @@ def func(self, indices, updates, shape):
     verify(Model)
 
 
+def test_segment_sum():
+    """SEGMENT_SUM lowers to scatter_nd with add reduction."""
+
+    class Model(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=(4, 2), dtype=tf.float32)])
+        def func(self, data):
+            return tf.raw_ops.SegmentSum(
+                data=data, segment_ids=tf.constant([0, 0, 1, 2], dtype=tf.int32)
+            )
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(data: R.Tensor((4, 2), dtype="float32")) -> R.Tensor((3, 2), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                lv: R.Tensor((3, 2), dtype="float32") = R.zeros(R.shape([3, 2]), dtype="float32")
+                lv1: R.Tensor((4, 1), dtype="int32") = R.expand_dims(
+                    R.const([0, 0, 1, 2], "int32"), axis=[1]
+                )
+                gv: R.Tensor((3, 2), dtype="float32") = R.scatter_nd(
+                    lv, lv1, data, reduction="add"
+                )
+                R.output(gv)
+            return gv
+
+    verify(Model, Expected)
+
+
+def test_unsorted_segment_min():
+    """UNSORTED_SEGMENT_MIN lowers to scatter_nd with min reduction."""
+
+    class Model(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=(4, 2), dtype=tf.float32)])
+        def func(self, data):
+            return tf.raw_ops.UnsortedSegmentMin(
+                data=data,
+                segment_ids=tf.constant([2, 0, 2, 1], dtype=tf.int32),
+                num_segments=tf.constant(3, dtype=tf.int32),
+            )
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(data: R.Tensor((4, 2), dtype="float32")) -> R.Tensor((3, 2), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                lv: R.Tensor((3, 2), dtype="float32") = R.full(
+                    R.shape([3, 2]), R.const(np.finfo(np.float32).max, "float32"), dtype="float32"
+                )
+                lv1: R.Tensor((4, 1), dtype="int32") = R.expand_dims(
+                    R.const([2, 0, 2, 1], "int32"), axis=[1]
+                )
+                gv: R.Tensor((3, 2), dtype="float32") = R.scatter_nd(
+                    lv, lv1, data, reduction="min"
+                )
+                R.output(gv)
+            return gv
+
+    verify(Model, Expected)
+
+
+def test_unsorted_segment_prod():
+    """UNSORTED_SEGMENT_PROD lowers to scatter_nd with mul reduction."""
+
+    class Model(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=(4, 2), dtype=tf.float32)])
+        def func(self, data):
+            return tf.raw_ops.UnsortedSegmentProd(
+                data=data,
+                segment_ids=tf.constant([1, 0, 1, 2], dtype=tf.int32),
+                num_segments=tf.constant(3, dtype=tf.int32),
+            )
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(data: R.Tensor((4, 2), dtype="float32")) -> R.Tensor((3, 2), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                lv: R.Tensor((3, 2), dtype="float32") = R.full(
+                    R.shape([3, 2]), R.const(1, "float32"), dtype="float32"
+                )
+                lv1: R.Tensor((4, 1), dtype="int32") = R.expand_dims(
+                    R.const([1, 0, 1, 2], "int32"), axis=[1]
+                )
+                gv: R.Tensor((3, 2), dtype="float32") = R.scatter_nd(
+                    lv, lv1, data, reduction="mul"
+                )
+                R.output(gv)
+            return gv
+
+    verify(Model, Expected)
+
+
 def test_batch_matmul():
     class BatchMatMul(tf.Module):
         @tf.function(