-
Notifications
You must be signed in to change notification settings - Fork 48
Allow registering an ArgHandler #504
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Allow registering an ArgHandler #504
Conversation
|
Auto-sync is disabled for ready for review pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here. |
|
Adding to the pre-populated list of extensions will increase the launch time for all kernels whether they use the extension or not - I'm not sure how great the impact will be, but I am concerned it could increase kernel launch time by a noticeable percentage. If we go with this route (as opposed to globally registered extensions being tried right before the |
|
Regarding performance - could we add support through the type's method instead of global registry? It should not impact existing kernel launches and will have O(1) overhead for those types that actually need custom support. UPD: I looked at the implementation - it is map lookup, so should be also O(1) overhead per argument. We should greatly improve performance by switching existing logic to |
|
/ok to test |
Agree. @gmarkall does the lookup approach alleviate your concern about perf? |
|
/ok to test |
|
/ok to test |
gmarkall
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've had another look and made some comments towards simplifying the API - I think it was always overcomplicated, so it would be good not to propagate that needless complication into a new API.
I think the testing is a bit light, and have some suggestions.
I also note that there's no documentation, but I don't think we need to block this PR for it because there's absolutely no documentation for argument handling extensions at all already. Once we're in a place where we have an API we're happy with, then I think it will be good to document it.
One other thought, which I'm not suggesting is a good idea, but could be a basis for further thought on the API: for a short while I considered that perhaps typeof implementations should register arg handlers, so that it's not necessary to manually register them at all. However, typeof is in the critical path for kernel launch, so it should not be fiddling around registering things. I wonder if there's another way we can avoid explicit registration at all, more closely integrated in the way typing works (similar to how it's not necessary to tell the jit decorator what to link if the extension adds its required files to the link during lowering).
|
/ok to test |
| def __init__(self, arr): | ||
| self.arr = arr | ||
|
|
||
| def numpy_array_wrapper_int32_arg_handler(ty, val, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think here ty will be the Numba type from typeof, which will be types.int32[::1] in this case, but would also match any other type that comes out of numpy_array_wrapper_int32_typeof_impl. Would it make a more exemplar test to just return ty, val.arr, here?
(If there's a flaw in my logic / understanding, please do let me know)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In fact, why make these specific to int32 types at all? Apart from the name, and the hardcoded types.int32[::1], it looks like these could be fully generic for a wrapper holding any dtype / shape NumPy array?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since this is test code, I prefer to spell things out in some cases. Here we're testing the effect of one or more handlers on one or more classes, it makes sense to me to write them separately and somewhat verbosely for readability as to what we're testing, even if things are repeated.
gmarkall
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for changing this - I think the API design looks much nicer now.
A couple of other notes:
- I don't think we should allow registering multiple handlers for the same Python type. This seems like it would most likely be a user error, or cause user surprise.
- There are some other notes on the diff on the tests.
|
/ok to test |
|
/ok to test |
| ty, val = extension.prepare_args( | ||
| ty, val, stream=stream, retr=retr | ||
| ) | ||
| elif handler := _arg_handlers.get(type(val)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure the if / elif logic here is correct. It seems to implement "If there are any extensions registered for this kernel, ignore all globally-registered extensions". But I could imagine a scenario where there is a globally-registered extension for one argument, and a "locally"-registered one for another argument - I think this situation is precluded from working in the current implementation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this section of the code might have been what pushed me towards having a class that the user supplies with a particular interface in the original implementation. The existing API requires we pass a class with a _prepare_args method. If we expect the same thing from the new API, it lets us assemble a single container of handlers we end up passing along the existing codepath, and we don't have to have deal with two kinds / can error if we find a duplicate along the way.
|
@gmarkall should we instead be changing the existing Either way we go, I think these two APIs should expect to be passed the same things, and there may be some flexibility due to the lack of existing docs. We can go with the function vs the class if we desire, but the more I think about it the more I think the |
Greptile OverviewGreptile SummaryAdded support for registering custom argument handlers via
Key changes:
The feature enables third-party libraries to seamlessly integrate custom objects with numba-cuda kernels. Confidence Score: 3/5
Important Files ChangedFile Analysis
Sequence DiagramsequenceDiagram
participant User
participant Library
participant register_arg_handler
participant typeof_impl
participant _arg_handlers
participant Kernel
participant _prepare_args
User->>Library: Import library
Library->>register_arg_handler: register_arg_handler(handler, types, impl)
register_arg_handler->>_arg_handlers: Check for collision
alt Collision detected
register_arg_handler-->>Library: raise ValueError
else No collision
register_arg_handler->>typeof_impl: register(type)(impl)
register_arg_handler->>_arg_handlers: Store handler
register_arg_handler-->>Library: Success
end
User->>Kernel: kernel(custom_obj)
Kernel->>_prepare_args: _prepare_args(ty, val, ...)
alt self.extensions exists
_prepare_args->>_prepare_args: Use extension.prepare_args
else No extensions
_prepare_args->>_arg_handlers: Get handler for type(val)
alt Handler found
_prepare_args->>_arg_handlers: handler(ty, val, ...)
_arg_handlers-->>_prepare_args: Return (new_ty, new_val)
end
end
_prepare_args->>_prepare_args: Continue with standard arg handling
_prepare_args-->>Kernel: Prepared args
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2 files reviewed, 1 comment
| for ty in handled_types: | ||
| if _arg_handlers.get(ty, None): | ||
| raise ValueError( | ||
| f"A handler for args of type {ty} is already registered." | ||
| ) | ||
| typeof_impl.register(ty)(impl) | ||
| _arg_handlers[ty] = handler |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logic: If handled_types contains multiple types and a collision occurs on a later type, earlier types will already be registered in both _arg_handlers and typeof_impl, leaving partial registration state
| for ty in handled_types: | |
| if _arg_handlers.get(ty, None): | |
| raise ValueError( | |
| f"A handler for args of type {ty} is already registered." | |
| ) | |
| typeof_impl.register(ty)(impl) | |
| _arg_handlers[ty] = handler | |
| # Check all types for collisions first | |
| for ty in handled_types: | |
| if _arg_handlers.get(ty, None): | |
| raise ValueError( | |
| f"A handler for args of type {ty} is already registered." | |
| ) | |
| # If all checks pass, register all types | |
| for ty in handled_types: | |
| typeof_impl.register(ty)(impl) | |
| _arg_handlers[ty] = handler |
This PR allows a user to register extra arg handling for any type they wish to pass to a kernel. Arg handlers must inherit from
ArgHandlerBaseand implement the required handling. Then, a user may pass it toregister_arg_handler, potentially upon the import of a library that expects to pass a custom object to a numba kernel.