Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions python/tvm/relax/frontend/tflite/tflite_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,9 @@ def __init__(self, model, subgraph, exp_tab, ctx):
"SCATTER_ND": self.convert_scatter_nd,
"SELECT": self.convert_select,
"SELECT_V2": self.convert_select,
"SEGMENT_SUM": functools.partial(
self._convert_segment_op, op_name="SEGMENT_SUM", reduction="add"
),
"SHAPE": self.convert_shape,
"SIN": functools.partial(self._convert_unary_elemwise, relax_op=_op.sin),
"SLICE": self.convert_slice,
Expand All @@ -246,6 +249,12 @@ def __init__(self, model, subgraph, exp_tab, ctx):
"TRANSPOSE_CONV": self.convert_transpose_conv,
"TRANSPOSE": self.convert_transpose,
"UNPACK": self.convert_unpack,
"UNSORTED_SEGMENT_MIN": functools.partial(
self._convert_segment_op, op_name="UNSORTED_SEGMENT_MIN", reduction="min"
),
"UNSORTED_SEGMENT_PROD": functools.partial(
self._convert_segment_op, op_name="UNSORTED_SEGMENT_PROD", reduction="mul"
),
# "UNIDIRECTIONAL_SEQUENCE_LSTM": self.convert_unidirectional_sequence_lstm,
"WHERE": self.convert_select,
"ZEROS_LIKE": self.convert_zeros_like,
Expand Down Expand Up @@ -2586,6 +2595,100 @@ def convert_scatter_nd(self, op):
data = relax.op.zeros(shape, updates_dtype)
return relax.op.scatter_nd(data, indices, updates, "update")

def _get_segment_scatter_base(self, output_shape, output_dtype, reduction):
    """Build the identity-element base tensor that scatter_nd reduces into.

    Parameters
    ----------
    output_shape : list of int
        Shape of the segment-reduction output.
    output_dtype : str
        Dtype string of the output tensor.
    reduction : str
        One of "add", "mul", "min".

    Returns
    -------
    result : relax.Expr
        A tensor filled with the reduction's identity element: 0 for "add",
        1 for "mul", and the dtype's largest representable value for "min".

    Raises
    ------
    tvm.error.OpNotImplemented
        If "min" is requested for a dtype that is neither floating nor integer.
    ValueError
        If the reduction mode is not recognized.
    """
    if reduction == "add":
        # Additive identity: emit zeros directly rather than full(0).
        return relax.op.zeros(output_shape, output_dtype)
    elif reduction == "mul":
        fill_value = relax.const(1, output_dtype)
    elif reduction == "min":
        dtype = np.dtype(output_dtype)
        # The identity for min is the largest value the dtype can hold.
        if np.issubdtype(dtype, np.floating):
            largest = np.finfo(dtype).max
        elif np.issubdtype(dtype, np.integer):
            largest = np.iinfo(dtype).max
        else:
            raise tvm.error.OpNotImplemented(
                f"UNSORTED_SEGMENT_MIN does not support output dtype {output_dtype}."
            )
        fill_value = relax.const(largest, output_dtype)
    else:
        raise ValueError(f"Unsupported segment reduction mode: {reduction}")
    return relax.op.full(output_shape, fill_value, output_dtype)

def _get_segment_num_segments(self, op_name, input_tensors):
    """Resolve the static segment count for a TFLite segment op.

    SEGMENT_SUM carries no explicit segment count, so it is derived from
    the constant segment_ids tensor as max(id) + 1; the unsorted variants
    provide an explicit scalar num_segments input (index 2).

    Raises
    ------
    tvm.error.OpNotImplemented
        If the tensor needed to determine the count is only known at
        runtime, or if SEGMENT_SUM sees negative segment ids.
    """
    if op_name == "SEGMENT_SUM":
        ids_tensor = input_tensors[1]
        # The output row count comes from the ids themselves, so they
        # must be constants baked into the model.
        if self.has_expr(ids_tensor.tensor_idx):
            raise tvm.error.OpNotImplemented(
                "TFLite SEGMENT_SUM with runtime segment_ids is not supported, "
                "because TFLite does not encode a reliable output segment count."
            )
        ids = self.get_tensor_value(ids_tensor)
        if np.any(ids < 0):
            raise tvm.error.OpNotImplemented(
                "TFLite SEGMENT_SUM with negative segment ids is not supported."
            )
        # Empty ids => zero segments.
        if ids.size == 0:
            return 0
        return int(np.max(ids)) + 1

    # UNSORTED_SEGMENT_*: the count is an explicit scalar input.
    count_tensor = input_tensors[2]
    if self.has_expr(count_tensor.tensor_idx):
        raise tvm.error.OpNotImplemented(
            f"TFLite {op_name} with runtime num_segments is not supported."
        )
    count_value = self.get_tensor_value(count_tensor)
    assert count_value.size == 1, f"{op_name} num_segments should be a scalar tensor"
    count = int(count_value.item())
    assert count >= 0, f"{op_name} num_segments should be non-negative"
    return count

