Skip to content

Commit d4ee0a8

Browse files
authored
optimize phi3 memory usage (intel#11867)
1 parent 5b83493 commit d4ee0a8

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

python/llm/src/ipex_llm/transformers/kv.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,21 @@ def update(
121121

122122
return self.key_cache[layer_idx], self.value_cache[layer_idx]
123123

124+
@classmethod
125+
def from_reserved(cls, layers: int,
126+
bsz: int, n_head: int, length: int, head_dim: int,
127+
dtype: torch.dtype, device: torch.device):
128+
past_key_values = cls()
129+
for _i in range(layers):
130+
k_cache, v_cache = init_kv_cache(
131+
bsz, n_head, head_dim,
132+
0, length + cls.KV_ALLOC_BLOCK_LENGTH,
133+
dtype, device
134+
)
135+
past_key_values.key_cache.append(k_cache)
136+
past_key_values.value_cache.append(v_cache)
137+
return past_key_values
138+
124139

125140
# Copied from transformers.models.llama.modeling_llama.repeat_kv
126141
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:

python/llm/src/ipex_llm/transformers/models/phi3.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -254,9 +254,9 @@ def model_forward(
254254
):
255255
# IPEX-LLM OPT: kv cache and quantize kv cache and sdp
256256
use_cache = use_cache if use_cache is not None else self.config.use_cache
257-
input = input_ids if input_ids is not None else inputs_embeds
258-
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input)
259-
use_compress_kv = should_use_compresskv(input, input.shape[1])
257+
inputs = input_ids if input_ids is not None else inputs_embeds
258+
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
259+
use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
260260
if use_cache:
261261
if use_compress_kv and not isinstance(past_key_values,
262262
DynamicCompressCache):
@@ -272,6 +272,14 @@ def model_forward(
272272
DynamicCompressCache
273273
)):
274274
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
275+
if past_key_values.get_seq_length() == 0:
276+
n_layer = self.config.num_hidden_layers
277+
n_head = self.config.num_attention_heads
278+
head_dim = self.config.hidden_size // self.config.num_attention_heads
279+
past_key_values = DynamicNormalCache.from_reserved(
280+
n_layer, inputs.size(0), n_head, inputs.size(1), head_dim,
281+
inputs.dtype, inputs.device
282+
)
275283
return origin_model_forward(
276284
self=self,
277285
input_ids=input_ids,

0 commit comments

Comments
 (0)