Merged
Changes from 4 commits
10 changes: 10 additions & 0 deletions tripy/nvtripy/frontend/ops/__init__.py
@@ -21,3 +21,13 @@
import nvtripy.frontend.ops.matmul
import nvtripy.frontend.ops.shape
import nvtripy.frontend.ops.slice

# Import regular methods that should be available as tensor methods
import nvtripy.frontend.ops.cast
import nvtripy.frontend.ops.copy
import nvtripy.frontend.ops.reshape
import nvtripy.frontend.ops.transpose
import nvtripy.frontend.ops.flatten
import nvtripy.frontend.ops.permute
import nvtripy.frontend.ops.squeeze
import nvtripy.frontend.ops.unsqueeze
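These imports matter only for their side effects: importing each module runs its `register_tensor_method` decorator (added in the files below), which is what exposes the method-style spellings on `Tensor`. A quick usage sketch, using forms taken from the tests in this PR:

```python
import numpy as np
import nvtripy as tp

t = tp.Tensor(np.ones((2, 3), dtype=np.float32))
a = tp.reshape(t, (3, 2))  # free-function form (unchanged)
b = t.reshape((3, 2))      # method form enabled by these imports
```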
11 changes: 7 additions & 4 deletions tripy/nvtripy/frontend/ops/_registry.py
@@ -1,5 +1,5 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,6 +16,7 @@
#

from typing import Any, Callable
from functools import wraps

# We use the tensor method registry to define methods on the `Tensor` class out of line.
# This lets the method live alongside the trace operation and makes it a bit more modular
@@ -28,14 +29,16 @@ def register_tensor_method(name: str):
Decorator to add the method to the tensor method registry with the name specified.
This does not use the FunctionRegistry decorator because every tensor method would also be
registered in the public function registry and we would prefer to avoid having overhead
from having to dispatch overloads and check types twice.
from having to dispatch overloads and check types twice. This needs to be the top level decorator so we can
get input type validation from other decorators like `public_api`.
"""

# We make a special exception for "shape" since we actually do want that to be a property
allowed_methods = ["shape"]
# We also add additional methods of the tensor class that are not magic methods
allowed_methods = ["copy", "cast", "shape", "reshape", "transpose", "flatten", "permute", "squeeze", "unsqueeze"]
assert name in allowed_methods or name.startswith(
"__"
), f"The tensor method registry should only be used for magic methods, but was used for: {name}"
), f"The tensor method registry should only be used for magic methods and specially allowed methods, but was used for: {name}"

def impl(func: Callable[..., Any]) -> Callable[..., Any]:
TENSOR_METHOD_REGISTRY[name] = func
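For readers outside the codebase: the registry is a plain name-to-function mapping, and the decorator only records the (already fully wrapped) callable. A minimal sketch of the pattern; the consuming loop is an assumption, since the `Tensor`-side code is not part of this diff:

```python
from typing import Any, Callable

TENSOR_METHOD_REGISTRY: dict = {}

def register_tensor_method(name: str):
    def impl(func: Callable[..., Any]) -> Callable[..., Any]:
        TENSOR_METHOD_REGISTRY[name] = func  # record under the given name
        return func
    return impl

# Assumed consumer (not shown in this diff): Tensor binds each entry as a
# method, so `t.cast(...)` dispatches to the decorated free function.
# for name, func in TENSOR_METHOD_REGISTRY.items():
#     setattr(Tensor, name, func)
```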
2 changes: 2 additions & 0 deletions tripy/nvtripy/frontend/ops/cast.py
@@ -20,12 +20,14 @@
from nvtripy.common.datatype import bool as tp_bool
from nvtripy.common.datatype import float32, int8
from nvtripy.frontend.ops import utils as op_utils
from nvtripy.frontend.ops._registry import register_tensor_method
from nvtripy.frontend.ops.dequantize import dequantize
from nvtripy.frontend.ops.quantize import quantize
from nvtripy.trace.ops.cast import Cast
from nvtripy.utils import wrappers


@register_tensor_method("cast")
@export.public_api(document_under="operations/functions")
@wrappers.interface(
dtype_constraints={"input": "T1", "dtype": "T2", wrappers.RETURN_VALUE: "T2"},
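The same two-line change repeats in each op module below. As the `_registry.py` docstring notes, `register_tensor_method` must be the top-level decorator so that the registered callable is the fully wrapped one and method-style calls still get `public_api`/`wrappers.interface` input validation. The resulting stack, abbreviated from this file:

```python
@register_tensor_method("cast")  # outermost: registers the fully wrapped callable
@export.public_api(document_under="operations/functions")
@wrappers.interface(
    dtype_constraints={"input": "T1", "dtype": "T2", wrappers.RETURN_VALUE: "T2"},
    # ...remaining arguments elided here, as in the diff
)
def cast(input, dtype):
    ...
```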
2 changes: 2 additions & 0 deletions tripy/nvtripy/frontend/ops/copy.py
@@ -21,9 +21,11 @@
from nvtripy.common import device as tp_device
from nvtripy.common.datatype import DATA_TYPES
from nvtripy.common.exception import raise_error
from nvtripy.frontend.ops._registry import register_tensor_method
from nvtripy.utils import wrappers


@register_tensor_method("copy")
@export.public_api(document_under="operations/functions")
@wrappers.interface(
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
2 changes: 2 additions & 0 deletions tripy/nvtripy/frontend/ops/flatten.py
@@ -17,9 +17,11 @@
from nvtripy import export
from nvtripy.common.exception import raise_error
from nvtripy.frontend.ops import utils as op_utils
from nvtripy.frontend.ops._registry import register_tensor_method
from nvtripy.utils import wrappers


@register_tensor_method("flatten")
@export.public_api(document_under="operations/functions")
@wrappers.interface(
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
2 changes: 2 additions & 0 deletions tripy/nvtripy/frontend/ops/permute.py
@@ -20,10 +20,12 @@
from nvtripy import export
from nvtripy.common.exception import raise_error
from nvtripy.frontend.ops import utils as op_utils
from nvtripy.frontend.ops._registry import register_tensor_method
from nvtripy.trace.ops.permute import Permute
from nvtripy.utils import wrappers


@register_tensor_method("permute")
@export.public_api(document_under="operations/functions")
@wrappers.interface(
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
2 changes: 2 additions & 0 deletions tripy/nvtripy/frontend/ops/reshape.py
@@ -20,6 +20,7 @@
from nvtripy import export
from nvtripy.common.exception import raise_error
from nvtripy.frontend.ops import utils as op_utils
from nvtripy.frontend.ops._registry import register_tensor_method
from nvtripy.trace.ops.reshape import Reshape
from nvtripy.types import ShapeLike
from nvtripy.utils import wrappers
@@ -42,6 +43,7 @@ def infer_dimensions(input: "nvtripy.Tensor", shape: ShapeLike) -> ShapeLike:
return {"shape": shape}


@register_tensor_method("reshape")
@export.public_api(document_under="operations/functions")
@wrappers.interface(
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
2 changes: 2 additions & 0 deletions tripy/nvtripy/frontend/ops/squeeze.py
@@ -16,9 +16,11 @@

from nvtripy import export, utils
from nvtripy.frontend.ops import utils as op_utils
from nvtripy.frontend.ops._registry import register_tensor_method
from nvtripy.utils import wrappers


@register_tensor_method("squeeze")
@export.public_api(document_under="operations/functions")
@wrappers.interface(
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
2 changes: 2 additions & 0 deletions tripy/nvtripy/frontend/ops/transpose.py
@@ -14,9 +14,11 @@
# limitations under the License.
from nvtripy import export
from nvtripy.common.exception import raise_error
from nvtripy.frontend.ops._registry import register_tensor_method
from nvtripy.utils import wrappers


@register_tensor_method("transpose")
@export.public_api(document_under="operations/functions")
@wrappers.interface(
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
2 changes: 2 additions & 0 deletions tripy/nvtripy/frontend/ops/unsqueeze.py
@@ -17,9 +17,11 @@

from nvtripy import export
from nvtripy.frontend.ops import utils as op_utils
from nvtripy.frontend.ops._registry import register_tensor_method
from nvtripy.utils import wrappers


@register_tensor_method("unsqueeze")
@export.public_api(document_under="operations/functions")
@wrappers.interface(
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
10 changes: 5 additions & 5 deletions tripy/nvtripy/utils/ast.py
@@ -1,5 +1,5 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -102,9 +102,8 @@ def get_ast_node_func_name(node) -> Optional[str]:

# Gets the column offset of the argument at `index` to function called `func_name` in the provided `code` snippet.
def get_arg_candidate_column_offsets(
code: str, index: int, num_positional: int, func_name: str, is_kwarg: bool, arg_names: List[str]
code: str, index: int, num_positional: int, func_name: str, is_kwarg: bool
) -> Tuple[int, int]:

candidates = []

result = get_parsed_ast(code)
@@ -123,8 +122,9 @@ def get_arg_candidate_column_offsets(
if is_kwarg:
arg_node = node.keywords[index - num_positional]
else:
# For methods, the `self` argument is omitted from ast.Call.args
if "self" in arg_names:
# Detect method calls by examining AST structure
is_method_call = isinstance(node.func, ast.Attribute)
if is_method_call:
index -= 1
# If the final argument is a starred object, then we treat any args
# past the end as pointing to the starred object (this would be a variadic call,
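The replacement heuristic rests on a standard-library invariant: for an attribute-style call the callee is an `ast.Attribute` node and the receiver does not appear in `ast.Call.args`. This is easy to confirm with plain `ast`:

```python
import ast

call = ast.parse("t.cast(tp.float32)", mode="eval").body
assert isinstance(call.func, ast.Attribute)  # attribute-style (method) call
print(call.func.attr)  # 'cast'
print(len(call.args))  # 1 -- the receiver `t` is not among the args
```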
6 changes: 3 additions & 3 deletions tripy/nvtripy/utils/wrappers.py
@@ -17,6 +17,7 @@

import functools
import inspect
import types
from dataclasses import dataclass
from textwrap import indent
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
@@ -40,7 +41,7 @@ class DataTypeConstraints:


# Try to include correct column offsets for non-tensor arguments.
def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name, arg_names):
def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name):
from nvtripy.frontend.tensor import Tensor

assert isinstance(arg, Tensor), f"This function should only be called for objects that are already Tensor instances"
@@ -94,7 +95,7 @@ def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name, arg_names):
dispatch_target = dispatch_target.replace("__r", "__")

candidates = utils.ast.get_arg_candidate_column_offsets(
source_info.code, arg_index, num_positional, dispatch_target or func_name, is_kwarg, arg_names
source_info.code, arg_index, num_positional, dispatch_target or func_name, is_kwarg
)

# Only set column range if there is exactly one candidate, otherwise we can't reliably determine
@@ -202,7 +203,6 @@ def add_arg(arg):
name in kwargs,
len(args),
func.__name__,
[name for name, _ in merged_args],
)

dtype = None
40 changes: 23 additions & 17 deletions tripy/tests/integration/test_cast.py
@@ -23,33 +23,39 @@
from tests.conftest import skip_if_older_than_sm89
from tests.helper import NUMPY_TO_TRIPY

dtype_pairs = [
(np.int32, np.float32),
(np.float32, np.int32),
(np.int32, np.int8),
(np.float32, np.int8),
(np.int8, np.int32),
(np.int8, np.float32),
# important to test conversion into bool because default StableHLO semantics
# are simply to truncate to i1, which is not desirable
(np.float32, bool),
(np.int32, bool),
# requires a dequantization first
# TODO(#219): Dequantize fails with dynamic shapes
# (np.int8, bool),
]


class TestCast:
@pytest.mark.parametrize(
"input_dtype, target_dtype",
[
(np.int32, np.float32),
(np.float32, np.int32),
(np.int32, np.int8),
(np.float32, np.int8),
(np.int8, np.int32),
(np.int8, np.float32),
# important to test conversion into bool because default StableHLO semantics
# are simply to truncate to i1, which is not desirable
(np.float32, bool),
(np.int32, bool),
# requires a dequantization first
# TODO(#219): Dequantize fails with dynamic shapes
# (np.int8, bool),
],
dtype_pairs,
)
def test_cast(self, input_dtype, target_dtype, eager_or_compiled):
@pytest.mark.parametrize("use_tensor_method", [False, True])
def test_cast(self, input_dtype, target_dtype, use_tensor_method, eager_or_compiled):
tp_target_dtype = NUMPY_TO_TRIPY[target_dtype]

# TODO(#222): Integer casts with negative numbers fail in many cases
input_tensor = tp.copy(tp.Tensor(np.ones((2, 3), dtype=input_dtype)), tp.device("gpu"))

output = eager_or_compiled(tp.cast, input_tensor, tp_target_dtype)
if use_tensor_method:
output = eager_or_compiled(lambda t: t.cast(tp_target_dtype), input_tensor)
else:
output = eager_or_compiled(tp.cast, input_tensor, tp_target_dtype)

np_input = cp.from_dlpack(input_tensor).get()
assert np.array_equal(cp.from_dlpack(output).get(), np_input.astype(target_dtype))
27 changes: 22 additions & 5 deletions tripy/tests/integration/test_copy.py
@@ -16,23 +16,40 @@
import numpy as np
import nvtripy as tp

import pytest


class TestCopy:
def test_to_cpu(self):
@pytest.mark.parametrize(
"copy_func",
[
lambda tensor, device: tp.copy(tensor, device), # Free function
lambda tensor, device: tensor.copy(device), # Tensor method
],
)
def test_copy_tensor_method(self, copy_func):
"""Test that both copy methods work with compilation."""
gpu_tensor = tp.Tensor(cp.ones((2, 2), dtype=cp.float32))
assert gpu_tensor.device.kind == "gpu"

cpu_tensor = tp.copy(gpu_tensor, tp.device("cpu"))
assert cpu_tensor.device.kind == "cpu"
cpu_tensor = copy_func(gpu_tensor, tp.device("cpu"))

assert cpu_tensor.device.kind == "cpu"
# If the tensor is really in CPU memory, we should be able to construct a NumPy array from it
assert np.from_dlpack(cpu_tensor).shape == (2, 2)

def test_to_gpu(self):
@pytest.mark.parametrize(
"copy_func",
[
lambda tensor, device: tp.copy(tensor, device), # Free function
lambda tensor, device: tensor.copy(device), # Tensor method
],
)
def test_to_gpu(self, copy_func):
cpu_tensor = tp.Tensor(np.ones((2, 2), dtype=np.float32))
assert cpu_tensor.device.kind == "cpu"

gpu_tensor = tp.copy(cpu_tensor, tp.device("gpu"))
gpu_tensor = copy_func(cpu_tensor, tp.device("gpu"))
assert gpu_tensor.device.kind == "gpu"

# If the tensor is really in GPU memory, we should be able to construct a Cupy array from it
26 changes: 17 additions & 9 deletions tripy/tests/integration/test_flatten.py
@@ -17,22 +17,30 @@
import pytest
import nvtripy as tp

test_cases = [
((2, 3, 4), 0, -1, (24,)), # Flatten all dimensions
((2, 3, 4), 1, -1, (2, 12)), # Flatten dimensions 1 through end
((2, 3, 4), 1, 2, (2, 12)), # Flatten dimensions 1 through 2
((2, 3, 4), 0, 1, (6, 4)), # Flatten dimensions 0 through 1
((2, 3, 4, 5), 1, 3, (2, 60)), # Flatten dimensions 1 through 3
]


class TestFlatten:
@pytest.mark.parametrize(
"shape, start_dim, end_dim, expected_shape",
[
((2, 3, 4), 0, -1, (24,)), # Flatten all dimensions
((2, 3, 4), 1, -1, (2, 12)), # Flatten dimensions 1 through end
((2, 3, 4), 1, 2, (2, 12)), # Flatten dimensions 1 through 2
((2, 3, 4), 0, 1, (6, 4)), # Flatten dimensions 0 through 1
((2, 3, 4, 5), 1, 3, (2, 60)), # Flatten dimensions 1 through 3
],
test_cases,
)
def test_flatten(self, shape, start_dim, end_dim, expected_shape, eager_or_compiled):
@pytest.mark.parametrize("use_tensor_method", [False, True])
def test_flatten(self, shape, start_dim, end_dim, expected_shape, use_tensor_method, eager_or_compiled):
cp_a = cp.arange(np.prod(shape)).reshape(shape).astype(np.float32)
a = tp.Tensor(cp_a)
b = eager_or_compiled(tp.flatten, a, start_dim=start_dim, end_dim=end_dim)

if use_tensor_method:
b = eager_or_compiled(lambda t: t.flatten(start_dim=start_dim, end_dim=end_dim), a)
else:
b = eager_or_compiled(tp.flatten, a, start_dim=start_dim, end_dim=end_dim)

assert b.shape == expected_shape
assert np.array_equal(cp.from_dlpack(b).get(), cp_a.reshape(expected_shape).get())

26 changes: 16 additions & 10 deletions tripy/tests/integration/test_reshape.py
@@ -20,23 +20,29 @@
import pytest
import nvtripy as tp

test_cases = [
((2, 4), (1, 8)),
((2, 4, 8, 9), (8, 8, 9)),
((2, 4), (8,)), # change rank of output
((2, 4), (1, -1)), # check negative dim
]


class TestReshape:
@pytest.mark.parametrize(
"shape, new_shape",
[
((2, 4), (1, 8)),
((2, 4, 8, 9), (8, 8, 9)),
((2, 4), (8,)), # change rank of output
((2, 4), (1, -1)), # check negative dim
],
test_cases,
)
def test_static_reshape(self, shape, new_shape, eager_or_compiled):
@pytest.mark.parametrize("use_tensor_method", [False, True])
def test_static_reshape(self, shape, new_shape, use_tensor_method, eager_or_compiled):
cp_a = cp.arange(np.prod(shape)).reshape(shape).astype(np.float32)
a = tp.Tensor(cp_a)
b = eager_or_compiled(tp.reshape, a, new_shape)
if -1 in new_shape:
new_shape = tuple(np.prod(shape) // -np.prod(new_shape) if d == -1 else d for d in new_shape)

if use_tensor_method:
b = eager_or_compiled(lambda t: t.reshape(new_shape), a)
else:
b = eager_or_compiled(tp.reshape, a, new_shape)

assert np.array_equal(cp.from_dlpack(b).get(), cp_a.reshape(new_shape).get())

def test_reshape_shape_tensor(self, eager_or_compiled):
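For reference, the `-1` handling added to `test_static_reshape` is the usual "infer one dimension" arithmetic; worked through for the `((2, 4), (1, -1))` case from the parametrization:

```python
import numpy as np

shape, new_shape = (2, 4), (1, -1)
# np.prod(new_shape) is -1 times the product of the known dims, so negating
# it divides the total element count by the known dims to infer the -1 dim.
resolved = tuple(
    int(np.prod(shape) // -np.prod(new_shape)) if d == -1 else d
    for d in new_shape
)
print(resolved)  # (1, 8)
```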