def _convert_segment_op(self, op, op_name, reduction):
    """Lower a TFLite segment reduction op to relax.op.scatter_nd.

    The op is expressed as a scatter into a base tensor pre-filled with the
    reduction's identity element: segment_ids (expanded by one trailing
    axis) become scatter indices, the data tensor is the updates, and
    scatter_nd applies the requested reduction ("add", "mul", or "min").

    Parameters
    ----------
    op : tflite operator
        The TFLite operator being converted.
    op_name : str
        TFLite op name ("SEGMENT_SUM", "UNSORTED_SEGMENT_MIN", ...).
    reduction : str
        scatter_nd reduction mode matching the op.
    """
    from tflite.TensorType import TensorType

    input_tensors = self.get_input_tensors(op)
    # SEGMENT_SUM: (data, segment_ids); unsorted variants add num_segments.
    num_expected = 2 if op_name == "SEGMENT_SUM" else 3
    assert (
        len(input_tensors) == num_expected
    ), f"{op_name} should have {num_expected} input tensors"

    data_tensor = input_tensors[0]
    segment_ids_tensor = input_tensors[1]
    assert all(
        not tensor.qnn_params for tensor in input_tensors
    ), "Quantized input is not expected."

    assert segment_ids_tensor.tensor.Type() in (TensorType.INT32, TensorType.INT64)
    if op_name != "SEGMENT_SUM":
        assert input_tensors[2].tensor.Type() in (TensorType.INT32, TensorType.INT64)
        # Constant segment_ids can be validated up front; the SEGMENT_SUM
        # path performs the same check while deriving the segment count.
        if not self.has_expr(segment_ids_tensor.tensor_idx):
            ids_value = self.get_tensor_value(segment_ids_tensor)
            if np.any(ids_value < 0):
                raise tvm.error.OpNotImplemented(
                    f"TFLite {op_name} with negative segment ids is not supported."
                )

    output_tensors = self.get_output_tensors(op)
    assert len(output_tensors) == 1, f"{op_name} should have 1 output tensor"
    out_dtype = self.get_tensor_type_str(output_tensors[0].tensor.Type())

    data_shape = to_int_list(self.get_tensor_shape(data_tensor))
    ids_shape = to_int_list(self.get_tensor_shape(segment_ids_tensor))
    ids_rank = len(ids_shape)
    # segment_ids addresses whole leading slices of data, so its shape
    # must be a prefix of the data shape.
    assert (
        data_shape[:ids_rank] == ids_shape
    ), f"{op_name} requires segment_ids shape to match a prefix of data shape"
    num_segments = self._get_segment_num_segments(op_name, input_tensors)
    out_shape = [num_segments] + data_shape[ids_rank:]

    data = self.get_tensor_expr(data_tensor)
    segment_ids = self.get_tensor_expr(segment_ids_tensor)
    # scatter_nd expects indices with a trailing index-depth axis.
    indices = relax.op.expand_dims(segment_ids, axis=[ids_rank])

    base = self._get_segment_scatter_base(out_shape, out_dtype, reduction)
    return relax.op.scatter_nd(base, indices, data, reduction)

def convert_select(self, op):
"""Convert TFLite SELECT"""
input_tensors = self.get_input_tensors(op)
Expand Down
95 changes: 95 additions & 0 deletions tests/python/relax/test_frontend_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1783,6 +1783,101 @@ def func(self, indices, updates, shape):
verify(Model)


