|
| 1 | +/****************************************************************-*- C++ -*-**** |
| 2 | + * Copyright (c) 2022-2025 NVIDIA Corporation & Affiliates. * |
| 3 | + * All rights reserved. * |
| 4 | + * * |
| 5 | + * This source code and the accompanying materials are made available under * |
| 6 | + * the terms of the Apache License 2.0 which accompanies this distribution. * |
| 7 | + ******************************************************************************/ |
| 8 | + |
| 9 | +#pragma once |
| 10 | + |
| 11 | +#include <functional> |
| 12 | +#include <memory> |
| 13 | +#include <stdexcept> |
| 14 | +#include <unordered_map> |
| 15 | + |
| 16 | +namespace cudaq { |
| 17 | + |
| 18 | +/// @brief A template class for implementing an extension point mechanism. |
| 19 | +/// |
| 20 | +/// This class provides a framework for registering and retrieving plugin-like |
| 21 | +/// extensions. It allows dynamic creation of objects based on registered types. |
| 22 | +/// |
| 23 | +/// @tparam T The base type of the extensions. |
| 24 | +/// @tparam CtorArgs Variadic template parameters for constructor arguments. |
| 25 | +/// |
| 26 | +/// How to use the extension_point class |
| 27 | +/// |
| 28 | +/// The extension_point class provides a mechanism for creating extensible |
| 29 | +/// frameworks with plugin-like functionality. Here's how to use it: |
| 30 | +/// |
| 31 | +/// 1. Define your extension point: |
| 32 | +/// Create a new class that inherits from cudaq::extension_point<YourClass>. |
| 33 | +/// This class should declare pure virtual methods that extensions will |
| 34 | +/// implement. |
| 35 | +/// |
| 36 | +/// @code |
| 37 | +/// class MyExtensionPoint : public cudaq::extension_point<MyExtensionPoint> { |
| 38 | +/// public: |
| 39 | +/// virtual std::string parrotBack(const std::string &msg) const = 0; |
| 40 | +/// }; |
| 41 | +/// @endcode |
| 42 | +/// |
| 43 | +/// 2. Implement concrete extensions: |
| 44 | +/// Create classes that inherit from your extension point and implement its |
| 45 | +/// methods. Use the CUDAQ_ADD_EXTENSION_CREATOR_FUNCTION macro to define a |
| 46 | +/// creator function. |
| 47 | +/// |
| 48 | +/// @code |
| 49 | +/// class RepeatBackOne : public MyExtensionPoint { |
| 50 | +/// public: |
| 51 | +/// std::string parrotBack(const std::string &msg) const override { |
| 52 | +/// return msg + " from RepeatBackOne."; |
| 53 | +/// } |
| 54 | +/// |
| 55 | +/// CUDAQ_ADD_EXTENSION_CREATOR_FUNCTION(MyExtensionPoint, RepeatBackOne) |
| 56 | +/// }; |
| 57 | +/// @endcode |
| 58 | +/// |
| 59 | +/// 3. Register your extensions: |
| 60 | +/// Use the CUDAQ_REGISTER_EXTENSION macro to register each extension. |
| 61 | +/// |
| 62 | +/// @code |
| 63 | +/// CUDAQ_REGISTER_EXTENSION(RepeatBackOne) |
| 64 | +/// @endcode |
| 65 | +/// |
| 66 | +/// 4. Use your extensions: |
| 67 | +/// You can now create instances of your extensions, check registrations, and |
| 68 | +/// more. |
| 69 | +/// |
| 70 | +/// @code |
| 71 | +/// auto extension = MyExtensionPoint::get("RepeatBackOne"); |
| 72 | +/// std::cout << extension->parrotBack("Hello") << std::endl; |
| 73 | +/// |
| 74 | +/// auto registeredTypes = MyExtensionPoint::get_registered(); |
| 75 | +/// bool isRegistered = MyExtensionPoint::is_registered("RepeatBackOne"); |
| 76 | +/// @endcode |
| 77 | +/// |
| 78 | +/// This approach allows for a flexible, extensible design where new |
| 79 | +/// functionality can be added without modifying existing code. |
| 80 | +template <typename T, typename... CtorArgs> |
| 81 | +class extension_point { |
| 82 | + |
| 83 | + /// Type alias for the creator function. |
| 84 | + using CreatorFunction = std::function<std::unique_ptr<T>(CtorArgs...)>; |
| 85 | + |
| 86 | +protected: |
| 87 | + /// @brief Get the registry of creator functions. |
| 88 | + /// @return A reference to the static registry map. |
| 89 | + /// See CUDAQ_INSTANTIATE_REGISTRY() macros below for sample implementations |
| 90 | + /// that need to be included in C++ source files. |
| 91 | + static std::unordered_map<std::string, CreatorFunction> &get_registry(); |
| 92 | + |
| 93 | +public: |
| 94 | + /// @brief Create an instance of a registered extension. |
| 95 | + /// @param name The identifier of the registered extension. |
| 96 | + /// @param args Constructor arguments for the extension. |
| 97 | + /// @return A unique pointer to the created instance. |
| 98 | + /// @throws std::runtime_error if the extension is not found. |
| 99 | + static std::unique_ptr<T> get(const std::string &name, CtorArgs... args) { |
| 100 | + auto ®istry = get_registry(); |
| 101 | + auto iter = registry.find(name); |
| 102 | + if (iter == registry.end()) |
| 103 | + throw std::runtime_error("Cannot find extension with name = " + name); |
| 104 | + |
| 105 | + return iter->second(std::forward<CtorArgs>(args)...); |
| 106 | + } |
| 107 | + |
| 108 | + /// @brief Get a list of all registered extension names. |
| 109 | + /// @return A vector of registered extension names. |
| 110 | + static std::vector<std::string> get_registered() { |
| 111 | + std::vector<std::string> names; |
| 112 | + auto ®istry = get_registry(); |
| 113 | + for (auto &[k, v] : registry) |
| 114 | + names.push_back(k); |
| 115 | + return names; |
| 116 | + } |
| 117 | + |
| 118 | + /// @brief Check if an extension is registered. |
| 119 | + /// @param name The identifier of the extension to check. |
| 120 | + /// @return True if the extension is registered, false otherwise. |
| 121 | + static bool is_registered(const std::string &name) { |
| 122 | + auto ®istry = get_registry(); |
| 123 | + return registry.find(name) != registry.end(); |
| 124 | + } |
| 125 | + virtual ~extension_point() = default; |
| 126 | +}; |
| 127 | + |
| 128 | +/// @brief Macro for defining a creator function for an extension. |
| 129 | +/// @param BASE The base class of the extension. |
| 130 | +/// @param TYPE The derived class implementing the extension. |
| 131 | +#define CUDAQ_ADD_EXTENSION_CREATOR_FUNCTION(BASE, TYPE) \ |
| 132 | + static inline bool register_type() { \ |
| 133 | + auto ®istry = get_registry(); \ |
| 134 | + registry[TYPE::class_identifier] = TYPE::create; \ |
| 135 | + return true; \ |
| 136 | + } \ |
| 137 | + static const bool registered_; \ |
| 138 | + static inline const std::string class_identifier = #TYPE; \ |
| 139 | + static std::unique_ptr<BASE> create() { return std::make_unique<TYPE>(); } |
| 140 | + |
| 141 | +#define CUDAQ_ADD_EXTENSION_CREATOR_FUNCTION_WITH_NAME(NAME, BASE, TYPE) \ |
| 142 | + static inline bool register_type() { \ |
| 143 | + auto ®istry = get_registry(); \ |
| 144 | + registry[#NAME] = TYPE::create; \ |
| 145 | + return true; \ |
| 146 | + } \ |
| 147 | + static const bool registered_; \ |
| 148 | + static std::unique_ptr<BASE> create() { return std::make_unique<TYPE>(); } |
| 149 | + |
| 150 | +/// @brief Macro for defining a custom creator function for an extension. |
| 151 | +/// @param TYPE The class implementing the extension. |
| 152 | +/// @param ... Custom implementation of the create function. |
| 153 | +#define CUDAQ_ADD_EXTENSION_CUSTOM_CREATOR_FUNCTION(TYPE, ...) \ |
| 154 | + static inline bool register_type() { \ |
| 155 | + auto ®istry = get_registry(); \ |
| 156 | + registry[TYPE::class_identifier] = TYPE::create; \ |
| 157 | + return true; \ |
| 158 | + } \ |
| 159 | + static const bool registered_; \ |
| 160 | + static inline const std::string class_identifier = #TYPE; \ |
| 161 | + __VA_ARGS__ |
| 162 | + |
| 163 | +#define CUDAQ_ADD_EXTENSION_CUSTOM_CREATOR_FUNCTION_WITH_NAME(TYPE, NAME, ...) \ |
| 164 | + static inline bool register_type() { \ |
| 165 | + auto ®istry = TYPE::get_registry(); \ |
| 166 | + registry.insert({NAME, TYPE::create}); \ |
| 167 | + return true; \ |
| 168 | + } \ |
| 169 | + static const bool registered_; \ |
| 170 | + static inline const std::string class_identifier = #TYPE; \ |
| 171 | + __VA_ARGS__ |
| 172 | + |
| 173 | +/// @brief Macro for registering an extension type. |
| 174 | +/// @param TYPE The class to be registered as an extension. |
| 175 | +#define CUDAQ_REGISTER_EXTENSION(TYPE) \ |
| 176 | + const bool TYPE::registered_ = TYPE::register_type(); |
| 177 | + |
| 178 | +/// In order to support building CUDA-Q libraries with g++ and building |
| 179 | +/// application code with nvq++ (which uses clang++ under the hood), you must |
| 180 | +/// implement the templated get_registry() function for every set of |
| 181 | +/// extension_point<Args..>. This *must* be done in a C++ file that is built |
| 182 | +/// with the CUDA-Q libraries. |
| 183 | +/// |
| 184 | +/// Use this version of the helper macro if the only template argument to |
| 185 | +/// extension_point<> is the derived class (with no additional creator args). |
| 186 | +#define CUDAQ_INSTANTIATE_REGISTRY_NO_ARGS(FULL_TYPE_NAME) \ |
| 187 | + template <> \ |
| 188 | + std::unordered_map<std::string, \ |
| 189 | + std::function<std::unique_ptr<FULL_TYPE_NAME>()>> & \ |
| 190 | + cudaq::extension_point<FULL_TYPE_NAME>::get_registry() { \ |
| 191 | + static std::unordered_map< \ |
| 192 | + std::string, std::function<std::unique_ptr<FULL_TYPE_NAME>()>> \ |
| 193 | + registry; \ |
| 194 | + return registry; \ |
| 195 | + } |
| 196 | + |
| 197 | +/// Use this variadic version of the helper macro if there are additional |
| 198 | +/// arguments for the creator function. |
| 199 | +#define CUDAQ_INSTANTIATE_REGISTRY(FULL_TYPE_NAME, ...) \ |
| 200 | + template <> \ |
| 201 | + std::unordered_map< \ |
| 202 | + std::string, \ |
| 203 | + std::function<std::unique_ptr<FULL_TYPE_NAME>(__VA_ARGS__)>> & \ |
| 204 | + cudaq::extension_point<FULL_TYPE_NAME, __VA_ARGS__>::get_registry() { \ |
| 205 | + static std::unordered_map< \ |
| 206 | + std::string, \ |
| 207 | + std::function<std::unique_ptr<FULL_TYPE_NAME>(__VA_ARGS__)>> \ |
| 208 | + registry; \ |
| 209 | + return registry; \ |
| 210 | + } |
| 211 | + |
| 212 | +} // namespace cudaq |
0 commit comments