|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: BSD-2-Clause |
| 3 | + |
| 4 | +import weakref |
| 5 | +from collections import ChainMap |
| 6 | + |
| 7 | +from numba.cuda import types |
| 8 | + |
| 9 | + |
| 10 | +class DataModelManager(object): |
| 11 | + """Manages mapping of FE types to their corresponding data model""" |
| 12 | + |
| 13 | + def __init__(self, handlers=None): |
| 14 | + """ |
| 15 | + Parameters |
| 16 | + ----------- |
| 17 | + handlers: Mapping[Type, DataModel] or None |
| 18 | + Optionally provide the initial handlers mapping. |
| 19 | + """ |
| 20 | + # { numba type class -> model factory } |
| 21 | + self._handlers = handlers or {} |
| 22 | + # { numba type instance -> model instance } |
| 23 | + self._cache = weakref.WeakKeyDictionary() |
| 24 | + |
| 25 | + def register(self, fetypecls, handler): |
| 26 | + """Register the datamodel factory corresponding to a frontend-type class""" |
| 27 | + assert issubclass(fetypecls, types.Type) |
| 28 | + self._handlers[fetypecls] = handler |
| 29 | + |
| 30 | + def lookup(self, fetype): |
| 31 | + """Returns the corresponding datamodel given the frontend-type instance""" |
| 32 | + try: |
| 33 | + return self._cache[fetype] |
| 34 | + except KeyError: |
| 35 | + pass |
| 36 | + handler = self._handlers[type(fetype)] |
| 37 | + model = self._cache[fetype] = handler(self, fetype) |
| 38 | + return model |
| 39 | + |
| 40 | + def __getitem__(self, fetype): |
| 41 | + """Shorthand for lookup()""" |
| 42 | + return self.lookup(fetype) |
| 43 | + |
| 44 | + def copy(self): |
| 45 | + """ |
| 46 | + Make a copy of the manager. |
| 47 | + Use this to inherit from the default data model and specialize it |
| 48 | + for custom target. |
| 49 | + """ |
| 50 | + return DataModelManager(self._handlers.copy()) |
| 51 | + |
| 52 | + def chain(self, other_manager): |
| 53 | + """Create a new DataModelManager by chaining the handlers mapping of |
| 54 | + `other_manager` with a fresh handlers mapping. |
| 55 | +
|
| 56 | + Any existing and new handlers inserted to `other_manager` will be |
| 57 | + visible to the new manager. Any handlers inserted to the new manager |
| 58 | + can override existing handlers in `other_manager` without actually |
| 59 | + mutating `other_manager`. |
| 60 | +
|
| 61 | + Parameters |
| 62 | + ---------- |
| 63 | + other_manager: DataModelManager |
| 64 | + """ |
| 65 | + chained = ChainMap(self._handlers, other_manager._handlers) |
| 66 | + return DataModelManager(chained) |
0 commit comments