Skip to content

Commit 59b9536

Browse files
committed
Replace class decorator with a function decorator
1 parent aa37807 commit 59b9536

File tree

20 files changed

+123
-146
lines changed

20 files changed

+123
-146
lines changed

tripy/tests/backend/mlir/test_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,6 @@ def test_reason_context(self):
6666

6767
with pytest.raises(
6868
TripyException,
69-
match=".*This is the first level of context\n This is the second level of context\n.*",
69+
match=".*This is the first level of context\n This is the second level of context.\n.*",
7070
) as exc:
7171
map_error_to_user_code_and_raise(flat_ir, exc, err_str)

tripy/tests/flat_ir/test_constant_deduplication.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,9 @@ def create_subgraph(config):
5353
# Create a function with no inputs and a single output
5454
func_result_tensor = FlatIRTensor.build(shape=[2], rank=1, dtype=int32, reason_details="", device=device("gpu"))
5555
setattr(result_tensor, "caller_tensor", func_result_tensor)
56-
func = FlatIRFunction("MockFunc", [], [result_tensor])
56+
func = FlatIRFunction("MockFunc", [], [result_tensor], [op1, op2, op3, mock_op])
5757
func_result_tensor.producer = func
5858

59-
# Insert all operations in a function
60-
func.ops = [op1, op2, op3, mock_op]
61-
6259
# Return function result tensor i.e. output of a function call
6360
return [], [func_result_tensor]
6461

