|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: BSD-2-Clause |
| 3 | + |
| 4 | +from numba.core.registry import DelayedRegistry |
| 5 | +from numba.cuda.vectorizers import CUDAVectorize, CUDAGUFuncVectorize |
| 6 | + |
| 7 | + |
| 8 | +class _BaseVectorize(object): |
| 9 | + @classmethod |
| 10 | + def get_identity(cls, kwargs): |
| 11 | + return kwargs.pop("identity", None) |
| 12 | + |
| 13 | + @classmethod |
| 14 | + def get_cache(cls, kwargs): |
| 15 | + return kwargs.pop("cache", False) |
| 16 | + |
| 17 | + @classmethod |
| 18 | + def get_writable_args(cls, kwargs): |
| 19 | + return kwargs.pop("writable_args", ()) |
| 20 | + |
| 21 | + @classmethod |
| 22 | + def get_target_implementation(cls, kwargs): |
| 23 | + target = kwargs.pop("target", "cpu") |
| 24 | + try: |
| 25 | + return cls.target_registry[target] |
| 26 | + except KeyError: |
| 27 | + raise ValueError("Unsupported target: %s" % target) |
| 28 | + |
| 29 | + |
| 30 | +class Vectorize(_BaseVectorize): |
| 31 | + target_registry = DelayedRegistry({"cuda": CUDAVectorize}) |
| 32 | + |
| 33 | + def __new__(cls, func, **kws): |
| 34 | + identity = cls.get_identity(kws) |
| 35 | + cache = cls.get_cache(kws) |
| 36 | + imp = cls.get_target_implementation(kws) |
| 37 | + return imp(func, identity=identity, cache=cache, targetoptions=kws) |
| 38 | + |
| 39 | + |
| 40 | +class GUVectorize(_BaseVectorize): |
| 41 | + target_registry = DelayedRegistry({"cuda": CUDAGUFuncVectorize}) |
| 42 | + |
| 43 | + def __new__(cls, func, signature, **kws): |
| 44 | + identity = cls.get_identity(kws) |
| 45 | + cache = cls.get_cache(kws) |
| 46 | + imp = cls.get_target_implementation(kws) |
| 47 | + writable_args = cls.get_writable_args(kws) |
| 48 | + return imp( |
| 49 | + func, |
| 50 | + signature, |
| 51 | + identity=identity, |
| 52 | + cache=cache, |
| 53 | + targetoptions=kws, |
| 54 | + writable_args=writable_args, |
| 55 | + ) |
| 56 | + |
| 57 | + |
| 58 | +def vectorize(ftylist_or_function=(), target="cuda", **kws): |
| 59 | + """vectorize(ftylist_or_function=(), target='cuda', identity=None, **kws) |
| 60 | +
|
| 61 | + A decorator that creates a NumPy ufunc object using Numba compiled |
| 62 | + code. When no arguments or only keyword arguments are given, |
| 63 | + vectorize will return a Numba dynamic ufunc (DUFunc) object, where |
| 64 | + compilation/specialization may occur at call-time. |
| 65 | +
|
| 66 | + Args |
| 67 | + ----- |
| 68 | + ftylist_or_function: function or iterable |
| 69 | +
|
| 70 | + When the first argument is a function, signatures are dealt |
| 71 | + with at call-time. |
| 72 | +
|
| 73 | + When the first argument is an iterable of type signatures, |
| 74 | + which are either function type object or a string describing |
| 75 | + the function type, signatures are finalized at decoration |
| 76 | + time. |
| 77 | +
|
| 78 | + Keyword Args |
| 79 | + ------------ |
| 80 | +
|
| 81 | + target: str |
| 82 | + A string for code generation target. Default to "cuda". |
| 83 | +
|
| 84 | + identity: int, str, or None |
| 85 | + The identity (or unit) value for the element-wise function |
| 86 | + being implemented. Allowed values are None (the default), 0, 1, |
| 87 | + and "reorderable". |
| 88 | +
|
| 89 | + cache: bool |
| 90 | + Turns on caching. |
| 91 | +
|
| 92 | +
|
| 93 | + Returns |
| 94 | + -------- |
| 95 | +
|
| 96 | + A NumPy universal function |
| 97 | +
|
| 98 | + Examples |
| 99 | + ------- |
| 100 | + @vectorize(['float32(float32, float32)', |
| 101 | + 'float64(float64, float64)'], identity=0) |
| 102 | + def sum(a, b): |
| 103 | + return a + b |
| 104 | +
|
| 105 | + @vectorize |
| 106 | + def sum(a, b): |
| 107 | + return a + b |
| 108 | +
|
| 109 | + @vectorize(identity=1) |
| 110 | + def mul(a, b): |
| 111 | + return a * b |
| 112 | +
|
| 113 | + """ |
| 114 | + if isinstance(ftylist_or_function, str): |
| 115 | + # Common user mistake |
| 116 | + ftylist = [ftylist_or_function] |
| 117 | + elif ftylist_or_function is not None: |
| 118 | + ftylist = ftylist_or_function |
| 119 | + |
| 120 | + def wrap(func): |
| 121 | + kws["target"] = target |
| 122 | + vec = Vectorize(func, **kws) |
| 123 | + for sig in ftylist: |
| 124 | + vec.add(sig) |
| 125 | + if len(ftylist) > 0: |
| 126 | + vec.disable_compile() |
| 127 | + return vec.build_ufunc() |
| 128 | + |
| 129 | + return wrap |
| 130 | + |
| 131 | + |
| 132 | +def guvectorize(*args, **kwargs): |
| 133 | + """guvectorize(ftylist, signature, target='cuda', identity=None, **kws) |
| 134 | +
|
| 135 | + A decorator to create NumPy generalized-ufunc object from Numba compiled |
| 136 | + code. |
| 137 | +
|
| 138 | + Args |
| 139 | + ----- |
| 140 | + ftylist: iterable |
| 141 | + An iterable of type signatures, which are either |
| 142 | + function type object or a string describing the |
| 143 | + function type. |
| 144 | +
|
| 145 | + signature: str |
| 146 | + A NumPy generalized-ufunc signature. |
| 147 | + e.g. "(m, n), (n, p)->(m, p)" |
| 148 | +
|
| 149 | + identity: int, str, or None |
| 150 | + The identity (or unit) value for the element-wise function |
| 151 | + being implemented. Allowed values are None (the default), 0, 1, |
| 152 | + and "reorderable". |
| 153 | +
|
| 154 | + cache: bool |
| 155 | + Turns on caching. |
| 156 | +
|
| 157 | + writable_args: tuple |
| 158 | + a tuple of indices of input variables that are writable. |
| 159 | +
|
| 160 | + target: str |
| 161 | + A string for code generation target. Defaults to "cuda". |
| 162 | +
|
| 163 | + Returns |
| 164 | + -------- |
| 165 | +
|
| 166 | + A NumPy generalized universal-function |
| 167 | +
|
| 168 | + Example |
| 169 | + ------- |
| 170 | + @guvectorize(['void(int32[:,:], int32[:,:], int32[:,:])', |
| 171 | + 'void(float32[:,:], float32[:,:], float32[:,:])'], |
| 172 | + '(x, y),(x, y)->(x, y)') |
| 173 | + def add_2d_array(a, b, c): |
| 174 | + for i in range(c.shape[0]): |
| 175 | + for j in range(c.shape[1]): |
| 176 | + c[i, j] = a[i, j] + b[i, j] |
| 177 | +
|
| 178 | + """ |
| 179 | + if len(args) == 1: |
| 180 | + ftylist = [] |
| 181 | + signature = args[0] |
| 182 | + kwargs.setdefault("is_dynamic", True) |
| 183 | + elif len(args) == 2: |
| 184 | + ftylist = args[0] |
| 185 | + signature = args[1] |
| 186 | + else: |
| 187 | + raise TypeError("guvectorize() takes one or two positional arguments") |
| 188 | + |
| 189 | + if isinstance(ftylist, str): |
| 190 | + # Common user mistake |
| 191 | + ftylist = [ftylist] |
| 192 | + |
| 193 | + kwargs.setdefault("target", "cuda") |
| 194 | + |
| 195 | + def wrap(func): |
| 196 | + guvec = GUVectorize(func, signature, **kwargs) |
| 197 | + for fty in ftylist: |
| 198 | + guvec.add(fty) |
| 199 | + if len(ftylist) > 0: |
| 200 | + guvec.disable_compile() |
| 201 | + return guvec.build_ufunc() |
| 202 | + |
| 203 | + return wrap |
0 commit comments