@@ -28,7 +28,7 @@ class Transformer : public EncoderOrDecoderBase {
 
 protected:
   using Base::options_; using Base::inference_; using Base::batchIndex_; using Base::graph_;
-  std::unordered_map<std::string, Expr> cache_; // caching transformation of the encoder that should not be created again
+  std::unordered_map<std::string, std::pair<Shape, Expr>> cache_; // caching transformation of the encoder that should not be created again
   mutable /*lazy*/ std::vector<float> sinusoidalEmbeddingsFreq_, sinusoidalEmbeddingsOffs_; // cached contributions to sinusoidal embeddings
 
   bool depthScaling_{false}; // As recommended in the GPT-2 paper, down-scale layer weights by a factor of 1 / sqrt(depth);
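The type change above is the heart of the fix: the old cache stored only the projected expression and, as the hunks below show, validated entries by comparing element counts, which is not a safe identity check. A minimal standalone sketch of the failure mode, using a hypothetical Shape stand-in rather than the real marian::Shape:

#include <cassert>
#include <vector>

// Hypothetical stand-in for marian::Shape, for illustration only.
struct Shape {
  std::vector<int> dims;
  int elements() const {
    int n = 1;
    for(int d : dims) n *= d;
    return n;
  }
  bool operator==(const Shape& rhs) const { return dims == rhs.dims; }
};

int main() {
  Shape cached{{2, 3}};   // shape seen when the projection was cached
  Shape incoming{{3, 2}}; // same element count, different layout
  assert(cached.elements() == incoming.elements()); // old check: passes, stale entry would be reused
  assert(!(cached == incoming));                    // new check: fails, projection is recomputed
}

Keeping the shape next to the expression also lets the validity check run without dereferencing the cached expression at all.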
@@ -288,10 +288,10 @@ class Transformer : public EncoderOrDecoderBase {
    // Caching transformation of the encoder that should not be created again.
    // @TODO: set this automatically by memoizing encoder context and
    // memoization propagation (short-term)
-   if (cache                                                   // if caching
-       && cache_.count(prefix + "_keys") > 0                   // and the keys expression has been seen
-       && cache_[prefix + "_keys"]->shape().elements() == keys->shape().elements()) { // and the underlying element size did not change
-     kh = cache_[prefix + "_keys"];                            // then return cached tensor
+   if (cache                                                   // if caching
+       && cache_.count(prefix + "_keys") > 0                   // and the keys expression has been seen
+       && cache_[prefix + "_keys"].first == keys->shape()) {   // and the underlying shape did not change
+     kh = cache_[prefix + "_keys"].second;                     // then return cached tensor
    }
    else {
      int dimKeys = keys->shape()[-1]; // different than dimModel when using lemma and factors combined with concatenation
@@ -300,22 +300,22 @@ class Transformer : public EncoderOrDecoderBase {
 
      kh = affine(keys, Wk, bk);     // [-4: beam depth, -3: batch size, -2: max length, -1: vector dim]
      kh = SplitHeads(kh, dimHeads); // [-4: batch size, -3: num heads, -2: max length, -1: split vector dim]
-     cache_[prefix + "_keys"] = kh;
+     cache_[prefix + "_keys"] = std::make_pair(keys->shape(), kh);
    }
 
    Expr vh;
    if (cache
        && cache_.count(prefix + "_values") > 0
-       && cache_[prefix + "_values"]->shape().elements() == values->shape().elements()) {
-     vh = cache_[prefix + "_values"];
+       && cache_[prefix + "_values"].first == values->shape()) {
+     vh = cache_[prefix + "_values"].second;
    } else {
      int dimValues = values->shape()[-1]; // different than dimModel when using lemma and factors combined with concatenation
      auto Wv = graph_->param(prefix + "_Wv", {dimValues, dimModel}, inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f));
      auto bv = graph_->param(prefix + "_bv", {1, dimModel}, inits::zeros());
 
      vh = affine(values, Wv, bv);   // [-4: beam depth, -3: batch size, -2: max length, -1: vector dim]
      vh = SplitHeads(vh, dimHeads); // [-4: batch size, -3: num heads, -2: max length, -1: split vector dim]
-     cache_[prefix + "_values"] = vh;
+     cache_[prefix + "_values"] = std::make_pair(values->shape(), vh);
    }
 
    int dimBeam = q->shape()[-4];
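After the change, both branches follow the same lookup-or-recompute pattern keyed on the input shape. Reduced to a self-contained sketch, with Shape, Expr, and cachedProjection as illustrative stand-ins rather than the Marian API:

#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

using Shape = std::vector<int>;   // stand-in for marian::Shape
using Expr  = std::vector<float>; // stand-in for marian::Expr

// Reuse the cached projection while the input shape is unchanged;
// otherwise recompute and refresh shape and value together.
Expr cachedProjection(std::unordered_map<std::string, std::pair<Shape, Expr>>& cache,
                      const std::string& key,
                      const Shape& inputShape,
                      const std::function<Expr()>& compute) {
  auto it = cache.find(key);
  if(it != cache.end() && it->second.first == inputShape)
    return it->second.second;                      // hit: shape matches, reuse
  Expr result = compute();                         // miss or stale shape: recompute
  cache[key] = std::make_pair(inputShape, result); // store shape next to the value
  return result;
}

Under this reading, the keys branch above corresponds to cachedProjection(cache_, prefix + "_keys", keys->shape(), [&]{ return SplitHeads(affine(keys, Wk, bk), dimHeads); }).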