
Commit e1e7e0a

add mem_fraction
1 parent 203cc24 commit e1e7e0a

24 files changed (+198, -124 lines)

lightllm/common/basemodel/basemodel.py

Lines changed: 2 additions & 0 deletions

@@ -55,6 +55,7 @@ def __init__(self, kvargs):
         self.graph_max_batch_size = kvargs.get("graph_max_batch_size", 16)
         self.graph_max_len_in_batch = kvargs.get("graph_max_len_in_batch", 8192)
         self.disable_cudagraph = kvargs.get("disable_cudagraph", False)
+        self.mem_fraction = kvargs.get("mem_fraction", 0.9)

         self._init_datatype()
         self._init_config()
@@ -119,6 +120,7 @@ def _init_mem_manager(self):
             head_num=self.config["num_attention_heads"] // self.world_size_,
             head_dim=self.config["n_embed"] // self.config["num_attention_heads"],
             layer_num=self.config["n_layer"],
+            mem_fraction=self.mem_fraction,
         )
         self.max_total_token_num = self.mem_manager.size
         return
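Note: the new kvarg defaults to 0.9 when it is not supplied. A minimal sketch of how a caller might set it, assuming only the kvargs visible in this hunk; a real model constructor needs many more entries (weights path, ports, etc.), which are omitted here:

# Illustrative kvargs only -- not the full constructor contract.
kvargs = {
    "graph_max_batch_size": 16,
    "graph_max_len_in_batch": 8192,
    "disable_cudagraph": False,
    "mem_fraction": 0.8,  # keep 20% of the card free for weights/activations/overhead
}

# Inside __init__ the value is read exactly as in the hunk above:
mem_fraction = kvargs.get("mem_fraction", 0.9)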

lightllm/common/int8kv_mem_manager.py

Lines changed: 4 additions & 4 deletions

@@ -4,14 +4,14 @@


 class INT8KVMemoryManager(MemoryManager):
-    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=True):
+    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=True, mem_fraction=0.9):
         self.kv_dtype = torch.int8
-        super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy=True)
+        super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy=True, mem_fraction=mem_fraction)

     def get_cell_size(self):
-        return self.head_num * self.head_dim * self.layer_num * 2 * torch._utils._element_size(
+        return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(
             self.kv_dtype
-        ) + self.head_num * self.layer_num * 2 * torch._utils._element_size(self.dtype)
+        ) + 2 * self.head_num * self.layer_num * torch._utils._element_size(self.dtype)

     def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
         self.kv_buffer = [
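Note: moving the factor 2 (one K and one V entry per token) to the front of both terms does not change the value; the reordering is for readability. A standalone sketch of the arithmetic with illustrative dimensions (head_num=32, head_dim=128, layer_num=32, not tied to any particular model):

def int8_kv_cell_size(head_num, head_dim, layer_num, kv_elem_size=1, scale_elem_size=2):
    # Bytes of KV cache per token: int8 K/V values plus one scale per head and layer.
    kv_bytes = 2 * head_num * head_dim * layer_num * kv_elem_size   # K and V, int8
    scale_bytes = 2 * head_num * layer_num * scale_elem_size        # per-head scales, e.g. fp16
    return kv_bytes + scale_bytes

# 2*32*128*32*1 + 2*32*32*2 = 262144 + 4096
print(int8_kv_cell_size(32, 128, 32))  # 266240 bytes, roughly 260 KB per token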

lightllm/common/mem_manager.py

Lines changed: 14 additions & 8 deletions

@@ -9,15 +9,15 @@


 class MemoryManager:
-    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False):
+    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
         self.size = size
         self.head_num = head_num
         self.head_dim = head_dim
         self.layer_num = layer_num
         self.always_copy = always_copy
-        self.kv_dtype = dtype
+        self.dtype = dtype
         # profile the max total token num if the size is None
-        self.profile_size()
+        self.profile_size(mem_fraction)
         # mem_state uses reference counting so that token sharing (e.g. for beam search) can be added later
         self.mem_state = torch.zeros((self.size,), dtype=torch.int32, device="cuda")
         self.indexes = torch.arange(0, self.size, dtype=torch.long, device="cuda")
@@ -33,26 +33,32 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
         self.shared_can_use_token_num = SharedInt(f"{str(nccl_port)}_mem_manger_can_use_token_num")

         self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-        self._init_buffers(self.size, dtype, head_num, head_dim, layer_num)
+        self._init_buffers(
+            self.size,
+            dtype,
+            head_num,
+            head_dim,
+            layer_num,
+        )

     def get_cell_size(self):
-        return self.head_num * self.head_dim * self.layer_num * 2 * torch._utils._element_size(self.kv_dtype)
+        return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype)

-    def profile_size(self):
+    def profile_size(self, mem_fraction):
         if self.size is not None:
             return
         import torch.distributed as dist

         tp_rank = dist.get_rank()
         world_size = dist.get_world_size()
         total_memory = get_total_gpu_memory()
-        available_memory = get_available_gpu_memory(tp_rank, world_size) - total_memory * (1 - 0.9)
+        available_memory = get_available_gpu_memory(tp_rank, world_size) - total_memory * (1 - mem_fraction)
         cell_size = self.get_cell_size()
         self.size = int(available_memory * 1024 ** 3 / cell_size)
         logger.info(
             f"{str(available_memory)} GB space is available after load the model weight\n"
             f"{str(cell_size / 1024 ** 2)} MB is the size of one token kv cache\n"
-            f"{self.size} is the profiled max_total_token_num with the mem_fraction 0.9\n"
+            f"{self.size} is the profiled max_total_token_num with the mem_fraction {mem_fraction}\n"
         )
         return
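Note: with mem_fraction exposed as a parameter, the profiled token budget scales with it instead of the hard-coded 0.9. A back-of-the-envelope sketch of the same formula with made-up numbers (80 GB card, 15 GB free after loading weights, 0.25 MB per-token cell size); none of these are measured values:

def profiled_token_num(available_gb, total_gb, cell_size_bytes, mem_fraction):
    # Mirrors profile_size: hold back (1 - mem_fraction) of the whole card,
    # then divide the remainder by the per-token KV cache size.
    usable_gb = available_gb - total_gb * (1 - mem_fraction)
    return int(usable_gb * 1024 ** 3 / cell_size_bytes)

print(profiled_token_num(15.0, 80.0, 0.25 * 1024 ** 2, 0.9))   # 28672 tokens
print(profiled_token_num(15.0, 80.0, 0.25 * 1024 ** 2, 0.95))  # 45056 tokens -- a higher fraction leaves more room for KV cache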

lightllm/common/ppl_int4kv_mem_manager.py

Lines changed: 4 additions & 4 deletions

@@ -4,15 +4,15 @@


 class PPLINT4KVMemoryManager(MemoryManager):
-    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=True):
+    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=True, mem_fraction=0.9):
         self.kv_dtype = torch.int8
         self.group_quant_size = 8
-        super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy=True)
+        super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy=always_copy, mem_fraction=mem_fraction)

     def get_cell_size(self):
-        return self.head_num * self.head_dim // 2 * self.layer_num * 2 * torch._utils._element_size(
+        return 2 * self.head_num * self.head_dim // 2 * self.layer_num * torch._utils._element_size(
             self.kv_dtype
-        ) + self.head_num * self.head_dim // self.group_quant_size * self.layer_num * torch._utils._element_size(
+        ) + 2 * self.head_num * self.head_dim // self.group_quant_size * self.layer_num * torch._utils._element_size(
             self.dtype
         )
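Note: `*` and `//` have equal precedence and associate left to right, so `2 * head_num * head_dim // 2 * ...` still halves the byte count for the 4-bit packed values, exactly as the old ordering did; the scale term, however, now carries the factor 2 and therefore counts scales for both K and V. A quick check with illustrative dimensions (not from a specific model):

head_num, head_dim, layer_num = 32, 128, 32
group_quant_size = 8
int8_size, fp16_size = 1, 2  # element sizes in bytes

# Packed int4 K/V values: old and new orderings agree.
old_kv = head_num * head_dim // 2 * layer_num * 2 * int8_size
new_kv = 2 * head_num * head_dim // 2 * layer_num * int8_size
assert old_kv == new_kv == 131072

# Scale term: the new expression doubles the old one (scales for K and for V).
old_scales = head_num * head_dim // group_quant_size * layer_num * fp16_size
new_scales = 2 * head_num * head_dim // group_quant_size * layer_num * fp16_size
print(old_scales, new_scales, new_kv + new_scales)  # 32768 65536 196608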

lightllm/common/ppl_int8kv_mem_manager.py

Lines changed: 20 additions & 7 deletions

@@ -4,14 +4,27 @@


 class PPLINT8KVMemoryManager(MemoryManager):
-    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=True):
-        super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy=True)
+    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=True, mem_fraction=0.9):
+        self.kv_dtype = torch.int8
+        self.group_quant_size = 8
+        super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy=always_copy, mem_fraction=mem_fraction)
+
+    def get_cell_size(self):
+        return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(
+            self.kv_dtype
+        ) + 2 * self.head_num * self.head_dim // self.group_quant_size * self.layer_num * torch._utils._element_size(
+            self.dtype
+        )

     def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
-        group_quant_size = 8
-        self.kv_buffer = [torch.empty((size, 2 * head_num, head_dim), dtype=torch.int8, device="cuda") for _ in range(layer_num)]
-        self.scale_buffer = [torch.empty((size, 2 * head_num, head_dim // group_quant_size), dtype=dtype, device="cuda") for _ in range(layer_num)]
-
+        self.kv_buffer = [
+            torch.empty((size, 2 * head_num, head_dim), dtype=torch.int8, device="cuda") for _ in range(layer_num)
+        ]
+        self.scale_buffer = [
+            torch.empty((size, 2 * head_num, head_dim // self.group_quant_size), dtype=dtype, device="cuda")
+            for _ in range(layer_num)
+        ]
+
     def _free_buffers(self):
         self.kv_buffer = None
-        self.scale_buffer = None
+        self.scale_buffer = None
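Note: the newly added get_cell_size now matches what _init_buffers allocates: one int8 slot per K/V element plus one scale per group of group_quant_size elements. A shape-only sketch with tiny illustrative sizes, kept on CPU so it runs without a GPU:

import torch

size, head_num, head_dim, layer_num = 4, 2, 16, 1
group_quant_size = 8

kv_buffer = [torch.empty((size, 2 * head_num, head_dim), dtype=torch.int8) for _ in range(layer_num)]
scale_buffer = [
    torch.empty((size, 2 * head_num, head_dim // group_quant_size), dtype=torch.float16)
    for _ in range(layer_num)
]

print(kv_buffer[0].shape)     # torch.Size([4, 4, 16]) -> int8 K/V values
print(scale_buffer[0].shape)  # torch.Size([4, 4, 2])  -> one fp16 scale per group of 8 values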

lightllm/models/deepseek2/model.py

Lines changed: 1 addition & 0 deletions

@@ -48,6 +48,7 @@ def _init_mem_manager(self):
             head_num=1,
             head_dim=self.config["kv_lora_rank"] + self.config["qk_rope_head_dim"],
             layer_num=self.config["num_hidden_layers"],
+            mem_fraction=self.mem_fraction,
         )
         return
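Note: for this model the manager is built with head_num=1 and head_dim = kv_lora_rank + qk_rope_head_dim, i.e. one compressed latent per token and layer rather than full per-head K/V. A rough sketch of the per-token footprint, assuming placeholder config values (kv_lora_rank=512, qk_rope_head_dim=64, 60 layers) that are not read from this diff; the exact cell size also depends on which MemoryManager subclass the model selects, which this hunk does not show:

# Placeholder values for illustration only.
kv_lora_rank, qk_rope_head_dim, num_hidden_layers = 512, 64, 60
fp16_size = 2  # bytes per element

head_dim = kv_lora_rank + qk_rope_head_dim        # 576
latent_bytes = head_dim * num_hidden_layers * fp16_size
print(head_dim, latent_bytes)  # 576, 69120 bytes (~67.5 KB) per token for a single latent copy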

lightllm/models/gemma_2b/model.py

Lines changed: 15 additions & 12 deletions

@@ -12,6 +12,7 @@

 from lightllm.common.mem_utils import MemoryManager

+
 class Gemma_2bTpPartModel(TpPartBaseModel):
     # weight class
     pre_and_post_weight_class = Gemma_2bPreAndPostLayerWeight
@@ -38,20 +39,22 @@ def _verify_params(self):
         # assert self.config["num_key_value_heads"] % self.world_size_ == 0
         assert self.config["num_attention_heads"] % self.world_size_ == 0
         return
-
+
     def _init_custom(self):
         self._init_to_get_rotary()
         return
-
+
     def _init_mem_manager(self):
-        self.mem_manager = MemoryManager(self.max_total_token_num,
-                                         dtype=self.data_type,
-                                         head_num=self.config["num_key_value_heads"],  # [SYM] always == 1
-                                         head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
-                                         layer_num=self.config["num_hidden_layers"])
+        self.mem_manager = MemoryManager(
+            self.max_total_token_num,
+            dtype=self.data_type,
+            head_num=self.config["num_key_value_heads"],  # [SYM] always == 1
+            head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
+            layer_num=self.config["num_hidden_layers"],
+            mem_fraction=self.mem_fraction,
+        )
         return

-
     def _init_to_get_rotary(self, default_base=10000):
         if self.config.get("rope_scaling", {}) is None:
             rope_scaling_factor = 1.0
@@ -64,16 +67,16 @@ def _init_to_get_rotary(self, default_base=10000):
             max_seq_len = self.config["max_sequence_length"]
         else:
             max_position_embeddings = self.config.get(
-                "max_position_embeddings",
-                2048 if base <= 10000.0 + 1e-5 else 16384
+                "max_position_embeddings", 2048 if base <= 10000.0 + 1e-5 else 16384
             )
             max_seq_len = max_position_embeddings * rope_scaling_factor

-        inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim_, 2, device="cpu", dtype=torch.float32) / self.head_dim_))
+        inv_freq = 1.0 / (
+            base ** (torch.arange(0, self.head_dim_, 2, device="cpu", dtype=torch.float32) / self.head_dim_)
+        )
         t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
         freqs = torch.outer(t, inv_freq)

         self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
         self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
         return
-
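Note: the reflowed _init_to_get_rotary keeps the standard RoPE recipe: inverse frequencies derived from the base, an outer product with the positions, then cached cos/sin tables. A standalone CPU sketch of the same computation with illustrative head_dim and sequence length, and without the extra 64K-position headroom used above:

import torch

def build_rotary_cache(head_dim, max_seq_len, base=10000.0, rope_scaling_factor=1.0):
    inv_freq = 1.0 / (
        base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
    )
    t = torch.arange(max_seq_len, dtype=torch.float32) / rope_scaling_factor
    freqs = torch.outer(t, inv_freq)
    return torch.cos(freqs), torch.sin(freqs)

cos, sin = build_rotary_cache(head_dim=256, max_seq_len=8192)
print(cos.shape, sin.shape)  # torch.Size([8192, 128]) torch.Size([8192, 128])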

lightllm/models/internlm2_wquant/model.py

Lines changed: 16 additions & 11 deletions

@@ -3,31 +3,36 @@
 import torch

 from lightllm.models.internlm2.layer_weights.pre_and_post_layer_weight import Internlm2PreAndPostLayerWeight
-from lightllm.models.internlm2_wquant.layer_weights.transformer_layer_weight import Internlm2TransformerLayerWeightQuantized
+from lightllm.models.internlm2_wquant.layer_weights.transformer_layer_weight import (
+    Internlm2TransformerLayerWeightQuantized,
+)
 from lightllm.models.internlm_wquant.model import InternlmTpPartModelWQuant
 from lightllm.common.mem_utils import select_mem_manager_class


 class Internlm2TpPartModelWQuant(InternlmTpPartModelWQuant):
     # weight class
-    pre_and_post_weight_class = Internlm2PreAndPostLayerWeight
+    pre_and_post_weight_class = Internlm2PreAndPostLayerWeight
     transformer_weight_class = Internlm2TransformerLayerWeightQuantized

     def __init__(self, kvargs):
         super().__init__(kvargs)
-
+
     def _verify_params(self):
         assert self.load_way in ["HF", "DS"], "llama only supports HF and DS format to load Now!"
         assert any("w4a16" in mode_ or "w8a16" in mode_ for mode_ in self.mode), "only for weight quant model"
         assert self.config["num_key_value_heads"] % self.world_size_ == 0
         assert self.config["num_attention_heads"] % self.world_size_ == 0
         return
-
+
     def _init_mem_manager(self):
-        self.mem_manager = select_mem_manager_class(self.mode)(self.max_total_token_num,
-                                                               dtype=torch.float16,
-                                                               head_num=self.config["num_key_value_heads"] // self.world_size_,
-                                                               head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
-                                                               layer_num=self.config["num_hidden_layers"],
-                                                               always_copy=True)
-        return
+        self.mem_manager = select_mem_manager_class(self.mode)(
+            self.max_total_token_num,
+            dtype=torch.float16,
+            head_num=self.config["num_key_value_heads"] // self.world_size_,
+            head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
+            layer_num=self.config["num_hidden_layers"],
+            always_copy=True,
+            mem_fraction=self.mem_fraction,
+        )
+        return
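Note: the weight-quant models obtain their KV-cache manager through select_mem_manager_class(self.mode) and now forward mem_fraction along with the other constructor arguments. A hypothetical sketch of such a mode-based dispatch; the real lightllm.common.mem_utils implementation and its mode strings are not shown in this diff, so the branch below is illustrative only:

# Hypothetical dispatch sketch -- not the actual mem_utils code.
from lightllm.common.mem_manager import MemoryManager
from lightllm.common.int8kv_mem_manager import INT8KVMemoryManager

def select_mem_manager_class_sketch(mode):
    # Choose a manager subclass from the mode strings; fall back to the fp16 manager.
    if any("int8kv" in m for m in mode):  # "int8kv" is an assumed mode substring
        return INT8KVMemoryManager
    return MemoryManager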

lightllm/models/internlm_wquant/model.py

Lines changed: 18 additions & 11 deletions

@@ -3,7 +3,9 @@
 import torch

 from lightllm.models.internlm_wquant.layer_infer.transformer_layer_infer import InternlmTransformerLayerInferWquant
-from lightllm.models.internlm_wquant.layer_weights.transformer_layer_weight import InternlmTransformerLayerWeightQuantized
+from lightllm.models.internlm_wquant.layer_weights.transformer_layer_weight import (
+    InternlmTransformerLayerWeightQuantized,
+)
 from lightllm.models.llama.model import LlamaTpPartModel
 from lightllm.common.mem_utils import select_mem_manager_class

@@ -17,19 +19,24 @@ class InternlmTpPartModelWQuant(LlamaTpPartModel):

     def __init__(self, kvargs):
         super().__init__(kvargs)
-
+
     def _verify_params(self):
         assert self.load_way in ["HF", "DS"], "llama only supports HF and DS format to load Now!"
-        assert any("w6a16" in mode_ or "w4a16" in mode_ or "w8a16" in mode_ for mode_ in self.mode), "only for weight quant model"
+        assert any(
+            "w6a16" in mode_ or "w4a16" in mode_ or "w8a16" in mode_ for mode_ in self.mode
+        ), "only for weight quant model"
         assert self.config["num_key_value_heads"] % self.world_size_ == 0
         assert self.config["num_attention_heads"] % self.world_size_ == 0
         return
-
+
     def _init_mem_manager(self):
-        self.mem_manager = select_mem_manager_class(self.mode)(self.max_total_token_num,
-                                                               dtype=torch.float16,
-                                                               head_num=self.config["num_key_value_heads"] // self.world_size_,
-                                                               head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
-                                                               layer_num=self.config["num_hidden_layers"],
-                                                               always_copy=True)
-        return
+        self.mem_manager = select_mem_manager_class(self.mode)(
+            self.max_total_token_num,
+            dtype=torch.float16,
+            head_num=self.config["num_key_value_heads"] // self.world_size_,
+            head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
+            layer_num=self.config["num_hidden_layers"],
+            always_copy=True,
+            mem_fraction=self.mem_fraction,
+        )
+        return

lightllm/models/llama/model.py

Lines changed: 1 addition & 0 deletions

@@ -62,6 +62,7 @@ def _init_mem_manager(self):
             head_num=self.config["num_key_value_heads"] // self.world_size_,
             head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
             layer_num=self.config["num_hidden_layers"],
+            mem_fraction=self.mem_fraction,
         )
         return
