@@ -334,6 +334,8 @@ int main(int argc, char ** argv) {
334
334
// number of tokens to keep when resetting context
335
335
if (params.n_keep < 0 || params.n_keep > (int ) embd_inp.size () || params.instruct || params.chatml ) {
336
336
params.n_keep = (int )embd_inp.size ();
337
+ } else {
338
+ params.n_keep += add_bos; // always keep the BOS token
337
339
}
338
340
339
341
// prefix & suffix for instruct mode
@@ -383,8 +385,8 @@ int main(int argc, char ** argv) {
383
385
}
384
386
}
385
387
386
- if (params.n_keep > 0 ) {
387
- LOG_TEE (" %s: static prompt based on n_keep: '" , __func__);
388
+ if (params.n_keep > add_bos ) {
389
+ LOG_TEE (" %s: static prompt based on n_keep: '" , __func__);
388
390
for (int i = 0 ; i < params.n_keep ; i++) {
389
391
LOG_TEE (" %s" , llama_token_to_piece (ctx, embd_inp[i]).c_str ());
390
392
}
@@ -540,14 +542,14 @@ int main(int argc, char ** argv) {
540
542
break ;
541
543
}
542
544
543
- const int n_left = n_past - params.n_keep - 1 ;
545
+ const int n_left = n_past - params.n_keep ;
544
546
const int n_discard = n_left/2 ;
545
547
546
548
LOG (" context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n " ,
547
549
n_past, n_left, n_ctx, params.n_keep , n_discard);
548
550
549
- llama_kv_cache_seq_rm (ctx, 0 , params.n_keep + 1 , params.n_keep + n_discard + 1 );
550
- llama_kv_cache_seq_shift (ctx, 0 , params.n_keep + 1 + n_discard, n_past, -n_discard);
551
+ llama_kv_cache_seq_rm (ctx, 0 , params.n_keep , params.n_keep + n_discard);
552
+ llama_kv_cache_seq_shift (ctx, 0 , params.n_keep + n_discard, n_past, -n_discard);
551
553
552
554
n_past -= n_discard;
553
555
0 commit comments