Skip to content

Commit a6e0c35

Browse files
committed
Improve detection of instance methods, test cleanup, simplify method registry
1 parent 44e33c6 commit a6e0c35

File tree

18 files changed

+67
-111
lines changed

18 files changed

+67
-111
lines changed

tripy/nvtripy/frontend/ops/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,13 @@
2121
import nvtripy.frontend.ops.matmul
2222
import nvtripy.frontend.ops.shape
2323
import nvtripy.frontend.ops.slice
24+
25+
# Import regular methods that should be available as tensor methods
26+
import nvtripy.frontend.ops.cast
27+
import nvtripy.frontend.ops.copy
28+
import nvtripy.frontend.ops.reshape
29+
import nvtripy.frontend.ops.transpose
30+
import nvtripy.frontend.ops.flatten
31+
import nvtripy.frontend.ops.permute
32+
import nvtripy.frontend.ops.squeeze
33+
import nvtripy.frontend.ops.unsqueeze

tripy/nvtripy/frontend/ops/_registry.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,28 +29,19 @@ def register_tensor_method(name: str):
2929
Decorator to add the method to the tensor method registry with the name specified.
3030
This does not use the FunctionRegistry decorator because every tensor method would also be
3131
registered in the public function registry and we would prefer to avoid having overhead
32-
from having to dispatch overloads and check types twice.
32+
from having to dispatch overloads and check types twice. This needs to be the top level decorator so we can
33+
get input type validation from other decorators like `public_api`.
3334
"""
3435

3536
# We make a special exception for "shape" since we actually do want that to be a property
3637
# We also add additional methods of the tensor class that are not magic methods
3738
allowed_methods = ["copy", "cast", "shape", "reshape", "transpose", "flatten", "permute", "squeeze", "unsqueeze"]
3839
assert name in allowed_methods or name.startswith(
3940
"__"
40-
), f"The tensor method registry should only be used for magic methods, but was used for: {name}"
41+
), f"The tensor method registry should only be used for magic methods and specially allowed methods, but was used for: {name}"
4142

4243
def impl(func: Callable[..., Any]) -> Callable[..., Any]:
43-
if name == "shape":
44-
TENSOR_METHOD_REGISTRY[name] = func
45-
else:
46-
# Create a method wrapper that maps 'self' to the first argument (input)
47-
# This is the standard pattern for all tensor methods except 'shape' (which is a property)
48-
@wraps(func)
49-
def method_wrapper(self, *args, **kwargs):
50-
return func(self, *args, **kwargs)
51-
52-
TENSOR_METHOD_REGISTRY[name] = method_wrapper
53-
44+
TENSOR_METHOD_REGISTRY[name] = func
5445
return func
5546

5647
return impl

tripy/nvtripy/frontend/ops/copy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
from nvtripy.utils import wrappers
2626

2727

28-
@export.public_api(document_under="operations/functions")
2928
@register_tensor_method("copy")
29+
@export.public_api(document_under="operations/functions")
3030
@wrappers.interface(
3131
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
3232
dtype_variables={"T1": list(DATA_TYPES.keys())},

tripy/nvtripy/frontend/ops/flatten.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
from nvtripy.utils import wrappers
2222

2323

24-
@export.public_api(document_under="operations/functions")
2524
@register_tensor_method("flatten")
25+
@export.public_api(document_under="operations/functions")
2626
@wrappers.interface(
2727
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
2828
dtype_variables={"T1": ["float32", "float16", "bfloat16", "int4", "int8", "int32", "int64", "bool"]},

tripy/nvtripy/frontend/ops/permute.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
from nvtripy.utils import wrappers
2626

2727

28-
@export.public_api(document_under="operations/functions")
2928
@register_tensor_method("permute")
29+
@export.public_api(document_under="operations/functions")
3030
@wrappers.interface(
3131
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
3232
dtype_variables={"T1": ["float32", "float16", "bfloat16", "int4", "int8", "int32", "int64", "bool"]},

tripy/nvtripy/frontend/ops/reshape.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ def infer_dimensions(input: "nvtripy.Tensor", shape: ShapeLike) -> ShapeLike:
4343
return {"shape": shape}
4444

4545

46-
@export.public_api(document_under="operations/functions")
4746
@register_tensor_method("reshape")
47+
@export.public_api(document_under="operations/functions")
4848
@wrappers.interface(
4949
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
5050
dtype_variables={"T1": ["float32", "float16", "bfloat16", "int4", "int8", "int32", "int64", "bool"]},

tripy/nvtripy/frontend/ops/squeeze.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
from nvtripy.utils import wrappers
2121

2222

23-
@export.public_api(document_under="operations/functions")
2423
@register_tensor_method("squeeze")
24+
@export.public_api(document_under="operations/functions")
2525
@wrappers.interface(
2626
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
2727
dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]},

tripy/nvtripy/frontend/ops/transpose.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
from nvtripy.utils import wrappers
1919

2020

21-
@export.public_api(document_under="operations/functions")
2221
@register_tensor_method("transpose")
22+
@export.public_api(document_under="operations/functions")
2323
@wrappers.interface(
2424
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
2525
dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]},

tripy/nvtripy/frontend/ops/unsqueeze.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
from nvtripy.utils import wrappers
2222

2323

24-
@export.public_api(document_under="operations/functions")
2524
@register_tensor_method("unsqueeze")
25+
@export.public_api(document_under="operations/functions")
2626
@wrappers.interface(
2727
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
2828
dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]},

tripy/nvtripy/frontend/tensor.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,6 @@ def from_trace_tensor(cls, trace_tensor, include_code_index=2):
125125
return instance
126126

127127
def __getattr__(self, name: str):
128-
if name in TENSOR_METHOD_REGISTRY:
129-
import types
130-
131-
return types.MethodType(TENSOR_METHOD_REGISTRY[name], self)
132-
133128
import nvtripy as tp
134129
from nvtripy.common.exception import search_for_missing_attr
135130

0 commit comments

Comments
 (0)