|
| 1 | +#define _GLIBCXX_USE_CXX11_ABI 1 |
| 2 | +#define HL_PERMIT_FAILED_UNROLL 1 |
| 3 | + |
| 4 | +#include "mul.hpp" |
| 5 | + |
| 6 | +#include "Halide.h" |
| 7 | +#include "HalideBuffer.h" |
| 8 | + |
| 9 | +#include <unordered_map> |
| 10 | + |
/* Rough sizes handed to Halide in this file: the bounds/parameter estimates
   set while building the pipelines and, for maxHalideRow, the row extent and
   clamp bound used for the feature buffers. */
static constexpr int maxHalideRow = 1000000;   /* row extent of feature buffers; upper clamp bound is this minus one */
static constexpr int featureCount = 32;        /* estimated extent of a feature-plane dimension */
static constexpr int activeRows = 60000;       /* estimated number of active rows */
static constexpr int groups = 1;               /* estimated number of groups */
static constexpr int featureRowCount = 100000; /* estimated number of feature rows */
| 17 | + |
| 18 | +template <typename Operation> |
| 19 | +using MulStrategyMap = |
| 20 | + std::unordered_map<LayerDimensions, std::unique_ptr<Operation>, |
| 21 | + LayerDimensionsHash>; |
| 22 | + |
| 23 | +template <typename Operation> |
| 24 | +const Operation &getHalideMul(int inFeatureCount, int outFeatureCount, |
| 25 | + int groups, bool cuda, |
| 26 | + MulStrategyMap<Operation> &container) { |
| 27 | + const LayerDimensions dims = {inFeatureCount, outFeatureCount, groups, cuda}; |
| 28 | + auto it = container.find(dims); |
| 29 | + |
| 30 | + if (it != container.end()) { |
| 31 | + return *(it->second); |
| 32 | + } |
| 33 | + |
| 34 | + auto mul = |
| 35 | + container.insert(std::make_pair(dims, std::make_unique<Operation>(dims))) |
| 36 | + .first->second.get(); |
| 37 | + return *mul; |
| 38 | +} |
| 39 | + |
| 40 | +struct HalideMulFactory::Impl { |
| 41 | + MulStrategyMap<HalideMulBackward> backward; |
| 42 | + MulStrategyMap<HalideMulForward> forward; |
| 43 | +}; |
| 44 | + |
| 45 | +HalideMulFactory::HalideMulFactory() : pimpl(new Impl()) {} |
| 46 | + |
| 47 | +HalideMulFactory::~HalideMulFactory() = default; |
| 48 | + |
| 49 | +const HalideMulFactory &HalideMulFactory::getInstance() { |
| 50 | + static HalideMulFactory instance; |
| 51 | + return instance; |
| 52 | +} |
| 53 | + |
| 54 | +const HalideMulForward & |
| 55 | +HalideMulFactory::getHalideMulForward(int inFeatureCount, int outFeatureCount, |
| 56 | + int groups, bool cuda) const { |
| 57 | + return getHalideMul<HalideMulForward>(inFeatureCount, outFeatureCount, groups, |
| 58 | + cuda, pimpl->forward); |
| 59 | +} |
| 60 | + |
| 61 | +const HalideMulBackward & |
| 62 | +HalideMulFactory::getHalideMulBackward(int inFeatureCount, int outFeatureCount, |
| 63 | + int groups, bool cuda) const { |
| 64 | + return getHalideMul<HalideMulBackward>(inFeatureCount, outFeatureCount, |
| 65 | + groups, cuda, pimpl->backward); |
| 66 | +} |
| 67 | + |
| 68 | +HalideMul::HalideMul(int inFeatureCount, int outFeatureCount, int groups) |
| 69 | + : dimensions({inFeatureCount, outFeatureCount, groups}) {} |
| 70 | + |
| 71 | +HalideMul::HalideMul(const LayerDimensions &dims) : dimensions(dims) {} |
| 72 | + |
| 73 | +HalideMul::~HalideMul() = default; |
| 74 | + |
| 75 | +/* Implementation of forward Halide matrix multiplication */ |
| 76 | +struct HalideMulForward::Impl { |
| 77 | +public: |
| 78 | + Impl(const LayerDimensions &dimensions, bool cuda) { |
| 79 | + Halide::Target target = Halide::get_host_target(); |
| 80 | + Halide::Func matmul = Halide::Func("matmul"); |
| 81 | + |
| 82 | + /* Variables */ |
| 83 | + Halide::Var i, g, j; |
| 84 | + Halide::RDom k{0, dimensions.inFeatureCount / dimensions.groups}; |
| 85 | + |
| 86 | + /* Algorithm */ |
| 87 | + Halide::Expr producer = clamp(rules(2 * i), 0, maxHalideRow - 1); |
| 88 | + matmul(j, i, g) = sum(inputFeatures(k, g, producer) * weights(j, k, g)); |
| 89 | + |
| 90 | + /* Schedule */ |
| 91 | + matmul.estimate(j, 0, featureCount) |
| 92 | + .estimate(g, 0, groups) |
| 93 | + .estimate(i, 0, featureRowCount); |
| 94 | + |
| 95 | + inputFeatures.dim(0).set_bounds_estimate(0, featureCount); |
| 96 | + inputFeatures.dim(1).set_bounds_estimate(0, groups); |
| 97 | + inputFeatures.dim(2).set_bounds_estimate(0, featureRowCount); |
| 98 | + |
| 99 | + weights.dim(0).set_bounds_estimate(0, featureCount); |
| 100 | + weights.dim(1).set_bounds_estimate(0, featureCount); |
| 101 | + weights.dim(2).set_bounds_estimate(0, groups); |
| 102 | + |
| 103 | + rules.dim(0).set_bounds_estimate(0, activeRows); |
| 104 | + activeRowsParam.set_estimate(activeRows); |
| 105 | + |
| 106 | + p = Halide::Pipeline({matmul}); |
| 107 | + |
| 108 | + if (!cuda) { |
| 109 | + p.auto_schedule(target); |
| 110 | + } else { |
| 111 | + target.set_feature(Halide::Target::CUDA); |
| 112 | + } |
| 113 | + |
| 114 | + p.compile_jit(target); |
| 115 | + }; |
| 116 | + |
| 117 | + Halide::ImageParam inputFeatures = |
| 118 | + Halide::ImageParam(Halide::type_of<float>(), 3, "source"); |
| 119 | + Halide::ImageParam weights = |
| 120 | + Halide::ImageParam(Halide::type_of<float>(), 3, "weight"); |
| 121 | + Halide::ImageParam rules = |
| 122 | + Halide::ImageParam(Halide::type_of<int>(), 1, "rules"); |
| 123 | + |
| 124 | + Halide::Param<int> activeRowsParam = Halide::Param<int>("row_count"); |
| 125 | + |
| 126 | + Halide::Pipeline p; |
| 127 | +}; |
| 128 | + |
| 129 | +HalideMulForward::HalideMulForward(int inFeatureCount, int outFeatureCount, |
| 130 | + int groups, bool cuda) |
| 131 | + : HalideMul(inFeatureCount, outFeatureCount, groups), |
| 132 | + pimpl(new Impl(dimensions, cuda)) {} |
| 133 | + |
| 134 | +HalideMulForward::HalideMulForward(const LayerDimensions &dims) |
| 135 | + : HalideMul(dims), pimpl(new Impl(dimensions, dims.cuda)) {} |
| 136 | + |
| 137 | +HalideMulForward::~HalideMulForward() = default; |
| 138 | + |
| 139 | +/* Executes the forward matrix multiplication created through the |
| 140 | + implementation object. */ |
| 141 | +void HalideMulForward::execute(float *input, float *weight, int *rules, |
| 142 | + float *output, int activeRowCount) const { |
| 143 | + |
| 144 | + int inputPlanes = dimensions.inFeatureCount / dimensions.groups; |
| 145 | + int outputPlanes = dimensions.outFeatureCount / dimensions.groups; |
| 146 | + |
| 147 | + pimpl->inputFeatures.set(Halide::Buffer<float>( |
| 148 | + input, inputPlanes, dimensions.groups, maxHalideRow)); |
| 149 | + pimpl->weights.set(Halide::Buffer<float>(weight, outputPlanes, inputPlanes, |
| 150 | + dimensions.groups)); |
| 151 | + pimpl->rules.set(Halide::Buffer<int>(rules, 2 * activeRowCount)); |
| 152 | + pimpl->activeRowsParam.set(activeRowCount); |
| 153 | + |
| 154 | + auto out = Halide::Buffer<float>(output, outputPlanes, activeRowCount, |
| 155 | + dimensions.groups); |
| 156 | + pimpl->p.realize(out); |
| 157 | +} |
| 158 | + |
| 159 | +/* Implementation of backward Halide matrix multiplication */ |
| 160 | +struct HalideMulBackward::Impl { |
| 161 | +public: |
| 162 | + Impl(const LayerDimensions &dimensions, bool cuda) { |
| 163 | + Halide::Target target = Halide::get_host_target(); |
| 164 | + |
| 165 | + int outputPlanes = dimensions.outFeatureCount / dimensions.groups; |
| 166 | + |
| 167 | + /* Variables */ |
| 168 | + Halide::Func o_matmul = Halide::Func("o_matmul"); |
| 169 | + Halide::Func o_weights = Halide::Func("o_weights"); |
| 170 | + Halide::Var i, g, k, j, gw, outp, inp; |
| 171 | + |
| 172 | + Halide::RDom planes = Halide::RDom(0, outputPlanes); |
| 173 | + Halide::RDom nums = Halide::RDom(0, activeRowsParam); |
| 174 | + |
| 175 | + /* Algorithm */ |
| 176 | + Halide::Expr producer = clamp(rules(2 * i + 1), 0, maxHalideRow - 1); |
| 177 | + |
| 178 | + Halide::Expr orAccess_dom = clamp(rules(2 * nums + 1), 0, maxHalideRow - 1); |
| 179 | + Halide::Expr irAccess_dom = clamp(rules(2 * nums), 0, maxHalideRow - 1); |
| 180 | + |
| 181 | + o_matmul(k, i, g) = |
| 182 | + sum(weights(planes, k, g) * outputFeatures(planes, g, producer)); |
| 183 | + |
| 184 | + o_weights(outp, inp, gw) = sum(outputFeatures(outp, gw, orAccess_dom) * |
| 185 | + inputFeatures(inp, gw, irAccess_dom)); |
| 186 | + |
| 187 | + /* Schedule */ |
| 188 | + o_matmul.estimate(k, 0, featureCount) |
| 189 | + .estimate(g, 0, groups) |
| 190 | + .estimate(i, 0, featureRowCount); |
| 191 | + o_weights.estimate(gw, 0, groups) |
| 192 | + .estimate(outp, 0, featureCount) |
| 193 | + .estimate(inp, 0, featureCount); |
| 194 | + |
| 195 | + inputFeatures.dim(0).set_bounds_estimate(0, featureCount); |
| 196 | + inputFeatures.dim(1).set_bounds_estimate(0, groups); |
| 197 | + inputFeatures.dim(2).set_bounds_estimate(0, featureRowCount); |
| 198 | + |
| 199 | + outputFeatures.dim(0).set_bounds_estimate(0, featureCount); |
| 200 | + outputFeatures.dim(1).set_bounds_estimate(0, groups); |
| 201 | + outputFeatures.dim(2).set_bounds_estimate(0, featureRowCount); |
| 202 | + |
| 203 | + weights.dim(0).set_bounds_estimate(0, featureCount); |
| 204 | + weights.dim(1).set_bounds_estimate(0, featureCount); |
| 205 | + weights.dim(2).set_bounds_estimate(0, groups); |
| 206 | + |
| 207 | + rules.dim(0).set_bounds_estimate(0, activeRows); |
| 208 | + activeRowsParam.set_estimate(activeRows); |
| 209 | + |
| 210 | + p = Halide::Pipeline({o_matmul, o_weights}); |
| 211 | + |
| 212 | + if (cuda) { |
| 213 | + target.set_feature(Halide::Target::CUDA); |
| 214 | + } else { |
| 215 | + p.auto_schedule(target); |
| 216 | + } |
| 217 | + |
| 218 | + p.compile_jit(target); |
| 219 | + }; |
| 220 | + |
| 221 | + Halide::ImageParam inputFeatures = |
| 222 | + Halide::ImageParam(Halide::type_of<float>(), 3, "input_features"); |
| 223 | + Halide::ImageParam outputFeatures = |
| 224 | + Halide::ImageParam(Halide::type_of<float>(), 3, "output_features"); |
| 225 | + Halide::ImageParam rules = |
| 226 | + Halide::ImageParam(Halide::type_of<int>(), 1, "rules"); |
| 227 | + Halide::ImageParam weights = |
| 228 | + Halide::ImageParam(Halide::type_of<float>(), 3, "weights"); |
| 229 | + |
| 230 | + Halide::Param<int> activeRowsParam = Halide::Param<int>("row_count"); |
| 231 | + |
| 232 | + Halide::Pipeline p; |
| 233 | +}; |
| 234 | + |
| 235 | +HalideMulBackward::HalideMulBackward(int inFeatureCount, int outFeatureCount, |
| 236 | + int groups, bool cuda) |
| 237 | + : HalideMul(inFeatureCount, outFeatureCount, groups), |
| 238 | + pimpl(new Impl(dimensions, cuda)) {} |
| 239 | + |
| 240 | +HalideMulBackward::HalideMulBackward(const LayerDimensions &dims) |
| 241 | + : HalideMul(dims), pimpl(new Impl(dimensions, dims.cuda)) {} |
| 242 | + |
| 243 | +HalideMulBackward::~HalideMulBackward() = default; |
| 244 | + |
| 245 | +/* Executes the backward matrix multiplications created through the |
| 246 | + implementation object. */ |
| 247 | +void HalideMulBackward::execute(float *inputFeatures, float *outputFeatures, |
| 248 | + int *rules, float *weights, |
| 249 | + float *dWeightsOutput, float *output, |
| 250 | + int activeRowCount) const { |
| 251 | + |
| 252 | + int inputPlanes = dimensions.inFeatureCount / dimensions.groups; |
| 253 | + int outputPlanes = dimensions.outFeatureCount / dimensions.groups; |
| 254 | + |
| 255 | + pimpl->inputFeatures.set(Halide::Buffer<float>( |
| 256 | + inputFeatures, inputPlanes, dimensions.groups, maxHalideRow)); |
| 257 | + pimpl->outputFeatures.set(Halide::Buffer<float>( |
| 258 | + outputFeatures, outputPlanes, dimensions.groups, maxHalideRow)); |
| 259 | + pimpl->weights.set(Halide::Buffer<float>(weights, outputPlanes, inputPlanes, |
| 260 | + dimensions.groups)); |
| 261 | + pimpl->rules.set(Halide::Buffer<int>(rules, 2 * activeRowCount)); |
| 262 | + |
| 263 | + pimpl->activeRowsParam.set(activeRowCount); |
| 264 | + |
| 265 | + auto halideOutput = Halide::Buffer<float>(output, inputPlanes, activeRowCount, |
| 266 | + dimensions.groups); |
| 267 | + auto halideWOutput = Halide::Buffer<float>(dWeightsOutput, outputPlanes, |
| 268 | + inputPlanes, dimensions.groups); |
| 269 | + |
| 270 | + pimpl->p.realize({halideOutput, halideWOutput}); |
| 271 | +} |
0 commit comments