22 changes: 22 additions & 0 deletions expose.cpp
@@ -406,8 +406,30 @@ extern "C"
{
return gpttype_load_state_kv(slot);
}

void set_savestate_limit(int limit)
{
gpttype_set_savestate_limit(limit);
}

void set_precomputed_lcs(const int* tokens, size_t len)
{
gpttype_set_precomputed_lcs(tokens, len);
}

bool clear_vram_kv()
{
return gpttype_clear_vram_kv();
}

bool clear_state_kv()
{
return gpttype_clear_state_kv(true);
}

bool free_slot_kv(int slot)
{
return gpttype_free_slot_kv(slot);
}

}
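
These exports are consumed by the host process (in KoboldCpp, normally from Python via ctypes). The sketch below is a hypothetical C++ caller showing one plausible ordering; the declarations mirror the exports above, and the specific slot numbers and call order are illustrative assumptions, not part of this change.

#include <cstddef>

// Hypothetical host-side declarations, mirroring the exports above.
extern "C" {
    void set_savestate_limit(int limit);
    void set_precomputed_lcs(const int* tokens, size_t len);
    bool clear_vram_kv();
    bool clear_state_kv();
    bool free_slot_kv(int slot);
}

void example_host_flow()
{
    set_savestate_limit(8);   // size the RAM slot pool before anything is saved
    // ... generation runs and slots are filled elsewhere ...
    free_slot_kv(2);          // release a single pooled slot
    clear_vram_kv();          // drop the on-GPU KV cache; RAM slots survive
    clear_state_kv();         // admin/reset: drop every RAM slot as well
}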
226 changes: 219 additions & 7 deletions gpttype_adapter.cpp
@@ -150,8 +150,10 @@ static int delayed_generated_tokens_limit = 0;
std::deque<std::string> delayed_generated_tokens; //for use with antislop sampling
static std::map<int,std::vector<int>> antislop_banned_token_ids; //first is the npast position, second is the array of banned ids at that index

const int savestate_limit = 3;
static savestate_data savestates[savestate_limit];
static std::vector<savestate_data> savestates;

std::vector<int> smart_cache_lcs_precomputed;
bool smart_cache_has_precomputed_lcs = false;

inline int kcpp_cpu_has_blas(void) {
#if defined(GGML_USE_BLAS) || defined(GGML_USE_CUDA) || defined(GGML_USE_VULKAN) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_SYCL)
@@ -1908,6 +1910,94 @@ void PurgeMissingTokens(llama_context * ctx, llama_context * draft_ctx, std::vec

}

extern "C" int smart_cache_compute_purge_diff(
const int* vram_tokens, size_t vram_len,
const int* new_tokens, size_t new_len,
int nctx, int genamt,
int* out_trimstart, int* out_purge_offset, int* out_purge_length)
{
if (!vram_tokens || !new_tokens || vram_len == 0 || new_len == 0)
{
*out_trimstart = 0;
*out_purge_offset = 0;
*out_purge_length = 0;
return 0;
}

std::vector<int> vram_vec(vram_tokens, vram_tokens + vram_len);
std::vector<int> new_vec(new_tokens, new_tokens + new_len);

// Use same thresholds as PurgeMissingTokens for consistency
const int ShortfallThreshold = 200 + std::min((nctx/30),140);
const int SlackAllowance = 60 + std::min((nctx/60),70);

int trimstart = 0;
int new_tokens_len = new_vec.size();
bool purgeneeded = true;

for (int i = 0; i < vram_vec.size(); ++i)
{
if (vram_vec[i] == new_vec[i])
{
trimstart += 1;
}
else
{
break;
}
if ((i + 2) >= new_tokens_len)
{
purgeneeded = false;
break;
}
}

if (!purgeneeded || new_tokens_len < 6 || vram_vec.size() < 6 || new_tokens_len - trimstart < ShortfallThreshold)
{
*out_trimstart = trimstart;
*out_purge_offset = 0;
*out_purge_length = 0;
return 0;
}

// Use same LCS threshold calculation as PurgeMissingTokens
const int LCSTokThreshold = std::max(std::min((new_tokens_len - trimstart) - (genamt+SlackAllowance), (int)(nctx*0.45)), ShortfallThreshold-SlackAllowance);

auto curr_ctx_without_memory = std::vector<int>(vram_vec.begin() + trimstart, vram_vec.end());
auto new_ctx_without_memory = std::vector<int>(new_vec.begin() + trimstart, new_vec.end());

auto shared = LongestCommonSubseq(curr_ctx_without_memory, new_ctx_without_memory);

if (shared.size() > LCSTokThreshold && ArrStartWith(new_ctx_without_memory, shared))
{
int found = ArrFindIndexOf(vram_vec, shared);
if (found >= 0 && found > trimstart)
{
*out_trimstart = trimstart;
*out_purge_offset = trimstart;
*out_purge_length = found - trimstart;
return 1;
}
}

*out_trimstart = trimstart;
*out_purge_offset = 0;
*out_purge_length = 0;
return 0;
}

extern "C" int smart_cache_get_context_params(int* out_nctx, int* out_max_length)
{
if (!kcpp_data || !out_nctx || !out_max_length)
{
return 0;
}

*out_nctx = kcpp_data->n_ctx;
*out_max_length = 512; // Default generation length - can be adjusted
return 1;
}
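
As a sanity check on how these two helpers fit together, here is a minimal caller sketch: it pulls nctx and the default generation length from smart_cache_get_context_params, then asks smart_cache_compute_purge_diff whether the cached VRAM tokens can be reused. The token buffers are placeholders and the interpretation of the out-parameters (keep the shared prefix, drop the middle span) follows their use above; this is an assumed usage pattern, not code from the PR.

#include <cstdio>
#include <vector>

void example_smart_cache_check(const std::vector<int>& vram_tokens,
                               const std::vector<int>& new_tokens)
{
    int nctx = 0, genamt = 0;
    if (!smart_cache_get_context_params(&nctx, &genamt))
    {
        return; // model not loaded yet
    }

    int trimstart = 0, purge_offset = 0, purge_length = 0;
    int purge = smart_cache_compute_purge_diff(
        vram_tokens.data(), vram_tokens.size(),
        new_tokens.data(), new_tokens.size(),
        nctx, genamt,
        &trimstart, &purge_offset, &purge_length);

    if (purge)
    {
        // The span [purge_offset, purge_offset + purge_length) can be dropped
        // from the VRAM cache while the shared prefix of trimstart tokens stays.
        printf("reuse %d-token prefix, purge %d tokens at offset %d\n",
               trimstart, purge_length, purge_offset);
    }
    else
    {
        // No mid-context purge; only the shared prefix is directly reusable.
        printf("no purge; shared prefix is %d tokens\n", trimstart);
    }
}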

static int GetBatchSize(int desiredBlasBatchSize,FileFormat in_file_format)
{
//check if approved to use BLAS
@@ -4660,10 +4750,12 @@ size_t gpttype_calc_new_state_kv()
}
size_t gpttype_calc_old_state_kv(int slot)
{
if (slot >= savestates.size()) return 0;
return savestates[slot].current_savestate_size + savestates[slot].current_draft_savestate_size;
}
size_t gpttype_calc_old_state_tokencount(int slot)
{
if (slot >= savestates.size()) return 0;
return savestates[slot].savestate_context_tokens.size();
}
size_t gpttype_calc_new_state_tokencount()
@@ -4678,6 +4770,8 @@ size_t gpttype_save_state_kv(int slot)
}
if(file_format == FileFormat::GGUF_GENERIC)
{
if (slot >= savestates.size()) return 0;

size_t totalbytes = 0;
if (!savestates[slot].current_savestate_buffer.empty()) { //JIT free
savestates[slot].current_savestate_buffer.clear();
@@ -4702,7 +4796,9 @@ size_t gpttype_save_state_kv(int slot)
totalbytes += res;
savestates[slot].current_savestate_size = newsize;
savestates[slot].savestate_context_tokens = current_context_tokens;
printf("\nKV Save State %d: Created SaveState of %zu tokens, costing %zu MB.\n",slot,current_context_tokens.size(),savestates[slot].current_savestate_size/(1024*1024));
if (debugmode == 1) {
printf("\nKV Save State %d: Saved %zu tokens (%zu MB)\n",slot,current_context_tokens.size(),savestates[slot].current_savestate_size/(1024*1024));
}
}

if(draft_ctx)
@@ -4722,7 +4818,9 @@ size_t gpttype_save_state_kv(int slot)
if (res2 > 0) {
totalbytes += res2;
savestates[slot].current_draft_savestate_size = newsize2;
printf("\nKV Save State %d: Created DraftSaveState of %zu tokens, costing %zu MB.\n",slot,current_context_tokens.size(),savestates[slot].current_draft_savestate_size/(1024*1024));
if (debugmode == 1) {
printf("\nKV Save State %d: Created DraftSaveState of %zu tokens, costing %zu MB.\n",slot,current_context_tokens.size(),savestates[slot].current_draft_savestate_size/(1024*1024));
}
}
}
return totalbytes;
@@ -4737,25 +4835,93 @@ bool gpttype_load_state_kv(int slot)
}
if(file_format == FileFormat::GGUF_GENERIC)
{
if (slot >= savestates.size()) return false;

if (savestates[slot].current_savestate_buffer.empty()) {
return false;
}
auto res = llama_state_set_data(llama_ctx_v4, savestates[slot].current_savestate_buffer.data(), savestates[slot].current_savestate_size);
if(res > 0)
{
current_context_tokens = savestates[slot].savestate_context_tokens;
printf("\nKV Load SaveState %d: Restored KV with %zu tokens.\n", slot,current_context_tokens.size());
if (debugmode == 1) {
printf("\nKV Load SaveState %d: Restored KV with %zu tokens.\n", slot,current_context_tokens.size());
}
if(draft_ctx && savestates[slot].current_draft_savestate_size>0)
{
llama_memory_clear(llama_get_memory(draft_ctx),true);
auto res2 = llama_state_set_data(draft_ctx, savestates[slot].current_draft_savestate_buffer.data(), savestates[slot].current_draft_savestate_size);
printf("\nKV Load DraftSaveState %d: Restored KV with %zu tokens.\n", slot,current_context_tokens.size());
if (debugmode == 1) {
printf("\nKV Load DraftSaveState %d: Restored KV with %zu tokens.\n", slot,current_context_tokens.size());
}
}
}
return (res > 0);
}
return false;
}
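
A sketch of the save/restore round trip as guarded by the new bounds checks; the slot index and pool size are illustrative, and the assumption is that the host sizes the pool with gpttype_set_savestate_limit (declared later in this file) before saving.

void example_slot_roundtrip()
{
    gpttype_set_savestate_limit(4);           // pool of 4 RAM slots

    size_t bytes = gpttype_save_state_kv(1);  // snapshot the current KV into slot 1
    if (bytes == 0)
    {
        return;                               // slot out of range, or nothing to save
    }

    // ... the VRAM context gets overwritten by another request ...

    if (gpttype_load_state_kv(1))             // restore slot 1 back into the live context
    {
        // current_context_tokens now mirrors the saved snapshot
    }
}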

// Clear VRAM KV cache only (does NOT touch RAM slots)
bool gpttype_clear_vram_kv()
{
if(kcpp_data==nullptr)
{
return false;
}
if(file_format == FileFormat::GGUF_GENERIC)
{
// Clear only the VRAM KV cache in llama_ctx_v4
llama_memory_clear(llama_get_memory(llama_ctx_v4), true);
if(draft_ctx)
{
llama_memory_clear(llama_get_memory(draft_ctx), true);
}

// Reset current_context_tokens since VRAM is now empty
current_context_tokens.clear();

if (debugmode == 1) {
printf("\nKV Clear VRAM: Cleared KV cache in VRAM (RAM slots preserved).\n");
}
return true;
}
return false;
}

// Free a single slot buffer (for smart cache slot pooling)
bool gpttype_free_slot_kv(int slot)
{
if(kcpp_data==nullptr)
{
return false;
}
if(file_format == FileFormat::GGUF_GENERIC)
{
if (slot >= savestates.size()) return false;

if (!savestates[slot].current_savestate_buffer.empty()) {
if (debugmode == 1) {
printf("\nKV Free Slot %d: Freed %zu MB.\n",slot, savestates[slot].current_savestate_size / (1024 * 1024));
}
savestates[slot].current_savestate_buffer.clear();
savestates[slot].current_savestate_buffer.shrink_to_fit();
savestates[slot].savestate_context_tokens.clear();
savestates[slot].current_savestate_size = 0;

if(draft_ctx && savestates[slot].current_draft_savestate_size>0)
{
savestates[slot].current_draft_savestate_buffer.clear();
savestates[slot].current_draft_savestate_buffer.shrink_to_fit();
savestates[slot].current_draft_savestate_size = 0;
}
return true;
}
return false;
}
return false;
}
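
One way the per-slot free supports pooling is host-side eviction when RAM runs low. The largest-slot-first policy below is purely illustrative; the diff only provides the per-slot free and the per-slot size query.

// Illustrative eviction helper: free the largest occupied slot to reclaim RAM.
// The policy is an assumption; only gpttype_calc_old_state_kv and
// gpttype_free_slot_kv come from this file.
bool example_evict_largest_slot(int slot_count)
{
    int victim = -1;
    size_t victim_bytes = 0;
    for (int slot = 0; slot < slot_count; ++slot)
    {
        size_t bytes = gpttype_calc_old_state_kv(slot);
        if (bytes > victim_bytes)
        {
            victim_bytes = bytes;
            victim = slot;
        }
    }
    return (victim >= 0) ? gpttype_free_slot_kv(victim) : false;
}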

// Clear ALL saved states in RAM (original behavior, for admin/reset)
bool gpttype_clear_state_kv(bool shrink)
{
if(kcpp_data==nullptr)
@@ -4764,7 +4930,7 @@ bool gpttype_clear_state_kv(bool shrink)
}
if(file_format == FileFormat::GGUF_GENERIC)
{
for(int slot=0;slot<savestate_limit;++slot)
for(int slot=0;slot<savestates.size();++slot)
{
if (!savestates[slot].current_savestate_buffer.empty()) {
printf("\nKV Clear SaveState %d: Freed %zu MB.\n",slot, savestates[slot].current_savestate_size / (1024 * 1024));
@@ -4790,3 +4956,49 @@ bool gpttype_clear_state_kv(bool shrink)
}
return false;
}

extern "C" const int* get_current_context_tokens(size_t* out_size)
{
if (current_context_tokens.empty())
{
*out_size = 0;
return nullptr;
}

*out_size = current_context_tokens.size();
return current_context_tokens.data();
}
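
The pointer returned above aliases current_context_tokens, so it is only valid until the next generation, load, or clear mutates that vector; a caller that needs a stable snapshot should copy immediately. A minimal sketch, assuming the caller copies on its side:

#include <vector>

std::vector<int> example_snapshot_context_tokens()
{
    size_t count = 0;
    const int* tokens = get_current_context_tokens(&count);
    // Copy right away: the pointer is invalidated as soon as
    // current_context_tokens is resized or reassigned.
    return (tokens != nullptr) ? std::vector<int>(tokens, tokens + count)
                               : std::vector<int>();
}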

void gpttype_set_savestate_limit(int limit)
{
if (limit < 1)
{
fprintf(stderr, "Warning: savestate limit must be >= 1, using 1\n");
limit = 1;
}

if (limit > 128)
{
fprintf(stderr, "Warning: savestate limit = %d may use excessive RAM. Recommended: <= 32\n", limit);
}

savestates.resize(limit);
}

void gpttype_set_precomputed_lcs(const int* tokens, size_t len)
{
if (tokens == nullptr || len == 0)
{
smart_cache_lcs_precomputed.clear();
smart_cache_has_precomputed_lcs = false;
return;
}

// Pre-allocate to avoid reallocation in hot path
if (smart_cache_lcs_precomputed.capacity() < len)
{
smart_cache_lcs_precomputed.reserve(std::max(len, size_t(4096)));
}
smart_cache_lcs_precomputed.assign(tokens, tokens + len);
smart_cache_has_precomputed_lcs = true;
}
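
The setter has two modes: a non-empty buffer arms the precomputed-LCS path, and a null/zero call clears it again. The token ids below are placeholders; a short sketch:

void example_precomputed_lcs_usage()
{
    std::vector<int> lcs = {1, 15043, 3186};               // placeholder token ids
    gpttype_set_precomputed_lcs(lcs.data(), lcs.size());   // arm the fast path

    // ... the generation path consumes the hint elsewhere ...

    gpttype_set_precomputed_lcs(nullptr, 0);                // clear it again
}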