Skip to content

Commit d3e9d2c

Browse files
Add tp.MaxPool (#204)
Signed-off-by: yizhuoz004 <[email protected]> Co-authored-by: pranavm-nvidia <[email protected]>
1 parent 7d1ebf8 commit d3e9d2c

File tree

10 files changed

+394
-48
lines changed

10 files changed

+394
-48
lines changed

tripy/tests/frontend/trace/ops/test_convolution.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def test_invalid_rank_fails(self, conv_func):
6666
@pytest.mark.parametrize(
6767
"padding, err, expect_input_stack_info",
6868
[
69-
(((2, 2),), r"expects padding-entries to have same dimension-size as size of window dimensions", True),
69+
(((2, 2),), r"Padding must have the same length as kernel_dims.", False),
7070
(((2, 2, 2), (2, 2, 2)), r"Padding must be provided as a sequence of pairs of integers.", False),
7171
(((1, 2), (-3, 1)), r"Negative padding is not supported.", False),
7272
],
@@ -84,7 +84,7 @@ def test_invalid_padding(self, conv_func, padding, err, expect_input_stack_info)
8484
"stride, err, expect_input_stack_info",
8585
[
8686
((-1, 0), r"Non-positive stride is not supported.", False),
87-
((2, 2, 2), r"expects window-strides to have same dimension-size as size of window dimensions", True),
87+
((2, 2, 2), r"Stride must have the same length as kernel_dims.", False),
8888
],
8989
)
9090
def test_invalid_stride(self, conv_func, stride, err, expect_input_stack_info):
@@ -136,11 +136,7 @@ def test_infer_rank(self, conv_func):
136136
"dilation, err, expect_input_stack_info",
137137
[
138138
((-1, 0), r"Non-positive dilation is not supported.", False),
139-
(
140-
(2, 2, 2),
141-
r"expects window-dilation factors to have same dimension-size as size of window dimensions.",
142-
True,
143-
),
139+
((2, 2, 2), r"Dilation must have the same length as kernel_dims.", False),
144140
],
145141
)
146142
def test_invalid_rhs_dilation(self, conv_func, dilation, err, expect_input_stack_info):
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from tests import helper
17+
import tripy as tp
18+
19+
20+
class TestPooling:
21+
def test_invalid_kernel_dims(self):
22+
a = tp.ones((1, 1, 4, 4))
23+
with helper.raises(tp.TripyException, "Unsupported kernel_dims, must be 2D or 3D."):
24+
tp.maxpool(a, (2,))
25+
26+
def test_invalid_stride(self):
27+
a = tp.ones((1, 1, 4, 4))
28+
with helper.raises(tp.TripyException, "Stride must have the same length as kernel_dims."):
29+
tp.maxpool(a, (2, 2), stride=(1,))
30+
31+
def test_invalid_padding_length(self):
32+
a = tp.ones((1, 1, 4, 4))
33+
with helper.raises(tp.TripyException, "Padding must have the same length as kernel_dims."):
34+
tp.maxpool(a, (2, 2), padding=((1, 1),))
35+
36+
def test_invalid_padding_contents(self):
37+
a = tp.ones((1, 1, 4, 4))
38+
with helper.raises(tp.TripyException, "Padding must be provided as a sequence of pairs of integers."):
39+
tp.maxpool(a, (2, 2), padding=((1, 1, 1), (1, 1, 1)))
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#
2+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import pytest
19+
import torch
20+
21+
import tripy as tp
22+
23+
24+
class TestPooling:
25+
26+
@pytest.mark.parametrize(
27+
"kernel_dims, stride, padding",
28+
[
29+
((3, 3), (1, 1), ((0, 0), (0, 0))),
30+
((4, 4), (2, 2), ((1, 1), (2, 2))),
31+
],
32+
)
33+
@pytest.mark.parametrize("dtype", [tp.float32, tp.float16, tp.int8])
34+
def test_maxpool_2d(self, kernel_dims, stride, padding, dtype):
35+
inp_tp = tp.reshape(tp.arange(64, dtype=dtype), (1, 1, 8, 8))
36+
out = tp.maxpool(inp_tp, kernel_dims=kernel_dims, stride=stride, padding=padding)
37+
out_torch = torch.from_dlpack(out).to("cpu")
38+
39+
torch_padding = (padding[0][0], padding[1][0])
40+
pool_torch = torch.nn.MaxPool2d(kernel_size=kernel_dims, stride=stride, padding=torch_padding)
41+
expected = pool_torch(torch.from_dlpack(inp_tp).to("cpu"))
42+
assert torch.allclose(expected, out_torch)
43+
assert expected.shape == out_torch.shape
44+
45+
@pytest.mark.parametrize(
46+
"kernel_dims, stride, padding",
47+
[
48+
((2, 2, 2), (2, 2, 2), ((0, 0), (1, 1), (1, 1))),
49+
],
50+
)
51+
@pytest.mark.parametrize("dtype", [tp.float32, tp.float16])
52+
def test_maxpool_3d(self, kernel_dims, stride, padding, dtype):
53+
inp_tp = tp.reshape(tp.arange(512, dtype=dtype), (1, 1, 8, 8, 8))
54+
out = tp.maxpool(inp_tp, kernel_dims=kernel_dims, stride=stride, padding=padding)
55+
out_torch = torch.from_dlpack(out).to("cpu")
56+
57+
torch_padding = (padding[0][0], padding[1][0], padding[2][0])
58+
pool_torch = torch.nn.MaxPool3d(kernel_size=kernel_dims, stride=stride, padding=torch_padding)
59+
expected = pool_torch(torch.from_dlpack(inp_tp).to("cpu"))
60+
assert torch.allclose(expected, out_torch)
61+
assert expected.shape == out_torch.shape

