Commit f1d2b33

Implement average pooling
1 parent f8b5db3 commit f1d2b33

3 files changed: +103 additions, -10 deletions


tripy/tests/integration/test_pooling.py

Lines changed: 24 additions & 9 deletions
@@ -31,13 +31,20 @@ class TestPooling:
         ],
     )
     @pytest.mark.parametrize("dtype", [tp.float32, tp.float16, tp.int8])
-    def test_maxpool_2d(self, kernel_dims, stride, padding, dtype):
+    @pytest.mark.parametrize("pool_type", ["max", "avg"])
+    def test_pool_2d(self, kernel_dims, stride, padding, dtype, pool_type):
         inp_tp = tp.reshape(tp.arange(64, dtype=dtype), (1, 1, 8, 8))
-        out = tp.maxpool(inp_tp, kernel_dims=kernel_dims, stride=stride, padding=padding)
-        out_torch = torch.from_dlpack(out).to("cpu")
-
         torch_padding = (padding[0][0], padding[1][0])
-        pool_torch = torch.nn.MaxPool2d(kernel_size=kernel_dims, stride=stride, padding=torch_padding)
+
+        if pool_type == "max":
+            out = tp.maxpool(inp_tp, kernel_dims=kernel_dims, stride=stride, padding=padding)
+            pool_torch = torch.nn.MaxPool2d(kernel_size=kernel_dims, stride=stride, padding=torch_padding)
+        elif pool_type == "avg":
+            pytest.skip("https://github.com/NVIDIA/TensorRT-Incubator/issues/237: Average pooling is not functional.")
+            out = tp.avgpool(inp_tp, kernel_dims=kernel_dims, stride=stride, padding=padding)
+            pool_torch = torch.nn.AvgPool2d(kernel_size=kernel_dims, stride=stride, padding=torch_padding)
+
+        out_torch = torch.from_dlpack(out).to("cpu")
         expected = pool_torch(torch.from_dlpack(inp_tp).to("cpu"))
         assert torch.allclose(expected, out_torch)
         assert expected.shape == out_torch.shape

@@ -49,13 +56,21 @@ def test_maxpool_2d(self, kernel_dims, stride, padding, dtype):
         ],
     )
     @pytest.mark.parametrize("dtype", [tp.float32, tp.float16])
-    def test_maxpool_3d(self, kernel_dims, stride, padding, dtype):
+    @pytest.mark.parametrize("pool_type", ["max", "avg"])
+    def test_pool_3d(self, kernel_dims, stride, padding, dtype, pool_type):
         inp_tp = tp.reshape(tp.arange(512, dtype=dtype), (1, 1, 8, 8, 8))
-        out = tp.maxpool(inp_tp, kernel_dims=kernel_dims, stride=stride, padding=padding)
+        torch_padding = (padding[0][0], padding[1][0], padding[2][0])
+
+        if pool_type == "max":
+            out = tp.maxpool(inp_tp, kernel_dims=kernel_dims, stride=stride, padding=padding)
+            pool_torch = torch.nn.MaxPool3d(kernel_size=kernel_dims, stride=stride, padding=torch_padding)
+        elif pool_type == "avg":
+            pytest.skip("https://github.com/NVIDIA/TensorRT-Incubator/issues/237: Average pooling is not functional.")
+            out = tp.avgpool(inp_tp, kernel_dims=kernel_dims, stride=stride, padding=padding)
+            pool_torch = torch.nn.AvgPool3d(kernel_size=kernel_dims, stride=stride, padding=torch_padding)
+
         out_torch = torch.from_dlpack(out).to("cpu")
 
-        torch_padding = (padding[0][0], padding[1][0], padding[2][0])
-        pool_torch = torch.nn.MaxPool3d(kernel_size=kernel_dims, stride=stride, padding=torch_padding)
         expected = pool_torch(torch.from_dlpack(inp_tp).to("cpu"))
         assert torch.allclose(expected, out_torch)
         assert expected.shape == out_torch.shape
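
For readers unfamiliar with the semantics being tested: each parametrized case compares Tripy's pooling output against the equivalent torch.nn module. Below is a plain NumPy reference for 2-D average pooling over an NCHW tensor, showing what the currently skipped "avg" cases will verify once issue #237 is resolved. This sketch is not part of the change; the function name and the stride-1, zero-padding defaults are illustrative only.

import numpy as np

def avgpool2d_reference(x, kernel, stride=(1, 1), padding=((0, 0), (0, 0))):
    # Naive NCHW average pooling, used only to illustrate the semantics
    # that test_pool_2d checks via torch.nn.AvgPool2d.
    n, c, h, w = x.shape
    (pad_top, pad_bottom), (pad_left, pad_right) = padding
    xp = np.pad(x, ((0, 0), (0, 0), (pad_top, pad_bottom), (pad_left, pad_right)))
    kh, kw = kernel
    sh, sw = stride
    out_h = (h + pad_top + pad_bottom - kh) // sh + 1
    out_w = (w + pad_left + pad_right - kw) // sw + 1
    out = np.empty((n, c, out_h, out_w), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            window = xp[:, :, i * sh : i * sh + kh, j * sw : j * sw + kw]
            out[:, :, i, j] = window.mean(axis=(2, 3))
    return out

# Mirrors the test input: values 0..63 arranged as a (1, 1, 8, 8) tensor.
x = np.arange(64, dtype=np.float32).reshape(1, 1, 8, 8)
print(avgpool2d_reference(x, kernel=(2, 2)).shape)  # -> (1, 1, 7, 7)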

tripy/tripy/flat_ir/ops/reduce_window.py

Lines changed: 17 additions & 0 deletions
@@ -36,6 +36,8 @@ class ReduceWindowOp(BaseFlatIROp):
     def _get_reduce_func(self):
         if self.reduce_mode == "max":
             return stablehlo.MaxOp
+        elif self.reduce_mode == "avg":
+            return stablehlo.AddOp
         else:
             raise NotImplementedError()
 
@@ -71,6 +73,21 @@ def to_mlir(self, operands):
         reduce_func = self._get_reduce_func()
         with ir.InsertionPoint(reduce_block):
             out = reduce_func(*reduce_block.arguments)
+            if self.reduce_mode == "avg":
+                # Calculate the number of elements in the window
+                window_elements = 1
+                for dim in self.window_dims:
+                    window_elements *= dim
+
+                # Create a dense elements attribute for the window size
+                window_size_attr = ir.DenseElementsAttr.get_splat(
+                    ir.RankedTensorType.get([], mlir_utils.get_mlir_dtype(input_dtype)),
+                    ir.FloatAttr.get(mlir_utils.get_mlir_dtype(input_dtype), float(window_elements)),
+                )
+                window_size_const = stablehlo.ConstantOp(window_size_attr)
+
+                # Divide the sum by the window size
+                out = stablehlo.DivOp(out, window_size_const)
             stablehlo.ReturnOp([out])
 
         return [reduce]
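
In other words, the "avg" mode reuses the reduce-window machinery by accumulating with stablehlo.AddOp and then dividing by the number of elements in the window, which is materialized as a splat constant. A minimal arithmetic sketch of that sum-then-divide decomposition, using NumPy purely for illustration:

import numpy as np

# Number of elements in the pooling window, computed the same way as in
# ReduceWindowOp.to_mlir above.
window_dims = (2, 2)
window_elements = 1
for dim in window_dims:
    window_elements *= dim

# For any window, sum / element-count is exactly the mean the op should produce.
window = np.array([[0.0, 1.0], [8.0, 9.0]], dtype=np.float32)  # top-left 2x2 window of the 8x8 test input
assert window.sum() / window_elements == window.mean()
print(window.sum() / window_elements)  # -> 4.5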

tripy/tripy/frontend/trace/ops/pooling.py

Lines changed: 62 additions & 1 deletion
@@ -106,7 +106,7 @@ def maxpool(
     Args:
         input: The input tensor.
         kernel_dims: The spatial shape of the pooling window. Only 2-D or 3-D ``kernel_dims`` are supported.
-            If the input has ``int8`` datatype, ``kernel_dims`` can only be 2-D.
+            If the input has :class:`int8` datatype, ``kernel_dims`` can only be 2-D.
         stride: A sequence of length :math:`M` indicating the stride of pooling across each spatial dimension,
             where :math:`M` is the number of spatial dimensions, i.e. :math:`M = \text{rank(input)} - 2`.
             Defaults to all 1.

@@ -139,3 +139,64 @@ def maxpool(
     padding = utils.default(padding, [(0, 0)] * spatial_dims)
 
     return Pooling.build([input], Pooling.Kind.MAX, kernel_dims, stride, padding)
+
+
+@export.public_api(document_under="operations/functions")
+@constraints.dtype_info(
+    dtype_variables={
+        "T1": ["float32", "float16", "int8"],
+    },
+    dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"},
+)
+def avgpool(
+    input: "tripy.Tensor",
+    kernel_dims: Sequence[int],
+    stride: Sequence[int] = None,
+    padding: Sequence[Tuple[int]] = None,
+) -> "tripy.Tensor":
+    r"""
+    Applies an average pooling over the input tensor.
+
+    The output's non-spatial dimensions are the same as input. For each input spatial dimension
+    :math:`D_{i}`, the corresponding output dimension will be:
+
+    .. math::
+        D_{out_i} = \left\lfloor\frac{D_{i} + \text{padding_before[i]} + \text{padding_after[i]} -
+                \text{kernel_dims[i]}}{\text{stride[i]}} + 1\right\rfloor
+
+    Args:
+        input: The input tensor.
+        kernel_dims: The spatial shape of the pooling window. Only 2-D or 3-D ``kernel_dims`` are supported.
+            If the input has :class:`int8` datatype, ``kernel_dims`` can only be 2-D.
+        stride: A sequence of length :math:`M` indicating the stride of pooling across each spatial dimension,
+            where :math:`M` is the number of spatial dimensions, i.e. :math:`M = \text{rank(input)} - 2`.
+            Defaults to all 1.
+        padding: A sequence of pairs of integers of length :math:`M` indicating the zero padding
+            to apply to the input along each spatial dimension before and after the dimension respectively,
+            where :math:`M` is the number of spatial dimensions, i.e. :math:`M = \text{rank(input)} - 2`.
+            Defaults to all 0.
+
+    Returns:
+        The result tensor after the pooling operation.
+
+    .. code-block:: python
+        :linenos:
+        :caption: Example
+
+        input = tp.reshape(tp.arange(16, dtype=tp.float32), (1, 1, 4, 4))
+        output = tp.avgpool(input, kernel_dims=(2, 2))
+
+        pool_torch = torch.nn.AvgPool2d((2, 2), stride=1)  # doc: omit
+        expected = pool_torch(torch.from_dlpack(input).to("cpu"))  # doc: omit
+
+        assert torch.allclose(torch.from_dlpack(output).to("cpu"), expected)
+    """
+    spatial_dims = len(kernel_dims)
+    if spatial_dims != 2 and spatial_dims != 3:
+        raise_error("Unsupported kernel_dims, must be 2D or 3D.", [f"Got kernel_dims={kernel_dims}"])
+
+    op_utils.check_conv_pooling_args(kernel_dims, stride, padding)
+    stride = utils.default(stride, [1] * spatial_dims)
+    padding = utils.default(padding, [(0, 0)] * spatial_dims)
+
+    return Pooling.build([input], Pooling.Kind.AVG, kernel_dims, stride, padding)
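
The output-shape formula in the new docstring can be checked by hand. A small sketch follows; the helper name pooled_dim is purely illustrative and not part of the API. Applied to the docstring example, a 4x4 spatial input with a 2x2 kernel, the default stride of 1, and no padding yields a 3x3 result.

import math

def pooled_dim(d, pad_before, pad_after, kernel, stride):
    # D_out_i = floor((D_i + padding_before[i] + padding_after[i] - kernel_dims[i]) / stride[i] + 1)
    return math.floor((d + pad_before + pad_after - kernel) / stride + 1)

# Docstring example: (1, 1, 4, 4) input, kernel_dims=(2, 2), stride 1, padding 0.
assert pooled_dim(4, 0, 0, 2, 1) == 3  # each spatial dim shrinks from 4 to 3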

0 commit comments
