Skip to content

Commit 5dbb605

Browse files
Update
1 parent cff1032 commit 5dbb605

File tree

4 files changed

+268
-31
lines changed

4 files changed

+268
-31
lines changed

cpp/benchmarks/CMakeLists.txt

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,7 @@ function(kvikio_add_benchmark)
4444
add_executable(${_KVIKIO_NAME} ${_KVIKIO_SOURCES})
4545
set_target_properties(${_KVIKIO_NAME} PROPERTIES INSTALL_RPATH "\$ORIGIN/../../../lib")
4646

47-
<<<<<<< HEAD
48-
=======
4947
target_include_directories(${_KVIKIO_NAME} PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}")
50-
>>>>>>> e54d507 (Add a simple threadpool.)
5148
target_link_libraries(${_KVIKIO_NAME} PUBLIC benchmark::benchmark kvikio::kvikio)
5249

5350
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
@@ -63,10 +60,6 @@ function(kvikio_add_benchmark)
6360
)
6461
endfunction()
6562

66-
<<<<<<< HEAD
67-
kvikio_add_benchmark(NAME THREADPOOL_BENCHMARK SOURCES "threadpool/threadpool_benchmark.cpp")
68-
=======
6963
kvikio_add_benchmark(
7064
NAME THREADPOOL_BENCHMARK SOURCES "threadpool/threadpool_benchmark.cpp" "utils/utils.cpp"
7165
)
72-
>>>>>>> e54d507 (Add a simple threadpool.)

cpp/benchmarks/threadpool/threadpool_benchmark.cpp

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
#include <benchmark/benchmark.h>
3030
#include <kvikio/defaults.hpp>
31+
#include <kvikio/threadpool_simple.hpp>
3132
#include <utils/utils.hpp>
3233

3334
namespace kvikio {
@@ -46,7 +47,7 @@ void task_compute(std::size_t num_compute_iterations)
4647
}
4748

4849
template <ScalingType scaling_type>
49-
void BM_threadpool_compute(benchmark::State& state)
50+
void BM_BS_threadpool_compute(benchmark::State& state)
5051
{
5152
auto const num_threads = state.range(0);
5253

@@ -67,22 +68,64 @@ void BM_threadpool_compute(benchmark::State& state)
6768

6869
state.counters["threads"] = num_threads;
6970
}
71+
72+
template <ScalingType scaling_type>
73+
void BM_simple_threadpool_compute(benchmark::State& state)
74+
{
75+
auto const num_threads = state.range(0);
76+
77+
std::size_t const num_compute_tasks =
78+
(scaling_type == ScalingType::STRONG_SCALING) ? 10'000 : (1'000 * num_threads);
79+
80+
std::size_t constexpr num_compute_iterations{1'000};
81+
kvikio::ThreadPoolSimple thread_pool(num_threads);
82+
83+
for (auto _ : state) {
84+
// Submit a total of "num_compute_tasks" tasks to the thread pool.
85+
for (auto i = std::size_t{0}; i < num_compute_tasks; ++i) {
86+
[[maybe_unused]] auto fut =
87+
thread_pool.submit_task([] { task_compute(num_compute_iterations); });
88+
}
89+
thread_pool.wait();
90+
}
91+
92+
state.counters["threads"] = num_threads;
93+
}
7094
} // namespace kvikio
7195

