@brandon-b-miller
Contributor

This PR allows a user to register extra arg handling for any type they wish to pass to a kernel. Arg handlers must inherit from ArgHandlerBase and implement the required handling. Then, a user may pass it to register_arg_handler, potentially upon the import of a library that expects to pass a custom object to a numba kernel.

@copy-pr-bot

copy-pr-bot bot commented Oct 3, 2025

Auto-sync is disabled for ready for review pull requests in this repository. Workflows must be run manually.

@gmarkall
Contributor

gmarkall commented Oct 3, 2025

Adding to the pre-populated list of extensions will increase the launch time for all kernels whether they use the extension or not - I'm not sure how great the impact will be, but I am concerned it could increase kernel launch time by a noticeable percentage. If we go with this route (as opposed to globally registered extensions being tried right before the NotImplementedError in _prepare_args) can we make some measurements to determine the impact on launch time, so we can understand the tradeoff we're making in ease of use vs. speed?
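One way to start on such a measurement is a micro-benchmark of the argument-preparation loop alone. The sketch below is a pure-Python stand-in: `prepare_args_scan`, `NoOpExtension`, and the argument list are hypothetical, and the numbers say nothing about real kernel launches, which would need to be timed on a GPU with the actual dispatcher.

```python
import timeit


class NoOpExtension:
    """Hypothetical extension that inspects every argument and changes nothing."""

    def prepare_args(self, ty, val, **kwargs):
        return ty, val


def prepare_args_scan(args, extensions):
    """Run every extension over every argument, as a pre-populated list would."""
    prepared = []
    for ty, val in args:
        for extension in extensions:
            ty, val = extension.prepare_args(ty, val)
        prepared.append((ty, val))
    return prepared


args = [("int64", i) for i in range(8)]  # eight plain scalar arguments

baseline = timeit.timeit(lambda: prepare_args_scan(args, []), number=20_000)
with_ext = timeit.timeit(
    lambda: prepare_args_scan(args, [NoOpExtension()]), number=20_000
)

print(f"no extensions:  {baseline:.4f} s")
print(f"one extension:  {with_ext:.4f} s")
print(f"relative cost:  {with_ext / baseline:.2f}x")
```

The relative cost printed here only bounds the Python-side overhead of the loop; it would need to be compared against the full launch time to judge whether the percentage is noticeable.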

@gmarkall added the "2 - In Progress" label Oct 7, 2025
@ZzEeKkAa
Contributor

ZzEeKkAa commented Oct 21, 2025

Regarding performance - could we add support through the type's method instead of a global registry? It should not impact existing kernel launches and will have O(1) overhead for those types that actually need custom support.

UPD: I looked at the implementation. It is a map lookup, so it should also be O(1) overhead per argument. We should greatly improve performance by switching the existing logic to register_arg_handler.
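The lookup being discussed can be sketched as follows. This is a simplified, hypothetical model (the `_arg_handlers` name appears in the PR diff, but `prepare_one`, `Celsius`, and the handler are made up here): a dictionary keyed on `type(val)` costs one hash lookup per argument regardless of how many handlers are registered.

```python
_arg_handlers = {}  # Python type -> handler callable


class Celsius:
    """Hypothetical custom argument type."""

    def __init__(self, degrees):
        self.degrees = degrees


def celsius_handler(ty, val, **kwargs):
    # Replace the wrapper with the plain float it carries.
    return "float64", val.degrees


_arg_handlers[Celsius] = celsius_handler


def prepare_one(ty, val):
    """One dict lookup per argument; unregistered types pass through untouched."""
    handler = _arg_handlers.get(type(val))
    if handler is not None:
        ty, val = handler(ty, val)
    return ty, val


plain = prepare_one("int64", 3)
custom = prepare_one("celsius", Celsius(21.5))
```

Because the lookup uses the exact `type(val)`, subclasses of a registered type would not match in this sketch.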

@brandon-b-miller
Contributor Author

/ok to test

@brandon-b-miller
Contributor Author

> Regarding performance - could we add support through the type's method instead of a global registry? It should not impact existing kernel launches and will have O(1) overhead for those types that actually need custom support.
>
> UPD: I looked at the implementation. It is a map lookup, so it should also be O(1) overhead per argument. We should greatly improve performance by switching the existing logic to register_arg_handler.

Agree. @gmarkall does the lookup approach alleviate your concern about perf?

@brandon-b-miller
Contributor Author

/ok to test

@brandon-b-miller added the "3 - Ready for Review" label and removed the "2 - In Progress" label Oct 27, 2025
@brandon-b-miller
Contributor Author

/ok to test

Contributor

@gmarkall left a comment

I've had another look and made some comments towards simplifying the API - I think it was always overcomplicated, so it would be good not to propagate that needless complication into a new API.

I think the testing is a bit light, and have some suggestions.

I also note that there's no documentation, but I don't think we need to block this PR for it because there's absolutely no documentation for argument handling extensions at all already. Once we're in a place where we have an API we're happy with, then I think it will be good to document it.

One other thought, which I'm not suggesting is a good idea, but could be a basis for further thought on the API: for a short while I considered that perhaps typeof implementations should register arg handlers, so that it's not necessary to manually register them at all. However, typeof is in the critical path for kernel launch, so it should not be fiddling around registering things. I wonder if there's another way we can avoid explicit registration at all, more closely integrated in the way typing works (similar to how it's not necessary to tell the jit decorator what to link if the extension adds its required files to the link during lowering).

@gmarkall gmarkall added 4 - Waiting on author Waiting for author to respond to review and removed 3 - Ready for Review Ready for review by team labels Nov 18, 2025
@brandon-b-miller
Contributor Author

/ok to test

    def __init__(self, arr):
        self.arr = arr

def numpy_array_wrapper_int32_arg_handler(ty, val, **kwargs):
Contributor

I think ty here will be the Numba type from typeof, which will be types.int32[::1] in this case, but the handler would also match any other type that comes out of numpy_array_wrapper_int32_typeof_impl. Would it make for a more exemplary test to just return ty, val.arr here?

(If there's a flaw in my logic / understanding, please do let me know)

Contributor

In fact, why make these specific to int32 types at all? Apart from the name, and the hardcoded types.int32[::1], it looks like these could be fully generic for a wrapper holding any dtype / shape NumPy array?
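A generic version of the pair might look like the sketch below. It is deliberately free of Numba and NumPy so that it runs anywhere: `fake_typeof` stands in for the registered typeof implementation and returns a descriptive string rather than a real Numba type, and `ArrayWrapper` and both function names are hypothetical.

```python
from array import array


class ArrayWrapper:
    """Hypothetical wrapper around any array-like object."""

    def __init__(self, arr):
        self.arr = arr


def fake_typeof(val):
    """Stand-in for a typeof implementation: derive the type from the payload."""
    return f"array({val.arr.typecode}, 1d)"


def array_wrapper_arg_handler(ty, val, **kwargs):
    # Pass the inferred type through unchanged and only unwrap the value,
    # so nothing here is tied to a particular dtype.
    return ty, val.arr


ints = ArrayWrapper(array("i", [1, 2, 3]))
floats = ArrayWrapper(array("d", [0.5, 1.5]))

int_ty, int_val = array_wrapper_arg_handler(fake_typeof(ints), ints)
float_ty, float_val = array_wrapper_arg_handler(fake_typeof(floats), floats)
```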

Contributor Author

@brandon-b-miller commented Nov 19, 2025

Since this is test code, I prefer to spell things out in some cases. Here we're testing the effect of one or more handlers on one or more classes, so it makes sense to me to write them separately and somewhat verbosely, to keep it readable what we're testing, even if things are repeated.

Contributor

@gmarkall left a comment

Thanks for changing this - I think the API design looks much nicer now.

A couple of other notes:

  • I don't think we should allow registering multiple handlers for the same Python type. This seems like it would most likely be a user error, or cause user surprise.
  • There are some other notes on the diff on the tests.

@brandon-b-miller
Contributor Author

/ok to test

@brandon-b-miller
Contributor Author

/ok to test

                ty, val = extension.prepare_args(
                    ty, val, stream=stream, retr=retr
                )
        elif handler := _arg_handlers.get(type(val)):
Contributor

I'm not sure the if / elif logic here is correct. It seems to implement "If there are any extensions registered for this kernel, ignore all globally-registered extensions". But I could imagine a scenario where there is a globally-registered extension for one argument, and a "locally"-registered one for another argument - I think this situation is precluded from working in the current implementation.
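To make the scenario concrete, here is a hypothetical sketch in which per-kernel extensions run first and the global registry is still consulted for every argument afterwards. The types `LocalThing` and `GlobalThing`, the free function `prepare_args`, and the ordering chosen are assumptions for illustration, not the PR's implementation.

```python
_arg_handlers = {}  # globally registered: Python type -> handler callable


class LocalThing:
    def __init__(self, value):
        self.value = value


class GlobalThing:
    def __init__(self, value):
        self.value = value


class LocalExtension:
    """Per-kernel extension that only understands LocalThing."""

    def prepare_args(self, ty, val, **kwargs):
        if isinstance(val, LocalThing):
            return "int64", val.value
        return ty, val


def global_handler(ty, val, **kwargs):
    return "float64", val.value


_arg_handlers[GlobalThing] = global_handler


def prepare_args(ty, val, extensions=()):
    # Per-kernel extensions get the first chance to transform the argument...
    for extension in extensions:
        ty, val = extension.prepare_args(ty, val)
    # ...and the global registry is consulted regardless (plain `if`, not `elif`).
    handler = _arg_handlers.get(type(val))
    if handler is not None:
        ty, val = handler(ty, val)
    return ty, val


exts = [LocalExtension()]
first = prepare_args("local", LocalThing(7), extensions=exts)
second = prepare_args("global", GlobalThing(2.5), extensions=exts)
```

With an `if` / `elif` split, the second call would leave `GlobalThing` untouched because the kernel has a non-empty extensions list.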

Contributor Author

I think this section of the code might have been what pushed me towards having a class that the user supplies with a particular interface in the original implementation. The existing API requires that we pass a class with a _prepare_args method. If we expect the same thing from the new API, it lets us assemble a single container of handlers that we pass along the existing codepath, so we don't have to deal with two kinds of handler and can raise an error if we find a duplicate along the way.

@brandon-b-miller
Contributor Author

@gmarkall should we instead be changing the existing extensions interface to be more like the one we're moving towards in this PR now? I note that the existing extensions API does not require a typeof_impl, meaning it only works for the case where a signature is provided as well.

Either way we go, I think these two APIs should expect to be passed the same things, and there may be some flexibility due to the lack of existing docs. We can go with the function vs. the class if we desire, but the more I think about it, the more I think the typeof_impl requirement should be optional. It doesn't buy the person registering the handler much over just calling typeof_impl.register on import in their own library, but it does make it hard to normalize the extensions API with the register_arg_handler API. What do you think?

@greptile-apps

greptile-apps bot commented Nov 21, 2025

Greptile Overview

Greptile Summary

Added support for registering custom argument handlers via register_arg_handler, allowing libraries to pass custom types to CUDA kernels. The implementation:

  • Introduces global _arg_handlers dictionary to store type-to-handler mappings
  • Modifies _Kernel._prepare_args to check registered handlers when no explicit extensions are provided
  • Registers typeof implementations alongside arg handlers for proper type inference
  • Includes collision detection to prevent duplicate handler registration

Key changes:

  • register_arg_handler(handler, handled_types, impl) - new public API for registering handlers
  • Modified arg preparation logic uses elif to check _arg_handlers only when self.extensions is falsy
  • Comprehensive test coverage including basic usage, multiple handlers, collision cases, and interaction with explicit extensions

The feature enables third-party libraries to seamlessly integrate custom objects with numba-cuda kernels.

Confidence Score: 3/5

  • Safe to merge with one logic issue that should be addressed
  • The PR adds valuable functionality with good test coverage, but has a critical logic bug where registering multiple types can leave the system in a partially registered state if a collision occurs on a later type. The collision check should validate all types before registering any.
  • Pay close attention to numba_cuda/numba/cuda/dispatcher.py lines 2131-2137 where the registration logic needs atomic behavior

Important Files Changed

File Analysis

| Filename | Score | Overview |
|---|---|---|
| numba_cuda/numba/cuda/dispatcher.py | 3/5 | Added register_arg_handler function and modified _prepare_args to check registered handlers when extensions aren't provided. Potential issue: partial registration state when handling multiple types with a collision. |
| numba_cuda/numba/cuda/tests/cudapy/test_extending.py | 4/5 | Added comprehensive test suite for arg handler registration covering basic usage, multiple handlers, collision detection, and interaction with explicit extensions. |

Sequence Diagram

sequenceDiagram
    participant User
    participant Library
    participant register_arg_handler
    participant typeof_impl
    participant _arg_handlers
    participant Kernel
    participant _prepare_args

    User->>Library: Import library
    Library->>register_arg_handler: register_arg_handler(handler, types, impl)
    register_arg_handler->>_arg_handlers: Check for collision
    alt Collision detected
        register_arg_handler-->>Library: raise ValueError
    else No collision
        register_arg_handler->>typeof_impl: register(type)(impl)
        register_arg_handler->>_arg_handlers: Store handler
        register_arg_handler-->>Library: Success
    end
    
    User->>Kernel: kernel(custom_obj)
    Kernel->>_prepare_args: _prepare_args(ty, val, ...)
    alt self.extensions exists
        _prepare_args->>_prepare_args: Use extension.prepare_args
    else No extensions
        _prepare_args->>_arg_handlers: Get handler for type(val)
        alt Handler found
            _prepare_args->>_arg_handlers: handler(ty, val, ...)
            _arg_handlers-->>_prepare_args: Return (new_ty, new_val)
        end
    end
    _prepare_args->>_prepare_args: Continue with standard arg handling
    _prepare_args-->>Kernel: Prepared args


@greptile-apps bot left a comment

2 files reviewed, 1 comment

Comment on lines +2131 to +2137
    for ty in handled_types:
        if _arg_handlers.get(ty, None):
            raise ValueError(
                f"A handler for args of type {ty} is already registered."
            )
        typeof_impl.register(ty)(impl)
        _arg_handlers[ty] = handler

logic: If handled_types contains multiple types and a collision occurs on a later type, the earlier types will already have been registered in both _arg_handlers and typeof_impl, leaving registration in a partial state.

Suggested change
-    for ty in handled_types:
-        if _arg_handlers.get(ty, None):
-            raise ValueError(
-                f"A handler for args of type {ty} is already registered."
-            )
-        typeof_impl.register(ty)(impl)
-        _arg_handlers[ty] = handler
+    # Check all types for collisions first
+    for ty in handled_types:
+        if _arg_handlers.get(ty, None):
+            raise ValueError(
+                f"A handler for args of type {ty} is already registered."
+            )
+    # If all checks pass, register all types
+    for ty in handled_types:
+        typeof_impl.register(ty)(impl)
+        _arg_handlers[ty] = handler

Labels

4 - Waiting on author Waiting for author to respond to review

3 participants