-
Notifications
You must be signed in to change notification settings - Fork 12.3k
llama : add high-throughput mode #14363
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
Right now I am comparatively less busy with my PhD so it would be a good time for me to write CUDA code that is still missing, if there is any. |
For now, these are the necessary CUDA changes:
// old
// q: [n_embd_k, n_batch, n_head, 1]
// k: [n_embd_k, n_kv, n_head_kv, 1]
// v: [n_embd_v, n_kv, n_head_kv, 1] !! not transposed !!
// mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
// res: [n_embd_v, n_head, n_batch, 1] !! permuted !!
GGML_API struct ggml_tensor * ggml_flash_attn_ext(
...);
// new - supports `n_seq` dimension:
// q: [n_embd_k, n_batch, n_head, n_seq]
// k: [n_embd_k, n_kv, n_head_kv, n_seq]
// v: [n_embd_v, n_kv, n_head_kv, n_seq] !! not transposed !!
// mask: [n_kv, n_batch_pad, n_seq, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
// res: [n_embd_v, n_head, n_batch, n_seq] !! permuted !!
GGML_API struct ggml_tensor * ggml_flash_attn_ext(
...); CPU might also need to be extended (not sure yet)
Edit: the CPU versions of |
ab2a2bb
to
1b74b9d
Compare
c246784
to
06bb08a
Compare
82277da
to
4534123
Compare
2f577c5
to
30b4d4e
Compare
6179578
to
dfceb01
Compare
eb5856c
to
ee0f729
Compare
ee0f729
to
deae7cd
Compare
988d0cd
to
dbcfcaa
Compare
src/llama-kv-cache-unified.cpp
Outdated
v_cells[s].resize(kv_size); | ||
} | ||
|
||
// by default, all sequence ids are mapped to the 0th virtual sequence |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd like to understand the purpose of virtual sequences.
- Is it to make the unified cache not unified?
- Should it be a separate cache type instead?
- why is
n_seq_virt
a number and not abool
of whether or not the cache is unified?- Is it to eventually allow
n_seq_max % n_seq_virt == 0
for a partially-unified cache?
- Is it to eventually allow
Are virtual sequences intended to be used with other types of caches eventually (e.g. recurrent)?- The concept here seems specific to the self-attention KV cache (unless I'm misunderstanding).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Today I found a better term instead of "virtual sequences": "streams". So I'll use "streams" here and will update the code later today or tomorrow.
Is it to make the unified cache not unified?
Roughly yes. The user will be able to select between unified (i.e. single stream) or non-unified (multiple streams). Each mode has advantages in different scenarios. Single stream is good when the sequences share large common prefixes. Multiple streams are good when the sequences are mostly or completely independent from each other.
The first iteration will support 1 stream (i.e. same as master
, vanilla unified KV cache) and n_seq_max
streams. The latter means that each sequence id is assigned to a separate stream.
In theory, we could assign multiple sequence ids to the same stream to get a partially-unified KV cache, but this would need extra work and it might not have any useful applications. So out of scope for now.
Should it be a separate cache type instead?
There is too much similar logic. Still thinking about it, but most likely it will end up in the same cache type.
The concept here seems specific to the self-attention KV cache (unless I'm misunderstanding)
Yes.
src/llama-batch.h
Outdated
// if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids | ||
llama_ubatch split_equal(uint32_t n_ubatch, bool sequential); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are sequential seq_ids
required when virtual sequences are used?
Is it because a contiguous (along the virtual sequence dimension) slice of the KV cache is used?
I wonder if there could be a way to avoid this requirement with ggml_get_rows
and/or ggml_mul_mat_id
. Might not be worth the extra indirection, though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are sequential seq_ids required when virtual sequences are used?
Is it because a contiguous (along the virtual sequence dimension) slice of the KV cache is used?
Yes, we make a view of the KV cache across the streams here:
llama.cpp/src/llama-kv-cache-unified.cpp
Lines 976 to 992 in dbcfcaa
ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const { | |
const int32_t ikv = map_layer_ids.at(il); | |
auto * k = layers[ikv].k; | |
const uint32_t ns = sinfo.s1 - sinfo.s0 + 1; | |
const uint64_t kv_size = get_size(); | |
return ggml_view_4d(ctx, k, | |
hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns, | |
ggml_row_size(k->type, hparams.n_embd_head_k), | |
ggml_row_size(k->type, hparams.n_embd_k_gqa(il)), | |
ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*kv_size), | |
ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*kv_size)*sinfo.s0); | |
} | |
The ns
var is the number of streams that participate in the current ubatch
. Their stream indices range from [s0, s1]
.
I wonder if there could be a way to avoid this requirement with ggml_get_rows and/or ggml_mul_mat_id. Might not be worth the extra indirection, though.
It should be possible. But I'm not sure if it would be worth - both in performance and in complexity. We can explore though.
src/llama-kv-cache-unified.cpp
Outdated
@@ -45,7 +46,7 @@ llama_kv_cache_unified::llama_kv_cache_unified( | |||
auto it = ctx_map.find(buft); | |||
if (it == ctx_map.end()) { | |||
ggml_init_params params = { | |||
/*.mem_size =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()), | |||
/*.mem_size =*/ size_t(2u*(1 + n_seq_virt)*n_layer_cache*ggml_tensor_overhead()), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the 1 +
intended? Why was it added?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For the per-stream views of the KV cache:
llama.cpp/src/llama-kv-cache-unified.cpp
Lines 125 to 133 in dbcfcaa
std::vector<ggml_tensor *> k_seq; | |
std::vector<ggml_tensor *> v_seq; | |
for (uint32_t s = 0; s < n_seq_virt; ++s) { | |
k_seq.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2])); | |
v_seq.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2])); | |
} | |
These are used to implement the llama_memory_seq_cp()
. This operation is no longer just assigning ids - it performs actual copy of the buffers in memory when we use multiple streams. Using these helper views, the operation is quite simple to implement:
llama.cpp/src/llama-kv-cache-unified.cpp
Lines 289 to 329 in dbcfcaa
bool is_full = true; | |
if (p0 > 0 && p0 + 1 < (int) get_size()) { | |
is_full = false; | |
} | |
if (p1 > 0 && p1 + 1 < (int) get_size()) { | |
is_full = false; | |
} | |
GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers"); | |
//LLAMA_LOG_WARN("%s: copying KV buffer from %d (virt = %d) to %d (virt = %d)\n", __func__, seq_id_src, s0, seq_id_dst, s1); | |
for (uint32_t il = 0; il < layers.size(); ++il) { | |
const auto & layer = layers[il]; | |
ggml_backend_tensor_copy(layer.k_seq[s0], layer.k_seq[s1]); | |
ggml_backend_tensor_copy(layer.v_seq[s0], layer.v_seq[s1]); | |
// TODO: do we need synchronization here? | |
} | |
// TODO: support this: | |
GGML_ASSERT(v_cells[s0].get_has_shift() == false && "cannot copy a KV buffer that has a pending shift"); | |
v_cells[s1].reset(); | |
for (uint32_t i = 0; i < v_cells[s0].size(); ++i) { | |
if (v_cells[s0].seq_has(i, seq_id_src)) { | |
v_cells[s1].pos_set(i, v_cells[s0].pos_get(i)); | |
v_cells[s1].seq_add(i, seq_id_dst); | |
} | |
} | |
v_heads[s1] = v_heads[s0]; | |
//for (uint32_t s = 0; s < n_seq_virt; ++s) { | |
// LLAMA_LOG_WARN("%s: seq %d: min = %d, max = %d\n", __func__, s, v_cells[s].seq_pos_min(s), v_cells[s].seq_pos_max(s)); | |
//} | |
} |
Though we cannot copy partial sequences when using multiple streams.
src/llama-batch.cpp
Outdated
// accept only increasing sequence ids | ||
if (sequential) { | ||
add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about decreasing sequence ids? Is the requirement that they are increasing, or that the included seq_ids
should be in a contiguous range?
(decreasing sequence ids might not really happen often in practice though)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Decreasing would also work - we just need continuous range. We can either add this, if there is an elegant way to search for this. Or we add some batch pre-processing step to move the complexity at a higher level. Or just delegate it to the user by warning when the batch is not arranged optimally.
dbcfcaa
to
33dcc3c
Compare
33dcc3c
to
5363817
Compare
5363817
to
7b00429
Compare
d04f824
to
fa2573e
Compare
c96c48c
to
5c00eb2
Compare
@slaren PTAL - any suggestions are welcome. Note there is currently no way to test this on non-Apple hardware until the necessary operators are implemented by the backends. |
target #14285
Overview
Improve multi-sequence decoding performance by avoiding the cross-sequence attention compute.
Note
The functionality currently requires the
LLAMA_SET_ROWS
from #14285 and support forggml_soft_max_ext()
/ggml_flash_attn_ext()
broadcast (#14435)ggml_set_rows()
ggml_soft_max_ext()
ggml_flash_attn_ext()
ggml_set_rows()
ggml_soft_max_ext()
ggml_flash_attn_ext()
ggml_set_rows()
ggml_soft_max_ext()
ggml_flash_attn_ext()
ggml_set_rows()
ggml_soft_max_ext()
ggml_flash_attn_ext()
Description
One significant drawback of the unified KV cache is that it leads to performing a lot of unnecessary computation in the attention when the unified buffer is shared between many large independent sequences. The reason is that we have to view this buffer continuously and therefore we end up computing large potions of "cross-sequence attention" which we then simply discard.
With this change, we add option to split the unified KV cache buffer into multiple buffers - one for each sequence. This decouples the sequences from each other and improves the performance and memory usage of the attention when more than one sequence is used. To achieve that, when the batch reaches the attention, we split it into multiple "streams":
llama.cpp/src/llama-graph.cpp
Lines 1035 to 1044 in c96c48c
Each stream has its own KV cache buffer and thus no longer "sees" the rest of the other streams - it attends only to the tokens that belong to the same stream.
With this approach we now have 2 modes:
To enable the new mode, simply add the
--attn-streams
CLI arg to the llama.cpp tools. It should generally perform better for multi-user or multi-sequence scenarios.API Changes
bool llama_context_params::attn_streams
. Default isfalse
Testing
Define the
LLAMA_SET_ROWS=1
environment variable and add the--attn-streams
argument:Qwen 2.5 Coder 3B Q8_0, M2 Ultra
Geamma 3 4B Q8_0, M2 Ultra
Using a more real-world example with
llama-parallel
:TODO
ggml_soft_max_ext()
support for virtual sequencesllama_memory_seq_cp
support for virtual sequencessplit_equal
support sequential idsLLAMA_HT
become regular compute parametern_ctx
meaning (total vs per-sequence)Require(no longer needed)n_embd_v_gqa(il) == const
when FA is offNext PRs
(split_equal + padding)
and stream split[TAG_NO_CACHE_PAD]
ggml_set_rows()
is fully adoptedllama-parallel
to use different RNG seeds for the different clients