batch : add n_used count #14512

Merged · 1 commit · Jul 4, 2025
9 changes: 9 additions & 0 deletions src/llama-batch.cpp
@@ -405,6 +405,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
     return n_outputs;
 }
 
+uint32_t llama_batch_allocr::get_n_used() const {
+    return n_used;
+}
+
 std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
     return out_ids;
 }
@@ -420,6 +424,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
 void llama_batch_allocr::split_reset() {
     out_ids.clear();
 
+    n_used = 0;
+
     used.clear();
     used.resize(get_n_tokens(), false);
@@ -444,6 +450,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
         idxs.push_back(cur_idx);
 
         used[cur_idx] = true;
+        ++n_used;
 
         ++cur_idx;
@@ -529,6 +536,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
             idxs_per_seq[s].push_back(idx);
 
             used[idx] = true;
+            ++n_used;
 
             ++cur_idx[s];
         }
@@ -570,6 +578,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
         idxs.push_back(cur_idx);
 
         used[cur_idx] = true;
+        ++n_used;
 
         if (idxs.size() >= n_ubatch) {
             break;
3 changes: 3 additions & 0 deletions src/llama-batch.h
@@ -54,6 +54,7 @@ class llama_batch_allocr {
     uint32_t get_n_tokens() const;
     uint32_t get_n_outputs() const;
+    uint32_t get_n_used() const;
 
     // the array of output indices in the order they were encountered during the ubatch splitting
     std::vector<int32_t> & get_out_ids();
@@ -125,6 +126,8 @@ class llama_batch_allocr {
     // batch indices of the output
     std::vector<int32_t> out_ids;
 
+    uint32_t n_used;
+
     // used[i] indicates if token i has already been used in a previous ubatch
     std::vector<bool> used;
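
Taken together, the counter gives batch consumers a cheap completeness check: after the split loop drains, n_used < n_tokens means some tokens could not be placed in any ubatch. A minimal consumer sketch of that pattern follows; it mirrors the init_batch changes in the files below, but the caller itself is hypothetical and elides the surrounding llama.cpp plumbing:

// hypothetical caller; balloc is a llama_batch_allocr that has already been
// loaded with a batch, n_ubatch is the target ubatch size
std::vector<llama_ubatch> ubatches;

while (true) {
    llama_ubatch ubatch = balloc.split_simple(n_ubatch);

    if (ubatch.n_tokens == 0) {
        // splitter has nothing more to hand out
        break;
    }

    ubatches.push_back(std::move(ubatch));
}

// new in this PR: detect that the splitter could not place every token
if (balloc.get_n_used() < balloc.get_n_tokens()) {
    // failed to find a suitable split; signal failure to the caller
    return nullptr;
}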
10 changes: 10 additions & 0 deletions src/llama-kv-cache-unified-iswa.cpp
@@ -113,6 +113,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
         ubatches.push_back(std::move(ubatch)); // NOLINT
     }
 
+    if (balloc.get_n_used() < balloc.get_n_tokens()) {
+        // failed to find a suitable split
+        break;
+    }
+
     auto sinfos_base = kv_base->prepare(ubatches);
     if (sinfos_base.empty()) {
         break;
@@ -144,6 +149,11 @@
         ubatches.push_back(std::move(ubatch)); // NOLINT
     }
 
+    if (balloc.get_n_used() < balloc.get_n_tokens()) {
+        // failed to find a suitable split
+        break;
+    }
+
     auto sinfos_base = kv_base->prepare(ubatches);
     if (sinfos_base.empty()) {
         break;
5 changes: 5 additions & 0 deletions src/llama-kv-cache-unified.cpp
@@ -360,6 +360,11 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
         ubatches.push_back(std::move(ubatch)); // NOLINT
     }
 
+    if (balloc.get_n_used() < balloc.get_n_tokens()) {
+        // failed to find a suitable split
+        break;
+    }
+
     auto sinfos = prepare(ubatches);
     if (sinfos.empty()) {
         break;
5 changes: 5 additions & 0 deletions src/llama-memory-hybrid.cpp
@@ -80,6 +80,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
         ubatches.push_back(std::move(ubatch)); // NOLINT
     }
 
+    if (balloc.get_n_used() < balloc.get_n_tokens()) {
+        // failed to find a suitable split
+        break;
+    }
+
     // prepare the recurrent batches first
     if (!mem_recr->prepare(ubatches)) {
         // TODO: will the recurrent cache be in an undefined context at this point?
3 changes: 2 additions & 1 deletion src/llama-memory-recurrent.cpp
@@ -377,7 +377,8 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
                 ubatch = balloc.split_equal(n_ubatch);
             }
 
-            if (ubatch.n_tokens == 0) {
+            if (balloc.get_n_used() < balloc.get_n_tokens()) {
+                // failed to find a suitable split
                 break;
             }
Comment on lines -380 to +381

@compilade (Collaborator) commented on Jul 8, 2025:

This breaks inference when a batch doesn't fit in a single ubatch.

(noticed when updating #14139 to the latest master and getting a SEGFAULT with Mamba)

The check should be outside of the split loop, like with the other memory types.

This should be fixed in #14575.

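As a reference for that fix, here is a sketch of the shape compilade describes, assumed to match the later correction in #14575: the completeness check moves out of the split loop, and the loop keeps terminating on an empty ubatch as before. Variable names follow the surrounding diff; the split-method selection is abbreviated:

std::vector<llama_ubatch> ubatches;

while (true) {
    // split_seq() or split_equal(), depending on the batch, as in the code above
    llama_ubatch ubatch = balloc.split_equal(n_ubatch);

    if (ubatch.n_tokens == 0) {
        // no more tokens to place; not necessarily an error
        break;
    }

    ubatches.push_back(std::move(ubatch));
}

// only after the loop: did the splitter place every token?
if (balloc.get_n_used() < balloc.get_n_tokens()) {
    // failed to find a suitable split
    return nullptr;
}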
