Skip to content

Commit d455b9b

Browse files
authored
[REVIEW][NFC] Vendor in serialize to allow for future CUDA-specific refactoring and changes (#349)
The serialize file contain a bunch of helper functions that we may want to specialize in the future since these are used in many of our functions. This vendors them into this repo.
1 parent b1da947 commit d455b9b

File tree

6 files changed

+739
-5
lines changed

6 files changed

+739
-5
lines changed

numba_cuda/numba/cuda/codegen.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from llvmlite import ir
22

3-
from numba.core import config, serialize
3+
from numba.core import config
4+
from numba.cuda import serialize
45
from .cudadrv import devices, driver, nvvm, runtime, nvrtc
56
from numba.cuda.core.codegen import Codegen, CodeLibrary
67
from numba.cuda.cudadrv.libs import get_cudalib

numba_cuda/numba/cuda/cudadrv/driver.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@
4444

4545

4646
from numba import mviewbuf
47-
from numba.core import serialize, config
48-
from numba.cuda import utils
47+
from numba.core import config
48+
from numba.cuda import utils, serialize
4949
from .error import CudaSupportError, CudaDriverError
5050
from .drvapi import API_PROTOTYPES
5151
from .drvapi import cu_occupancy_b2d_size, cu_stream_callback_pyobj, cu_uuid

numba_cuda/numba/cuda/dispatcher.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
import weakref
99
import uuid
1010

11-
from numba.core import compiler, serialize, sigutils, types, typing, config
12-
from numba.cuda import utils
11+
from numba.core import compiler, sigutils, types, typing, config
12+
from numba.cuda import serialize, utils
1313
from numba.cuda.core.caching import Cache, CacheImpl, NullCache
1414
from numba.core.compiler_lock import global_compiler_lock
1515
from numba.core.dispatcher import _DispatcherBase

numba_cuda/numba/cuda/serialize.py

Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
"""
2+
Serialization support for compiled functions.
3+
"""
4+
5+
import sys
6+
import abc
7+
import io
8+
import copyreg
9+
10+
11+
import pickle
12+
from numba import cloudpickle
13+
from llvmlite import ir
14+
15+
16+
#
17+
# Pickle support
18+
#
19+
20+
21+
def _rebuild_reduction(cls, *args):
22+
"""
23+
Global hook to rebuild a given class from its __reduce__ arguments.
24+
"""
25+
return cls._rebuild(*args)
26+
27+
28+
# Keep unpickled object via `numba_unpickle` alive.
29+
_unpickled_memo = {}
30+
31+
32+
def _numba_unpickle(address, bytedata, hashed):
33+
"""Used by `numba_unpickle` from _helperlib.c
34+
35+
Parameters
36+
----------
37+
address : int
38+
bytedata : bytes
39+
hashed : bytes
40+
41+
Returns
42+
-------
43+
obj : object
44+
unpickled object
45+
"""
46+
key = (address, hashed)
47+
try:
48+
obj = _unpickled_memo[key]
49+
except KeyError:
50+
_unpickled_memo[key] = obj = cloudpickle.loads(bytedata)
51+
return obj
52+
53+
54+
def dumps(obj):
55+
"""Similar to `pickle.dumps()`. Returns the serialized object in bytes."""
56+
pickler = NumbaPickler
57+
with io.BytesIO() as buf:
58+
p = pickler(buf, protocol=4)
59+
p.dump(obj)
60+
pickled = buf.getvalue()
61+
62+
return pickled
63+
64+
65+
def runtime_build_excinfo_struct(static_exc, exc_args):
66+
exc, static_args, locinfo = cloudpickle.loads(static_exc)
67+
real_args = []
68+
exc_args_iter = iter(exc_args)
69+
for arg in static_args:
70+
if isinstance(arg, ir.Value):
71+
real_args.append(next(exc_args_iter))
72+
else:
73+
real_args.append(arg)
74+
return (exc, tuple(real_args), locinfo)
75+
76+
77+
# Alias to pickle.loads to allow `serialize.loads()`
78+
loads = cloudpickle.loads
79+
80+
81+
class _CustomPickled:
82+
"""A wrapper for objects that must be pickled with `NumbaPickler`.
83+
84+
Standard `pickle` will pick up the implementation registered via `copyreg`.
85+
This will spawn a `NumbaPickler` instance to serialize the data.
86+
87+
`NumbaPickler` overrides the handling of this type so as not to spawn a
88+
new pickler for the object when it is already being pickled by a
89+
`NumbaPickler`.
90+
"""
91+
92+
__slots__ = "ctor", "states"
93+
94+
def __init__(self, ctor, states):
95+
self.ctor = ctor
96+
self.states = states
97+
98+
def _reduce(self):
99+
return _CustomPickled._rebuild, (self.ctor, self.states)
100+
101+
@classmethod
102+
def _rebuild(cls, ctor, states):
103+
return cls(ctor, states)
104+
105+
106+
def _unpickle__CustomPickled(serialized):
107+
"""standard unpickling for `_CustomPickled`.
108+
109+
Uses `NumbaPickler` to load.
110+
"""
111+
ctor, states = loads(serialized)
112+
return _CustomPickled(ctor, states)
113+
114+
115+
def _pickle__CustomPickled(cp):
116+
"""standard pickling for `_CustomPickled`.
117+
118+
Uses `NumbaPickler` to dump.
119+
"""
120+
serialized = dumps((cp.ctor, cp.states))
121+
return _unpickle__CustomPickled, (serialized,)
122+
123+
124+
# Register custom pickling for the standard pickler.
125+
copyreg.pickle(_CustomPickled, _pickle__CustomPickled)
126+
127+
128+
def custom_reduce(cls, states):
129+
"""For customizing object serialization in `__reduce__`.
130+
131+
Object states provided here are used as keyword arguments to the
132+
`._rebuild()` class method.
133+
134+
Parameters
135+
----------
136+
states : dict
137+
Dictionary of object states to be serialized.
138+
139+
Returns
140+
-------
141+
result : tuple
142+
This tuple conforms to the return type requirement for `__reduce__`.
143+
"""
144+
return custom_rebuild, (_CustomPickled(cls, states),)
145+
146+
147+
def custom_rebuild(custom_pickled):
148+
"""Customized object deserialization.
149+
150+
This function is referenced internally by `custom_reduce()`.
151+
"""
152+
cls, states = custom_pickled.ctor, custom_pickled.states
153+
return cls._rebuild(**states)
154+
155+
156+
def is_serialiable(obj):
157+
"""Check if *obj* can be serialized.
158+
159+
Parameters
160+
----------
161+
obj : object
162+
163+
Returns
164+
--------
165+
can_serialize : bool
166+
"""
167+
with io.BytesIO() as fout:
168+
pickler = NumbaPickler(fout)
169+
try:
170+
pickler.dump(obj)
171+
except pickle.PicklingError:
172+
return False
173+
else:
174+
return True
175+
176+
177+
def _no_pickle(obj):
178+
raise pickle.PicklingError(f"Pickling of {type(obj)} is unsupported")
179+
180+
181+
def disable_pickling(typ):
182+
"""This is called on a type to disable pickling"""
183+
NumbaPickler.disabled_types.add(typ)
184+
# Return `typ` to allow use as a decorator
185+
return typ
186+
187+
188+
class NumbaPickler(cloudpickle.CloudPickler):
189+
disabled_types = set()
190+
"""A set of types that pickling cannot is disabled.
191+
"""
192+
193+
def reducer_override(self, obj):
194+
# Overridden to disable pickling of certain types
195+
if type(obj) in self.disabled_types:
196+
_no_pickle(obj) # noreturn
197+
return super().reducer_override(obj)
198+
199+
200+
def _custom_reduce__custompickled(cp):
201+
return cp._reduce()
202+
203+
204+
NumbaPickler.dispatch_table[_CustomPickled] = _custom_reduce__custompickled
205+
206+
207+
class ReduceMixin(abc.ABC):
208+
"""A mixin class for objects that should be reduced by the NumbaPickler
209+
instead of the standard pickler.
210+
"""
211+
212+
# Subclass MUST override the below methods
213+
214+
@abc.abstractmethod
215+
def _reduce_states(self):
216+
raise NotImplementedError
217+
218+
@abc.abstractclassmethod
219+
def _rebuild(cls, **kwargs):
220+
raise NotImplementedError
221+
222+
# Subclass can override the below methods
223+
224+
def _reduce_class(self):
225+
return self.__class__
226+
227+
# Private methods
228+
229+
def __reduce__(self):
230+
return custom_reduce(self._reduce_class(), self._reduce_states())
231+
232+
233+
class PickleCallableByPath:
234+
"""Wrap a callable object to be pickled by path to workaround limitation
235+
in pickling due to non-pickleable objects in function non-locals.
236+
237+
Note:
238+
- Do not use this as a decorator.
239+
- Wrapped object must be a global that exist in its parent module and it
240+
can be imported by `from the_module import the_object`.
241+
242+
Usage:
243+
244+
>>> def my_fn(x):
245+
>>> ...
246+
>>> wrapped_fn = PickleCallableByPath(my_fn)
247+
>>> # refer to `wrapped_fn` instead of `my_fn`
248+
"""
249+
250+
def __init__(self, fn):
251+
self._fn = fn
252+
253+
def __call__(self, *args, **kwargs):
254+
return self._fn(*args, **kwargs)
255+
256+
def __reduce__(self):
257+
return type(self)._rebuild, (
258+
self._fn.__module__,
259+
self._fn.__name__,
260+
)
261+
262+
@classmethod
263+
def _rebuild(cls, modname, fn_path):
264+
return cls(getattr(sys.modules[modname], fn_path))

0 commit comments

Comments
 (0)