Skip to content

Commit 73fdf57

Browse files
committed
move extension_point from cuda-qx to cudaq
Signed-off-by: Alex McCaskey <[email protected]>
1 parent 449e61e commit 73fdf57

File tree

3 files changed

+375
-0
lines changed

3 files changed

+375
-0
lines changed
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
/****************************************************************-*- C++ -*-****
2+
* Copyright (c) 2022-2025 NVIDIA Corporation & Affiliates. *
3+
* All rights reserved. *
4+
* *
5+
* This source code and the accompanying materials are made available under *
6+
* the terms of the Apache License 2.0 which accompanies this distribution. *
7+
******************************************************************************/
8+
9+
#pragma once
10+
11+
#include <functional>
12+
#include <memory>
13+
#include <stdexcept>
14+
#include <unordered_map>
15+
16+
namespace cudaq {
17+
18+
/// @brief A template class for implementing an extension point mechanism.
19+
///
20+
/// This class provides a framework for registering and retrieving plugin-like
21+
/// extensions. It allows dynamic creation of objects based on registered types.
22+
///
23+
/// @tparam T The base type of the extensions.
24+
/// @tparam CtorArgs Variadic template parameters for constructor arguments.
25+
///
26+
/// How to use the extension_point class
27+
///
28+
/// The extension_point class provides a mechanism for creating extensible
29+
/// frameworks with plugin-like functionality. Here's how to use it:
30+
///
31+
/// 1. Define your extension point:
32+
/// Create a new class that inherits from cudaq::extension_point<YourClass>.
33+
/// This class should declare pure virtual methods that extensions will
34+
/// implement.
35+
///
36+
/// @code
37+
/// class MyExtensionPoint : public cudaq::extension_point<MyExtensionPoint> {
38+
/// public:
39+
/// virtual std::string parrotBack(const std::string &msg) const = 0;
40+
/// };
41+
/// @endcode
42+
///
43+
/// 2. Implement concrete extensions:
44+
/// Create classes that inherit from your extension point and implement its
45+
/// methods. Use the CUDAQ_ADD_EXTENSION_CREATOR_FUNCTION macro to define a
46+
/// creator function.
47+
///
48+
/// @code
49+
/// class RepeatBackOne : public MyExtensionPoint {
50+
/// public:
51+
/// std::string parrotBack(const std::string &msg) const override {
52+
/// return msg + " from RepeatBackOne.";
53+
/// }
54+
///
55+
/// CUDAQ_ADD_EXTENSION_CREATOR_FUNCTION(MyExtensionPoint, RepeatBackOne)
56+
/// };
57+
/// @endcode
58+
///
59+
/// 3. Register your extensions:
60+
/// Use the CUDAQ_REGISTER_EXTENSION macro to register each extension.
61+
///
62+
/// @code
63+
/// CUDAQ_REGISTER_EXTENSION(RepeatBackOne)
64+
/// @endcode
65+
///
66+
/// 4. Use your extensions:
67+
/// You can now create instances of your extensions, check registrations, and
68+
/// more.
69+
///
70+
/// @code
71+
/// auto extension = MyExtensionPoint::get("RepeatBackOne");
72+
/// std::cout << extension->parrotBack("Hello") << std::endl;
73+
///
74+
/// auto registeredTypes = MyExtensionPoint::get_registered();
75+
/// bool isRegistered = MyExtensionPoint::is_registered("RepeatBackOne");
76+
/// @endcode
77+
///
78+
/// This approach allows for a flexible, extensible design where new
79+
/// functionality can be added without modifying existing code.
80+
template <typename T, typename... CtorArgs>
81+
class extension_point {
82+
83+
/// Type alias for the creator function.
84+
using CreatorFunction = std::function<std::unique_ptr<T>(CtorArgs...)>;
85+
86+
protected:
87+
/// @brief Get the registry of creator functions.
88+
/// @return A reference to the static registry map.
89+
/// See CUDAQ_INSTANTIATE_REGISTRY() macros below for sample implementations
90+
/// that need to be included in C++ source files.
91+
static std::unordered_map<std::string, CreatorFunction> &get_registry();
92+
93+
public:
94+
/// @brief Create an instance of a registered extension.
95+
/// @param name The identifier of the registered extension.
96+
/// @param args Constructor arguments for the extension.
97+
/// @return A unique pointer to the created instance.
98+
/// @throws std::runtime_error if the extension is not found.
99+
static std::unique_ptr<T> get(const std::string &name, CtorArgs... args) {
100+
auto &registry = get_registry();
101+
auto iter = registry.find(name);
102+
if (iter == registry.end())
103+
throw std::runtime_error("Cannot find extension with name = " + name);
104+
105+
return iter->second(std::forward<CtorArgs>(args)...);
106+
}
107+
108+
/// @brief Get a list of all registered extension names.
109+
/// @return A vector of registered extension names.
110+
static std::vector<std::string> get_registered() {
111+
std::vector<std::string> names;
112+
auto &registry = get_registry();
113+
for (auto &[k, v] : registry)
114+
names.push_back(k);
115+
return names;
116+
}
117+
118+
/// @brief Check if an extension is registered.
119+
/// @param name The identifier of the extension to check.
120+
/// @return True if the extension is registered, false otherwise.
121+
static bool is_registered(const std::string &name) {
122+
auto &registry = get_registry();
123+
return registry.find(name) != registry.end();
124+
}
125+
virtual ~extension_point() = default;
126+
};
127+
128+
/// @brief Macro for defining a creator function for an extension.
129+
/// @param BASE The base class of the extension.
130+
/// @param TYPE The derived class implementing the extension.
131+
#define CUDAQ_ADD_EXTENSION_CREATOR_FUNCTION(BASE, TYPE) \
132+
static inline bool register_type() { \
133+
auto &registry = get_registry(); \
134+
registry[TYPE::class_identifier] = TYPE::create; \
135+
return true; \
136+
} \
137+
static const bool registered_; \
138+
static inline const std::string class_identifier = #TYPE; \
139+
static std::unique_ptr<BASE> create() { return std::make_unique<TYPE>(); }
140+
141+
#define CUDAQ_ADD_EXTENSION_CREATOR_FUNCTION_WITH_NAME(NAME, BASE, TYPE) \
142+
static inline bool register_type() { \
143+
auto &registry = get_registry(); \
144+
registry[#NAME] = TYPE::create; \
145+
return true; \
146+
} \
147+
static const bool registered_; \
148+
static std::unique_ptr<BASE> create() { return std::make_unique<TYPE>(); }
149+
150+
/// @brief Macro for defining a custom creator function for an extension.
151+
/// @param TYPE The class implementing the extension.
152+
/// @param ... Custom implementation of the create function.
153+
#define CUDAQ_ADD_EXTENSION_CUSTOM_CREATOR_FUNCTION(TYPE, ...) \
154+
static inline bool register_type() { \
155+
auto &registry = get_registry(); \
156+
registry[TYPE::class_identifier] = TYPE::create; \
157+
return true; \
158+
} \
159+
static const bool registered_; \
160+
static inline const std::string class_identifier = #TYPE; \
161+
__VA_ARGS__
162+
163+
#define CUDAQ_ADD_EXTENSION_CUSTOM_CREATOR_FUNCTION_WITH_NAME(TYPE, NAME, ...) \
164+
static inline bool register_type() { \
165+
auto &registry = TYPE::get_registry(); \
166+
registry.insert({NAME, TYPE::create}); \
167+
return true; \
168+
} \
169+
static const bool registered_; \
170+
static inline const std::string class_identifier = #TYPE; \
171+
__VA_ARGS__
172+
173+
/// @brief Macro for registering an extension type.
174+
/// @param TYPE The class to be registered as an extension.
175+
#define CUDAQ_REGISTER_EXTENSION(TYPE) \
176+
const bool TYPE::registered_ = TYPE::register_type();
177+
178+
/// In order to support building CUDA-Q libraries with g++ and building
179+
/// application code with nvq++ (which uses clang++ under the hood), you must
180+
/// implement the templated get_registry() function for every set of
181+
/// extension_point<Args..>. This *must* be done in a C++ file that is built
182+
/// with the CUDA-Q libraries.
183+
///
184+
/// Use this version of the helper macro if the only template argument to
185+
/// extension_point<> is the derived class (with no additional creator args).
186+
#define CUDAQ_INSTANTIATE_REGISTRY_NO_ARGS(FULL_TYPE_NAME) \
187+
template <> \
188+
std::unordered_map<std::string, \
189+
std::function<std::unique_ptr<FULL_TYPE_NAME>()>> & \
190+
cudaq::extension_point<FULL_TYPE_NAME>::get_registry() { \
191+
static std::unordered_map< \
192+
std::string, std::function<std::unique_ptr<FULL_TYPE_NAME>()>> \
193+
registry; \
194+
return registry; \
195+
}
196+
197+
/// Use this variadic version of the helper macro if there are additional
198+
/// arguments for the creator function.
199+
#define CUDAQ_INSTANTIATE_REGISTRY(FULL_TYPE_NAME, ...) \
200+
template <> \
201+
std::unordered_map< \
202+
std::string, \
203+
std::function<std::unique_ptr<FULL_TYPE_NAME>(__VA_ARGS__)>> & \
204+
cudaq::extension_point<FULL_TYPE_NAME, __VA_ARGS__>::get_registry() { \
205+
static std::unordered_map< \
206+
std::string, \
207+
std::function<std::unique_ptr<FULL_TYPE_NAME>(__VA_ARGS__)>> \
208+
registry; \
209+
return registry; \
210+
}
211+
212+
} // namespace cudaq

unittests/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,3 +493,7 @@ if (CUDAQ_ENABLE_PYTHON)
493493
gtest_discover_tests(test_domains
494494
TEST_SUFFIX _Sampling PROPERTIES ENVIRONMENT "PYTHONPATH=${CMAKE_BINARY_DIR}/python")
495495
endif()
496+
497+
add_executable(test_extension_point extension/test_extension_point.cpp)
498+
target_link_libraries(test_extension_point PRIVATE GTest::gtest_main cudaq)
499+
gtest_discover_tests(test_extension_point)
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
/****************************************************************-*- C++ -*-****
2+
* Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. *
3+
* All rights reserved. *
4+
* *
5+
* This source code and the accompanying materials are made available under *
6+
* the terms of the Apache License 2.0 which accompanies this distribution. *
7+
******************************************************************************/
8+
9+
#include "cudaq/utils/extension_point.h"
10+
11+
#include <gtest/gtest.h>
12+
13+
namespace cudaq::testing {
14+
15+
// Define a new extension point for the framework
16+
class MyExtensionPoint : public cudaq::extension_point<MyExtensionPoint> {
17+
public:
18+
virtual std::string parrotBack(const std::string &msg) const = 0;
19+
virtual ~MyExtensionPoint() = default;
20+
};
21+
22+
} // namespace cudaq::testing
23+
24+
CUDAQ_INSTANTIATE_REGISTRY_NO_ARGS(cudaq::testing::MyExtensionPoint)
25+
26+
namespace cudaq::testing {
27+
28+
// Define a concrete realization of that extension point
29+
class RepeatBackOne : public MyExtensionPoint {
30+
public:
31+
std::string parrotBack(const std::string &msg) const override {
32+
return msg + " from RepeatBackOne.";
33+
}
34+
35+
// Extension must provide a creator function
36+
CUDAQ_ADD_EXTENSION_CREATOR_FUNCTION(MyExtensionPoint, RepeatBackOne)
37+
};
38+
39+
// Extensions must register themselves
40+
CUDAQ_REGISTER_EXTENSION(RepeatBackOne)
41+
42+
class RepeatBackTwo : public MyExtensionPoint {
43+
public:
44+
std::string parrotBack(const std::string &msg) const override {
45+
return msg + " from RepeatBackTwo.";
46+
}
47+
CUDAQ_ADD_EXTENSION_CREATOR_FUNCTION(MyExtensionPoint, RepeatBackTwo)
48+
};
49+
CUDAQ_REGISTER_EXTENSION(RepeatBackTwo)
50+
51+
} // namespace cudaq::testing
52+
53+
TEST(ExtensionPointTester, checkSimpleExtensionPoint) {
54+
55+
auto registeredNames = cudaq::testing::MyExtensionPoint::get_registered();
56+
EXPECT_EQ(registeredNames.size(), 2);
57+
EXPECT_TRUE(std::find(registeredNames.begin(), registeredNames.end(),
58+
"RepeatBackTwo") != registeredNames.end());
59+
EXPECT_TRUE(std::find(registeredNames.begin(), registeredNames.end(),
60+
"RepeatBackOne") != registeredNames.end());
61+
EXPECT_TRUE(std::find(registeredNames.begin(), registeredNames.end(),
62+
"RepeatBackThree") == registeredNames.end());
63+
64+
{
65+
auto var = cudaq::testing::MyExtensionPoint::get("RepeatBackOne");
66+
EXPECT_EQ(var->parrotBack("Hello World"),
67+
"Hello World from RepeatBackOne.");
68+
}
69+
{
70+
auto var = cudaq::testing::MyExtensionPoint::get("RepeatBackTwo");
71+
EXPECT_EQ(var->parrotBack("Hello World"),
72+
"Hello World from RepeatBackTwo.");
73+
}
74+
}
75+
76+
namespace cudaq::testing {
77+
78+
class MyExtensionPointWithArgs
79+
: public cudaq::extension_point<MyExtensionPointWithArgs, int, double> {
80+
protected:
81+
int i;
82+
double d;
83+
84+
public:
85+
MyExtensionPointWithArgs(int i, double d) : i(i), d(d) {}
86+
virtual std::tuple<int, double, std::string> parrotBack() const = 0;
87+
virtual ~MyExtensionPointWithArgs() = default;
88+
};
89+
90+
} // namespace cudaq::testing
91+
92+
CUDAQ_INSTANTIATE_REGISTRY(cudaq::testing::MyExtensionPointWithArgs, int,
93+
double)
94+
95+
namespace cudaq::testing {
96+
97+
class RepeatBackOneWithArgs : public MyExtensionPointWithArgs {
98+
public:
99+
using MyExtensionPointWithArgs::MyExtensionPointWithArgs;
100+
std::tuple<int, double, std::string> parrotBack() const override {
101+
return std::make_tuple(i, d, "RepeatBackOne");
102+
}
103+
104+
CUDAQ_ADD_EXTENSION_CUSTOM_CREATOR_FUNCTION(
105+
RepeatBackOneWithArgs,
106+
static std::unique_ptr<MyExtensionPointWithArgs> create(int i, double d) {
107+
return std::make_unique<RepeatBackOneWithArgs>(i, d);
108+
})
109+
};
110+
111+
CUDAQ_REGISTER_EXTENSION(RepeatBackOneWithArgs)
112+
113+
class RepeatBackTwoWithArgs : public MyExtensionPointWithArgs {
114+
public:
115+
using MyExtensionPointWithArgs::MyExtensionPointWithArgs;
116+
std::tuple<int, double, std::string> parrotBack() const override {
117+
return std::make_tuple(i, d, "RepeatBackTwo");
118+
}
119+
120+
CUDAQ_ADD_EXTENSION_CUSTOM_CREATOR_FUNCTION(
121+
RepeatBackTwoWithArgs,
122+
static std::unique_ptr<MyExtensionPointWithArgs> create(int i, double d) {
123+
return std::make_unique<RepeatBackTwoWithArgs>(i, d);
124+
})
125+
};
126+
127+
CUDAQ_REGISTER_EXTENSION(RepeatBackTwoWithArgs)
128+
129+
} // namespace cudaq::testing
130+
131+
TEST(CoreTester, checkSimpleExtensionPointWithArgs) {
132+
133+
auto registeredNames =
134+
cudaq::testing::MyExtensionPointWithArgs::get_registered();
135+
EXPECT_EQ(registeredNames.size(), 2);
136+
EXPECT_TRUE(std::find(registeredNames.begin(), registeredNames.end(),
137+
"RepeatBackTwoWithArgs") != registeredNames.end());
138+
EXPECT_TRUE(std::find(registeredNames.begin(), registeredNames.end(),
139+
"RepeatBackOneWithArgs") != registeredNames.end());
140+
EXPECT_TRUE(std::find(registeredNames.begin(), registeredNames.end(),
141+
"RepeatBackThree") == registeredNames.end());
142+
143+
{
144+
auto var = cudaq::testing::MyExtensionPointWithArgs::get(
145+
"RepeatBackOneWithArgs", 5, 2.2);
146+
auto [i, d, msg] = var->parrotBack();
147+
EXPECT_EQ(msg, "RepeatBackOne");
148+
EXPECT_EQ(i, 5);
149+
EXPECT_NEAR(d, 2.2, 1e-2);
150+
}
151+
{
152+
auto var = cudaq::testing::MyExtensionPointWithArgs::get(
153+
"RepeatBackTwoWithArgs", 15, 12.2);
154+
auto [i, d, msg] = var->parrotBack();
155+
EXPECT_EQ(msg, "RepeatBackTwo");
156+
EXPECT_EQ(i, 15);
157+
EXPECT_NEAR(d, 12.2, 1e-2);
158+
}
159+
}

0 commit comments

Comments
 (0)