@@ -49,7 +49,7 @@ def infer_dtypes(self):
4949 self .outputs [0 ].dtype = self .inputs [0 ].dtype
5050
5151 def to_flat_ir (self , inputs , outputs ):
52- from tripy .flat_ir .ops import ConstantOp , ReduceWindowOp
52+ from tripy .flat_ir .ops import ConstantOp , DivideOp , ReduceWindowOp
5353 from tripy .flat_ir .tensor import FlatIRTensor
5454
5555 init_value = 0
@@ -70,14 +70,74 @@ def to_flat_ir(self, inputs, outputs):
7070 window_strides = [1 ] * extra_dims + list (self .stride )
7171 padding = [(0 , 0 )] * extra_dims + list (self .padding )
7272
73- ReduceWindowOp .build (
74- [inputs [0 ], init_const ],
75- outputs ,
76- reduce_mode = self .kind .op ,
77- window_dims = window_dims ,
78- window_strides = window_strides ,
79- padding = padding ,
80- )
73+ if self .kind .op == "max" :
74+ ReduceWindowOp .build (
75+ [inputs [0 ], init_const ],
76+ outputs ,
77+ reduce_mode = self .kind .op ,
78+ window_dims = window_dims ,
79+ window_strides = window_strides ,
80+ padding = padding ,
81+ )
82+ elif self .kind .op == "avg" :
83+
84+ reduce_out = FlatIRTensor .build (
85+ rank = outputs [0 ].rank ,
86+ dtype = outputs [0 ].dtype ,
87+ device = outputs [0 ].device ,
88+ reason_details = [f"create the output of reduce `{ self .kind .op } ` operation." ],
89+ )
90+
91+ ReduceWindowOp .build (
92+ [inputs [0 ], init_const ],
93+ [reduce_out ],
94+ reduce_mode = self .kind .op ,
95+ window_dims = window_dims ,
96+ window_strides = window_strides ,
97+ padding = padding ,
98+ )
99+
100+ window_elements = 1
101+ for dim in window_dims :
102+ window_elements *= dim
103+
104+ # window_elements = compute_window_elements(self.kernel_dims, self.padding)
105+ init_const = FlatIRTensor .build (
106+ shape = (),
107+ rank = 0 ,
108+ dtype = outputs [0 ].dtype ,
109+ device = outputs [0 ].device ,
110+ reason_details = [
111+ f"create the constant value tensor (containing { window_elements } ) for the divisor of average pool operation."
112+ ],
113+ )
114+ ConstantOp .build ([], [init_const ], data = window_elements )
115+ with FlatIRTensor .context (
116+ [f"expand the rank of constant tensor which is the divisor of average pool operation." ]
117+ ):
118+ init_const = op_utils .expand_rank_of_tensor (init_const , inputs [0 ].rank )
119+
120+ with FlatIRTensor .context ([f"broadcast the inputs of division operation." ]):
121+ shape_of_input0 = op_utils .get_shape_of_tensor (reduce_out )
122+ shape_of_input1 = op_utils .get_shape_of_tensor (init_const )
123+
124+ # Compute element-wise max of input shapes to get the desired output shape.
125+ output_shape_tensor = op_utils .compute_shape_of_broadcast (
126+ shape_of_input0 ,
127+ shape_of_input1 ,
128+ inputs [0 ].rank ,
129+ shape1_name = f"the shape of the first input { shape_of_input0 } " ,
130+ shape2_name = f"the shape of the second input { shape_of_input1 } " ,
131+ )
132+
133+ init_const = op_utils .insert_broadcast (
134+ init_const ,
135+ out_rank = inputs [0 ].rank ,
136+ shape_of_target_tensor = output_shape_tensor ,
137+ tensor_details = f"left operand" ,
138+ )
139+
140+ DivideOp .build ([reduce_out , init_const ], outputs )
81141
82142
83143@export .public_api (document_under = "operations/functions" )
@@ -106,7 +166,7 @@ def maxpool(
106166 Args:
107167 input: The input tensor.
108168 kernel_dims: The spatial shape of the pooling window. Only 2-D or 3-D ``kernel_dims`` are supported.
109- If the input has `` int8` ` datatype, ``kernel_dims`` can only be 2-D.
169+ If the input has :class:` int8` datatype, ``kernel_dims`` can only be 2-D.
110170 stride: A sequence of length :math:`M` indicating the stride of pooling across each spatial dimension,
111171 where :math:`M` is the number of spatial dimensions, i.e. :math:`M = \text{rank(input)} - 2`.
112172 Defaults to all 1.
@@ -139,3 +199,64 @@ def maxpool(
139199 padding = utils .default (padding , [(0 , 0 )] * spatial_dims )
140200
141201 return Pooling .build ([input ], Pooling .Kind .MAX , kernel_dims , stride , padding )
202+
203+
@export.public_api(document_under="operations/functions")
@constraints.dtype_info(
    dtype_variables={
        "T1": ["float32", "float16"],
    },
    dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"},
)
def avgpool(
    input: "tripy.Tensor",
    kernel_dims: Sequence[int],
    stride: Sequence[int] = None,
    padding: Sequence[Tuple[int]] = None,
) -> "tripy.Tensor":
    r"""
    Applies an average pooling over the input tensor.

    The output's non-spatial dimensions are the same as input. For each input spatial dimension
    :math:`D_{i}`, the corresponding output dimension will be:

    .. math::
        D_{out_i} = \left\lfloor\frac{D_{i} + \text{padding_before[i]} + \text{padding_after[i]} -
                \text{kernel_dims[i]}}{\text{stride[i]}} + 1\right\rfloor

    Args:
        input: The input tensor.
        kernel_dims: The spatial shape of the pooling window. Only 2-D or 3-D ``kernel_dims`` are supported.
        stride: A sequence of length :math:`M` indicating the stride of pooling across each spatial dimension,
            where :math:`M` is the number of spatial dimensions, i.e. :math:`M = \text{rank(input)} - 2`.
            Defaults to all 1.
        padding: A sequence of pairs of integers of length :math:`M` indicating the zero padding
            to apply to the input along each spatial dimension before and after the dimension respectively,
            where :math:`M` is the number of spatial dimensions, i.e. :math:`M = \text{rank(input)} - 2`.
            Defaults to all 0.

    Returns:
        The result tensor after the pooling operation.

    Raises:
        An error (reported through ``raise_error``) if ``kernel_dims`` is neither 2-D nor 3-D.

    .. code-block:: python
        :linenos:
        :caption: Example

        input = tp.reshape(tp.arange(16, dtype=tp.float32), (1, 1, 4, 4))
        output = tp.avgpool(input, kernel_dims=(2, 2))

        pool_torch = torch.nn.AvgPool2d((2, 2), stride=1) # doc: omit
        expected = pool_torch(torch.from_dlpack(input).to("cpu")) # doc: omit

        assert torch.allclose(torch.from_dlpack(output).to("cpu"), expected)
    """
    # The number of spatial dimensions is taken from the kernel, not from the input rank.
    spatial_dims = len(kernel_dims)
    if spatial_dims not in (2, 3):
        raise_error("Unsupported kernel_dims, must be 2D or 3D.", [f"Got kernel_dims={kernel_dims}"])

    # Validate with the caller's original values (possibly None) before defaults are filled in.
    op_utils.check_conv_pooling_args(kernel_dims, stride, padding)
    stride = utils.default(stride, [1] * spatial_dims)
    padding = utils.default(padding, [(0, 0)] * spatial_dims)

    # NOTE(review): Pooling.to_flat_ir divides the windowed reduction by the product of the
    # window dims, so zero-padded positions appear to count toward the average — confirm.
    return Pooling.build([input], Pooling.Kind.AVG, kernel_dims, stride, padding)
0 commit comments