|
| 1 | +/// Copyright (c) RenChu Wang - All Rights Reserved |
| 2 | + |
| 3 | +#include <cassert> |
| 4 | +#include <iostream> |
| 5 | +#include <memory> |
| 6 | +#include <numeric> |
| 7 | +#include <semaphore> |
| 8 | +#include <sstream> |
| 9 | +#include <vector> |
| 10 | + |
| 11 | +using namespace std; |
| 12 | + |
| 13 | +// The compute class, with couple of members, representing differnt operations. |
| 14 | +class compute : enable_shared_from_this<compute> { |
| 15 | + public: |
| 16 | + compute() : call_count_(0) {} |
| 17 | + virtual ~compute() {} |
| 18 | + virtual int get() = 0; |
| 19 | + virtual vector<shared_ptr<compute>> children() const = 0; |
| 20 | + virtual string str() const = 0; |
| 21 | + |
| 22 | + int operator()() { |
| 23 | + ++call_count_; |
| 24 | + cout << this->str() << "\n"; |
| 25 | + return get(); |
| 26 | + } |
| 27 | + |
| 28 | + size_t cnt() const { return call_count_; } |
| 29 | + |
| 30 | + private: |
| 31 | + size_t call_count_; |
| 32 | +}; |
| 33 | + |
| 34 | +class scoped_semaphore { |
| 35 | + public: |
| 36 | + scoped_semaphore(counting_semaphore<>& sem, string by) |
| 37 | + : sem_(sem), by_(by) { |
| 38 | + cout << "acq(" << by_ << ")\n"; |
| 39 | + sem_.acquire(); |
| 40 | + } |
| 41 | + ~scoped_semaphore() { |
| 42 | + cout << "rel(" << by_ << ")\n"; |
| 43 | + sem_.release(); |
| 44 | + } |
| 45 | + |
| 46 | + private: |
| 47 | + // This is guarenteed to be alive, so using a reference. |
| 48 | + counting_semaphore<>& sem_; |
| 49 | + string by_; |
| 50 | +}; |
| 51 | +// The semaphore class. Simulate a fixed amount of resources (threads). |
| 52 | +class with_semaphore { |
| 53 | + protected: |
| 54 | + with_semaphore(counting_semaphore<>& sem) : sem_(sem) {} |
| 55 | + scoped_semaphore acquire_semaphore(string by) { |
| 56 | + // Using copy elision, to avoid acquiring and releasing and acquiring. |
| 57 | + return scoped_semaphore(sem_, by); |
| 58 | + } |
| 59 | + |
| 60 | + private: |
| 61 | + counting_semaphore<>& sem_; |
| 62 | +}; |
| 63 | + |
| 64 | +class literal : public compute { |
| 65 | + public: |
| 66 | + literal(int i) : data_(i) {} |
| 67 | + int get() override { return data_; } |
| 68 | + vector<shared_ptr<compute>> children() const override { return {}; } |
| 69 | + string str() const override { |
| 70 | + stringstream out; |
| 71 | + out << data_; |
| 72 | + return out.str(); |
| 73 | + } |
| 74 | + |
| 75 | + private: |
| 76 | + int data_; |
| 77 | +}; |
| 78 | + |
| 79 | +class summation : public compute { |
| 80 | + public: |
| 81 | + summation(vector<shared_ptr<compute>> op) : operands_(op) {} |
| 82 | + |
| 83 | + int get() override { |
| 84 | + int summation = 0; |
| 85 | + for (auto op : operands_) { |
| 86 | + summation += (*op)(); |
| 87 | + } |
| 88 | + return summation; |
| 89 | + } |
| 90 | + |
| 91 | + vector<shared_ptr<compute>> children() const override { return operands_; } |
| 92 | + string str() const override { |
| 93 | + stringstream out; |
| 94 | + for (size_t i = 0; i < operands_.size(); ++i) { |
| 95 | + out << operands_[i]->str(); |
| 96 | + |
| 97 | + if (i != operands_.size() - 1) { |
| 98 | + out << " + "; |
| 99 | + } |
| 100 | + } |
| 101 | + return out.str(); |
| 102 | + } |
| 103 | + |
| 104 | + private: |
| 105 | + vector<shared_ptr<compute>> operands_; |
| 106 | +}; |
| 107 | + |
| 108 | +class product : public compute { |
| 109 | + public: |
| 110 | + product(vector<shared_ptr<compute>> op) : operands_(op) {} |
| 111 | + |
| 112 | + int get() override { |
| 113 | + int product = 1; |
| 114 | + for (auto op : operands_) { |
| 115 | + product *= (*op)(); |
| 116 | + } |
| 117 | + return product; |
| 118 | + } |
| 119 | + |
| 120 | + vector<shared_ptr<compute>> children() const override { return operands_; } |
| 121 | + string str() const override { |
| 122 | + stringstream out; |
| 123 | + for (size_t i = 0; i < operands_.size(); ++i) { |
| 124 | + out << operands_[i]->str(); |
| 125 | + |
| 126 | + if (i != operands_.size() - 1) { |
| 127 | + out << " * "; |
| 128 | + } |
| 129 | + } |
| 130 | + return out.str(); |
| 131 | + } |
| 132 | + |
| 133 | + private: |
| 134 | + vector<shared_ptr<compute>> operands_; |
| 135 | +}; |
| 136 | + |
| 137 | +class cache : public compute { |
| 138 | + public: |
| 139 | + cache(shared_ptr<compute> op) : operand(op) {} |
| 140 | + |
| 141 | + int get() override { |
| 142 | + if (value_ >= 0) { |
| 143 | + return value_; |
| 144 | + } |
| 145 | + |
| 146 | + value_ = (*operand)(); |
| 147 | + return value_; |
| 148 | + } |
| 149 | + |
| 150 | + vector<shared_ptr<compute>> children() const override { return {operand}; } |
| 151 | + string str() const override { |
| 152 | + stringstream out; |
| 153 | + out << "c(" << operand->str() << ")"; |
| 154 | + return out.str(); |
| 155 | + } |
| 156 | + |
| 157 | + private: |
| 158 | + shared_ptr<compute> operand; |
| 159 | + int value_ = -1; |
| 160 | +}; |
| 161 | + |
| 162 | +// Task classes are like `Task` in python, |
| 163 | +// where execution starts immediately, |
| 164 | +// but depends on other `Task`s, |
| 165 | +// so this means that recursively triggering the tasks can cause a deadlock, |
| 166 | +// due to the limited budget smaller than the dependencies. |
| 167 | +// For example, when budget = 1, a depending on b, |
| 168 | +// a would require a thread to run, and then b would require a thread to run, |
| 169 | +// exceeding the budget (a runs first before b). |
| 170 | +class task_literal : public literal, with_semaphore { |
| 171 | + public: |
| 172 | + task_literal(int i, counting_semaphore<>& sem) |
| 173 | + : literal(i), with_semaphore(sem) {} |
| 174 | + |
| 175 | + int get() override { |
| 176 | + auto sem{acquire_semaphore("task_lit_" + str())}; |
| 177 | + return literal::get(); |
| 178 | + } |
| 179 | +}; |
| 180 | +class task_summation : public summation, with_semaphore { |
| 181 | + public: |
| 182 | + task_summation(vector<shared_ptr<compute>> op, counting_semaphore<>& sem) |
| 183 | + : summation(op), with_semaphore(sem) {} |
| 184 | + |
| 185 | + int get() override { |
| 186 | + auto sem{acquire_semaphore("task_sum_" + str())}; |
| 187 | + return summation::get(); |
| 188 | + } |
| 189 | +}; |
| 190 | + |
| 191 | +// Lazy classes doesn't cause deadlocks, |
| 192 | +// because they are reduced from the leaves of the expression tree, |
| 193 | +// which can be linearly ordered (no deadlock so long as semaphore > 1). |
| 194 | +class lazy_literal : public literal, with_semaphore { |
| 195 | + public: |
| 196 | + lazy_literal(int i, counting_semaphore<>& sem) |
| 197 | + : literal(i), with_semaphore(sem) {} |
| 198 | + int get() override { |
| 199 | + // As literal has no children, this can be the same as `literal_task`. |
| 200 | + auto sem{acquire_semaphore("lazy_lit_" + str())}; |
| 201 | + return literal::get(); |
| 202 | + } |
| 203 | +}; |
| 204 | +class lazy_summation : public summation, with_semaphore { |
| 205 | + public: |
| 206 | + lazy_summation(vector<shared_ptr<compute>> op, counting_semaphore<>& sem) |
| 207 | + : summation(op), with_semaphore(sem) {} |
| 208 | + int get() override { |
| 209 | + vector<int> values; |
| 210 | + |
| 211 | +// Uisng omp parallel to simulate multiple threads. |
| 212 | +#pragma omp parallel for |
| 213 | + for (auto child : children()) { |
| 214 | +#pragma omp critical |
| 215 | + values.push_back(((*child)())); |
| 216 | + } |
| 217 | + |
| 218 | + // Only acquiring when needed. |
| 219 | + // Previous I put this into the for loop, before the child call, |
| 220 | + // but this just means that we are acquiring twice for the child call, |
| 221 | + // and none for the current call. |
| 222 | + auto sem{acquire_semaphore("lazy_sum_" + str())}; |
| 223 | + return accumulate(values.begin(), values.end(), 0); |
| 224 | + } |
| 225 | +}; |
| 226 | + |
| 227 | +int main() { |
| 228 | + using expr = shared_ptr<compute>; |
| 229 | + expr one, two, three, sum_six, prod_six, twelve, thrity_six; |
| 230 | + vector<expr> one_two_three, six_six; |
| 231 | + |
| 232 | + { |
| 233 | + one = make_shared<literal>(1); |
| 234 | + two = make_shared<literal>(2); |
| 235 | + three = make_shared<literal>(3); |
| 236 | + |
| 237 | + one_two_three = {one, two, three}; |
| 238 | + sum_six = make_shared<summation>(one_two_three); |
| 239 | + prod_six = make_shared<product>(one_two_three); |
| 240 | + |
| 241 | + // Too lazy to implemen cache / product for lazy / task. |
| 242 | + expr cache_sum_six = make_shared<cache>(sum_six); |
| 243 | + expr cache_prod_six = make_shared<cache>(prod_six); |
| 244 | + |
| 245 | + six_six = {cache_sum_six, cache_prod_six}; |
| 246 | + |
| 247 | + twelve = make_shared<summation>(six_six); |
| 248 | + thrity_six = make_shared<product>(six_six); |
| 249 | + |
| 250 | + assert((*one)() == 1); |
| 251 | + assert((*two)() == 2); |
| 252 | + assert((*three)() == 3); |
| 253 | + assert((*sum_six)() == 6); |
| 254 | + assert((*prod_six)() == 6); |
| 255 | + assert((*cache_sum_six)() == 6); |
| 256 | + assert((*cache_prod_six)() == 6); |
| 257 | + assert((*twelve)() == 12); |
| 258 | + assert((*thrity_six)() == 36); |
| 259 | + assert(sum_six->cnt() == 2); |
| 260 | + assert(prod_six->cnt() == 2); |
| 261 | + |
| 262 | + cout << "twelve = " << (*twelve)() << "\n"; |
| 263 | + cout << "thrity_six = " << (*thrity_six)() << "\n"; |
| 264 | + cout << "Done normal\n\n\n\n"; |
| 265 | + } |
| 266 | + |
| 267 | + counting_semaphore<> sem(1); |
| 268 | + { |
| 269 | + one = make_shared<lazy_literal>(1, sem); |
| 270 | + two = make_shared<lazy_literal>(2, sem); |
| 271 | + three = make_shared<lazy_literal>(3, sem); |
| 272 | + |
| 273 | + one_two_three = {one, two, three}; |
| 274 | + sum_six = make_shared<lazy_summation>(one_two_three, sem); |
| 275 | + six_six = {sum_six, sum_six}; |
| 276 | + twelve = make_shared<lazy_summation>(six_six, sem); |
| 277 | + assert((*twelve)() == 12); |
| 278 | + cout << "Done lazy\n\n\n\n"; |
| 279 | + } |
| 280 | + |
| 281 | + { |
| 282 | + one = make_shared<task_literal>(1, sem); |
| 283 | + two = make_shared<task_literal>(2, sem); |
| 284 | + three = make_shared<task_literal>(3, sem); |
| 285 | + |
| 286 | + one_two_three = {one, two, three}; |
| 287 | + sum_six = make_shared<task_summation>(one_two_three, sem); |
| 288 | + six_six = {sum_six, sum_six}; |
| 289 | + twelve = make_shared<task_summation>(six_six, sem); |
| 290 | + assert((*twelve)() == 12); |
| 291 | + |
| 292 | + // Impossible to achieve. |
| 293 | + cout << "Done deadlock\n"; |
| 294 | + } |
| 295 | +} |
0 commit comments