
Commit 2bc0540

niushengxiao authored and committed

feat: extend the num head terms for fp8 calibration

1 parent 2964a98 · commit 2bc0540

File tree

4 files changed (+1506, -750 lines)


lightllm/common/offline_fp8_quant_mem_manager.py

8 additions & 3 deletions

@@ -25,7 +25,6 @@ def __init__(
 
         self.qmax = torch.finfo(torch.float8_e4m3fn).max
         self.qmin = torch.finfo(torch.float8_e4m3fn).min
-        self.layer_num = layer_num
         self.total_head_num = head_num * dist.get_world_size() if dist.is_initialized() else head_num
         self.count = 0
         self.scales = None
@@ -45,7 +44,13 @@ def __init__(
             self.scales_list = cfg["scales"]
             self.scales = torch.tensor(self.scales_list, dtype=torch.float32, device="cuda").view(cfg["scales_shape"])
             if not get_env_start_args().enable_fa3:
-                self.scales = torch.repeat_interleave(self.scales, self.head_num, dim=-1)
+                self.scales = torch.repeat_interleave(self.scales, head_num, dim=-1)
+            elif cfg["num_head"] > self.total_head_num:
+                factor = cfg["num_head"] // self.total_head_num
+                self.scales = self.scales[..., ::factor].contiguous()
+            elif cfg["num_head"] < self.total_head_num:
+                factor = self.total_head_num // cfg["num_head"]
+                self.scales = torch.repeat_interleave(self.scales, factor, dim=-1).contiguous()
             if get_env_start_args().enable_fa3 and dist.is_initialized() and dist.get_world_size() > 1:
                 half_head = self.total_head_num // 2
                 start_head = dist.get_rank() * head_num
@@ -77,7 +82,7 @@ def _load_and_check_config(self):
             raise ValueError(
                 f"num_layers {cfg['num_layers']} in config " f"not match current layer_num {self.layer_num}"
             )
-        if cfg["num_head"] != self.total_head_num:
+        if cfg["num_head"] % self.total_head_num != 0 and self.total_head_num % cfg["num_head"] != 0:
             raise ValueError(
                 f"num_head {cfg['num_head']} in config " f"not match current model head num {self.total_head_num}"
             )
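Taken together, the two new elif branches let a calibration file recorded under one tensor-parallel layout be reused under another: when the head count stored in the config is an integer multiple of the model's current total head count, the per-head scales are downsampled by striding; when it is an integer divisor, each scale is repeated. A minimal standalone sketch of that reconciliation, with an illustrative function name and tensor shape (neither appears in the commit):

import torch

def reconcile_head_scales(scales: torch.Tensor, cfg_num_head: int, total_head_num: int) -> torch.Tensor:
    if cfg_num_head > total_head_num:
        # Calibration saw more heads than this run: keep every factor-th scale.
        factor = cfg_num_head // total_head_num
        return scales[..., ::factor].contiguous()
    if cfg_num_head < total_head_num:
        # Calibration saw fewer heads: repeat each scale factor times.
        factor = total_head_num // cfg_num_head
        return torch.repeat_interleave(scales, factor, dim=-1).contiguous()
    return scales

# Example: scales calibrated with 16 heads, model now running 8 in total.
scales = torch.rand(32, 2, 16)  # assumed (num_layers, k/v, num_head) layout
assert reconcile_head_scales(scales, 16, 8).shape == (32, 2, 8)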

lightllm/server/api_cli.py

0 additions & 1 deletion

@@ -175,7 +175,6 @@ def make_argument_parser() -> argparse.ArgumentParser:
         export_fp8kv_calibration record and export kv cache quant calibration results to a json file.
         It can be used for llama and qwen model.
         Calibration need to disable cudagraph and use fa3 or flashinfer backend.
-        Tp size must no more than head num when calibration.
         ppl_int8kv mode use int8 to store kv cache, and use ppl fast kernel;
         ppl_fp16 mode use ppl fast fp16 decode attention kernel;
         you need to read source code to make sure the supported detail mode for all models""",
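The removed help-text sentence ("Tp size must no more than head num when calibration") matched the old exact-match requirement in _load_and_check_config; with the relaxed check, a mismatched head count is accepted as long as one count evenly divides the other. A one-line restatement of the new condition, under a hypothetical helper name:

def head_counts_compatible(cfg_num_head: int, total_head_num: int) -> bool:
    # Accept the calibration config when either head count divides the other.
    return cfg_num_head % total_head_num == 0 or total_head_num % cfg_num_head == 0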
