2121from nvtripy .common .exception import raise_error
2222from nvtripy .frontend .ops import utils as op_utils
2323from nvtripy .trace .ops .shape import Shape
24- from nvtripy .trace .ops .slice import SliceFill
24+ from nvtripy .trace .ops .slice import SliceFill , SliceReflect
2525from nvtripy .types import IntLike
2626from nvtripy .utils import wrappers
2727
@@ -46,22 +46,36 @@ def pad(
4646 of ``input``. Each element of ``pad`` is a tuple of integers or :class:`DimensionSize` s ``(low, high)``,
4747 which represents the padding sizes before the lowest index and after the highest index at
4848 the corresponding dimension.
49- mode: The padding mode. Only "constant" is supported.
50- value: The padding value for "constant" mode.
49+ mode: The padding mode. Must be one of:
50+
51+ - ``"constant"``: Pads with a constant value (default).
52+ - ``"reflect"``: Pads by reflecting the input tensor.
53+ value: The padding value for "constant" mode. Has no effect for other modes.
5154
5255 Returns:
5356 The padded tensor.
5457
5558 .. code-block:: python
5659 :linenos:
57- :caption: Constant padding.
60+ :caption: Constant Padding
5861
5962 input = tp.reshape(tp.arange(6, dtype=tp.float32), (2, 3))
6063 output = tp.pad(input, [(1, 0), (0, 1)])
6164
6265 input_np = np.arange(6, dtype=np.float32).reshape((2, 3)) # doc: omit
6366 expected = np.pad(input_np, ((1, 0), (0, 1))) # doc: omit
6467 assert np.array_equal(cp.from_dlpack(output).get(), expected)
68+
69+ .. code-block:: python
70+ :linenos:
71+ :caption: Reflect Padding
72+
73+ input = tp.reshape(tp.arange(6, dtype=tp.float32), (2, 3))
74+ output = tp.pad(input, [(1, 1), (1, 1)], mode="reflect")
75+
76+ input_np = np.arange(6, dtype=np.float32).reshape((2, 3)) # doc: omit
77+ expected = np.pad(input_np, ((1, 1), (1, 1)), mode="reflect") # doc: omit
78+ assert np.array_equal(cp.from_dlpack(output).get(), expected)
6579 """
6680 from nvtripy .frontend .ops .cast import cast
6781 from nvtripy .frontend .tensor import Tensor
@@ -72,7 +86,7 @@ def pad(
7286 [f"Got pad={ pad } , " , f" input's rank={ input .rank } " ],
7387 )
7488
75- supported_modes = {"constant" }
89+ supported_modes = {"constant" , "reflect" }
7690 if mode not in supported_modes :
7791 raise_error (
7892 "Unsupported padding mode." ,
@@ -89,4 +103,9 @@ def pad(
89103 sizes = input_shape + padding_lows + padding_highs
90104 steps = op_utils .tensor_from_shape_like ([1 ] * input .rank )
91105
92- return op_utils .create_op (SliceFill , [input , starts , sizes , steps , cast (Tensor (value ), dtype = input .dtype )])
106+ inputs = [input , starts , sizes , steps ]
107+ if mode == "constant" :
108+ inputs .append (cast (Tensor (value ), dtype = input .dtype ))
109+
110+ OpType = {"constant" : SliceFill , "reflect" : SliceReflect }[mode ]
111+ return op_utils .create_op (OpType , inputs )
0 commit comments