Skip to content

batch : add optional for sequential equal split #14511

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit on Jul 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion src/llama-batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ bool llama_batch_allocr::init(

// note: tracking the other way around is not necessary for now
//seq_cpl[s0][s1] = true;

has_cpl = true;
}
}
}
Expand Down Expand Up @@ -459,9 +461,17 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
return ubatch_add(idxs, idxs.size(), false);
}

llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
if (sequential && has_cpl) {
LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__);

return {};
}

std::vector<seq_set_t> cur_seq_set;

llama_seq_id last_seq_id = -1;

// determine the non-overlapping sequence sets participating in this ubatch
for (int32_t i = 0; i < batch.n_tokens; ++i) {
if (used[i]) {
Expand All @@ -478,9 +488,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
}
}

// accept only increasing sequence ids
if (sequential) {
add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
}

if (add) {
cur_seq_set.push_back(seq_set[i]);

last_seq_id = batch.seq_id[i][0];

if (cur_seq_set.size() > n_ubatch) {
break;
}
Expand Down
6 changes: 5 additions & 1 deletion src/llama-batch.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ class llama_batch_allocr {
llama_ubatch split_simple(uint32_t n_ubatch);

// make ubatches of equal-length sequences sets
llama_ubatch split_equal(uint32_t n_ubatch);
// if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);

// sequence-set-wise split - each ubatch contains a single sequence-set
llama_ubatch split_seq(uint32_t n_ubatch);
Expand Down Expand Up @@ -112,6 +113,9 @@ class llama_batch_allocr {
using pos_set_t = std::set<llama_pos>;
using seq_cpl_t = std::vector<bool>;

// helper flag to quickly determine if there are any coupled sequences in the batch
bool has_cpl;

std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1

Expand Down
2 changes: 1 addition & 1 deletion src/llama-kv-cache-unified-iswa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all

std::vector<llama_ubatch> ubatches;
while (true) {
auto ubatch = balloc.split_equal(n_ubatch);
auto ubatch = balloc.split_equal(n_ubatch, false);

if (ubatch.n_tokens == 0) {
break;
Expand Down
2 changes: 1 addition & 1 deletion src/llama-memory-hybrid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
// if all tokens are output, split by sequence
ubatch = balloc.split_seq(n_ubatch);
} else {
ubatch = balloc.split_equal(n_ubatch);
ubatch = balloc.split_equal(n_ubatch, false);
}

if (ubatch.n_tokens == 0) {
Expand Down
2 changes: 1 addition & 1 deletion src/llama-memory-recurrent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
// if all tokens are output, split by sequence
ubatch = balloc.split_seq(n_ubatch);
} else {
ubatch = balloc.split_equal(n_ubatch);
ubatch = balloc.split_equal(n_ubatch, false);
}

if (ubatch.n_tokens == 0) {
Expand Down
Loading