Skip to content

Commit e5fe27e

Browse files
Updates maximum, minimum, and where to accept non-tripy-tensors
1 parent 67d21c4 commit e5fe27e

File tree

3 files changed

+9
-3
lines changed

3 files changed

+9
-3
lines changed

tripy/nvtripy/frontend/ops/binary/maximum.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,17 @@
1515
from nvtripy import export
1616
from nvtripy.frontend.ops.binary.create import create_binary_op
1717
from nvtripy.trace.ops.binary import Max
18+
from nvtripy.types import TensorLike
1819
from nvtripy.utils import wrappers
1920

2021

2122
@export.public_api(document_under="operations/functions")
2223
@wrappers.interface(
2324
dtype_constraints={"lhs": "T1", "rhs": "T1", wrappers.RETURN_VALUE: "T1"},
2425
dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]},
26+
convert_to_tensors=True,
2527
)
26-
def maximum(lhs: "nvtripy.Tensor", rhs: "nvtripy.Tensor") -> "nvtripy.Tensor":
28+
def maximum(lhs: TensorLike, rhs: TensorLike) -> "nvtripy.Tensor":
2729
"""
2830
Performs an elementwise maximum.
2931

tripy/nvtripy/frontend/ops/binary/minimum.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,17 @@
1515
from nvtripy import export
1616
from nvtripy.frontend.ops.binary.create import create_binary_op
1717
from nvtripy.trace.ops.binary import Min
18+
from nvtripy.types import TensorLike
1819
from nvtripy.utils import wrappers
1920

2021

2122
@export.public_api(document_under="operations/functions")
2223
@wrappers.interface(
2324
dtype_constraints={"lhs": "T1", "rhs": "T1", wrappers.RETURN_VALUE: "T1"},
2425
dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]},
26+
convert_to_tensors=True,
2527
)
26-
def minimum(lhs: "nvtripy.Tensor", rhs: "nvtripy.Tensor") -> "nvtripy.Tensor":
28+
def minimum(lhs: TensorLike, rhs: TensorLike) -> "nvtripy.Tensor":
2729
"""
2830
Performs an elementwise minimum.
2931

tripy/nvtripy/frontend/ops/where.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from nvtripy import export
2020
from nvtripy.frontend.ops import utils as op_utils
2121
from nvtripy.trace.ops.where import Where
22+
from nvtripy.types import TensorLike
2223
from nvtripy.utils import wrappers
2324

2425

@@ -29,8 +30,9 @@
2930
"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64"],
3031
"T2": ["bool"],
3132
},
33+
convert_to_tensors=True,
3234
)
33-
def where(condition: "nvtripy.Tensor", input: "nvtripy.Tensor", other: "nvtripy.Tensor") -> "nvtripy.Tensor":
35+
def where(condition: "nvtripy.Tensor", input: TensorLike, other: TensorLike) -> "nvtripy.Tensor":
3436
r"""
3537
Returns a new tensor of elements selected from either ``input`` or ``other``, depending on ``condition``.
3638

0 commit comments

Comments
 (0)