Input prompts:

```python
prompts: List[str] = [
    # For these prompts, the expected answer is the natural continuation of the prompt
    "I believe the meaning of life is",
    "Simply put, the theory of relativity states that ",
    """A brief message congratulating the team on the launch:
Hi everyone,
I just """,
    # Few-shot prompt (providing a few examples before asking the model to complete more)
    "Roosevelt was the first president of the United States, he has",
]
```

- After optimizing the decode phase with CUDA graph, the time for a single decode phase is
8.2402 ms, compared to 17.2241 ms before using CUDA graph: roughly a 2x speedup, almost the same improvement vLLM gains from applying CUDA graph.
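The win comes from capturing the decode step's kernels once and replaying them, eliminating per-token Python and kernel-launch overhead. A minimal sketch of the capture/replay pattern, with a stand-in `decode_step` in place of lite_llama's real decode forward (names are illustrative; a GPU is required):

```python
# Sketch: capture one decode step into a CUDA graph, then replay it per token.
import torch

def decode_step(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Placeholder for one decode forward pass (a matmul stands in for the model).
    return x @ w

def make_graphed_decode(w: torch.Tensor, batch: int, hidden: int):
    """Capture decode_step once; later calls replay the recorded kernels."""
    static_in = torch.zeros(batch, hidden, device="cuda")

    # Warm up on a side stream (the documented capture pattern).
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        decode_step(static_in, w)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = decode_step(static_in, w)

    def run(x: torch.Tensor) -> torch.Tensor:
        static_in.copy_(x)   # write new inputs into the captured buffer
        graph.replay()       # relaunch recorded kernels, no per-op launch overhead
        return static_out.clone()

    return run
```

Shapes and buffer addresses must stay fixed across replays, which is why the decode phase (fixed batch, one new token per step) benefits while the variable-length prefill phase does not.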
```bash
INFO: After apply cuda graph, Decode inference time: 8.2402 ms
INFO: Before apply cuda graph, Decode inference time: 17.2241 ms
```

- Building on the previous step, FlashAttention replaces the original standard attention.
FlashAttention-1 helps more during training, and its speedup is limited when prompts are very short; the decode phase of inference is better served by Flash-Decoding.
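The core trick in FlashAttention is to process K/V in tiles while maintaining a running softmax maximum and normalizer, so the full attention score matrix is never materialized. The sketch below is a plain-PyTorch numerical reference for that idea, not the fused kernel lite_llama actually runs:

```python
# Illustrative flash-style attention: iterate over KV tiles, keeping a running
# max and normalizer instead of materializing all (Nq x Nk) scores at once.
import torch

def flash_attn_ref(q, k, v, block: int = 32):
    # q: (Nq, d), k/v: (Nk, d); returns (Nq, d)
    scale = q.shape[-1] ** -0.5
    m = torch.full((q.shape[0], 1), float("-inf"))   # running row max
    l = torch.zeros(q.shape[0], 1)                   # running softmax denominator
    acc = torch.zeros_like(q)                        # unnormalized output accumulator
    for start in range(0, k.shape[0], block):
        kb = k[start:start + block]
        vb = v[start:start + block]
        s = (q @ kb.T) * scale                       # scores for this KV tile
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        p = torch.exp(s - m_new)                     # tile softmax numerator
        correction = torch.exp(m - m_new)            # rescale previous partials
        l = l * correction + p.sum(dim=-1, keepdim=True)
        acc = acc * correction + p @ vb
        m = m_new
    return acc / l
```

The result matches standard softmax attention, but memory stays O(N) in the sequence length instead of O(N^2).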
```bash
INFO: input tokens shape is torch.Size([8, 115])

# Before using flashattention
INFO:lite_llama.generate:Batch inference time: 3152.0476 ms
INFO:lite_llama.generate:Tokens per second: 97.71 tokens/s

# After using flashattention
INFO:lite_llama.generate:Batch inference time: 2681.3823 ms
INFO:lite_llama.generate:Tokens per second: 114.87 tokens/s
```

- Continuing the optimization: upgrading `flashattention` to `flashattention2` reduces some computation.
```bash
INFO:lite_llama.generate:Batch inference time: 2103.0737 ms
INFO:lite_llama.generate:Tokens per second: 146.45 tokens/s
```

- Further optimized by using `flashdecoding` in the decoding phase to increase the parallelism of attention computation during decoding, thereby fully exploiting the GPU's compute power.
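During decode there is only one query token per sequence, so there is little query-level parallelism to tile over; Flash-Decoding instead splits the K/V sequence into chunks, computes partial attention per chunk, and merges the partials using their log-sum-exps. A sketch of that math, where the sequential loop stands in for parallel GPU blocks:

```python
# Split-KV decode attention: per-chunk partial softmax outputs are merged
# exactly using each chunk's log-sum-exp, matching full-sequence attention.
import torch

def flash_decoding_ref(q, k, v, num_splits: int = 4):
    # q: (d,) single decode query; k/v: (N, d) cached keys/values.
    scale = q.shape[-1] ** -0.5
    chunks = torch.chunk(torch.arange(k.shape[0]), num_splits)
    partial_out, partial_lse = [], []
    for idx in chunks:  # on GPU, each chunk runs as an independent block
        s = (k[idx] @ q) * scale                      # (chunk,)
        partial_lse.append(torch.logsumexp(s, dim=0))
        partial_out.append(torch.softmax(s, dim=0) @ v[idx])
    lse = torch.stack(partial_lse)                    # (num_splits,)
    w = torch.softmax(lse, dim=0)                     # exact merge weights
    return (torch.stack(partial_out) * w[:, None]).sum(dim=0)
```

The merge is exact because each chunk's softmax denominator is exp(lse); reweighting by softmax over the per-chunk lse values recovers the global normalization.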
```bash
INFO:lite_llama.generate:Decode stage Batch inference time: 1641.4178 ms
INFO:lite_llama.generate:Decode stage tokens per second : 187.64 tokens/s
```

- Further optimization: efficient dynamic management of the KV cache (similar to TokenAttention), addressing memory waste and inefficient allocation in KV cache usage.
```bash
INFO:lite_llama.generate:Decode stage Batch inference time: 1413.9111 ms
INFO:lite_llama.generate:Decode stage tokens per second : 217.84 tokens/s
```
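Token-granularity KV cache management of this kind can be pictured as a pool of physical cache slots handed out per token and returned when a sequence finishes, rather than reserving a contiguous max-length region per sequence. The class below is an illustrative sketch, not lite_llama's actual API:

```python
# Sketch of TokenAttention-style KV cache slot management: a free pool of
# physical slot ids, allocated per token and released when a request ends.
import torch

class KVSlotAllocator:
    def __init__(self, num_slots: int):
        self.free = list(range(num_slots))  # physical row ids in the cache tensor

    def alloc(self, n: int) -> torch.Tensor:
        # Grab n slots for n new tokens; they need not be contiguous.
        if n > len(self.free):
            raise RuntimeError("KV cache exhausted")
        ids, self.free = self.free[:n], self.free[n:]
        return torch.tensor(ids, dtype=torch.long)

    def release(self, ids: torch.Tensor) -> None:
        # Return a finished sequence's slots to the pool for reuse.
        self.free.extend(ids.tolist())
```

Attention kernels then gather K/V rows by these slot indices instead of assuming contiguity, so memory scales with the tokens actually generated.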
- A simple optimization is to replace the `repeat_kv` function with `GQA_KV_heads_index`.
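To illustrate the difference: in grouped-query attention, `repeat_kv` materializes each KV head once per query head in its group, while an index-based approach simply maps each query head to its KV head. The `gqa_kv_heads_index` function below is an illustrative reconstruction of the idea, not necessarily lite_llama's exact signature:

```python
# repeat_kv duplicates KV heads to match query heads; an index mapping avoids
# the duplication logic and can be precomputed once per model.
import torch

def repeat_kv(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
    # kv: (batch, n_kv_heads, seq, dim) -> (batch, n_kv_heads * n_rep, seq, dim)
    b, h, s, d = kv.shape
    return kv[:, :, None].expand(b, h, n_rep, s, d).reshape(b, h * n_rep, s, d)

def gqa_kv_heads_index(kv: torch.Tensor, n_q_heads: int) -> torch.Tensor:
    # Each query head reads from KV head (q_head // group_size): a gather
    # instead of an expand+reshape copy.
    n_kv = kv.shape[1]
    idx = torch.arange(n_q_heads) // (n_q_heads // n_kv)
    return kv[:, idx]
```

In eager PyTorch the gather still copies, but inside a fused kernel the per-query-head KV index lets the kernel read the shared KV head directly and skip materialization entirely.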
- A common and straightforward optimization is fusing the key and value linear layers.
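The idea is to concatenate the K and V projection weights so a single GEMM produces both projections, halving kernel launches for that step. A minimal sketch (the module name is hypothetical):

```python
# One linear layer whose output splits into K and V, replacing two GEMMs.
import torch
import torch.nn as nn

class FusedKV(nn.Module):
    def __init__(self, hidden: int, kv_dim: int):
        super().__init__()
        # Weight rows are [W_k; W_v] stacked, so one matmul computes both.
        self.wkv = nn.Linear(hidden, 2 * kv_dim, bias=False)
        self.kv_dim = kv_dim

    def forward(self, x: torch.Tensor):
        k, v = self.wkv(x).split(self.kv_dim, dim=-1)
        return k, v
```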
- A commonly used optimization is operator fusion: fusing the residual connection's skip operation with the `rmsnorm` operator to form a new `skip_rmsnorm` operator.
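Functionally, the fused operator performs the residual add and the RMSNorm in one pass, returning both the normalized output and the updated residual. The sketch below shows the math; the actual speedup comes from doing this in one GPU kernel so the activation is read and written once:

```python
# Fused "add residual, then RMSNorm": returns the normed output plus the
# post-add hidden state, which becomes the next layer's residual input.
import torch

def skip_rmsnorm(x, residual, weight, eps: float = 1e-6):
    h = x + residual                                           # skip connection
    rms = torch.rsqrt(h.pow(2).mean(dim=-1, keepdim=True) + eps)
    return h * rms * weight, h
```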
- Refactored and optimized the `MHA` module, improving the `context_attention` and `token_attention` kernels to support `Nopad attention` as well as dynamic allocation and management of the `kv cache`.
- `token_attention` now supports directly passing `kv_cache` indices and the actual sequence length `seq_len`, reducing `concat` and `view` operations within the `MHA` module and enabling Nopad `token_attention`.
- During each prefill/decode step, the number of `kv_cache` indices is dynamically allocated based on the actual prompt length, instead of pre-allocating a contiguous `kv_cache` region for `(max(prompt_len) + max_gen_len) * batch_size` tokens before inference.
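With per-sequence `kv_cache` indices and true lengths available, decode attention can gather each sequence's K/V rows straight from the cache, with no padding and no `concat`/`view` reshuffling. A simplified single-head sketch under assumed layouts (names and shapes are illustrative, not lite_llama's kernel interface):

```python
# Nopad decode attention: kv_indices holds each sequence's cache slot ids
# back-to-back; seq_lens gives each sequence's true length.
import torch

def nopad_token_attention(q, kv_cache_k, kv_cache_v, kv_indices, seq_lens):
    # q: (batch, d); kv_cache_k / kv_cache_v: (num_slots, d)
    outs, offset = [], 0
    for b, n in enumerate(seq_lens):
        idx = kv_indices[offset:offset + n]      # this sequence's slots
        offset += n
        k = kv_cache_k[idx]                      # gather, no padding needed
        v = kv_cache_v[idx]
        s = (k @ q[b]) * q.shape[-1] ** -0.5
        outs.append(torch.softmax(s, dim=0) @ v)
    return torch.stack(outs)
```

Because lengths are explicit, no attention is wasted on pad positions, and the cache slots referenced by `kv_indices` need not be contiguous.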