Skip to content

Commit 9e728f0

Browse files
committed
Add tensor methods
1 parent f821499 commit 9e728f0

File tree

10 files changed

+34
-3
lines changed

10 files changed

+34
-3
lines changed

tripy/nvtripy/frontend/ops/_registry.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#
2-
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
# SPDX-License-Identifier: Apache-2.0
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -32,13 +32,23 @@ def register_tensor_method(name: str):
3232
"""
3333

3434
# We make a special exception for "shape" since we actually do want that to be a property
35-
allowed_methods = ["shape"]
35+
allowed_methods = ["copy", "cast", "shape", "reshape", "transpose", "flatten", "permute", "squeeze", "unsqueeze"]
3636
assert name in allowed_methods or name.startswith(
3737
"__"
3838
), f"The tensor method registry should only be used for magic methods, but was used for: {name}"
3939

4040
def impl(func: Callable[..., Any]) -> Callable[..., Any]:
41-
TENSOR_METHOD_REGISTRY[name] = func
41+
# Don't wrap properties (like shape)
42+
if name == "shape":
43+
TENSOR_METHOD_REGISTRY[name] = func
44+
else:
45+
# Create a method wrapper that maps 'self' to the first argument (input)
46+
# This is the standard pattern for all tensor methods except 'shape' (which is a property)
47+
def method_wrapper(self, *args, **kwargs):
48+
return func(self, *args, **kwargs)
49+
50+
TENSOR_METHOD_REGISTRY[name] = method_wrapper
51+
4252
return func
4353

4454
return impl

tripy/nvtripy/frontend/ops/cast.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@
2020
from nvtripy.common.datatype import bool as tp_bool
2121
from nvtripy.common.datatype import float32, int8
2222
from nvtripy.frontend.ops import utils as op_utils
23+
from nvtripy.frontend.ops._registry import register_tensor_method
2324
from nvtripy.frontend.ops.dequantize import dequantize
2425
from nvtripy.frontend.ops.quantize import quantize
2526
from nvtripy.trace.ops.cast import Cast
2627
from nvtripy.utils import wrappers
2728

2829

30+
@register_tensor_method("cast")
2931
@export.public_api(document_under="operations/functions")
3032
@wrappers.interface(
3133
dtype_constraints={"input": "T1", "dtype": "T2", wrappers.RETURN_VALUE: "T2"},

tripy/nvtripy/frontend/ops/copy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@
2121
from nvtripy.common import device as tp_device
2222
from nvtripy.common.datatype import DATA_TYPES
2323
from nvtripy.common.exception import raise_error
24+
from nvtripy.frontend.ops._registry import register_tensor_method
2425
from nvtripy.utils import wrappers
2526

2627

2728
@export.public_api(document_under="operations/functions")
29+
@register_tensor_method("copy")
2830
@wrappers.interface(
2931
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
3032
dtype_variables={"T1": list(DATA_TYPES.keys())},

tripy/nvtripy/frontend/ops/flatten.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717
from nvtripy import export
1818
from nvtripy.common.exception import raise_error
1919
from nvtripy.frontend.ops import utils as op_utils
20+
from nvtripy.frontend.ops._registry import register_tensor_method
2021
from nvtripy.utils import wrappers
2122

2223

2324
@export.public_api(document_under="operations/functions")
25+
@register_tensor_method("flatten")
2426
@wrappers.interface(
2527
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
2628
dtype_variables={"T1": ["float32", "float16", "bfloat16", "int4", "int8", "int32", "int64", "bool"]},

tripy/nvtripy/frontend/ops/permute.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@
2020
from nvtripy import export
2121
from nvtripy.common.exception import raise_error
2222
from nvtripy.frontend.ops import utils as op_utils
23+
from nvtripy.frontend.ops._registry import register_tensor_method
2324
from nvtripy.trace.ops.permute import Permute
2425
from nvtripy.utils import wrappers
2526

2627

2728
@export.public_api(document_under="operations/functions")
29+
@register_tensor_method("permute")
2830
@wrappers.interface(
2931
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
3032
dtype_variables={"T1": ["float32", "float16", "bfloat16", "int4", "int8", "int32", "int64", "bool"]},

tripy/nvtripy/frontend/ops/reshape.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from nvtripy import export
2121
from nvtripy.common.exception import raise_error
2222
from nvtripy.frontend.ops import utils as op_utils
23+
from nvtripy.frontend.ops._registry import register_tensor_method
2324
from nvtripy.trace.ops.reshape import Reshape
2425
from nvtripy.types import ShapeLike
2526
from nvtripy.utils import wrappers
@@ -43,6 +44,7 @@ def infer_dimensions(input: "nvtripy.Tensor", shape: ShapeLike) -> ShapeLike:
4344

4445

4546
@export.public_api(document_under="operations/functions")
47+
@register_tensor_method("reshape")
4648
@wrappers.interface(
4749
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
4850
dtype_variables={"T1": ["float32", "float16", "bfloat16", "int4", "int8", "int32", "int64", "bool"]},

tripy/nvtripy/frontend/ops/squeeze.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616

1717
from nvtripy import export, utils
1818
from nvtripy.frontend.ops import utils as op_utils
19+
from nvtripy.frontend.ops._registry import register_tensor_method
1920
from nvtripy.utils import wrappers
2021

2122

2223
@export.public_api(document_under="operations/functions")
24+
@register_tensor_method("squeeze")
2325
@wrappers.interface(
2426
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
2527
dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]},

tripy/nvtripy/frontend/ops/transpose.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414
# limitations under the License.
1515
from nvtripy import export
1616
from nvtripy.common.exception import raise_error
17+
from nvtripy.frontend.ops._registry import register_tensor_method
1718
from nvtripy.utils import wrappers
1819

1920

2021
@export.public_api(document_under="operations/functions")
22+
@register_tensor_method("transpose")
2123
@wrappers.interface(
2224
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
2325
dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]},

tripy/nvtripy/frontend/ops/unsqueeze.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717

1818
from nvtripy import export
1919
from nvtripy.frontend.ops import utils as op_utils
20+
from nvtripy.frontend.ops._registry import register_tensor_method
2021
from nvtripy.utils import wrappers
2122

2223

2324
@export.public_api(document_under="operations/functions")
25+
@register_tensor_method("unsqueeze")
2426
@wrappers.interface(
2527
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
2628
dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]},

tripy/nvtripy/frontend/tensor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,11 @@ def from_trace_tensor(cls, trace_tensor, include_code_index=2):
125125
return instance
126126

127127
def __getattr__(self, name: str):
128+
if name in TENSOR_METHOD_REGISTRY:
129+
import types
130+
131+
return types.MethodType(TENSOR_METHOD_REGISTRY[name], self)
132+
128133
import nvtripy as tp
129134
from nvtripy.common.exception import search_for_missing_attr
130135

0 commit comments

Comments
 (0)