def test_segment_sum():
    """SEGMENT_SUM lowers to scatter_nd with add reduction."""

    class Model(tf.Module):
        @tf.function(input_signature=[tf.TensorSpec(shape=(4, 2), dtype=tf.float32)])
        def func(self, data):
            # Constant segment_ids: rows 0 and 1 fold into segment 0, row 2
            # into segment 1, row 3 into segment 2 (max id 2 -> 3 segments).
            return tf.raw_ops.SegmentSum(
                data=data, segment_ids=tf.constant([0, 0, 1, 2], dtype=tf.int32)
            )

    @I.ir_module
    class Expected:
        @R.function
        def main(data: R.Tensor((4, 2), dtype="float32")) -> R.Tensor((3, 2), dtype="float32"):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                # Base filled with the additive identity (zeros); scatter_nd
                # accumulates the data rows into it.
                lv: R.Tensor((3, 2), dtype="float32") = R.zeros(R.shape([3, 2]), dtype="float32")
                # segment_ids expanded with a trailing axis to serve as
                # scatter_nd indices of shape (4, 1).
                lv1: R.Tensor((4, 1), dtype="int32") = R.expand_dims(
                    R.const([0, 0, 1, 2], "int32"), axis=[1]
                )
                gv: R.Tensor((3, 2), dtype="float32") = R.scatter_nd(
                    lv, lv1, data, reduction="add"
                )
                R.output(gv)
            return gv

    verify(Model, Expected)


def test_unsorted_segment_min():
    """UNSORTED_SEGMENT_MIN lowers to scatter_nd with min reduction."""

    class Model(tf.Module):
        @tf.function(input_signature=[tf.TensorSpec(shape=(4, 2), dtype=tf.float32)])
        def func(self, data):
            # Unsorted ids with an explicit num_segments of 3.
            return tf.raw_ops.UnsortedSegmentMin(
                data=data,
                segment_ids=tf.constant([2, 0, 2, 1], dtype=tf.int32),
                num_segments=tf.constant(3, dtype=tf.int32),
            )

    @I.ir_module
    class Expected:
        @R.function
        def main(data: R.Tensor((4, 2), dtype="float32")) -> R.Tensor((3, 2), dtype="float32"):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                # Base filled with float32 max — the identity for min — so
                # any scattered value replaces it.
                lv: R.Tensor((3, 2), dtype="float32") = R.full(
                    R.shape([3, 2]), R.const(np.finfo(np.float32).max, "float32"), dtype="float32"
                )
                # segment_ids expanded with a trailing axis to serve as
                # scatter_nd indices of shape (4, 1).
                lv1: R.Tensor((4, 1), dtype="int32") = R.expand_dims(
                    R.const([2, 0, 2, 1], "int32"), axis=[1]
                )
                gv: R.Tensor((3, 2), dtype="float32") = R.scatter_nd(
                    lv, lv1, data, reduction="min"
                )
                R.output(gv)
            return gv

    verify(Model, Expected)


def test_unsorted_segment_prod():
    """UNSORTED_SEGMENT_PROD lowers to scatter_nd with mul reduction."""

    class Model(tf.Module):
        @tf.function(input_signature=[tf.TensorSpec(shape=(4, 2), dtype=tf.float32)])
        def func(self, data):
            # Unsorted ids with an explicit num_segments of 3.
            return tf.raw_ops.UnsortedSegmentProd(
                data=data,
                segment_ids=tf.constant([1, 0, 1, 2], dtype=tf.int32),
                num_segments=tf.constant(3, dtype=tf.int32),
            )

    @I.ir_module
    class Expected:
        @R.function
        def main(data: R.Tensor((4, 2), dtype="float32")) -> R.Tensor((3, 2), dtype="float32"):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                # Base filled with ones — the multiplicative identity — so
                # scattered values multiply into it.
                lv: R.Tensor((3, 2), dtype="float32") = R.full(
                    R.shape([3, 2]), R.const(1, "float32"), dtype="float32"
                )
                # segment_ids expanded with a trailing axis to serve as
                # scatter_nd indices of shape (4, 1).
                lv1: R.Tensor((4, 1), dtype="int32") = R.expand_dims(
                    R.const([1, 0, 1, 2], "int32"), axis=[1]
                )
                gv: R.Tensor((3, 2), dtype="float32") = R.scatter_nd(
                    lv, lv1, data, reduction="mul"
                )
                R.output(gv)
            return gv

    verify(Model, Expected)


def test_batch_matmul():
class BatchMatMul(tf.Module):
@tf.function(
Expand Down
Loading