tripy/tripy/backend/mlir/utils.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#
1717

1818
import contextlib
19+
import numbers
1920
import os
2021
import re
2122
import sys
@@ -84,13 +85,20 @@ def get_mlir_dtype(dtype: "tripy.dtype"):
8485
}[dtype.name]
8586

8687

87-
def get_mlir_scalar_attr(dtype: "tripy.dtype", value):
88-
from tripy.common.datatype import floating
89-
88+
def get_mlir_scalar_attr(mlir_dtype, value):
9089
# MLIR represents float dtypes as FloatAttr
9190
# and non-float dtypes as IntegerAttr
92-
attr_func = ir.FloatAttr.get if issubclass(dtype, floating) else ir.IntegerAttr.get
93-
return attr_func(get_mlir_dtype(dtype), value)
91+
attr_func = ir.IntegerAttr.get if isinstance(mlir_dtype, ir.IntegerType) else ir.FloatAttr.get
92+
return attr_func(mlir_dtype, value)
93+
94+
95+
def list_to_dense_attr(data: List, mlir_dtype):
96+
if isinstance(data, numbers.Number):
97+
return [get_mlir_scalar_attr(mlir_dtype, data)]
98+
attrs = []
99+
for element in data:
100+
attrs.extend(list_to_dense_attr(element, mlir_dtype))
101+
return attrs
94102

95103

96104
def get_mlir_quant_dtype(

tripy/tripy/flat_ir/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from tripy.flat_ir.ops.plugin import PluginOp
4343
from tripy.flat_ir.ops.pow import PowOp
4444
from tripy.flat_ir.ops.reduce import ArgMinMaxOp, ReduceOp
45+
from tripy.flat_ir.ops.reduce_window import ReduceWindowOp
4546
from tripy.flat_ir.ops.reshape import DynamicReshapeOp
4647
from tripy.flat_ir.ops.round_nearest_even import RoundNearestEvenOp
4748
from tripy.flat_ir.ops.rsqrt import RsqrtOp

tripy/tripy/flat_ir/ops/constant.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,15 +77,9 @@ def to_mlir(self, operands):
7777
)
7878
else:
7979
out_dtype = self.outputs[0].dtype
80-
81-
def to_attrs(data):
82-
if isinstance(data, numbers.Number):
83-
return [mlir_utils.get_mlir_scalar_attr(out_dtype, data)]
84-
attrs = []
85-
for element in data:
86-
attrs.extend(to_attrs(element))
87-
return attrs
88-
89-
attr = ir.DenseElementsAttr.get(attrs=to_attrs(self.data), type=self.outputs[0].to_mlir())
80+
attr = ir.DenseElementsAttr.get(
81+
attrs=mlir_utils.list_to_dense_attr(self.data, mlir_utils.get_mlir_dtype(out_dtype)),
82+
type=self.outputs[0].to_mlir(),
83+
)
9084