7296
int main(int argc, char** argv)
7397
{
7498
benchmark::Initialize(&argc, argv);
7599

76-
benchmark::RegisterBenchmark("BM_threadpool_compute:strong_scaling",
77-
kvikio::BM_threadpool_compute<kvikio::ScalingType::STRONG_SCALING>)
100+
benchmark::RegisterBenchmark(
101+
"BS_threadpool_compute:strong_scaling",
102+
kvikio::BM_BS_threadpool_compute<kvikio::ScalingType::STRONG_SCALING>)
103+
->RangeMultiplier(2)
104+
->Range(1, 64) // Increase from 1 to 64 (inclusive of both endpoints) with x2 stepping.
105+
->UseRealTime() // Use the wall clock to determine the number of benchmark iterations.
106+
->Unit(benchmark::kMillisecond)
107+
->MinTime(2); // Minimum of 2 seconds.
108+
109+
benchmark::RegisterBenchmark("BS_threadpool_compute:weak_scaling",
110+
kvikio::BM_BS_threadpool_compute<kvikio::ScalingType::WEAK_SCALING>)
111+
->RangeMultiplier(2)
112+
->Range(1, 64)
113+
->UseRealTime()
114+
->Unit(benchmark::kMillisecond)
115+
->MinTime(2);
116+
117+
benchmark::RegisterBenchmark(
118+
"simple_threadpool_compute:strong_scaling",
119+
kvikio::BM_simple_threadpool_compute<kvikio::ScalingType::STRONG_SCALING>)
78120
->RangeMultiplier(2)
79121
->Range(1, 64) // Increase from 1 to 64 (inclusive of both endpoints) with x2 stepping.
80122
->UseRealTime() // Use the wall clock to determine the number of benchmark iterations.
81123
->Unit(benchmark::kMillisecond)
82124
->MinTime(2); // Minimum of 2 seconds.
83125

84-
benchmark::RegisterBenchmark("BM_threadpool_compute:weak_scaling",
85-
kvikio::BM_threadpool_compute<kvikio::ScalingType::WEAK_SCALING>)
126+
benchmark::RegisterBenchmark(
127+
"simple_threadpool_compute:weak_scaling",
128+
kvikio::BM_simple_threadpool_compute<kvikio::ScalingType::WEAK_SCALING>)
86129
->RangeMultiplier(2)
87130
->Range(1, 64)
88131
->UseRealTime()
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright (c) 2025, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
19+
20+
namespace kvikio {
21+
/**
 * @brief Move-only, type-erased holder for a callable that is invoked with no arguments.
 *
 * The callable is stored on the heap behind a small virtual interface. Because the wrapper is
 * never copied, the callable itself only needs to be constructible from the argument passed to
 * the wrapper's constructor, so closures owning move-only state can be stored.
 *
 * A default-constructed or moved-from wrapper is empty; it converts to `false` in a boolean
 * context and must not be invoked.
 */
class SimpleFunctionWrapper {
 private:
  /// Type-erased interface through which the stored callable is invoked and destroyed.
  struct inner_base {
    virtual void operator()() = 0;

    virtual ~inner_base() = default;
  };

  /// Concrete holder owning a callable of the decayed type of `F`.
  template <typename F>
  struct inner : inner_base {
    using F_decay = std::decay_t<F>;
    // The stored callable is invoked as an lvalue, so that is the form that is checked.
    static_assert(std::is_invocable_r_v<void, F_decay&>);

    template <typename U>
    explicit inner(U&& f) : _f(std::forward<U>(f))
    {
    }

    void operator()() override { std::invoke(_f); }

    ~inner() override = default;

    F_decay _f;
  };

  std::unique_ptr<inner_base> _callable;

 public:
  /**
   * @brief Wrap a callable.
   *
   * The overload is disabled when the argument is itself a SimpleFunctionWrapper. Without that
   * constraint this forwarding constructor would be a better match than the deleted copy
   * constructor for a non-const lvalue wrapper, and "copying" a wrapper would compile and then
   * recurse into this constructor without end.
   *
   * @tparam F Callable type; its decayed type must be invocable with no arguments.
   * @param f Callable to store. Lvalues are copied, rvalues are moved.
   */
  template <typename F,
            typename = std::enable_if_t<!std::is_same_v<std::decay_t<F>, SimpleFunctionWrapper>>>
  SimpleFunctionWrapper(F&& f)
    : _callable(std::make_unique<inner<std::decay_t<F>>>(std::forward<F>(f)))
  {
    static_assert(std::is_invocable_r_v<void, std::decay_t<F>&>);
  }

  /// @brief Create an empty wrapper.
  SimpleFunctionWrapper() = default;

  SimpleFunctionWrapper(SimpleFunctionWrapper&&) noexcept            = default;
  SimpleFunctionWrapper& operator=(SimpleFunctionWrapper&&) noexcept = default;

  SimpleFunctionWrapper(const SimpleFunctionWrapper&)            = delete;
  SimpleFunctionWrapper& operator=(const SimpleFunctionWrapper&) = delete;

  /**
   * @brief Invoke the stored callable.
   *
   * The wrapper must not be empty: the held pointer is dereferenced without a check.
   */
  void operator()() { (*_callable)(); }

  /// @brief `true` if a callable is stored, `false` if the wrapper is empty.
  explicit operator bool() const noexcept { return _callable != nullptr; }
};
65+
66+
using FunctionWrapper = SimpleFunctionWrapper;
67+
} // namespace kvikio

cpp/include/kvikio/threadpool_simple.hpp

Lines changed: 153 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,49 +15,183 @@
1515
*/
1616

1717
#include <atomic>
18-
#include <functional>
18+
#include <condition_variable>
1919
#include <future>
20+
#include <optional>
21+
#include <queue>
2022
#include <thread>
21-
#include <type_traits>
22-
#include <vector>
23+
#include <utility>
24+
25+
#include <kvikio/function_wrapper.hpp>
2326

24-
namespace kvikio {
2527
/**
26-
* @brief A simple thread pool that executes tasks in an embarrassingly parallel manner.
27-
*
28-
* The implementation is header-only.
28+
* @brief A simple, header-only thread pool that executes tasks in an embarrassingly parallel
29+
* manner.
2930
*/
31+
namespace kvikio {
32+
/**
 * @brief Static accessors for an index stored in a thread-local slot of the calling thread.
 *
 * The slot starts out empty on every thread. Only ThreadPoolSimple (declared a friend) can write
 * to it.
 */
class this_thread {
 public:
  /// @brief Index stored for the calling thread, or an empty optional if none has been stored.
  static std::optional<std::size_t> get_thread_idx() { return this_thread_idx; }

  /// @brief Whether an index has been stored for the calling thread.
  static bool is_from_pool() { return this_thread_idx.has_value(); }

 private:
  friend class ThreadPoolSimple;

  /// Store `idx` in the calling thread's slot.
  static void set_thread_idx(std::size_t idx) { this_thread_idx = idx; }

  /// One slot per thread; empty until set_thread_idx() is called on that thread.
  inline static thread_local std::optional<std::size_t> this_thread_idx{};
};
45+
46+
struct Worker {
47+
std::thread thread;
48+
std::condition_variable task_available_cv;
49+
std::condition_variable task_done_cv;
50+
std::mutex task_mutex;
51+
std::queue<FunctionWrapper> task_queue;
52+
bool should_stop{false};
53+
};
54+
3055
class ThreadPoolSimple {
3156
public:
32-
ThreadPoolSimple(
33-
unsigned int num_threads, const std::function<void()>& worker_thread_init_func = [] {})
34-
: _num_threads{num_threads}, _worker_thread_init_func{worker_thread_init_func}
57+
template <typename F>
58+
ThreadPoolSimple(unsigned int num_threads, F&& worker_thread_init_func)
59+
: _num_threads{num_threads}, _worker_thread_init_func{std::forward<F>(worker_thread_init_func)}
3560
{
61+
create_threads();
3662
}
3763

38-
void reset();
64+
ThreadPoolSimple(unsigned int num_threads) : ThreadPoolSimple(num_threads, FunctionWrapper{}) {}
65+
66+
~ThreadPoolSimple() { destroy_threads(); }
67+
68+
template <typename F>
69+
void reset(unsigned int num_threads, F&& worker_thread_init_func)
70+
{
71+
wait();
72+
destroy_threads();
73+
74+
_num_threads = num_threads;
75+
_worker_thread_init_func = std::forward<F>(worker_thread_init_func);
76+
create_threads();
77+
}
78+
79+
void reset(unsigned int num_threads) { reset(num_threads, FunctionWrapper{}); }
80+
81+
void wait()
82+
{
83+
for (unsigned int thread_idx = 0; thread_idx < _num_threads; ++thread_idx) {
84+
auto& task_done_cv = _workers[thread_idx].task_done_cv;
85+
auto& mut = _workers[thread_idx].task_mutex;
86+
auto& task_queue = _workers[thread_idx].task_queue;
87+
88+
std::unique_lock lock(mut);
89+
task_done_cv.wait(lock, [&] { return task_queue.empty(); });
90+
}
91+
}
92+
93+
unsigned int num_thread() const { return _num_threads; }
3994

4095
template <typename F, typename R = std::invoke_result_t<std::decay_t<F>>>
4196
[[nodiscard]] std::future<R> submit_task(F&& task)
4297
{
98+
auto tid =
99+
std::atomic_fetch_add_explicit(&_task_submission_counter, 1, std::memory_order_relaxed);
100+
tid %= _num_threads;
101+
102+
return submit_task_to_thread(std::forward<F>(task), tid);
103+
}
104+
105+
template <typename F, typename R = std::invoke_result_t<std::decay_t<F>>>
106+
[[nodiscard]] std::future<R> submit_task_to_thread(F&& task, std::size_t thread_idx)
107+
{
108+
auto& task_available_cv = _workers[thread_idx].task_available_cv;
109+
auto& mut = _workers[thread_idx].task_mutex;
110+
auto& task_queue = _workers[thread_idx].task_queue;
111+
112+
std::promise<R> p;
113+
auto fut = p.get_future();
114+
115+
{
116+
std::lock_guard lock(mut);
117+
118+
task_queue.emplace([task = std::forward<F>(task), p = std::move(p), thread_idx]() mutable {
119+
try {
120+
if constexpr (std::is_same_v<R, void>) {
121+
task();
122+
p.set_value();
123+
} else {
124+
p.set_value(task());
125+
}
126+
} catch (...) {
127+
p.set_exception(std::current_exception());
128+
}
129+
});
130+
}
131+
132+
task_available_cv.notify_one();
133+
return fut;
43134
}
44135

45136
private:
46-
void worker() {}
137+
void run_worker(std::size_t thread_idx)
138+
{
139+
this_thread::set_thread_idx(thread_idx);
140+
141+
auto& task_available_cv = _workers[thread_idx].task_available_cv;
142+
auto& task_done_cv = _workers[thread_idx].task_done_cv;
143+
auto& mut = _workers[thread_idx].task_mutex;
144+
auto& task_queue = _workers[thread_idx].task_queue;
145+
auto& should_stop = _workers[thread_idx].should_stop;
146+
147+
if (_worker_thread_init_func) { std::invoke(_worker_thread_init_func); }
148+
149+
while (true) {
150+
std::unique_lock lock(mut);
151+
152+
if (task_queue.empty()) { task_done_cv.notify_all(); }
153+
154+
task_available_cv.wait(lock, [&] { return !task_queue.empty() || should_stop; });
155+
156+
if (should_stop) { break; }
157+
158+
auto task = std::move(task_queue.front());
159+
task_queue.pop();
160+
lock.unlock();
161+
162+
task();
163+
}
164+
}
47165

48166
void create_threads()
49167
{
50-
for (unsigned int i = 0; i < _num_threads; ++i) {
51-
_thread_container.emplace_back(&ThreadPoolSimple::worker, _worker_thread_init_func);
168+
_workers = std::make_unique<Worker[]>(_num_threads);
169+
for (unsigned int thread_idx = 0; thread_idx < _num_threads; ++thread_idx) {
170+
_workers[thread_idx].thread = std::thread([this, thread_idx] { run_worker(thread_idx); });
52171
}
53172
}
54173

55-
void destroy_threads() {}
174+
void destroy_threads()
175+
{
176+
for (unsigned int thread_idx = 0; thread_idx < _num_threads; ++thread_idx) {
177+
auto& task_available_cv = _workers[thread_idx].task_available_cv;
178+
auto& mut = _workers[thread_idx].task_mutex;
179+
180+
{
181+
std::lock_guard lock(mut);
182+
_workers[thread_idx].should_stop = true;
183+
}
184+
185+
task_available_cv.notify_one();
186+
187+
_workers[thread_idx].thread.join();
188+
}
189+
}
56190

57-
std::atomic_bool _done{false};
58191
unsigned int _num_threads{};
59-
std::function<void()> _worker_thread_init_func{};
60-
std::vector<std::thread> _thread_container{};
192+
FunctionWrapper _worker_thread_init_func;
193+
std::unique_ptr<Worker[]> _workers;
194+
std::atomic_size_t _task_submission_counter{0};
61195
};
62196

63-
} // namespace kvikio
197+
} // namespace kvikio

0 commit comments

Comments
 (0)