@@ -43,7 +43,8 @@ def __init__(self, shapes: Sequence[ShapeBounds] = None):
4343
4444 def __str__ (self ) -> str :
4545 """Generate a string representation of the FlatIR."""
46- from tripy .flat_ir .ops .base import FlatIRFunction , BaseFlatIROp
46+ from tripy .flat_ir .function import FlatIRFunction
47+ from tripy .flat_ir .ops .base import BaseFlatIROp
4748
4849 ir = []
4950
@@ -96,7 +97,8 @@ def to_mlir(self):
9697 from mlir_tensorrt .compiler import ir
9798 from mlir_tensorrt .compiler .dialects import func as func_dialect
9899 from tripy .backend .mlir .utils import make_ir_context , make_tensor_location
99- from tripy .flat_ir .ops .base import FlatIRFunction , BaseFlatIROp
100+ from tripy .flat_ir .function import FlatIRFunction
101+ from tripy .flat_ir .ops .base import BaseFlatIROp
100102
101103 def _base_op_to_mlir (op , mlir_tensor_map ):
102104 op_inputs = [mlir_tensor_map [input_tensor .name ] for input_tensor in op .inputs ]
@@ -376,7 +378,8 @@ def integrate_subgraph(self, inputs: List["FlatIRTensor"], outputs: List["FlatIR
376378 """
377379 Integrate a subgraph delineated by the given inputs and outputs into this FlatIR.
378380 """
379- from tripy .flat_ir .ops .base import BaseFlatIROp , FlatIRFunction
381+ from tripy .flat_ir .function import FlatIRFunction
382+ from tripy .flat_ir .ops .base import BaseFlatIROp
380383 from tripy .flat_ir .tensor import FlatIRTensor
381384 from tripy .flat_ir .ops import ConstantOp
382385
0 commit comments