2121class FlatIRFunction :
2222 """Represents a function in the Flat IR."""
2323
24- def __init__ (self , name : str , inputs : List ["FlatIRTensor" ], outputs : List ["FlatIRTensor" ]):
24+ def __init__ (self , name : str , inputs : List ["FlatIRTensor" ], outputs : List ["FlatIRTensor" ], ops ):
2525 """Initialize a FlatIRFunction."""
2626 self .name = name
2727 self .inputs = inputs
2828 self .outputs = outputs
29- self .ops : List [ BaseFlatIROp ] = []
30- self .traced_ir_ops : List [ BaseFlatIROp ] = [] # Used for flat ir function deduplication
29+ self .ops = ops
30+ self .traced_ir_ops = ops # Used only for function deduplication.
3131 self .trace_input_names = None
3232 self .trace_output_names = None
3333 self .caller_replacements = []
@@ -38,8 +38,7 @@ def clone_with_new_io(
3838 """
3939 Create a clone of the function with new inputs and outputs.
4040 """
41- new_func = FlatIRFunction (self .name , new_inputs , new_outputs )
42- new_func .ops = self .ops
41+ new_func = FlatIRFunction (self .name , new_inputs , new_outputs , self .ops )
4342 new_func .trace_input_names = self .trace_input_names
4443 new_func .trace_output_names = self .trace_output_names
4544 return new_func
@@ -54,13 +53,6 @@ def get_caller_inputs(self) -> List["FlatIRTensor"]:
5453 def get_caller_outputs (self ) -> List ["FlatIRTensor" ]:
5554 return [getattr (out , "caller_tensor" ) for out in self .outputs ]
5655
57- def add_op (self , op : BaseFlatIROp ) -> None :
58- """
59- Add an operation to the function ops. `trace_ir_ops` are used only for function deduplication.
60- """
61- self .ops .append (op )
62- self .traced_ir_ops .append (op )
63-
6456 def __str__ (self ) -> str :
6557 """Generate a string representation of the function."""
6658 function_signature = [
0 commit comments