Skip to content

Commit 358b4d2

Browse files
authored
fix (#503)
1 parent 1a4d24e commit 358b4d2

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

mlx_lm/models/cache.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,7 @@ def _update_concat(self, keys, values):
399399
# preserve context
400400
self.keys = self._temporal_order(self.keys)
401401
self.values = self._temporal_order(self.values)
402+
self._idx = self.keys.shape[2]
402403

403404
# The largest size is self.max_size + S - 1 to ensure
404405
# every token gets at least self.max_size context

tests/test_prompt_cache.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,20 @@ def test_save_load_batch_caches(self):
554554
for c, lc in zip(cache, loaded_cache):
555555
self.assertTrue(mx.array_equal(c.left_padding, left_padding))
556556

557+
def test_rotating_cache_updates(self):
558+
cache = RotatingKVCache(max_size=8)
559+
k = v = mx.zeros((1, 1, 10, 1))
560+
cache.update_and_fetch(k, v)
561+
562+
for _ in range(3):
563+
k = v = mx.zeros((1, 1, 1, 1))
564+
cache.update_and_fetch(k, v)
565+
566+
k = v = mx.zeros((1, 1, 3, 1))
567+
k, v = cache.update_and_fetch(k, v)
568+
self.assertEqual(k.shape[2], 10)
569+
self.assertEqual(v.shape[2], 10)
570+
557571

558572
if __name__ == "__main__":
559573
unittest.main()

0 commit comments

Comments
 (0)