tripy/tests/flat_ir/test_function_deduplication.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,11 @@ def create_function(
7777
flat_ir.register_tensor(callee_out)
7878
setattr(callee_out, "caller_tensor", original_out)
7979

80-
func = FlatIRFunction(name, [callee_input], callee_outputs)
8180
mock_op = MockOp([callee_input], [callee_outputs[0]])
8281
const_op = ConstantOp.build([], [callee_outputs[1]], data=[3, 4, 5])
8382
callee_outputs[1].producer = const_op
8483

85-
func.ops.extend([mock_op, const_op])
84+
func = FlatIRFunction(name, [callee_input], callee_outputs, [mock_op, const_op])
8685
for out in output_tensors:
8786
out.producer = func
8887

tripy/tripy/backend/mlir/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ def find_inputs(graph_nodes):
405405
inputs = [subgraph.register_tensor(inp.to_flat_ir()) for inp in op.inputs]
406406
outputs = [subgraph.register_tensor(out.to_flat_ir()) for out in op.outputs]
407407
# Pass shallow copies of inputs/outputs so that the op is free to modify them
408-
op.convert_to_flat_ir(copy.copy(inputs), copy.copy(outputs), subgraph)
408+
op.to_flat_ir(copy.copy(inputs), copy.copy(outputs))
409409
subgraph.integrate_subgraph(inputs, outputs)
410410

411411
mlir = subgraph.to_mlir()

tripy/tripy/flat_ir/function.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@
2121
class FlatIRFunction:
2222
"""Represents a function in the Flat IR."""
2323

24-
def __init__(self, name: str, inputs: List["FlatIRTensor"], outputs: List["FlatIRTensor"]):
24+
def __init__(self, name: str, inputs: List["FlatIRTensor"], outputs: List["FlatIRTensor"], ops):
2525
"""Initialize a FlatIRFunction."""
2626
self.name = name
2727
self.inputs = inputs
2828
self.outputs = outputs
29-
self.ops: List[BaseFlatIROp] = []
30-
self.traced_ir_ops: List[BaseFlatIROp] = [] # Used for flat ir function deduplication
29+
self.ops = ops
30+
self.traced_ir_ops = ops # Used only for function deduplication.
3131
self.trace_input_names = None
3232
self.trace_output_names = None
3333
self.caller_replacements = []
@@ -38,8 +38,7 @@ def clone_with_new_io(
3838
"""
3939
Create a clone of the function with new inputs and outputs.
4040
"""
41-
new_func = FlatIRFunction(self.name, new_inputs, new_outputs)
42-
new_func.ops = self.ops
41+
new_func = FlatIRFunction(self.name, new_inputs, new_outputs, self.ops)
4342
new_func.trace_input_names = self.trace_input_names
4443
new_func.trace_output_names = self.trace_output_names
4544
return new_func
@@ -54,13 +53,6 @@ def get_caller_inputs(self) -> List["FlatIRTensor"]:
5453
def get_caller_outputs(self) -> List["FlatIRTensor"]:
5554
return [getattr(out, "caller_tensor") for out in self.outputs]
5655

57-
def add_op(self, op: BaseFlatIROp) -> None:
58-
"""
59-
Add an operation to the function ops. `trace_ir_ops` are used only for function deduplication.
60-
"""
61-
self.ops.append(op)
62-
self.traced_ir_ops.append(op)
63-
6456
def __str__(self) -> str:
6557
"""Generate a string representation of the function."""
6658
function_signature = [

tripy/tripy/frontend/trace/ops/base.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -201,26 +201,6 @@ def to_flat_ir(self, inputs: List["FlatIRTensor"], outputs: List["FlatIRTensor"]
201201
"""
202202
...
203203

204-
def convert_to_flat_ir(
205-
self, inputs: List["FlatIRTensor"], outputs: List["FlatIRTensor"], flat_ir: Optional["FlatIR"] = None
206-
) -> None:
207-
"""
208-
Convert the trace operation to Flat IR representation.
209-
210-
This method decides whether to call the wrapped or unwrapped version
211-
of to_flat_ir based on the is_flat_ir_conversion_wrapped flag.
212-
213-
Args:
214-
inputs: List of input FlatIRTensor objects
215-
outputs: List of output FlatIRTensor objects
216-
flat_ir: Optional FlatIR object for the wrapped version
217-
"""
218-
try:
219-
self.to_flat_ir(inputs, outputs)
220-
except TypeError as e:
221-
assert "to_flat_ir() missing 1 required positional argument: 'flat_ir'" in str(e)
222-
self.to_flat_ir(inputs, outputs, flat_ir)
223-
224204
def str_skip_fields(self) -> Set[str]:
225205
"""
226206
Returns names of dataclass fields to skip when generating a string representation of the op.

tripy/tripy/frontend/trace/ops/binary_elementwise.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828

2929

3030
@dataclass(repr=False)
31-
@frontend_utils.wraps_to_flat_ir_to_func
3231
class BinaryElementwise(BaseTraceOp):
3332
class Kind:
3433
SUM = " + "
@@ -127,6 +126,7 @@ def broadcast_inputs(self, inputs):
127126

128127
return [broadcasted_input_0, broadcasted_input_1]
129128

129+
@frontend_utils.make_function
130130
def to_flat_ir(self, inputs, outputs):
131131
from tripy.flat_ir.ops import AddOp, DivideOp, FloorOp, MaxOp, MinOp, MulOp, PowOp, SubtractOp
132132
from tripy.flat_ir.tensor import FlatIRTensor
@@ -193,7 +193,6 @@ def to_flat_ir(self, inputs, outputs):
193193

194194

195195
@dataclass(repr=False)
196-
@frontend_utils.wraps_to_flat_ir_to_func
197196
class Comparison(BinaryElementwise):
198197
class Kind:
199198
class KindElem(str):
@@ -217,6 +216,7 @@ def __new__(cls, content, compare_direction):
217216
def infer_dtypes(self):
218217
self.outputs[0].dtype = datatype.bool
219218

219+
@frontend_utils.make_function
220220
def to_flat_ir(self, inputs, outputs):
221221
from tripy.flat_ir.ops import CompareOp
222222

tripy/tripy/frontend/trace/ops/cast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323

2424

2525
@dataclass(repr=False)
26-
@frontend_utils.wraps_to_flat_ir_to_func
2726
class Cast(BaseTraceOp):
2827
dtype: "tripy.common.dtype"
2928

@@ -46,6 +45,7 @@ def infer_shape_output_idxs(self, inputs):
4645
def infer_dtypes(self):
4746
self.outputs[0].dtype = self.dtype
4847

48+
@frontend_utils.make_function
4949
def to_flat_ir(self, inputs, outputs):
5050
from tripy.common.datatype import int32, int64, float32, bool as tp_bool
5151
from tripy.flat_ir.ops import CompareOp, ConvertOp, ConstantOp, DynamicBroadcastOp

tripy/tripy/frontend/trace/ops/dequantize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727

2828

2929
@dataclass(repr=False)
30-
@frontend_utils.wraps_to_flat_ir_to_func
3130
class Dequantize(BaseTraceOp):
3231

3332
dtype: datatype.dtype
@@ -36,6 +35,7 @@ class Dequantize(BaseTraceOp):
3635
def infer_dtypes(self):
3736
self.outputs[0].dtype = self.dtype
3837

38+
@frontend_utils.make_function
3939
def to_flat_ir(self, inputs, outputs):
4040
from tripy.common.datatype import int32
4141
from tripy.flat_ir.ops import (

tripy/tripy/frontend/trace/ops/fill.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929

3030

3131
@dataclass(repr=False)
32-
@frontend_utils.wraps_to_flat_ir_to_func
3332
class Fill(BaseTraceOp):
3433
value: float
3534
output_rank: int
@@ -55,6 +54,7 @@ def infer_rank(self):
5554
self.output_rank = input_shape[0]
5655
self.outputs[0].rank = self.output_rank
5756

57+
@frontend_utils.make_function
5858
def to_flat_ir(self, inputs, outputs):
5959
from tripy.flat_ir.ops import ConstantOp, ConvertOp, DynamicBroadcastOp
6060
from tripy.flat_ir.tensor import FlatIRTensor

0 commit comments

Comments
 (0)