9185
return [stablehlo.ConstantOp(attr)]
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#
2+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
from dataclasses import dataclass
19+
from typing import Sequence, Tuple
20+
21+
from mlir_tensorrt.compiler import ir
22+
from mlir_tensorrt.compiler.dialects import stablehlo
23+
24+
from tripy.backend.mlir import utils as mlir_utils
25+
from tripy.flat_ir.ops.base import BaseFlatIROp
26+
27+
28+
@dataclass(repr=False)
29+
class ReduceWindowOp(BaseFlatIROp):
30+
31+
reduce_mode: str
32+
window_dims: Sequence[int]
33+
window_strides: Sequence[int]
34+
padding: Sequence[Tuple[int]]
35+
36+
def _get_reduce_func(self):
37+
if self.reduce_mode == "max":
38+
return stablehlo.MaxOp
39+
else:
40+
raise NotImplementedError()
41+
42+
def to_mlir(self, operands):
43+
input_dtype = self.inputs[0].dtype
44+
out_type = self.outputs[0].to_mlir()
45+
46+
window_dims_attr = ir.DenseI64ArrayAttr.get(self.window_dims)
47+
window_strides_attr = ir.DenseI64ArrayAttr.get(self.window_strides)
48+
padding_attr_type = ir.RankedTensorType.get(
49+
[len(self.padding), 2],
50+
ir.IntegerType.get_signless(64),
51+
)
52+
padding_attr = ir.DenseElementsAttr.get(
53+
attrs=mlir_utils.list_to_dense_attr(self.padding, ir.IntegerType.get_signless(64)),
54+
type=padding_attr_type,
55+
)
56+
57+
reduce = stablehlo.ReduceWindowOp(
58+
result=[out_type],
59+
inputs=[operands[0]],
60+
init_values=[operands[1]],
61+
window_dimensions=window_dims_attr,
62+
window_strides=window_strides_attr,
63+
padding=padding_attr,
64+
)
65+
66+
reduce_arg_type = ir.RankedTensorType.get(
67+
[],
68+
mlir_utils.get_mlir_dtype(input_dtype),
69+
)
70+
reduce_block = ir.Block.create_at_start(reduce.regions[0], [reduce_arg_type, reduce_arg_type])
71+
reduce_func = self._get_reduce_func()
72+
with ir.InsertionPoint(reduce_block):
73+
out = reduce_func(*reduce_block.arguments)
74+
stablehlo.ReturnOp([out])
75+
76+
return [reduce]

tripy/tripy/frontend/module/convolution.py

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from tripy.common import datatype
2424
from tripy.frontend.module.module import Module
2525
from tripy.frontend.module.parameter import Parameter, DefaultParameter
26+
from tripy.frontend.trace.ops import utils as op_utils
2627

2728
from tripy.common.exception import raise_error
2829

@@ -70,37 +71,12 @@ def __init__(
7071
],
7172
)
7273

74+
op_utils.check_conv_pooling_args(kernel_dims, stride, padding, dilation)
7375
rank = len(kernel_dims) + 2
7476
self.padding = utils.default(padding, tuple(((0, 0) for _ in range(rank - 2))))
75-
76-
if not all(len(pad) == 2 for pad in self.padding):
77-
raise_error(
78-
f"Padding must be provided as a sequence of pairs of integers.",
79-
details=[f"Supplied padding attribute: {self.padding} contains sequences that are not of length 2."],
80-
)
81-
82-
if not all(p1 >= 0 and p2 >= 0 for p1, p2 in self.padding):
83-
raise_error(
84-
"Negative padding is not supported.",
85-
details=[f"Got padding: {self.padding} but all values must be non-negative integers."],
86-
)
87-
8877
self.stride = utils.default(stride, (1,) * (rank - 2))
89-
90-
if not all(s > 0 for s in self.stride):
91-
raise_error(
92-
"Non-positive stride is not supported.",
93-
details=[f"Got stride: {self.stride} but all values must be integers greater than 0."],
94-
)
95-
9678
self.dilation = utils.default(dilation, (1,) * (rank - 2))
9779

98-
if not all(isinstance(d, int) and d > 0 for d in self.dilation):
99-
raise_error(
100-
"Non-positive dilation is not supported.",
101-
details=[f"Got dilation: {self.dilation} but all values must be integers greater than 0."],
102-
)
103-
10480
if bias:
10581
self.bias = DefaultParameter((out_channels,), dtype=dtype)
10682

0 commit comments

Comments
 (0)