Skip to content

Commit c009bde

Browse files
authored
[CLI] Report KV cache memory usage in mlc_llm compile (#3221)
This PR prints out the memory usage of the KV cache: the size in MB of one token's KV cache, and the total MB for model weights + intermediate buffers + a 4K-long KV cache. If the required fields are not present in `config` and `metadata` (e.g. for an old model), nothing is printed. Sample output in CLI: ``` [2025-05-03 21:44:24] INFO model_metadata.py:94: Total memory usage without KV cache: 2254.16 MB (Parameters: 923.16 MB. Temporary buffer: 1331.00 MB) [2025-05-03 21:44:24] INFO model_metadata.py:128: KV cache size: 0.11 MB per token in the context window [2025-05-03 21:44:24] INFO model_metadata.py:133: Total memory usage with a 4K KV cache: 2702.16 MB ```
1 parent 06bdbc2 commit c009bde

File tree

1 file changed

+38
-1
lines changed

1 file changed

+38
-1
lines changed

python/mlc_llm/cli/model_metadata.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,49 @@ def _report_memory_usage(metadata: Dict[str, Any], config: Union[Dict, ConfigBas
9393
total_size = params_bytes + temp_func_bytes
9494
logger.info(
9595
"%s: %.2f MB (Parameters: %.2f MB. Temporary buffer: %.2f MB)",
96-
green("Total memory usage without KV cache:"),
96+
green("Total memory usage without KV cache"),
9797
total_size / 1024 / 1024,
9898
params_bytes / 1024 / 1024,
9999
temp_func_bytes / 1024 / 1024,
100100
)
101101

102+
# Compute KV cache size per token of context window.
103+
if isinstance(config, ConfigBase):
104+
config = asdict(config)
105+
if (
106+
"head_dim" in config
107+
and "num_hidden_layers" in config
108+
and "num_key_value_heads" in config
109+
and "quantization" in metadata
110+
):
111+
quantization_type = metadata["quantization"]
112+
dtype_bytes = None
113+
if "f32" in quantization_type:
114+
dtype_bytes = 4
115+
elif "bf16" in quantization_type:
116+
dtype_bytes = 2
117+
elif "f16" in quantization_type:
118+
dtype_bytes = 2
119+
# TODO: If support quantized KV in future, need to change this # pylint: disable=fixme
120+
if dtype_bytes is not None:
121+
bytes_per_token = (
122+
config["head_dim"]
123+
* config["num_hidden_layers"]
124+
* config["num_key_value_heads"]
125+
* dtype_bytes
126+
* 2 # 2 for key and value
127+
)
128+
logger.info(
129+
"%s: %.2f MB per token in the context window",
130+
green("KV cache size"),
131+
bytes_per_token / 1024 / 1024,
132+
)
133+
logger.info(
134+
"%s: %.2f MB",
135+
green("Total memory usage with a 4K KV cache"),
136+
(total_size + bytes_per_token * 4096) / 1024 / 1024,
137+
)
138+
102139
logger.info(
103140
"To reduce memory usage, "
104141
"tweak `prefill_chunk_size`, `context_window_size` and `sliding_window_size`"

0 commit comments

Comments (0)