diff --git a/tripy/nvtripy/frontend/ops/__init__.py b/tripy/nvtripy/frontend/ops/__init__.py
index ecb193a38..2db2d6c0b 100644
--- a/tripy/nvtripy/frontend/ops/__init__.py
+++ b/tripy/nvtripy/frontend/ops/__init__.py
@@ -21,3 +21,13 @@
 import nvtripy.frontend.ops.matmul
 import nvtripy.frontend.ops.shape
 import nvtripy.frontend.ops.slice
+
+# Import the op modules whose functions should also be available as tensor methods
+import nvtripy.frontend.ops.cast
+import nvtripy.frontend.ops.copy
+import nvtripy.frontend.ops.reshape
+import nvtripy.frontend.ops.transpose
+import nvtripy.frontend.ops.flatten
+import nvtripy.frontend.ops.permute
+import nvtripy.frontend.ops.squeeze
+import nvtripy.frontend.ops.unsqueeze
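These imports matter purely for their side effects: each op module applies `register_tensor_method` at import time, which populates the registry that the `Tensor` class later reads its methods from. A minimal sketch of the mechanism, simplified from `_registry.py` below (the step that attaches registry entries to `Tensor` is assumed and not shown in this diff):

    # Simplified sketch -- not the actual nvtripy implementation.
    from typing import Any, Callable, Dict

    TENSOR_METHOD_REGISTRY: Dict[str, Callable[..., Any]] = {}

    def register_tensor_method(name: str):
        def impl(func: Callable[..., Any]) -> Callable[..., Any]:
            # Runs when the defining module is imported, so the imports
            # above populate the registry as a side effect.
            TENSOR_METHOD_REGISTRY[name] = func
            return func
        return impl

    @register_tensor_method("reshape")  # executes at import time
    def reshape(input, shape):
        ...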
""" # We make a special exception for "shape" since we actually do want that to be a property - allowed_methods = ["shape"] + # We also add additional methods of the tensor class that are not magic methods + allowed_methods = ["copy", "cast", "shape", "reshape", "transpose", "flatten", "permute", "squeeze", "unsqueeze"] assert name in allowed_methods or name.startswith( "__" - ), f"The tensor method registry should only be used for magic methods, but was used for: {name}" + ), f"The tensor method registry should only be used for magic methods and specially allowed methods, but was used for: {name}" def impl(func: Callable[..., Any]) -> Callable[..., Any]: TENSOR_METHOD_REGISTRY[name] = func diff --git a/tripy/nvtripy/frontend/ops/cast.py b/tripy/nvtripy/frontend/ops/cast.py index e937e8d4f..9c197af57 100644 --- a/tripy/nvtripy/frontend/ops/cast.py +++ b/tripy/nvtripy/frontend/ops/cast.py @@ -20,12 +20,14 @@ from nvtripy.common.datatype import bool as tp_bool from nvtripy.common.datatype import float32, int8 from nvtripy.frontend.ops import utils as op_utils +from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend.ops.dequantize import dequantize from nvtripy.frontend.ops.quantize import quantize from nvtripy.trace.ops.cast import Cast from nvtripy.utils import wrappers +@register_tensor_method("cast") @export.public_api(document_under="operations/functions") @wrappers.interface( dtype_constraints={"input": "T1", "dtype": "T2", wrappers.RETURN_VALUE: "T2"}, diff --git a/tripy/nvtripy/frontend/ops/copy.py b/tripy/nvtripy/frontend/ops/copy.py index df706e725..4e5c10d6a 100644 --- a/tripy/nvtripy/frontend/ops/copy.py +++ b/tripy/nvtripy/frontend/ops/copy.py @@ -21,9 +21,11 @@ from nvtripy.common import device as tp_device from nvtripy.common.datatype import DATA_TYPES from nvtripy.common.exception import raise_error +from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.utils import wrappers +@register_tensor_method("copy") @export.public_api(document_under="operations/functions") @wrappers.interface( dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, diff --git a/tripy/nvtripy/frontend/ops/flatten.py b/tripy/nvtripy/frontend/ops/flatten.py index 40ea5bcd2..a356bf123 100644 --- a/tripy/nvtripy/frontend/ops/flatten.py +++ b/tripy/nvtripy/frontend/ops/flatten.py @@ -17,9 +17,11 @@ from nvtripy import export from nvtripy.common.exception import raise_error from nvtripy.frontend.ops import utils as op_utils +from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.utils import wrappers +@register_tensor_method("flatten") @export.public_api(document_under="operations/functions") @wrappers.interface( dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, diff --git a/tripy/nvtripy/frontend/ops/permute.py b/tripy/nvtripy/frontend/ops/permute.py index 1a62c35fb..a367ce939 100644 --- a/tripy/nvtripy/frontend/ops/permute.py +++ b/tripy/nvtripy/frontend/ops/permute.py @@ -20,10 +20,12 @@ from nvtripy import export from nvtripy.common.exception import raise_error from nvtripy.frontend.ops import utils as op_utils +from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.permute import Permute from nvtripy.utils import wrappers +@register_tensor_method("permute") @export.public_api(document_under="operations/functions") @wrappers.interface( dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, diff --git a/tripy/nvtripy/frontend/ops/reshape.py 
diff --git a/tripy/nvtripy/frontend/ops/reshape.py b/tripy/nvtripy/frontend/ops/reshape.py
index 943dc6e56..262dc78b8 100644
--- a/tripy/nvtripy/frontend/ops/reshape.py
+++ b/tripy/nvtripy/frontend/ops/reshape.py
@@ -20,6 +20,7 @@
 from nvtripy import export
 from nvtripy.common.exception import raise_error
 from nvtripy.frontend.ops import utils as op_utils
+from nvtripy.frontend.ops._registry import register_tensor_method
 from nvtripy.trace.ops.reshape import Reshape
 from nvtripy.types import ShapeLike
 from nvtripy.utils import wrappers
@@ -42,6 +43,7 @@ def infer_dimensions(input: "nvtripy.Tensor", shape: ShapeLike) -> ShapeLike:
     return {"shape": shape}
 
 
+@register_tensor_method("reshape")
 @export.public_api(document_under="operations/functions")
 @wrappers.interface(
     dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
diff --git a/tripy/nvtripy/frontend/ops/squeeze.py b/tripy/nvtripy/frontend/ops/squeeze.py
index 0a7311570..5ec260e15 100644
--- a/tripy/nvtripy/frontend/ops/squeeze.py
+++ b/tripy/nvtripy/frontend/ops/squeeze.py
@@ -16,9 +16,11 @@
 from nvtripy import export, utils
 from nvtripy.frontend.ops import utils as op_utils
+from nvtripy.frontend.ops._registry import register_tensor_method
 from nvtripy.utils import wrappers
 
 
+@register_tensor_method("squeeze")
 @export.public_api(document_under="operations/functions")
 @wrappers.interface(
     dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
diff --git a/tripy/nvtripy/frontend/ops/transpose.py b/tripy/nvtripy/frontend/ops/transpose.py
index 49756e6b2..1d4b8cf3d 100644
--- a/tripy/nvtripy/frontend/ops/transpose.py
+++ b/tripy/nvtripy/frontend/ops/transpose.py
@@ -14,9 +14,11 @@
 # limitations under the License.
 from nvtripy import export
 from nvtripy.common.exception import raise_error
+from nvtripy.frontend.ops._registry import register_tensor_method
 from nvtripy.utils import wrappers
 
 
+@register_tensor_method("transpose")
 @export.public_api(document_under="operations/functions")
 @wrappers.interface(
     dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
diff --git a/tripy/nvtripy/frontend/ops/unsqueeze.py b/tripy/nvtripy/frontend/ops/unsqueeze.py
index 44076c2ef..fa3e25045 100644
--- a/tripy/nvtripy/frontend/ops/unsqueeze.py
+++ b/tripy/nvtripy/frontend/ops/unsqueeze.py
@@ -17,9 +17,11 @@
 from nvtripy import export
 from nvtripy.frontend.ops import utils as op_utils
+from nvtripy.frontend.ops._registry import register_tensor_method
 from nvtripy.utils import wrappers
 
 
+@register_tensor_method("unsqueeze")
 @export.public_api(document_under="operations/functions")
 @wrappers.interface(
     dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
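The `utils/ast.py` change below replaces the old `"self" in arg_names` heuristic with a structural check: a method call parses to an `ast.Call` whose `func` is an `ast.Attribute`, whereas a bare function call yields an `ast.Name`. A standalone illustration:

    import ast

    for snippet in ("t.reshape((1, 8))", "reshape(t, (1, 8))", "tp.reshape(t, (1, 8))"):
        call = ast.parse(snippet).body[0].value
        print(f"{snippet} -> {type(call.func).__name__}")
    # t.reshape((1, 8)) -> Attribute
    # reshape(t, (1, 8)) -> Name
    # tp.reshape(t, (1, 8)) -> Attribute (module-qualified calls also parse as attribute access)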
diff --git a/tripy/nvtripy/utils/ast.py b/tripy/nvtripy/utils/ast.py
index 3a61ff1cb..79830b5a5 100644
--- a/tripy/nvtripy/utils/ast.py
+++ b/tripy/nvtripy/utils/ast.py
@@ -1,5 +1,5 @@
 #
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -102,9 +102,8 @@ def get_ast_node_func_name(node) -> Optional[str]:
 
 # Gets the column offset of the argument at `index` to function called `func_name` in the provided `code` snippet.
 def get_arg_candidate_column_offsets(
-    code: str, index: int, num_positional: int, func_name: str, is_kwarg: bool, arg_names: List[str]
+    code: str, index: int, num_positional: int, func_name: str, is_kwarg: bool
 ) -> Tuple[int, int]:
-
     candidates = []
 
     result = get_parsed_ast(code)
@@ -123,8 +122,10 @@ def get_arg_candidate_column_offsets(
         if is_kwarg:
             arg_node = node.keywords[index - num_positional]
         else:
+            # Detect method calls by examining the AST structure:
             # For methods, the `self` argument is omitted from ast.Call.args
-            if "self" in arg_names:
+            is_method_call = isinstance(node.func, ast.Attribute)
+            if is_method_call:
                 index -= 1
             # If the final argument is a starred object, then we treat any args
             # past the end as pointing to the starred object (this would be a variadic call,
diff --git a/tripy/nvtripy/utils/wrappers.py b/tripy/nvtripy/utils/wrappers.py
index c47e7330b..7f84427c3 100644
--- a/tripy/nvtripy/utils/wrappers.py
+++ b/tripy/nvtripy/utils/wrappers.py
@@ -17,6 +17,7 @@
 
 import functools
 import inspect
+import types
 from dataclasses import dataclass
 from textwrap import indent
 from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
@@ -40,7 +41,7 @@ class DataTypeConstraints:
 
 
 # Try to include correct column offsets for non-tensor arguments.
-def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name, arg_names):
+def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name):
     from nvtripy.frontend.tensor import Tensor
 
     assert isinstance(arg, Tensor), f"This function should only be called for objects that are already Tensor instances"
@@ -94,7 +95,7 @@ def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name):
         dispatch_target = dispatch_target.replace("__r", "__")
 
     candidates = utils.ast.get_arg_candidate_column_offsets(
-        source_info.code, arg_index, num_positional, dispatch_target or func_name, is_kwarg, arg_names
+        source_info.code, arg_index, num_positional, dispatch_target or func_name, is_kwarg
     )
 
     # Only set column range if there is exactly one candidate, otherwise we can't reliably determine
@@ -202,7 +203,6 @@ def add_arg(arg):
             name in kwargs,
             len(args),
             func.__name__,
-            [name for name, _ in merged_args],
         )
 
         dtype = None
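The integration tests below exercise each op through both spellings. The equivalence they check, shown standalone (a sketch; `tp.int32` and the `.dtype` attribute are assumed from the library's documented dtype support):

    import numpy as np
    import nvtripy as tp

    t = tp.Tensor(np.ones((2, 3), dtype=np.float32))
    # Both spellings go through the same wrapped function, so they agree:
    assert t.cast(tp.int32).dtype == tp.cast(t, tp.int32).dtype == tp.int32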
diff --git a/tripy/tests/integration/test_cast.py b/tripy/tests/integration/test_cast.py
index 84adf346d..87e4cf4a1 100644
--- a/tripy/tests/integration/test_cast.py
+++ b/tripy/tests/integration/test_cast.py
@@ -23,33 +23,39 @@
 from tests.conftest import skip_if_older_than_sm89
 from tests.helper import NUMPY_TO_TRIPY
 
+dtype_pairs = [
+    (np.int32, np.float32),
+    (np.float32, np.int32),
+    (np.int32, np.int8),
+    (np.float32, np.int8),
+    (np.int8, np.int32),
+    (np.int8, np.float32),
+    # important to test conversion into bool because default StableHLO semantics
+    # are simply to truncate to i1, which is not desirable
+    (np.float32, bool),
+    (np.int32, bool),
+    # requires a dequantization first
+    # TODO(#219): Dequantize fails with dynamic shapes
+    # (np.int8, bool),
+]
+
 
 class TestCast:
     @pytest.mark.parametrize(
         "input_dtype, target_dtype",
-        [
-            (np.int32, np.float32),
-            (np.float32, np.int32),
-            (np.int32, np.int8),
-            (np.float32, np.int8),
-            (np.int8, np.int32),
-            (np.int8, np.float32),
-            # important to test conversion into bool because default StableHLO semantics
-            # are simply to truncate to i1, which is not desirable
-            (np.float32, bool),
-            (np.int32, bool),
-            # requires a dequantization first
-            # TODO(#219): Dequantize fails with dynamic shapes
-            # (np.int8, bool),
-        ],
+        dtype_pairs,
     )
-    def test_cast(self, input_dtype, target_dtype, eager_or_compiled):
+    @pytest.mark.parametrize("use_tensor_method", [False, True])
+    def test_cast(self, input_dtype, target_dtype, use_tensor_method, eager_or_compiled):
         tp_target_dtype = NUMPY_TO_TRIPY[target_dtype]
 
         # TODO(#222): Integer casts with negative numbers fail in many cases
         input_tensor = tp.copy(tp.Tensor(np.ones((2, 3), dtype=input_dtype)), tp.device("gpu"))
-        output = eager_or_compiled(tp.cast, input_tensor, tp_target_dtype)
+        if use_tensor_method:
+            output = eager_or_compiled(lambda t: t.cast(tp_target_dtype), input_tensor)
+        else:
+            output = eager_or_compiled(tp.cast, input_tensor, tp_target_dtype)
         np_input = cp.from_dlpack(input_tensor).get()
         assert np.array_equal(cp.from_dlpack(output).get(), np_input.astype(target_dtype))
diff --git a/tripy/tests/integration/test_copy.py b/tripy/tests/integration/test_copy.py
index 406c40fbe..423b9a0e4 100644
--- a/tripy/tests/integration/test_copy.py
+++ b/tripy/tests/integration/test_copy.py
@@ -16,23 +16,40 @@
 
 import numpy as np
 import nvtripy as tp
+import pytest
 
 
 class TestCopy:
-    def test_to_cpu(self):
+    @pytest.mark.parametrize(
+        "copy_func",
+        [
+            lambda tensor, device: tp.copy(tensor, device),  # Free function
+            lambda tensor, device: tensor.copy(device),  # Tensor method
+        ],
+    )
+    def test_to_cpu(self, copy_func):
+        """Test that both the free function and the tensor method copy to CPU."""
         gpu_tensor = tp.Tensor(cp.ones((2, 2), dtype=cp.float32))
         assert gpu_tensor.device.kind == "gpu"
 
-        cpu_tensor = tp.copy(gpu_tensor, tp.device("cpu"))
-        assert cpu_tensor.device.kind == "cpu"
+        cpu_tensor = copy_func(gpu_tensor, tp.device("cpu"))
+        assert cpu_tensor.device.kind == "cpu"
 
         # If the tensor is really in CPU memory, we should be able to construct a NumPy array from it
         assert np.from_dlpack(cpu_tensor).shape == (2, 2)
 
-    def test_to_gpu(self):
+    @pytest.mark.parametrize(
+        "copy_func",
+        [
+            lambda tensor, device: tp.copy(tensor, device),  # Free function
+            lambda tensor, device: tensor.copy(device),  # Tensor method
+        ],
+    )
+    def test_to_gpu(self, copy_func):
         cpu_tensor = tp.Tensor(np.ones((2, 2), dtype=np.float32))
         assert cpu_tensor.device.kind == "cpu"
 
-        gpu_tensor = tp.copy(cpu_tensor, tp.device("gpu"))
+        gpu_tensor = copy_func(cpu_tensor, tp.device("gpu"))
         assert gpu_tensor.device.kind == "gpu"
 
         # If the tensor is really in GPU memory, we should be able to construct a Cupy array from it
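A note on the recurring `lambda t: t.method(...)` pattern in these tests: `eager_or_compiled` needs a callable taking only Tensor inputs that it can either invoke eagerly or compile, so the lambda captures the non-Tensor arguments (`start_dim`, dtypes, devices, and so on) in its closure. Outside the test harness the two spellings reduce to the same thing (a sketch assuming a GPU build with CuPy available):

    import cupy as cp
    import numpy as np
    import nvtripy as tp

    a = tp.Tensor(cp.arange(24).reshape(2, 3, 4).astype(np.float32))
    # Matches the ((2, 3, 4), 1, -1, (2, 12)) case from the parametrization below:
    assert a.flatten(start_dim=1, end_dim=-1).shape == tp.flatten(a, start_dim=1, end_dim=-1).shape == (2, 12)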
diff --git a/tripy/tests/integration/test_flatten.py b/tripy/tests/integration/test_flatten.py
index 2934f4a1a..8a8239328 100644
--- a/tripy/tests/integration/test_flatten.py
+++ b/tripy/tests/integration/test_flatten.py
@@ -17,22 +17,30 @@
 import pytest
 
 import nvtripy as tp
 
+test_cases = [
+    ((2, 3, 4), 0, -1, (24,)),  # Flatten all dimensions
+    ((2, 3, 4), 1, -1, (2, 12)),  # Flatten dimensions 1 through end
+    ((2, 3, 4), 1, 2, (2, 12)),  # Flatten dimensions 1 through 2
+    ((2, 3, 4), 0, 1, (6, 4)),  # Flatten dimensions 0 through 1
+    ((2, 3, 4, 5), 1, 3, (2, 60)),  # Flatten dimensions 1 through 3
+]
+
 
 class TestFlatten:
     @pytest.mark.parametrize(
         "shape, start_dim, end_dim, expected_shape",
-        [
-            ((2, 3, 4), 0, -1, (24,)),  # Flatten all dimensions
-            ((2, 3, 4), 1, -1, (2, 12)),  # Flatten dimensions 1 through end
-            ((2, 3, 4), 1, 2, (2, 12)),  # Flatten dimensions 1 through 2
-            ((2, 3, 4), 0, 1, (6, 4)),  # Flatten dimensions 0 through 1
-            ((2, 3, 4, 5), 1, 3, (2, 60)),  # Flatten dimensions 1 through 3
-        ],
+        test_cases,
     )
-    def test_flatten(self, shape, start_dim, end_dim, expected_shape, eager_or_compiled):
+    @pytest.mark.parametrize("use_tensor_method", [False, True])
+    def test_flatten(self, shape, start_dim, end_dim, expected_shape, use_tensor_method, eager_or_compiled):
         cp_a = cp.arange(np.prod(shape)).reshape(shape).astype(np.float32)
         a = tp.Tensor(cp_a)
-        b = eager_or_compiled(tp.flatten, a, start_dim=start_dim, end_dim=end_dim)
+
+        if use_tensor_method:
+            b = eager_or_compiled(lambda t: t.flatten(start_dim=start_dim, end_dim=end_dim), a)
+        else:
+            b = eager_or_compiled(tp.flatten, a, start_dim=start_dim, end_dim=end_dim)
+
         assert b.shape == expected_shape
         assert np.array_equal(cp.from_dlpack(b).get(), cp_a.reshape(expected_shape).get())
diff --git a/tripy/tests/integration/test_reshape.py b/tripy/tests/integration/test_reshape.py
index 3eafa65f5..b527ed4e5 100644
--- a/tripy/tests/integration/test_reshape.py
+++ b/tripy/tests/integration/test_reshape.py
@@ -20,23 +20,29 @@
 import pytest
 
 import nvtripy as tp
 
+test_cases = [
+    ((2, 4), (1, 8)),
+    ((2, 4, 8, 9), (8, 8, 9)),
+    ((2, 4), (8,)),  # change rank of output
+    ((2, 4), (1, -1)),  # check negative dim
+]
+
 
 class TestReshape:
     @pytest.mark.parametrize(
         "shape, new_shape",
-        [
-            ((2, 4), (1, 8)),
-            ((2, 4, 8, 9), (8, 8, 9)),
-            ((2, 4), (8,)),  # change rank of output
-            ((2, 4), (1, -1)),  # check negative dim
-        ],
+        test_cases,
     )
-    def test_static_reshape(self, shape, new_shape, eager_or_compiled):
+    @pytest.mark.parametrize("use_tensor_method", [False, True])
+    def test_static_reshape(self, shape, new_shape, use_tensor_method, eager_or_compiled):
         cp_a = cp.arange(np.prod(shape)).reshape(shape).astype(np.float32)
         a = tp.Tensor(cp_a)
-        b = eager_or_compiled(tp.reshape, a, new_shape)
-        if -1 in new_shape:
-            new_shape = tuple(np.prod(shape) // -np.prod(new_shape) if d == -1 else d for d in new_shape)
+
+        if use_tensor_method:
+            b = eager_or_compiled(lambda t: t.reshape(new_shape), a)
+        else:
+            b = eager_or_compiled(tp.reshape, a, new_shape)
+
         assert np.array_equal(cp.from_dlpack(b).get(), cp_a.reshape(new_shape).get())
 
     def test_reshape_shape_tensor(self, eager_or_compiled):
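The `-1` normalization deleted from `test_static_reshape` above was redundant: NumPy's `reshape` infers a `-1` dimension on its own, so the reference result can be computed from `new_shape` directly. For example:

    import numpy as np

    a = np.arange(8, dtype=np.float32).reshape(2, 4)
    # np.reshape resolves the -1 itself; no manual normalization is needed
    # before comparing against the nvtripy output:
    assert a.reshape((1, -1)).shape == (1, 8)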
diff --git a/tripy/tests/integration/test_squeeze.py b/tripy/tests/integration/test_squeeze.py
index 350f075d4..d873333fd 100644
--- a/tripy/tests/integration/test_squeeze.py
+++ b/tripy/tests/integration/test_squeeze.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,18 +15,26 @@
 import nvtripy as tp
 import pytest
 
+test_cases = [
+    ((1, 2, 1), 0, (2, 1)),  # Squeeze first dimension
+    ((1, 2, 1), (0, 2), (2,)),  # Squeeze first and third dimensions
+    ((1, 2, 1), tuple(), (1, 2, 1)),  # No dimensions to squeeze
+    ((1, 2, 1), (-3, -1), (2,)),  # Squeeze using negative dimensions
+]
+
 
 class TestSqueeze:
     @pytest.mark.parametrize(
         "input_shape, dims, expected_shape",
-        [
-            ((1, 2, 1), 0, (2, 1)),  # Squeeze first dimension
-            ((1, 2, 1), (0, 2), (2,)),  # Squeeze first and third dimensions
-            ((1, 2, 1), tuple(), (1, 2, 1)),  # No dimensions to squeeze
-            ((1, 2, 1), (-3, -1), (2,)),  # Squeeze using negative dimensions
-        ],
+        test_cases,
     )
-    def test_squeeze(self, input_shape, dims, expected_shape):
+    @pytest.mark.parametrize("use_tensor_method", [False, True])
+    def test_squeeze(self, input_shape, dims, expected_shape, use_tensor_method):
         input_tensor = tp.ones(input_shape, dtype=tp.float32)
-        output_tensor = tp.squeeze(input_tensor, dims=dims)
+
+        if use_tensor_method:
+            output_tensor = input_tensor.squeeze(dims)
+        else:
+            output_tensor = tp.squeeze(input_tensor, dims=dims)
+
         assert output_tensor.shape == expected_shape
diff --git a/tripy/tests/integration/test_unsqueeze.py b/tripy/tests/integration/test_unsqueeze.py
index fdc646a83..dca392660 100644
--- a/tripy/tests/integration/test_unsqueeze.py
+++ b/tripy/tests/integration/test_unsqueeze.py
@@ -24,16 +24,17 @@ class TestUnsqueezeOp:
     @pytest.mark.parametrize("axis", [-1, 0, 2])
-    def test_unsqueeze_dynamic_op(self, axis, eager_or_compiled):
-        def func(a):
-            return tp.unsqueeze(a, dim=axis)
-
+    @pytest.mark.parametrize("use_tensor_method", [False, True])
+    def test_unsqueeze_dynamic_op(self, axis, use_tensor_method, eager_or_compiled):
         inp = np.ones((4, 2, 2, 3), dtype=np.float32)
 
-        out = eager_or_compiled(func, tp.Tensor(inp))
+        if use_tensor_method:
+            out = eager_or_compiled(lambda t: t.unsqueeze(axis), tp.Tensor(inp))
+        else:
+            out = eager_or_compiled(tp.unsqueeze, tp.Tensor(inp), axis)
+
         ref_out = np.expand_dims(inp, axis=axis)
         assert tp.allclose(out, tp.Tensor(ref_out))
-
         assert out.shape == tuple(ref_out.shape)
 
     def test_unsqueeze_compile(self):
diff --git a/tripy/tests/utils/test_ast.py b/tripy/tests/utils/test_ast.py
index e3964e02c..da229f5ad 100644
--- a/tripy/tests/utils/test_ast.py
+++ b/tripy/tests/utils/test_ast.py
@@ -1,5 +1,5 @@
 #
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -334,6 +334,6 @@ def test_get_arg_candidate_column_offsets(
         if expected[i] is None:
             continue
         assert (
-            get_arg_candidate_column_offsets(call_str, i, num_positional, func_name, i >= num_positional, arg_names)
+            get_arg_candidate_column_offsets(call_str, i, num_positional, func_name, i >= num_positional)
             == expected[i]
         )