1818import tripy as tp
1919from tripy .frontend .trace import Trace
2020from tripy .flat_ir .ops import ArgMinMaxOp , ConvertOp , DivideOp , DynamicBroadcastOp , MulOp , ReduceOp
21+ from tripy .flat_ir .ops .base import FlatIRFunction
2122import re
2223
2324
@@ -30,10 +31,11 @@ def test_sum_str(self):
3031 trace = Trace ([out ])
3132 flat_ir = trace .to_flat_ir ()
3233
33- reduce = flat_ir .ops [- 1 ]
34+ func_reduce = flat_ir .ops [- 1 ]
35+ reduce = func_reduce .ops [- 1 ]
3436 assert isinstance (reduce , ReduceOp )
3537 assert re .match (
36- r"out : \[rank=\(1\), dtype=\(int32\), loc=\(gpu:0\)\] = ReduceOp\(inp , t_inter[0-9]+, reduce_mode='sum', reduce_dims=\[0\]\)" ,
38+ r"t_inter[0-9]+ : \[rank=\(1\), dtype=\(int32\), loc=\(gpu:0\)\] = ReduceOp\(t_inter[0-9]+ , t_inter[0-9]+, reduce_mode='sum', reduce_dims=\[0\]\)" ,
3739 str (reduce ),
3840 )
3941
@@ -45,11 +47,11 @@ def test_max_str(self):
4547 trace = Trace ([out ])
4648 flat_ir = trace .to_flat_ir ()
4749
48- reduce = flat_ir .ops [- 1 ]
50+ func_reduce = flat_ir .ops [- 1 ]
51+ reduce = func_reduce .ops [- 1 ]
4952 assert isinstance (reduce , ReduceOp )
50-
5153 assert re .match (
52- r"out : \[rank=\(1\), dtype=\(int32\), loc=\(gpu:0\)\] = ReduceOp\(inp , t_inter[0-9]+, reduce_mode='max', reduce_dims=\[0\]\)" ,
54+ r"t_inter[0-9]+ : \[rank=\(1\), dtype=\(int32\), loc=\(gpu:0\)\] = ReduceOp\(t_inter[0-9]+ , t_inter[0-9]+, reduce_mode='max', reduce_dims=\[0\]\)" ,
5355 str (reduce ),
5456 )
5557
@@ -61,30 +63,43 @@ def test_mean_str(self):
6163 trace = Trace ([out ])
6264 flat_ir = trace .to_flat_ir ()
6365
64- div = flat_ir .ops [- 1 ]
66+ func_div = flat_ir .ops [- 1 ]
67+ div = func_div .ops [- 1 ]
68+ broadcast_a = func_div .ops [- 3 ]
69+ broadcast_b = func_div .ops [- 2 ]
6570 assert isinstance (div , DivideOp )
6671 assert re .match (
67- r"out : \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = DivideOp\(t_inter[0-9]+, t_inter[0-9]+\)" ,
72+ r"t_inter[0-9]+ : \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = DivideOp\(t_inter[0-9]+, t_inter[0-9]+\)" ,
6873 str (div ),
6974 )
7075
71- broadcast = flat_ir .ops [- 2 ]
72- assert isinstance (broadcast , DynamicBroadcastOp )
76+ assert isinstance (broadcast_a , DynamicBroadcastOp )
77+ assert re .match (
78+ r"t_inter[0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicBroadcastOp\(t_inter[0-9]+, t_inter[0-9]+, broadcast_dim=\[[0-9]*\]\)" ,
79+ str (broadcast_a ),
80+ )
81+
82+ assert isinstance (broadcast_b , DynamicBroadcastOp )
7383 assert re .match (
7484 r"t_inter[0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicBroadcastOp\(t_inter[0-9]+, t_inter[0-9]+, broadcast_dim=\[[0-9]*\]\)" ,
75- str (broadcast ),
85+ str (broadcast_b ),
7686 )
7787
78- mul = flat_ir .ops [- 15 ]
88+ mul = flat_ir .ops [- 3 ]. ops [ - 1 ]
7989 assert isinstance (mul , MulOp )
8090 assert re .match (
81- r"t [0-9]+: \[rank=\(0\), dtype=\(int32\), loc=\(gpu:0\)\] = MulOp\(t_inter[0-9]+, t_inter[0-9]+\)" ,
91+ r"t_inter [0-9]+: \[rank=\(0\), dtype=\(int32\), loc=\(gpu:0\)\] = MulOp\(t_inter[0-9]+, t_inter[0-9]+\)" ,
8292 str (mul ),
8393 )
84- reduce = flat_ir .ops [2 ]
94+
95+ func_reduce = flat_ir .ops [1 ]
96+ assert isinstance (func_reduce , FlatIRFunction )
97+
98+ reduce = func_reduce .ops [- 1 ]
8599 assert isinstance (reduce , ReduceOp )
100+
86101 assert re .match (
87- r"t [0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = ReduceOp\(inp , t_inter[0-9]+, reduce_mode='sum', reduce_dims=\[0\]\)" ,
102+ r"t_inter [0-9]+: \[rank=\(1\), dtype=\(float32\), loc=\(gpu:0\)\] = ReduceOp\(t_inter[0-9]+ , t_inter[0-9]+, reduce_mode='sum', reduce_dims=\[0\]\)" ,
88103 str (reduce ),
89104 )
90105
@@ -96,11 +111,14 @@ def test_argmax_str(self):
96111 trace = Trace ([out ])
97112 flat_ir = trace .to_flat_ir ()
98113
99- reduce = flat_ir .ops [- 1 ]
100- assert isinstance (reduce , ArgMinMaxOp )
114+ func_argminmax = flat_ir .ops [- 1 ]
115+ assert isinstance (func_argminmax , FlatIRFunction )
116+
117+ argminmax = func_argminmax .ops [- 1 ]
118+ assert isinstance (argminmax , ArgMinMaxOp )
101119 assert re .match (
102- r"out : \[rank=\(1\), dtype=\(int32\), loc=\(gpu:0\)\] = ArgMinMaxOp\(inp, t [0-9]+, t_inter[0-9]+, t_inter[0-9]+, reduce_mode='argmax', reduce_dims=\[0\]\)" ,
103- str (reduce ),
120+ r"t_inter[0-9]+ : \[rank=\(1\), dtype=\(int32\), loc=\(gpu:0\)\] = ArgMinMaxOp\(t_inter[0-9]+, t_inter [0-9]+, t_inter[0-9]+, t_inter[0-9]+, reduce_mode='argmax', reduce_dims=\[0\]\)" ,
121+ str (argminmax ),
104122 )
105123
106124 def test_argmin_str (self ):
@@ -111,9 +129,12 @@ def test_argmin_str(self):
111129 trace = Trace ([out ])
112130 flat_ir = trace .to_flat_ir ()
113131
114- reduce = flat_ir .ops [- 1 ]
115- assert isinstance (reduce , ArgMinMaxOp )
132+ func_argminmax = flat_ir .ops [- 1 ]
133+ assert isinstance (func_argminmax , FlatIRFunction )
134+
135+ argminmax = func_argminmax .ops [- 1 ]
136+ assert isinstance (argminmax , ArgMinMaxOp )
116137 assert re .match (
117- r"out : \[rank=\(1\), dtype=\(int32\), loc=\(gpu:0\)\] = ArgMinMaxOp\(inp, t [0-9]+, t_inter[0-9]+, t_inter[0-9]+, reduce_mode='argmin', reduce_dims=\[0\]\)" ,
118- str (reduce ),
138+ r"t_inter[0-9]+ : \[rank=\(1\), dtype=\(int32\), loc=\(gpu:0\)\] = ArgMinMaxOp\(t_inter[0-9]+, t_inter [0-9]+, t_inter[0-9]+, t_inter[0-9]+, reduce_mode='argmin', reduce_dims=\[0\]\)" ,
139+ str (argminmax ),
119140 )
0 commit comments