1515 */
1616
1717#include < atomic>
18- #include < functional >
18+ #include < condition_variable >
1919#include < future>
20+ #include < optional>
21+ #include < queue>
2022#include < thread>
21- #include < type_traits>
22- #include < vector>
23+ #include < utility>
24+
25+ #include < kvikio/function_wrapper.hpp>
2326
24- namespace kvikio {
/**
 * @brief A simple, header-only thread pool that executes tasks in an embarrassingly parallel
 * manner.
 */
31+ namespace kvikio {
32+ class this_thread {
33+ public:
34+ static bool is_from_pool () { return get_thread_idx ().has_value (); }
35+
36+ static std::optional<std::size_t > get_thread_idx () { return this_thread_idx; }
37+
38+ private:
39+ friend class ThreadPoolSimple ;
40+
41+ static void set_thread_idx (std::size_t thread_idx) { this_thread_idx = thread_idx; }
42+
43+ inline static thread_local std::optional<std::size_t > this_thread_idx{std::nullopt };
44+ };
45+
46+ struct Worker {
47+ std::thread thread;
48+ std::condition_variable task_available_cv;
49+ std::condition_variable task_done_cv;
50+ std::mutex task_mutex;
51+ std::queue<FunctionWrapper> task_queue;
52+ bool should_stop{false };
53+ };
54+
3055class ThreadPoolSimple {
3156 public:
32- ThreadPoolSimple (
33- unsigned int num_threads, const std::function< void ()>& worker_thread_init_func = [] {} )
34- : _num_threads{num_threads}, _worker_thread_init_func{worker_thread_init_func}
57+ template < typename F>
58+ ThreadPoolSimple ( unsigned int num_threads, F&& worker_thread_init_func)
59+ : _num_threads{num_threads}, _worker_thread_init_func{std::forward<F>( worker_thread_init_func) }
3560 {
61+ create_threads ();
3662 }
3763
38- void reset ();
64+ ThreadPoolSimple (unsigned int num_threads) : ThreadPoolSimple(num_threads, FunctionWrapper{}) {}
65+
66+ ~ThreadPoolSimple () { destroy_threads (); }
67+
68+ template <typename F>
69+ void reset (unsigned int num_threads, F&& worker_thread_init_func)
70+ {
71+ wait ();
72+ destroy_threads ();
73+
74+ _num_threads = num_threads;
75+ _worker_thread_init_func = std::forward<F>(worker_thread_init_func);
76+ create_threads ();
77+ }
78+
79+ void reset (unsigned int num_threads) { reset (num_threads, FunctionWrapper{}); }
80+
81+ void wait ()
82+ {
83+ for (unsigned int thread_idx = 0 ; thread_idx < _num_threads; ++thread_idx) {
84+ auto & task_done_cv = _workers[thread_idx].task_done_cv ;
85+ auto & mut = _workers[thread_idx].task_mutex ;
86+ auto & task_queue = _workers[thread_idx].task_queue ;
87+
88+ std::unique_lock lock (mut);
89+ task_done_cv.wait (lock, [&] { return task_queue.empty (); });
90+ }
91+ }
92+
93+ unsigned int num_thread () const { return _num_threads; }
3994
4095 template <typename F, typename R = std::invoke_result_t <std::decay_t <F>>>
4196 [[nodiscard]] std::future<R> submit_task (F&& task)
4297 {
98+ auto tid =
99+ std::atomic_fetch_add_explicit (&_task_submission_counter, 1 , std::memory_order_relaxed);
100+ tid %= _num_threads;
101+
102+ return submit_task_to_thread (std::forward<F>(task), tid);
103+ }
104+
105+ template <typename F, typename R = std::invoke_result_t <std::decay_t <F>>>
106+ [[nodiscard]] std::future<R> submit_task_to_thread (F&& task, std::size_t thread_idx)
107+ {
108+ auto & task_available_cv = _workers[thread_idx].task_available_cv ;
109+ auto & mut = _workers[thread_idx].task_mutex ;
110+ auto & task_queue = _workers[thread_idx].task_queue ;
111+
112+ std::promise<R> p;
113+ auto fut = p.get_future ();
114+
115+ {
116+ std::lock_guard lock (mut);
117+
118+ task_queue.emplace ([task = std::forward<F>(task), p = std::move (p), thread_idx]() mutable {
119+ try {
120+ if constexpr (std::is_same_v<R, void >) {
121+ task ();
122+ p.set_value ();
123+ } else {
124+ p.set_value (task ());
125+ }
126+ } catch (...) {
127+ p.set_exception (std::current_exception ());
128+ }
129+ });
130+ }
131+
132+ task_available_cv.notify_one ();
133+ return fut;
43134 }
44135
45136 private:
46- void worker () {}
137+ void run_worker (std::size_t thread_idx)
138+ {
139+ this_thread::set_thread_idx (thread_idx);
140+
141+ auto & task_available_cv = _workers[thread_idx].task_available_cv ;
142+ auto & task_done_cv = _workers[thread_idx].task_done_cv ;
143+ auto & mut = _workers[thread_idx].task_mutex ;
144+ auto & task_queue = _workers[thread_idx].task_queue ;
145+ auto & should_stop = _workers[thread_idx].should_stop ;
146+
147+ if (_worker_thread_init_func) { std::invoke (_worker_thread_init_func); }
148+
149+ while (true ) {
150+ std::unique_lock lock (mut);
151+
152+ if (task_queue.empty ()) { task_done_cv.notify_all (); }
153+
154+ task_available_cv.wait (lock, [&] { return !task_queue.empty () || should_stop; });
155+
156+ if (should_stop) { break ; }
157+
158+ auto task = std::move (task_queue.front ());
159+ task_queue.pop ();
160+ lock.unlock ();
161+
162+ task ();
163+ }
164+ }
47165
48166 void create_threads ()
49167 {
50- for (unsigned int i = 0 ; i < _num_threads; ++i) {
51- _thread_container.emplace_back (&ThreadPoolSimple::worker, _worker_thread_init_func);
168+ _workers = std::make_unique<Worker[]>(_num_threads);
169+ for (unsigned int thread_idx = 0 ; thread_idx < _num_threads; ++thread_idx) {
170+ _workers[thread_idx].thread = std::thread ([this , thread_idx] { run_worker (thread_idx); });
52171 }
53172 }
54173
55- void destroy_threads () {}
174+ void destroy_threads ()
175+ {
176+ for (unsigned int thread_idx = 0 ; thread_idx < _num_threads; ++thread_idx) {
177+ auto & task_available_cv = _workers[thread_idx].task_available_cv ;
178+ auto & mut = _workers[thread_idx].task_mutex ;
179+
180+ {
181+ std::lock_guard lock (mut);
182+ _workers[thread_idx].should_stop = true ;
183+ }
184+
185+ task_available_cv.notify_one ();
186+
187+ _workers[thread_idx].thread .join ();
188+ }
189+ }
56190
57- std::atomic_bool _done{false };
58191 unsigned int _num_threads{};
59- std::function<void ()> _worker_thread_init_func{};
60- std::vector<std::thread> _thread_container{};
192+ FunctionWrapper _worker_thread_init_func;
193+ std::unique_ptr<Worker[]> _workers;
194+ std::atomic_size_t _task_submission_counter{0 };
61195};
62196
63- } // namespace kvikio
197+ } // namespace kvikio
0 commit comments