diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 91b1d6078a252..8d84c68053d32 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -405,6 +405,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const { return n_outputs; } +uint32_t llama_batch_allocr::get_n_used() const { + return n_used; +} + std::vector & llama_batch_allocr::get_out_ids() { return out_ids; } @@ -420,6 +424,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const { void llama_batch_allocr::split_reset() { out_ids.clear(); + n_used = 0; + used.clear(); used.resize(get_n_tokens(), false); @@ -444,6 +450,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) { idxs.push_back(cur_idx); used[cur_idx] = true; + ++n_used; ++cur_idx; @@ -529,6 +536,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) { idxs_per_seq[s].push_back(idx); used[idx] = true; + ++n_used; ++cur_idx[s]; } @@ -570,6 +578,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) { idxs.push_back(cur_idx); used[cur_idx] = true; + ++n_used; if (idxs.size() >= n_ubatch) { break; diff --git a/src/llama-batch.h b/src/llama-batch.h index d2c5376188a0b..edff8cdd63cfc 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -54,6 +54,7 @@ class llama_batch_allocr { uint32_t get_n_tokens() const; uint32_t get_n_outputs() const; + uint32_t get_n_used() const; // the array of output indices in the order they were encountered during the ubatch splitting std::vector & get_out_ids(); @@ -125,6 +126,8 @@ class llama_batch_allocr { // batch indices of the output std::vector out_ids; + uint32_t n_used; + // used[i] indicates if token i has already been used in a previous ubatch std::vector used; diff --git a/src/llama-kv-cache-unified-iswa.cpp b/src/llama-kv-cache-unified-iswa.cpp index ee202cc710bd6..ab4c41c780397 100644 --- a/src/llama-kv-cache-unified-iswa.cpp +++ b/src/llama-kv-cache-unified-iswa.cpp @@ -113,6 +113,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all ubatches.push_back(std::move(ubatch)); // NOLINT } + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + auto sinfos_base = kv_base->prepare(ubatches); if (sinfos_base.empty()) { break; @@ -144,6 +149,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all ubatches.push_back(std::move(ubatch)); // NOLINT } + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + auto sinfos_base = kv_base->prepare(ubatches); if (sinfos_base.empty()) { break; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index ff22079851b2a..d3129cc53281e 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -360,6 +360,11 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch( ubatches.push_back(std::move(ubatch)); // NOLINT } + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + auto sinfos = prepare(ubatches); if (sinfos.empty()) { break; diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index 03d974d852039..908e927faa2aa 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -80,6 +80,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba ubatches.push_back(std::move(ubatch)); // NOLINT } + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + // prepare the recurrent batches first if (!mem_recr->prepare(ubatches)) { // TODO: will the recurrent cache be in an undefined context at this point? diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 6ed84057ccfe2..ca0c8dd56f2fa 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -377,7 +377,8 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & ubatch = balloc.split_equal(n_ubatch); } - if (ubatch.n_tokens == 0) { + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split break; }