Skip to content

Commit 4ea2d8b

Browse files
Adds support for reflect mode in the pad operation
1 parent 6dcd6a8 commit 4ea2d8b

File tree

3 files changed

+58
-37
lines changed

3 files changed

+58
-37
lines changed

tripy/nvtripy/frontend/ops/pad.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from nvtripy.common.exception import raise_error
2222
from nvtripy.frontend.ops import utils as op_utils
2323
from nvtripy.trace.ops.shape import Shape
24-
from nvtripy.trace.ops.slice import SliceFill
24+
from nvtripy.trace.ops.slice import SliceFill, SliceReflect
2525
from nvtripy.types import IntLike
2626
from nvtripy.utils import wrappers
2727

@@ -46,22 +46,36 @@ def pad(
4646
of ``input``. Each element of ``pad`` is a tuple of integers or :class:`DimensionSize` s ``(low, high)``,
4747
which represents the padding sizes before the lowest index and after the highest index at
4848
the corresponding dimension.
49-
mode: The padding mode. Only "constant" is supported.
50-
value: The padding value for "constant" mode.
49+
mode: The padding mode. Must be one of:
50+
51+
- ``"constant"``: Pads with a constant value (default).
52+
- ``"reflect"``: Pads by reflecting the input tensor.
53+
value: The padding value for "constant" mode. Has no effect for other modes.
5154
5255
Returns:
5356
The padded tensor.
5457
5558
.. code-block:: python
5659
:linenos:
57-
:caption: Constant padding.
60+
:caption: Constant Padding
5861
5962
input = tp.reshape(tp.arange(6, dtype=tp.float32), (2, 3))
6063
output = tp.pad(input, [(1, 0), (0, 1)])
6164
6265
input_np = np.arange(6, dtype=np.float32).reshape((2, 3)) # doc: omit
6366
expected = np.pad(input_np, ((1, 0), (0, 1))) # doc: omit
6467
assert np.array_equal(cp.from_dlpack(output).get(), expected)
68+
69+
.. code-block:: python
70+
:linenos:
71+
:caption: Reflect Padding
72+
73+
input = tp.reshape(tp.arange(6, dtype=tp.float32), (2, 3))
74+
output = tp.pad(input, [(1, 1), (1, 1)], mode="reflect")
75+
76+
input_np = np.arange(6, dtype=np.float32).reshape((2, 3)) # doc: omit
77+
expected = np.pad(input_np, ((1, 1), (1, 1)), mode="reflect") # doc: omit
78+
assert np.array_equal(cp.from_dlpack(output).get(), expected)
6579
"""
6680
from nvtripy.frontend.ops.cast import cast
6781
from nvtripy.frontend.tensor import Tensor
@@ -72,7 +86,7 @@ def pad(
7286
[f"Got pad={pad}, ", f" input's rank={input.rank}"],
7387
)
7488

75-
supported_modes = {"constant"}
89+
supported_modes = {"constant", "reflect"}
7690
if mode not in supported_modes:
7791
raise_error(
7892
"Unsupported padding mode.",
@@ -89,4 +103,9 @@ def pad(
89103
sizes = input_shape + padding_lows + padding_highs
90104
steps = op_utils.tensor_from_shape_like([1] * input.rank)
91105

92-
return op_utils.create_op(SliceFill, [input, starts, sizes, steps, cast(Tensor(value), dtype=input.dtype)])
106+
inputs = [input, starts, sizes, steps]
107+
if mode == "constant":
108+
inputs.append(cast(Tensor(value), dtype=input.dtype))
109+
110+
OpType = {"constant": SliceFill, "reflect": SliceReflect}[mode]
111+
return op_utils.create_op(OpType, inputs)

tripy/nvtripy/trace/ops/slice.py

Lines changed: 23 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -22,40 +22,33 @@
2222
from nvtripy.trace.ops.base import TraceOp
2323

2424

25-
@dataclass(repr=False)
26-
class Slice(TraceOp):
25+
def make_slice_op(name, mode):
26+
@dataclass(repr=False)
27+
class SliceOp(TraceOp):
2728

28-
infer_rank = op_utils.InferRankPolicies.same_as_input()
29+
infer_rank = op_utils.InferRankPolicies.same_as_input()
2930

30-
def infer_dtypes(self):
31-
self.outputs[0].dtype = self.inputs[0].dtype
31+
def infer_dtypes(self):
32+
self.outputs[0].dtype = self.inputs[0].dtype
3233

33-
def to_mlir(self, inputs, outputs):
34-
assert len(inputs) == 4, "Slice operation must have exactly 4 inputs."
34+
def to_mlir(self, inputs, outputs):
35+
mode_attr = tensorrt.SliceModeAttr.get(mode)
3536

36-
return [tensorrt.slice(inputs[0], start=inputs[1], size=inputs[2], stride=inputs[3])]
37+
return [
38+
tensorrt.slice(
39+
inputs[0],
40+
start=inputs[1],
41+
size=inputs[2],
42+
stride=inputs[3],
43+
fill=inputs[4] if len(inputs) > 4 else None,
44+
mode=mode_attr,
45+
)
46+
]
3747

48+
SliceOp.__name__ = name
49+
return SliceOp
3850

39-
@dataclass(repr=False)
40-
class SliceFill(TraceOp):
4151

42-
infer_rank = op_utils.InferRankPolicies.same_as_input()
43-
44-
def infer_dtypes(self):
45-
self.outputs[0].dtype = self.inputs[0].dtype
46-
47-
def to_mlir(self, inputs, outputs):
48-
assert len(inputs) == 5, "SliceFill operation must have exactly 5 inputs."
49-
50-
mode_attr = tensorrt.SliceModeAttr.get("kFILL")
51-
52-
return [
53-
tensorrt.slice(
54-
inputs[0],
55-
start=inputs[1],
56-
size=inputs[2],
57-
stride=inputs[3],
58-
fill=inputs[4],
59-
mode=mode_attr,
60-
)
61-
]
52+
Slice = make_slice_op("Slice", "kDEFAULT")
53+
SliceFill = make_slice_op("SliceFill", "kFILL")
54+
SliceReflect = make_slice_op("SliceReflect", "kREFLECT")

tripy/tests/integration/test_pad.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -45,3 +45,12 @@ def test_pad_tensor(self, eager_or_compiled):
4545
expected = np.pad(inp, ((0, 2), (3, 0)))
4646

4747
assert np.array_equal(cp.from_dlpack(out).get(), expected)
48+
49+
@pytest.mark.parametrize("pad", [((1, 0), (0, 1)), ((2, 1), (1, 2))])
50+
def test_pad_reflect(self, pad, eager_or_compiled):
51+
inp = np.arange(6, dtype=np.float32).reshape((2, 3))
52+
53+
out = eager_or_compiled(tp.pad, tp.Tensor(inp), pad, mode="reflect")
54+
expected = np.pad(inp, pad, mode="reflect")
55+
56+
assert np.array_equal(cp.from_dlpack(out).get(), expected)

0 commit comments

Comments
 (0)