@@ -29,28 +29,19 @@ def register_tensor_method(name: str):
2929 Decorator to add the method to the tensor method registry with the name specified.
3030 This does not use the FunctionRegistry decorator because every tensor method would also be
3131 registered in the public function registry and we would prefer to avoid having overhead
32- from having to dispatch overloads and check types twice.
32+ from having to dispatch overloads and check types twice. This needs to be the top level decorator so we can
33+ get input type validation from other decorators like `public_api`.
3334 """
3435
3536 # We make a special exception for "shape" since we actually do want that to be a property
3637 # We also add additional methods of the tensor class that are not magic methods
3738 allowed_methods = ["copy" , "cast" , "shape" , "reshape" , "transpose" , "flatten" , "permute" , "squeeze" , "unsqueeze" ]
3839 assert name in allowed_methods or name .startswith (
3940 "__"
40- ), f"The tensor method registry should only be used for magic methods, but was used for: { name } "
41+ ), f"The tensor method registry should only be used for magic methods and specially allowed methods , but was used for: { name } "
4142
4243 def impl (func : Callable [..., Any ]) -> Callable [..., Any ]:
43- if name == "shape" :
44- TENSOR_METHOD_REGISTRY [name ] = func
45- else :
46- # Create a method wrapper that maps 'self' to the first argument (input)
47- # This is the standard pattern for all tensor methods except 'shape' (which is a property)
48- @wraps (func )
49- def method_wrapper (self , * args , ** kwargs ):
50- return func (self , * args , ** kwargs )
51-
52- TENSOR_METHOD_REGISTRY [name ] = method_wrapper
53-
44+ TENSOR_METHOD_REGISTRY [name ] = func
5445 return func
5546
5647 return impl
0 commit comments