Skip to content

Commit ecd0a28

Browse files
committed
Add flat ir function
1 parent 6e51621 commit ecd0a28

31 files changed

+716
-223
lines changed

tripy/tests/backend/mlir/test_compiler.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,11 @@ def test_reason_context(self):
5757
b = tp.ones((1,))
5858
trace = Trace([a + b])
5959
flat_ir = trace.to_flat_ir()
60-
producer = flat_ir.outputs[0].producer.inputs[0]
60+
func_binary = flat_ir.outputs[0].producer
61+
producer = func_binary.ops[-1].inputs[0]
6162
flat_ir_inputs = ",".join(map(lambda i: i.name, producer.producer.inputs))
62-
trace_inputs = ",".join(producer.producer.trace_input_names)
63-
trace_output = producer.producer.trace_output_names[0]
63+
trace_inputs = ",".join(func_binary.trace_input_names)
64+
trace_output = func_binary.trace_output_names[0]
6465
err_str = f'loc("{flat_ir_inputs};;<out>;;{producer.name};;<trace_in>;;{trace_inputs};;<trace_out>;;{trace_output}"): Test error'
6566

6667
with pytest.raises(

tripy/tests/flat_ir/ops/test_broadcast.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,10 @@ def test_str(self):
2929
trace = Trace([out])
3030
flat_ir = trace.to_flat_ir()
3131

32-
broadcast = flat_ir.ops[-1]
32+
func_broadcast = flat_ir.ops[-1]
33+
broadcast = func_broadcast.ops[-1]
3334
assert isinstance(broadcast, DynamicBroadcastOp)
3435
assert re.match(
35-
r"out: \[rank=\(2\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicBroadcastOp\(t_inter2, t[0-9]+, broadcast_dim=\[\]\)",
36+
r"t_inter3: \[rank=\(2\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicBroadcastOp\(t_inter4, t_inter2, broadcast_dim=\[\]\)",
3637
str(broadcast),
3738
)

tripy/tests/flat_ir/ops/test_divide.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,20 +31,21 @@ def test_str(self):
3131
trace = Trace([out])
3232
flat_ir = trace.to_flat_ir()
3333

34-
div = flat_ir.ops[-1]
35-
broadcast_a = flat_ir.ops[-3]
36-
broadcast_b = flat_ir.ops[-2]
34+
func_div = flat_ir.ops[-1]
35+
div = func_div.ops[-1]
36+
broadcast_a = func_div.ops[-3]
37+
broadcast_b = func_div.ops[-2]
3738
assert isinstance(div, DivideOp)
3839

3940
assert re.match(
40-
r"t_inter[0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicBroadcastOp\(a, t_inter[0-9]+, broadcast_dim=\[0\]\)",
41+
r"t_inter[0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicBroadcastOp\(t_inter[0-9]+, t_inter[0-9]+, broadcast_dim=\[0\]\)",
4142
str(broadcast_a),
4243
)
4344
assert re.match(
44-
r"t_inter[0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicBroadcastOp\(b, t_inter[0-9]+, broadcast_dim=\[0\]\)",
45+
r"t_inter[0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicBroadcastOp\(t_inter[0-9]+, t_inter[0-9]+, broadcast_dim=\[0\]\)",
4546
str(broadcast_b),
4647
)
4748
assert re.match(
48-
r"out: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = DivideOp\(t_inter[0-9]+, t_inter[0-9]+\)",
49+
r"t_inter[0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = DivideOp\(t_inter[0-9]+, t_inter[0-9]+\)",
4950
str(div),
5051
)

tripy/tests/flat_ir/ops/test_gather.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import tripy as tp
2020

2121
from tripy.flat_ir.ops import DynamicGatherOp
22+
from tripy.flat_ir.ops.base import FlatIRFunction
2223
from tripy.frontend.trace import Trace
2324
import re
2425

@@ -35,12 +36,15 @@ def test_gather_str(self, axis):
3536
trace = Trace([out])
3637
flat_ir = trace.to_flat_ir()
3738

38-
gather = flat_ir.ops[-1]
39-
reshape = flat_ir.ops[-2]
39+
func_gather = flat_ir.ops[-1]
40+
assert isinstance(func_gather, FlatIRFunction)
41+
42+
gather = func_gather.ops[-1]
43+
reshape = func_gather.ops[-7]
4044
print(str(reshape))
4145
assert isinstance(gather, DynamicGatherOp)
4246
assert re.match(
43-
rf"out: \[rank=\(3\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicGatherOp\(data, indices, t_inter[0-9]+, axis={axis}\)",
47+
rf"t_inter[0-9]+: \[rank=\(3\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicGatherOp\(t_inter[0-9]+, t_inter[0-9]+, t_inter[0-9]+, axis={axis}\)",
4448
str(gather),
4549
)
4650

@@ -51,7 +55,7 @@ def test_gather_mlir(self, axis):
5155
flat_ir = trace.to_flat_ir()
5256
mlir_text = str(flat_ir.to_mlir())
5357
if axis == 0:
54-
target = '"stablehlo.dynamic_gather"(%c, %c_0, %2) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>}> : (tensor<2x3xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<1x?xi32>'
58+
target = '"stablehlo.dynamic_gather"(%arg0, %arg1, %2) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>}> : (tensor<2x3xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<1x?xi32>'
5559
else:
56-
target = '"stablehlo.dynamic_gather"(%c, %c_0, %2) <{dimension_numbers = #stablehlo.gather<offset_dims = [0], collapsed_slice_dims = [1], start_index_map = [1], index_vector_dim = 1>}> : (tensor<2x3xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<?x1xi32>'
60+
target = '"stablehlo.dynamic_gather"(%arg0, %arg1, %2) <{dimension_numbers = #stablehlo.gather<offset_dims = [0], collapsed_slice_dims = [1], start_index_map = [1], index_vector_dim = 1>}> : (tensor<2x3xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<?x1xi32>'
5761
assert target in mlir_text, mlir_text

tripy/tests/flat_ir/ops/test_maximum.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import tripy as tp
2020
from tripy.frontend.trace import Trace
2121
from tripy.flat_ir.ops import MaxOp
22+
from tripy.flat_ir.ops.base import FlatIRFunction
2223

2324

2425
class TestMaxOp:
@@ -31,20 +32,22 @@ def test_str(self):
3132
trace = Trace([out])
3233
flat_ir = trace.to_flat_ir()
3334

34-
max_op = flat_ir.ops[-1]
35-
broadcast_a = flat_ir.ops[-3]
36-
broadcast_b = flat_ir.ops[-2]
35+
func_max = flat_ir.ops[-1]
36+
assert isinstance(func_max, FlatIRFunction)
3737

38+
max_op = func_max.ops[-1]
39+
broadcast_a = func_max.ops[-3]
40+
broadcast_b = func_max.ops[-2]
3841
assert isinstance(max_op, MaxOp)
3942
assert re.match(
40-
r"t_inter[0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicBroadcastOp\(a, t_inter[0-9]+, broadcast_dim=\[0\]\)",
43+
r"t_inter[0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicBroadcastOp\(t_inter[0-9]+, t_inter[0-9]+, broadcast_dim=\[0\]\)",
4144
str(broadcast_a),
4245
)
4346
assert re.match(
44-
r"t_inter[0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicBroadcastOp\(b, t_inter[0-9]+, broadcast_dim=\[0\]\)",
47+
r"t_inter[0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicBroadcastOp\(t_inter[0-9]+, t_inter[0-9]+, broadcast_dim=\[0\]\)",
4548
str(broadcast_b),
4649
)
4750
assert re.match(
48-
r"out: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = MaxOp\(t_inter[0-9]+, t_inter[0-9]+\)",
51+
r"t_inter[0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = MaxOp\(t_inter[0-9]+, t_inter[0-9]+\)",
4952
str(max_op),
5053
)

tripy/tests/flat_ir/ops/test_minimum.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,21 +31,22 @@ def test_str(self):
3131
trace = Trace([out])
3232
flat_ir = trace.to_flat_ir()
3333

34-
min_op = flat_ir.ops[-1]
35-
broadcast_a = flat_ir.ops[-3]
36-
broadcast_b = flat_ir.ops[-2]
34+
func_min = flat_ir.ops[-1]
35+
min_op = func_min.ops[-1]
36+
broadcast_a = func_min.ops[-3]
37+
broadcast_b = func_min.ops[-2]
3738

3839
assert isinstance(min_op, MinOp)
3940
assert re.match(
40-
r"t_inter[0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicBroadcastOp\(a, t_inter[0-9]+, broadcast_dim=\[0\]\)",
41+
r"t_inter[0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicBroadcastOp\(t_inter[0-9]+, t_inter[0-9]+, broadcast_dim=\[0\]\)",
4142
str(broadcast_a),
4243
)
4344
assert re.match(
44-
r"t_inter[0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicBroadcastOp\(b, t_inter[0-9]+, broadcast_dim=\[0\]\)",
45+
r"t_inter[0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicBroadcastOp\(t_inter[0-9]+, t_inter[0-9]+, broadcast_dim=\[0\]\)",
4546
str(broadcast_b),
4647
)
4748

4849
assert re.match(
49-
r"out: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = MinOp\(t_inter[0-9]+, t_inter[0-9]+\)",
50+
r"t_inter[0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = MinOp\(t_inter[0-9]+, t_inter[0-9]+\)",
5051
str(min_op),
5152
)

tripy/tests/flat_ir/ops/test_reduce.py

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import tripy as tp
1919
from tripy.frontend.trace import Trace
2020
from tripy.flat_ir.ops import ArgMinMaxOp, ConvertOp, DivideOp, DynamicBroadcastOp, MulOp, ReduceOp
21+
from tripy.flat_ir.ops.base import FlatIRFunction
2122
import re
2223

2324

@@ -30,10 +31,11 @@ def test_sum_str(self):
3031
trace = Trace([out])
3132
flat_ir = trace.to_flat_ir()
3233

33-
reduce = flat_ir.ops[-1]
34+
func_reduce = flat_ir.ops[-1]
35+
reduce = func_reduce.ops[-1]
3436
assert isinstance(reduce, ReduceOp)
3537
assert re.match(
36-
r"out: \[rank=\(1\), dtype=\(int32\), loc=\(gpu:0\)\] = ReduceOp\(inp, t_inter[0-9]+, reduce_mode='sum', reduce_dims=\[0\]\)",
38+
r"t_inter[0-9]+: \[rank=\(1\), dtype=\(int32\), loc=\(gpu:0\)\] = ReduceOp\(t_inter[0-9]+, t_inter[0-9]+, reduce_mode='sum', reduce_dims=\[0\]\)",
3739
str(reduce),
3840
)
3941

@@ -45,11 +47,11 @@ def test_max_str(self):
4547
trace = Trace([out])
4648
flat_ir = trace.to_flat_ir()
4749

48-
reduce = flat_ir.ops[-1]
50+
func_reduce = flat_ir.ops[-1]
51+
reduce = func_reduce.ops[-1]
4952
assert isinstance(reduce, ReduceOp)
50-
5153
assert re.match(
52-
r"out: \[rank=\(1\), dtype=\(int32\), loc=\(gpu:0\)\] = ReduceOp\(inp, t_inter[0-9]+, reduce_mode='max', reduce_dims=\[0\]\)",
54+
r"t_inter[0-9]+: \[rank=\(1\), dtype=\(int32\), loc=\(gpu:0\)\] = ReduceOp\(t_inter[0-9]+, t_inter[0-9]+, reduce_mode='max', reduce_dims=\[0\]\)",
5355
str(reduce),
5456
)
5557

@@ -61,30 +63,43 @@ def test_mean_str(self):
6163
trace = Trace([out])
6264
flat_ir = trace.to_flat_ir()
6365

64-
div = flat_ir.ops[-1]
66+
func_div = flat_ir.ops[-1]
67+
div = func_div.ops[-1]
68+
broadcast_a = func_div.ops[-3]
69+
broadcast_b = func_div.ops[-2]
6570
assert isinstance(div, DivideOp)
6671
assert re.match(
67-
r"out: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = DivideOp\(t_inter[0-9]+, t_inter[0-9]+\)",
72+
r"t_inter[0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = DivideOp\(t_inter[0-9]+, t_inter[0-9]+\)",
6873
str(div),
6974
)
7075

71-
broadcast = flat_ir.ops[-2]
72-
assert isinstance(broadcast, DynamicBroadcastOp)
76+
assert isinstance(broadcast_a, DynamicBroadcastOp)
77+
assert re.match(
78+
r"t_inter[0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicBroadcastOp\(t_inter[0-9]+, t_inter[0-9]+, broadcast_dim=\[[0-9]*\]\)",
79+
str(broadcast_a),
80+
)
81+
82+
assert isinstance(broadcast_b, DynamicBroadcastOp)
7383
assert re.match(
7484
r"t_inter[0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicBroadcastOp\(t_inter[0-9]+, t_inter[0-9]+, broadcast_dim=\[[0-9]*\]\)",
75-
str(broadcast),
85+
str(broadcast_b),
7686
)
7787

78-
mul = flat_ir.ops[-15]
88+
mul = flat_ir.ops[-3].ops[-1]
7989
assert isinstance(mul, MulOp)
8090
assert re.match(
81-
r"t[0-9]+: \[rank=\(0\), dtype=\(int32\), loc=\(gpu:0\)\] = MulOp\(t_inter[0-9]+, t_inter[0-9]+\)",
91+
r"t_inter[0-9]+: \[rank=\(0\), dtype=\(int32\), loc=\(gpu:0\)\] = MulOp\(t_inter[0-9]+, t_inter[0-9]+\)",
8292
str(mul),
8393
)
84-
reduce = flat_ir.ops[2]
94+
95+
func_reduce = flat_ir.ops[1]
96+
assert isinstance(func_reduce, FlatIRFunction)
97+
98+
reduce = func_reduce.ops[-1]
8599
assert isinstance(reduce, ReduceOp)
100+
86101
assert re.match(
87-
r"t[0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = ReduceOp\(inp, t_inter[0-9]+, reduce_mode='sum', reduce_dims=\[0\]\)",
102+
r"t_inter[0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = ReduceOp\(t_inter[0-9]+, t_inter[0-9]+, reduce_mode='sum', reduce_dims=\[0\]\)",
88103
str(reduce),
89104
)
90105

@@ -96,11 +111,14 @@ def test_argmax_str(self):
96111
trace = Trace([out])
97112
flat_ir = trace.to_flat_ir()
98113

99-
reduce = flat_ir.ops[-1]
100-
assert isinstance(reduce, ArgMinMaxOp)
114+
func_argminmax = flat_ir.ops[-1]
115+
assert isinstance(func_argminmax, FlatIRFunction)
116+
117+
argminmax = func_argminmax.ops[-1]
118+
assert isinstance(argminmax, ArgMinMaxOp)
101119
assert re.match(
102-
r"out: \[rank=\(1\), dtype=\(int32\), loc=\(gpu:0\)\] = ArgMinMaxOp\(inp, t[0-9]+, t_inter[0-9]+, t_inter[0-9]+, reduce_mode='argmax', reduce_dims=\[0\]\)",
103-
str(reduce),
120+
r"t_inter[0-9]+: \[rank=\(1\), dtype=\(int32\), loc=\(gpu:0\)\] = ArgMinMaxOp\(t_inter[0-9]+, t_inter[0-9]+, t_inter[0-9]+, t_inter[0-9]+, reduce_mode='argmax', reduce_dims=\[0\]\)",
121+
str(argminmax),
104122
)
105123

106124
def test_argmin_str(self):
@@ -111,9 +129,12 @@ def test_argmin_str(self):
111129
trace = Trace([out])
112130
flat_ir = trace.to_flat_ir()
113131

114-
reduce = flat_ir.ops[-1]
115-
assert isinstance(reduce, ArgMinMaxOp)
132+
func_argminmax = flat_ir.ops[-1]
133+
assert isinstance(func_argminmax, FlatIRFunction)
134+
135+
argminmax = func_argminmax.ops[-1]
136+
assert isinstance(argminmax, ArgMinMaxOp)
116137
assert re.match(
117-
r"out: \[rank=\(1\), dtype=\(int32\), loc=\(gpu:0\)\] = ArgMinMaxOp\(inp, t[0-9]+, t_inter[0-9]+, t_inter[0-9]+, reduce_mode='argmin', reduce_dims=\[0\]\)",
118-
str(reduce),
138+
r"t_inter[0-9]+: \[rank=\(1\), dtype=\(int32\), loc=\(gpu:0\)\] = ArgMinMaxOp\(t_inter[0-9]+, t_inter[0-9]+, t_inter[0-9]+, t_inter[0-9]+, reduce_mode='argmin', reduce_dims=\[0\]\)",
139+
str(argminmax),
119140
)

tripy/tests/flat_ir/ops/test_subtract.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import tripy as tp
2020
from tripy.frontend.trace import Trace
2121
from tripy.flat_ir.ops import SubtractOp
22+
from tripy.flat_ir.ops.base import FlatIRFunction
2223

2324

2425
class TestSubtractOp:
@@ -31,21 +32,24 @@ def test_str(self):
3132
trace = Trace([out])
3233
flat_ir = trace.to_flat_ir()
3334

34-
sub = flat_ir.ops[-1]
35-
broadcast_a = flat_ir.ops[-3]
36-
broadcast_b = flat_ir.ops[-2]
35+
func_sub = flat_ir.ops[-1]
36+
assert isinstance(func_sub, FlatIRFunction)
37+
38+
sub = func_sub.ops[-1]
39+
broadcast_a = func_sub.ops[-3]
40+
broadcast_b = func_sub.ops[-2]
3741

3842
assert isinstance(sub, SubtractOp)
3943

4044
assert re.match(
41-
r"t_inter[0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicBroadcastOp\(a, t_inter[0-9]+, broadcast_dim=\[0\]\)",
45+
r"t_inter[0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicBroadcastOp\(t_inter[0-9]+, t_inter[0-9]+, broadcast_dim=\[0\]\)",
4246
str(broadcast_a),
4347
)
4448
assert re.match(
45-
r"t_inter[0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicBroadcastOp\(b, t_inter[0-9]+, broadcast_dim=\[0\]\)",
49+
r"t_inter[0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicBroadcastOp\(t_inter[0-9]+, t_inter[0-9]+, broadcast_dim=\[0\]\)",
4650
str(broadcast_b),
4751
)
4852
assert re.match(
49-
r"out: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = SubtractOp\(t_inter[0-9]+, t_inter[0-9]+\)",
53+
r"t_inter[0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = SubtractOp\(t_inter[0-9]+, t_inter[0-9]+\)",
5054
str(sub),
5155
)

tripy/tests/flat_ir/test_constant_deduplication.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import pytest
1717
from tripy.flat_ir.flat_ir import FlatIR
18+
from tripy.flat_ir.ops.base import FlatIRFunction
1819
from tripy.flat_ir.ops import ConstantOp
1920
from tripy.flat_ir.tensor import FlatIRTensor
2021
from tripy.common.device import device
@@ -29,7 +30,7 @@ def __init__(self, inputs, outputs):
2930
output.producer = self
3031

3132

32-
def create_subgraph():
33+
def create_subgraph(config):
3334
# Create constant tensors and ops
3435
const1 = FlatIRTensor.build(shape=[2], rank=1, dtype=int32, reason_details="", device=device("gpu"))
3536
op1 = ConstantOp.build([], [const1], data=[1, 2])
@@ -48,18 +49,34 @@ def create_subgraph():
4849
result_tensor = FlatIRTensor.build(shape=[2], rank=1, dtype=int32, reason_details="", device=device("gpu"))
4950
mock_op = MockOp([const1, const2, const3], [result_tensor])
5051

52+
if config == "func":
53+
# Create a function with no inputs and a single output
54+
func_result_tensor = FlatIRTensor.build(shape=[2], rank=1, dtype=int32, reason_details="", device=device("gpu"))
55+
setattr(result_tensor, "caller_tensor", func_result_tensor)
56+
func = FlatIRFunction("MockFunc", [], [result_tensor])
57+
func_result_tensor.producer = func
58+
59+
# Insert all operations in a function
60+
func.ops = [op1, op2, op3, mock_op]
61+
62+
# Return function result tensor i.e. output of a function call
63+
return [], [func_result_tensor]
64+
5165
return [], [result_tensor]
5266

5367

54-
def test_integrate_subgraph_constant_deduplication():
68+
@pytest.mark.parametrize("config", ["main", "func"])
69+
def test_integrate_subgraph_constant_deduplication(config):
5570
flat_ir = FlatIR()
56-
inputs, outputs = create_subgraph()
71+
inputs, outputs = create_subgraph(config)
5772

5873
# Integrate the subgraph
5974
flat_ir.integrate_subgraph(inputs, outputs)
6075

6176
# Verify that only two ConstantOps remain
62-
constant_ops = [op for op in flat_ir.ops if isinstance(op, ConstantOp)]
77+
ops = flat_ir.ops[0].ops if isinstance(flat_ir.ops[0], FlatIRFunction) else flat_ir.ops
78+
79+
constant_ops = [op for op in ops if isinstance(op, ConstantOp)]
6380
assert len(constant_ops) == 2, "There should be only two ConstantOps after integration"
6481

6582
# Verify that the remaining ConstantOps have different data
@@ -72,12 +89,6 @@ def test_integrate_subgraph_constant_deduplication():
7289
assert set(constant_data) == expected_data, f"Expected constant data {expected_data}, but got {set(constant_data)}"
7390

7491
# Verify that the mock op now uses the same tensor for its first two inputs
75-
mock_op = [op for op in flat_ir.ops if isinstance(op, MockOp)][0]
92+
mock_op = [op for op in ops if isinstance(op, MockOp)][0]
7693
assert mock_op.inputs[0] is mock_op.inputs[1], "The mock op should use the same tensor for its first two inputs"
7794
assert mock_op.inputs[0] is not mock_op.inputs[2], "The mock op should still have a different third input"
78-
79-
# Verify that tensor replacements were applied
80-
assert len(flat_ir.tensor_replacements) > 0, "There should be tensor replacements after integration"
81-
82-
# Verify that the constant map has the correct number of entries
83-
assert len(flat_ir.constant_map) == 2, "Constant map should have 2 entries"

0 commit comments

Comments
 (0)