
Commit d186d94

Check shapes on transformer cache
1 parent e27da62 commit d186d94
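
This commit changes the encoder-side key/value cache from storing bare expressions to storing each expression together with the shape it was built from. A cache hit now requires the incoming keys/values shape to match the cached shape exactly, instead of only comparing element counts.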

File tree

1 file changed (+9 −9 lines)


src/models/transformer.h

Lines changed: 9 additions & 9 deletions
@@ -28,7 +28,7 @@ class Transformer : public EncoderOrDecoderBase {
 
 protected:
   using Base::options_; using Base::inference_; using Base::batchIndex_; using Base::graph_;
-  std::unordered_map<std::string, Expr> cache_;  // caching transformation of the encoder that should not be created again
+  std::unordered_map<std::string, std::pair<Shape, Expr>> cache_;  // caching transformation of the encoder that should not be created again
   mutable/*lazy*/ std::vector<float> sinusoidalEmbeddingsFreq_, sinusoidalEmbeddingsOffs_; // cached contributions to sinusoidal embeddings
 
   bool depthScaling_{false}; // As recommended in the GPT-2 paper, down-scale layer weights by a factor of 1 / sqrt(depth);
@@ -288,10 +288,10 @@ class Transformer : public EncoderOrDecoderBase {
     // Caching transformation of the encoder that should not be created again.
     // @TODO: set this automatically by memoizing encoder context and
     //        memoization propagation (short-term)
-    if (cache                                             // if caching
-        && cache_.count(prefix + "_keys") > 0             // and the keys expression has been seen
-        && cache_[prefix + "_keys"]->shape().elements() == keys->shape().elements()) { // and the underlying element size did not change
-      kh = cache_[prefix + "_keys"];                      // then return cached tensor
+    if (cache                                             // if caching
+        && cache_.count(prefix + "_keys") > 0             // and the keys expression has been seen
+        && cache_[prefix + "_keys"].first == keys->shape()) { // and the underlying element size did not change
+      kh = cache_[prefix + "_keys"].second;               // then return cached tensor
     }
     else {
       int dimKeys = keys->shape()[-1]; // different than dimModel when using lemma and factors combined with concatenation
@@ -300,22 +300,22 @@ class Transformer : public EncoderOrDecoderBase {
 
       kh = affine(keys, Wk, bk);     // [-4: beam depth, -3: batch size, -2: max length, -1: vector dim]
       kh = SplitHeads(kh, dimHeads); // [-4: batch size, -3: num heads, -2: max length, -1: split vector dim]
-      cache_[prefix + "_keys"] = kh;
+      cache_[prefix + "_keys"] = std::make_pair(keys->shape(), kh);
     }
 
     Expr vh;
     if (cache
         && cache_.count(prefix + "_values") > 0
-        && cache_[prefix + "_values"]->shape().elements() == values->shape().elements()) {
-      vh = cache_[prefix + "_values"];
+        && cache_[prefix + "_values"].first == values->shape()) {
+      vh = cache_[prefix + "_values"].second;
     } else {
       int dimValues = values->shape()[-1]; // different than dimModel when using lemma and factors combined with concatenation
       auto Wv = graph_->param(prefix + "_Wv", {dimValues, dimModel}, inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f));
       auto bv = graph_->param(prefix + "_bv", {1, dimModel}, inits::zeros());
 
       vh = affine(values, Wv, bv); // [-4: batch size, -3: num heads, -2: max length, -1: split vector dim]
       vh = SplitHeads(vh, dimHeads);
-      cache_[prefix + "_values"] = vh;
+      cache_[prefix + "_values"] = std::make_pair(values->shape(), vh);
     }
 
     int dimBeam = q->shape()[-4];
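
For context, a minimal standalone sketch of the pattern this commit adopts: a cache that stores the source shape next to the cached value and only reuses an entry on an exact shape match. Comparing shapes rather than element counts matters because two tensors can hold the same number of elements while having different dimensions, which would let an element-count check hand back a stale, mis-shaped expression. Shape, Expr, and cachedOrBuild below are simplified stand-ins introduced only for illustration, not the Marian types or API.

// Standalone sketch (not Marian code): a cache keyed by name that stores the
// source shape next to the cached value, so a lookup only hits on an exact
// shape match.
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct Shape {
  std::vector<int> dims;
  bool operator==(const Shape& other) const { return dims == other.dims; }
  size_t elements() const {
    size_t n = 1;
    for(int d : dims) n *= (size_t)d;
    return n;
  }
};

using Expr = std::shared_ptr<int>; // placeholder for a graph expression node

std::unordered_map<std::string, std::pair<Shape, Expr>> cache_;

// Return the cached expression if the recorded shape matches exactly;
// otherwise (re)build it and store the new shape alongside the result.
Expr cachedOrBuild(const std::string& key, const Shape& shape,
                   const std::function<Expr()>& build) {
  auto it = cache_.find(key);
  if(it != cache_.end() && it->second.first == shape)
    return it->second.second;                 // shapes match: reuse
  Expr fresh = build();                       // first use or shape changed
  cache_[key] = std::make_pair(shape, fresh);
  return fresh;
}

int main() {
  Shape a{{2, 3, 4}}; // 24 elements
  Shape b{{4, 3, 2}}; // also 24 elements, but a different shape
  assert(a.elements() == b.elements());

  auto first  = cachedOrBuild("enc_keys", a, [] { return std::make_shared<int>(1); });
  auto second = cachedOrBuild("enc_keys", b, [] { return std::make_shared<int>(2); });

  // An element-count comparison would have returned the stale `first` here;
  // the exact shape check forces a rebuild instead.
  assert(first != second);
  return 0;
}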
