-
Notifications
You must be signed in to change notification settings - Fork 121
Expand file tree
/
Copy pathtest_cache.py
More file actions
111 lines (81 loc) · 3.05 KB
/
test_cache.py
File metadata and controls
111 lines (81 loc) · 3.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
import os
import sqlite3
import time
from contextlib import closing

import pytest

from cuda.tile._cache import cache_key, cache_lookup, cache_store, evict_lru
def test_cache_key_equal():
    """Calling ``cache_key`` twice with identical arguments gives equal keys."""
    args = ("v1", "sm_90", 3, b"data")
    first = cache_key(*args)
    second = cache_key(*args)
    assert first == second
def test_cache_key_differs():
    """Changing any single ``cache_key`` argument must yield a different key."""
    base = cache_key("v1", "sm_90", 3, b"data")
    # Each variant differs from the baseline call in exactly one position.
    variants = (
        ("v2", "sm_90", 3, b"data"),
        ("v1", "sm_80", 3, b"data"),
        ("v1", "sm_90", 2, b"data"),
        ("v1", "sm_90", 3, b"other"),
    )
    for args in variants:
        assert cache_key(*args) != base
def make_cubin(tmp_path, name, content):
    """Write ``content`` to the file ``tmp_path / name`` and return its path.

    ``tmp_path`` must support the ``/`` operator and the result must offer
    ``write_bytes`` (e.g. a ``pathlib.Path``); ``content`` is the raw bytes
    to write.
    """
    cubin_path = tmp_path / name
    cubin_path.write_bytes(content)
    return cubin_path
@pytest.fixture
def cache_env(tmp_path):
    """Provide ``(cache_dir, tmp_path)`` for a test.

    ``cache_dir`` is the string path of a ``cache`` entry under pytest's
    per-test ``tmp_path``; the fixture itself does not create it.
    """
    return str(tmp_path / "cache"), tmp_path
def test_store_then_lookup(cache_env):
    """Round-trip: bytes stored under a key are returned by a later lookup."""
    cache_dir, scratch = cache_env
    payload = b"\x7fELF_fake_cubin_data"
    entry_key = cache_key("v1", "sm_90", 3, b"data")

    cubin = make_cubin(scratch, "kernel.cubin", payload)
    cache_store(cache_dir, entry_key, cubin)

    hit = cache_lookup(cache_dir, entry_key, str(scratch))
    assert hit is not None
    assert hit.read_bytes() == payload
def test_lookup_updates_atime(cache_env):
    """A cache hit must refresh the entry's ``atime`` column in ``cache.db``.

    The test stores one entry, rewinds its recorded ``atime`` directly in the
    SQLite database, performs a lookup, and then checks that the stored
    ``atime`` is greater than the rewound value.
    """
    cache_dir, tmp_path = cache_env
    key = cache_key("v1", "sm_90", 3, b"data")
    cache_store(cache_dir, key, make_cubin(tmp_path, "kernel.cubin", b"data"))

    db_path = os.path.join(cache_dir, "cache.db")
    old_time = time.time() - 1000

    # Manually set an old atime in the DB so the refresh is observable.
    # ``closing`` guarantees the connection is released even if a statement
    # raises; using the connection itself as a context manager would only
    # manage the transaction, not close the handle.
    with closing(sqlite3.connect(db_path)) as conn:
        conn.execute("UPDATE cache SET atime = ? WHERE key = ?", (old_time, key))
        conn.commit()

    cache_lookup(cache_dir, key, str(tmp_path))

    with closing(sqlite3.connect(db_path)) as conn:
        row = conn.execute(
            "SELECT atime FROM cache WHERE key = ?", (key,)
        ).fetchone()

    # Fail with a clear message instead of a TypeError if the row is gone.
    assert row is not None, "cache entry missing from cache.db after lookup"
    assert row[0] > old_time
def test_lookup_miss(cache_env):
    """Looking up a key that was never stored must return ``None``."""
    cache_dir, tmp_path = cache_env
    # Pass ``tmp_path`` as the third argument, as every other ``cache_lookup``
    # call in this file does, rather than the cache directory itself.
    result = cache_lookup(cache_dir, "a" * 64, str(tmp_path))
    assert result is None
def test_evict_lru(cache_env):
    """``evict_lru`` must remove the entries with the oldest ``atime`` first.

    Five 1000-byte entries are stored and assigned strictly increasing
    ``atime`` values (0.0 through 4.0). After evicting down to a 3000-byte
    budget, only the three entries with the largest ``atime`` should still be
    found by ``cache_lookup``.
    """
    cache_dir, tmp_path = cache_env
    db_path = os.path.join(cache_dir, "cache.db")

    # Populate 5 entries (1000 bytes each, 5000 total).
    keys = []
    for i in range(5):
        key = cache_key(str(i), "sm_90", 3, b"data")
        keys.append(key)
        cache_store(cache_dir, key,
                    make_cubin(tmp_path, f"k{i}.cubin", b"x" * 1000))

    # Set controlled atimes so eviction order is deterministic. ``closing``
    # releases the connection even if the update raises.
    with closing(sqlite3.connect(db_path)) as conn:
        conn.executemany(
            "UPDATE cache SET atime = ? WHERE key = ?",
            [(float(i), key) for i, key in enumerate(keys)],
        )
        conn.commit()

    # Evict to keep 3000 bytes; newest 3 survive (indices 2, 3, 4).
    evict_lru(cache_dir, 3000)

    remaining = [k for k in keys
                 if cache_lookup(cache_dir, k, str(tmp_path)) is not None]
    assert remaining == keys[2:]