
Commit b77175d

Add: File Based Caching for lm_eval tests (#2051)
# LM Eval Caching System

## Overview

Caches base model evaluation results to speed up tests by ~30-50%. When multiple quantization tests share the same base model and evaluation parameters, cached results are reused instead of re-evaluating.

## Quick Start

### Enable (Default)

```bash
pytest tests/lmeval/test_lmeval.py
```

The first run caches results; subsequent runs reuse cached results for matching parameters.

### Disable

```bash
DISABLE_LMEVAL_CACHE=1 pytest tests/lmeval/test_lmeval.py
```

### Custom Cache Location

```bash
LMEVAL_CACHE_DIR=/tmp/cache pytest tests/lmeval/test_lmeval.py
```

## How It Works

Results are cached based on:

- Model ID (e.g., `TinyLlama/TinyLlama-1.1B-Chat-v1.0`)
- Task (e.g., `gsm8k`)
- Few-shot count
- Sample limit
- Batch size
- Model arguments (hashed)
- lm_eval version
- Seed (if set)

### Cache Storage

Single CSV file: `.lmeval_cache/cache.csv` (columns abridged below; the file also stores `lmeval_version` and `seed`):

```csv
model,task,num_fewshot,limit,batch_size,model_args_hash,result
TinyLlama/TinyLlama-1.1B-Chat-v1.0,gsm8k,5,1000,100,abc123def456,"{'results': {...}}"
```

## Usage in Tests

```python
from tests.testing_utils import cached_lm_eval_run

class TestLMEval:
    @cached_lm_eval_run
    def _eval_base_model(self) -> dict:
        return self._eval_model(self.model)
```

## Cache Management

### Inspect Cache

```bash
cat .lmeval_cache/cache.csv

# Or with pandas
python -c "import pandas as pd; print(pd.read_csv('.lmeval_cache/cache.csv'))"
```

### Clear Cache

```bash
rm -rf .lmeval_cache/
```

### Ignore in Git

Already added to `.gitignore`:

```gitignore
.lmeval_cache/
```

## Environment Variables

| Variable | Values | Default | Description |
|----------|--------|---------|-------------|
| `DISABLE_LMEVAL_CACHE` | `1`, `true`, `yes` | unset (caching enabled) | Disable caching |
| `LMEVAL_CACHE_DIR` | path | `.lmeval_cache` | Cache directory |

## Performance Example

**Without cache:** 175s total

```
Test 1 (W4A16): Base (60s) + Compressed (30s) = 90s
Test 2 (W8A8):  Base (60s) + Compressed (25s) = 85s
```

**With cache:** 115s total (34% faster)

```
Test 1 (W4A16): Base (60s)  + Compressed (30s) = 90s  [MISS]
Test 2 (W8A8):  Base (0.1s) + Compressed (25s) = 25s  [HIT]
```

## Example Cache Logs

When running tests, you'll see these log messages:

**First test (cache miss):**

```
LM-Eval cache MISS: meta-llama/Meta-Llama-3-8B-Instruct/gsm8k
# ... evaluation runs ...
LM-Eval cache WRITE: meta-llama/Meta-Llama-3-8B-Instruct/gsm8k
```

**Second test with same base model (cache hit):**

```
LM-Eval cache HIT: meta-llama/Meta-Llama-3-8B-Instruct/gsm8k
# ... evaluation skipped, cached result returned ...
```

### Testing the Cache

Run two tests that share the same base model:

```bash
# Clean cache
rm -rf .lmeval_cache/

# Test 1: FP8_DYNAMIC scheme - Cache MISS
CUDA_VISIBLE_DEVICES=5 CADENCE=weekly \
TEST_DATA_FILE=tests/lmeval/configs/fp8_dynamic_per_token.yaml \
.venv/bin/python -m pytest tests/lmeval/test_lmeval.py::TestLMEval::test_lm_eval -v

# Test 2: FP8 static scheme (same base model) - Cache HIT
CUDA_VISIBLE_DEVICES=5 CADENCE=weekly \
TEST_DATA_FILE=tests/lmeval/configs/fp8_static_per_tensor.yaml \
.venv/bin/python -m pytest tests/lmeval/test_lmeval.py::TestLMEval::test_lm_eval -v

# Inspect cache
cat .lmeval_cache/cache.csv
```

Expected output shows the second test completes much faster due to cached base model evaluation.
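The `model_args_hash` column is a short SHA-256 fingerprint of the evaluation's model arguments. If you want to tie a row in `cache.csv` back to a specific set of model arguments, the hash can be recomputed outside the test suite. A minimal sketch, mirroring the `_sha256_hash` helper this commit adds to `tests/testing_utils.py` (the `model_args` values below are hypothetical):

```python
import hashlib
import json
from typing import Optional

import pandas as pd


def sha256_hash(text: str, length: Optional[int] = None) -> str:
    """Hex SHA-256 digest, optionally truncated (the cache keeps 16 characters)."""
    digest = hashlib.sha256(text.encode()).hexdigest()
    return digest[:length] if length else digest


# Hypothetical model arguments; sort_keys makes the hash independent of dict order.
model_args = {"add_bos_token": True, "dtype": "bfloat16"}
target_hash = sha256_hash(json.dumps(model_args, sort_keys=True), 16)

# Show any cached rows carrying this fingerprint.
df = pd.read_csv(".lmeval_cache/cache.csv")
print(df[df["model_args_hash"] == target_hash])
```

Because the cache key serializes model arguments the same way (sorted keys, 16-character digest), the rows printed here are exactly the entries a test with those arguments would hit.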
## Troubleshooting

### Cache Not Working

Check the environment variable:

```bash
echo $DISABLE_LMEVAL_CACHE  # Should be empty
```

Verify the cache file exists:

```bash
ls -lh .lmeval_cache/cache.csv
```

### Cache Always Misses

Ensure an exact parameter match:

- Model ID (case-sensitive)
- Task name
- All numeric parameters (fewshot, limit, batch_size)
- Model arguments

### Corrupted Cache

Simply delete and recreate:

```bash
rm .lmeval_cache/cache.csv
pytest tests/lmeval/test_lmeval.py
```

## CI/CD Integration

### GitHub Actions

```yaml
- name: Run LM Eval tests with cache
  env:
    LMEVAL_CACHE_DIR: ${{ runner.temp }}/lm_cache
  run: pytest tests/lmeval/test_lmeval.py
```

## Design Notes

### Why CSV?

- Matches existing timing data format (`timings/*.csv`)
- Simple pandas integration

### Error Handling

Failures are logged but don't break tests: the cache simply falls back to uncached execution on any error.

Rebased version of #1900

---------

Signed-off-by: Rahul-Tuli <[email protected]>
1 parent 9cb1c6d commit b77175d

File tree

3 files changed: +169 lines, -8 lines

.gitignore

Lines changed: 3 additions & 0 deletions
```diff
@@ -805,6 +805,9 @@ timings/
 output_finetune/
 env_log.json
 
+# LM Eval cache
+.lmeval_cache/
+
 # uv artifacts
 uv.lock
 .venv/
```

tests/lmeval/test_lmeval.py

Lines changed: 8 additions & 7 deletions
```diff
@@ -15,7 +15,7 @@
 from llmcompressor.core import active_session
 from tests.e2e.e2e_utils import load_model, run_oneshot_for_e2e_testing
 from tests.test_timer.timer_utils import get_singleton_manager, log_time
-from tests.testing_utils import requires_gpu
+from tests.testing_utils import cached_lm_eval_run, requires_gpu
 
 
 class LmEvalConfig(BaseModel):
@@ -100,12 +100,12 @@ def set_up(self, test_data_file: str):
         self.recipe = eval_config.get("recipe")
         self.quant_type = eval_config.get("quant_type")
         self.save_dir = eval_config.get("save_dir")
+        self.seed = eval_config.get("seed", None)
 
-        seed = eval_config.get("seed", None)
-        if seed is not None:
-            random.seed(seed)
-            numpy.random.seed(seed)
-            torch.manual_seed(seed)
+        if self.seed is not None:
+            random.seed(self.seed)
+            numpy.random.seed(self.seed)
+            torch.manual_seed(self.seed)
 
         logger.info("========== RUNNING ==============")
         logger.info(self.scheme)
@@ -161,8 +161,9 @@ def test_lm_eval(self, test_data_file: str):
         self.tear_down()
 
     @log_time
+    @cached_lm_eval_run
     def _eval_base_model(self) -> dict:
-        """Evaluate the base (uncompressed) model."""
+        """Evaluate the base (uncompressed) model with caching."""
         return self._eval_model(self.model)
 
     @log_time
```
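One detail worth noting in the diff above: `@log_time` is stacked above `@cached_lm_eval_run`, so the timing wrapper measures whatever the cache wrapper does, and a cache hit shows up as a near-instant call. A minimal sketch of that stacking order, using simplified, hypothetical stand-ins for both decorators:

```python
import functools
import time

_cache = {}


def log_time(func):
    """Outer decorator: times whatever the wrapped callable does."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        print(f"{func.__name__} took {time.perf_counter() - start:.3f}s")
        return result
    return wrapper


def cached(func):
    """Inner decorator: returns a stored result when one exists."""
    @functools.wraps(func)
    def wrapper(*args):
        if args not in _cache:
            _cache[args] = func(*args)
        return _cache[args]
    return wrapper


@log_time   # applied last, so it wraps (and times) the caching layer
@cached     # applied first, directly around the slow evaluation
def evaluate(x: int) -> int:
    time.sleep(0.2)  # stand-in for a slow base-model evaluation
    return x * 2


evaluate(3)  # cache miss: timed at roughly 0.2s
evaluate(3)  # cache hit: timed as near-zero
```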

tests/testing_utils.py

Lines changed: 158 additions & 1 deletion
```diff
@@ -1,20 +1,37 @@
 import dataclasses
 import enum
+import hashlib
+import json
 import logging
 import os
 from dataclasses import dataclass
 from enum import Enum
+from functools import wraps
 from pathlib import Path
 from subprocess import PIPE, STDOUT, run
-from typing import Callable, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
+import pandas as pd
 import pytest
 import torch
 import yaml
 from datasets import Dataset
+from loguru import logger
 from transformers import ProcessorMixin
 
 TEST_DATA_FILE = os.environ.get("TEST_DATA_FILE", None)
+DISABLE_LMEVAL_CACHE = os.environ.get("DISABLE_LMEVAL_CACHE", "").lower() in (
+    "1",
+    "true",
+    "yes",
+)
+LMEVAL_CACHE_DIR = Path(os.environ.get("LMEVAL_CACHE_DIR", ".lmeval_cache"))
+LMEVAL_CACHE_FILE = LMEVAL_CACHE_DIR / "cache.csv"
+
+
+def _sha256_hash(text: str, length: Optional[int] = None) -> str:
+    hash_result = hashlib.sha256(text.encode()).hexdigest()
+    return hash_result[:length] if length else hash_result
 
 
 # TODO: maybe test type as decorators?
@@ -292,3 +309,143 @@ def requires_cadence(cadence: Union[str, List[str]]) -> Callable:
     return pytest.mark.skipif(
         (current_cadence not in cadence), reason="cadence mismatch"
     )
+
+
+@dataclass(frozen=True)
+class LMEvalCacheKey:
+    """Cache key for LM Eval results based on evaluation parameters."""
+
+    model: str
+    task: str
+    num_fewshot: int
+    limit: int
+    batch_size: int
+    model_args_hash: str
+    lmeval_version: str
+    seed: Optional[int]
+
+    @classmethod
+    def from_test_instance(cls, test_instance: Any) -> "LMEvalCacheKey":
+        """Create cache key from test instance."""
+        try:
+            import lm_eval
+
+            lmeval_version = lm_eval.__version__
+        except (ImportError, AttributeError):
+            lmeval_version = "unknown"
+
+        lmeval = test_instance.lmeval
+        model_args_json = json.dumps(lmeval.model_args, sort_keys=True)
+        seed = getattr(test_instance, "seed", None)
+
+        return cls(
+            model=test_instance.model,
+            task=lmeval.task,
+            num_fewshot=lmeval.num_fewshot,
+            limit=lmeval.limit,
+            batch_size=lmeval.batch_size,
+            model_args_hash=_sha256_hash(model_args_json, 16),
+            lmeval_version=lmeval_version,
+            seed=seed,
+        )
+
+    def _matches(self, row: pd.Series) -> bool:
+        """Check if a DataFrame row matches this cache key."""
+        # Handle NaN for seed comparison (pandas reads None as NaN)
+        seed_matches = (pd.isna(row["seed"]) and self.seed is None) or (
+            row["seed"] == self.seed
+        )
+        return (
+            row["model"] == self.model
+            and row["task"] == self.task
+            and row["num_fewshot"] == self.num_fewshot
+            and row["limit"] == self.limit
+            and row["batch_size"] == self.batch_size
+            and row["model_args_hash"] == self.model_args_hash
+            and row["lmeval_version"] == self.lmeval_version
+            and seed_matches
+        )
+
+    def get_cached_result(self) -> Optional[Dict]:
+        """Load cached result from CSV file."""
+        if not LMEVAL_CACHE_FILE.exists():
+            return None
+
+        try:
+            df = pd.read_csv(LMEVAL_CACHE_FILE)
+            matches = df[df.apply(self._matches, axis=1)]
+
+            if matches.empty:
+                return None
+
+            return json.loads(matches.iloc[0]["result"])
+
+        except Exception as e:
+            logger.debug(f"Cache read failed: {e}")
+            return None
+
+    def store_result(self, result: Dict) -> None:
+        """Store result in CSV file."""
+        try:
+            LMEVAL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
+
+            new_row = {
+                "model": self.model,
+                "task": self.task,
+                "num_fewshot": self.num_fewshot,
+                "limit": self.limit,
+                "batch_size": self.batch_size,
+                "model_args_hash": self.model_args_hash,
+                "lmeval_version": self.lmeval_version,
+                "seed": self.seed,
+                "result": json.dumps(result, default=str),
+            }
+
+            # Load existing cache or create new
+            if LMEVAL_CACHE_FILE.exists():
+                df = pd.read_csv(LMEVAL_CACHE_FILE)
+                # Remove duplicate entries for this key
+                df = df[~df.apply(self._matches, axis=1)]
+                df = pd.concat([df, pd.DataFrame([new_row])], ignore_index=True)
+            else:
+                df = pd.DataFrame([new_row])
+
+            df.to_csv(LMEVAL_CACHE_FILE, index=False)
+            logger.info(f"LM-Eval cache WRITE: {self.model}/{self.task}")
+
+        except Exception as e:
+            logger.debug(f"Cache write failed: {e}")
+
+
+def cached_lm_eval_run(func: Callable) -> Callable:
+    """
+    Decorator to cache lm_eval results in CSV format.
+
+    Caches results based on model, task, num_fewshot, limit, batch_size,
+    and model_args to avoid redundant base model evaluations.
+
+    Environment variables:
+        DISABLE_LMEVAL_CACHE: Set to "1"/"true"/"yes" to disable
+        LMEVAL_CACHE_DIR: Custom cache directory (default: .lmeval_cache)
+    """

+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        # Skip caching if disabled
+        if DISABLE_LMEVAL_CACHE:
+            return func(self, *args, **kwargs)
+
+        # Try to get cached result
+        cache_key = LMEvalCacheKey.from_test_instance(self)
+        if (cached_result := cache_key.get_cached_result()) is not None:
+            logger.info(f"LM-Eval cache HIT: {cache_key.model}/{cache_key.task}")
+            return cached_result
+
+        # Run evaluation and cache result
+        logger.info(f"LM-Eval cache MISS: {cache_key.model}/{cache_key.task}")
+        result = func(self, *args, **kwargs)
+        cache_key.store_result(result)
+
+        return result
+
+    return wrapper
```
