
Commit cb02274

Allow scaling seqrep penalty for mid-word tokens
1 parent bd727dd commit cb02274

5 files changed: +60 −6 lines

examples/common.cpp

Lines changed: 7 additions & 0 deletions
@@ -280,6 +280,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.seqrep_lpenalty = std::stof(argv[i]);
+        } else if (arg == "--seqrep-mw-scale") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.seqrep_mw_scale = std::stof(argv[i]);
         } else if (arg == "--mirostat") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -591,6 +597,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  --seqrep-tolerance N  tolerance for fuzzy matching sequences (default: %d, 0 = disabled)\n", params.seqrep_tolerance);
     fprintf(stdout, "  --seqrep-ppenalty N   presence penalty for tokens that can continue a sequence (default: %f, 0.0 = disabled)\n", params.seqrep_ppenalty);
     fprintf(stdout, "  --seqrep-lpenalty N   penalty for tokens that can continue a sequence, multiplied by length (default: %f, 0.0 = disabled)\n", params.seqrep_lpenalty);
+    fprintf(stdout, "  --seqrep-mw-scale N   scale the penalty for mid-word tokens; 1.0 applies the full penalty (default: %f, 1.0 = disabled)\n", params.seqrep_mw_scale);
     fprintf(stdout, "  --mirostat N          use Mirostat sampling.\n");
     fprintf(stdout, "                        Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
     fprintf(stdout, "                        (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", params.mirostat);

examples/common.h

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@ struct gpt_params {
     int32_t seqrep_tolerance = 0;    // tolerance for fuzzy sequence matching (0 = disabled)
     float   seqrep_ppenalty  = 0.0f; // flat penalty (0.0 = disabled)
     float   seqrep_lpenalty  = 0.0f; // stacking penalty based on length (0.0 = disabled)
+    float   seqrep_mw_scale  = 0.1f; // scale penalty when applied to mid-word tokens (1.0 = apply full penalty)
     int32_t mirostat         = 0;    // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
     float   mirostat_tau     = 5.00f; // target entropy
     float   mirostat_eta     = 0.10f; // learning rate

examples/main/main.cpp

Lines changed: 3 additions & 3 deletions
@@ -334,9 +334,9 @@ int main(int argc, char ** argv) {
             fprintf(stderr, "Input suffix: '%s'\n", params.input_suffix.c_str());
         }
     }
-    fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, seqrep(last_n = %d, min_len = %d, tolerance = %d, ppenalty = %f, lpenalty = %f), top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
+    fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, seqrep(last_n = %d, min_len = %d, tolerance = %d, ppenalty = %f, lpenalty = %f, mw_scale = %f), top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
             params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty,
-            params.seqrep_last_n, params.seqrep_min_len, params.seqrep_tolerance, params.seqrep_ppenalty, params.seqrep_lpenalty,
+            params.seqrep_last_n, params.seqrep_min_len, params.seqrep_tolerance, params.seqrep_ppenalty, params.seqrep_lpenalty, params.seqrep_mw_scale,
             params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
     fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     fprintf(stderr, "\n\n");
@@ -604,7 +604,7 @@ int main(int argc, char ** argv) {
                 llama_sample_seqrep_penalty(ctx, &candidates_p,
                     last_n_tokens.data() + last_n_tokens.size() - seqrep_last_n_repeat,
                     seqrep_last_n_repeat, params.seqrep_min_len, params.seqrep_tolerance,
-                    params.seqrep_ppenalty, params.seqrep_lpenalty);
+                    params.seqrep_ppenalty, params.seqrep_lpenalty, params.seqrep_mw_scale);
                 if (!penalize_nl) {
                     logits[llama_token_nl()] = nl_logit;
                 }

llama.cpp

Lines changed: 41 additions & 2 deletions
@@ -42,6 +42,7 @@
 #include <queue>
 #include <cassert>
 #include <cstring>
+#include <cctype>
 #include <climits>
 #include <memory>
 #include <algorithm>
@@ -2690,11 +2691,44 @@ static size_t llama_seqrep_find_match(const llama_token * last_tokens_p, const s
     return matches;
 }
 
-void llama_sample_seqrep_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens_p, size_t last_tokens_size, size_t min_length, size_t tolerance, float flat_penalty, float length_penalty) {
+// Internal helper function for sequence matching.
+// Bit 1 set indicates the token starts on a word boundary. NL, " blah", "," - word boundary. "blah", "blah:" - not a word boundary.
+// Bit 2 set indicates the token ends on a word boundary. NL, "blah:", "blah " - ends on word boundary. " blah", "blah" - doesn't end on word boundary.
+// Errata: UTF-8 safe but only considers ASCII characters. The ASCII single quote is treated as a non-boundary, which isn't always correct.
+static uint8_t llama_seqrep_check_word(struct llama_context * ctx, const llama_token token) {
+    if (token == llama_token_bos() || token == llama_token_eos() || token == llama_token_nl()) {
+        // BOS, EOS, NL are always a boundary.
+        return 3;
+    }
+    const char * token_str = llama_token_to_str(ctx, token);
+    assert(token_str != NULL);
+    if (token_str[0] == '\0') {
+        // 0-length token string, can't be a boundary.
+        return 0;
+    }
+
+    const char start_char = token_str[0];
+    char end_char;
+    for (const char * curr_char = token_str; ; curr_char++) {
+        // Guaranteed to iterate at least once since we already checked if the string was 0-length.
+        if (*(curr_char + 1) == '\0') {
+            end_char = *curr_char;
+            break;
+        }
+    }
+    return uint8_t(
+        (start_char != '\'' && !isalnum((int)start_char) ? 1 : 0) +
+        (end_char   != '\'' && !isalnum((int)end_char)   ? 2 : 0)
+    );
+}
+
+void llama_sample_seqrep_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens_p, size_t last_tokens_size, size_t min_length, size_t tolerance, float flat_penalty, float length_penalty, float mid_word_scale) {
     if (min_length < 2 || last_tokens_size <= min_length ||
             (flat_penalty == 0.0f && length_penalty == 0.0f)) {
         return;
     }
+    assert(ctx);
 
     const int64_t t_start_sample_us = ggml_time_us();
 
@@ -2719,9 +2753,14 @@ void llama_sample_seqrep_penalty(struct llama_context * ctx, llama_token_data_ar
                 penalize_tokens[penalize_token] = pt_iter->second + matched_length;
             }
         }
+
+        const bool ends_on_word = (llama_seqrep_check_word(ctx, last_tokens_p[last_tokens_size - 1]) & 2) != 0;
+
         for (const auto it : penalize_tokens) {
+            const bool pt_starts_word = (llama_seqrep_check_word(ctx, it.first) & 1) != 0;
             candidates->data[it.first].logit -=
-                float(it.second) * length_penalty + float(it.second > 0) * flat_penalty;
+                (float(it.second) * length_penalty + float(it.second > 0) * flat_penalty)
+                * (ends_on_word || pt_starts_word ? 1.0f : mid_word_scale);
         }
 
         if (ctx) {
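To make the boundary flags and the scaling arithmetic concrete, here is a small self-contained sketch (not part of the commit): it classifies plain C strings directly, whereas the real helper reads token strings via llama_token_to_str, and the penalty numbers are made up.

// Standalone sketch of llama_seqrep_check_word and the mid-word scaling above.
#include <cctype>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Same bit-flag classification as llama_seqrep_check_word, on a plain string.
static uint8_t check_word(const char * s) {
    const size_t len = strlen(s);
    if (len == 0) return 0;
    const char start_char = s[0], end_char = s[len - 1];
    return uint8_t(
        (start_char != '\'' && !isalnum((unsigned char)start_char) ? 1 : 0) +
        (end_char   != '\'' && !isalnum((unsigned char)end_char)   ? 2 : 0));
}

int main() {
    printf("%u\n", check_word(" blah")); // 1: starts on a boundary (leading space)
    printf("%u\n", check_word("blah:")); // 2: ends on a boundary (':')
    printf("%u\n", check_word("blah"));  // 0: mid-word at both ends

    // With matched_length = 4, length_penalty = 0.2, flat_penalty = 0.5 and
    // mid_word_scale = 0.1, a mid-word continuation loses
    // (4 * 0.2 + 0.5) * 0.1 = 0.13 logits instead of the full 1.3.
    const float length_penalty = 0.2f, flat_penalty = 0.5f, mid_word_scale = 0.1f;
    const int   matched_length = 4;
    const bool  ends_on_word = false, pt_starts_word = false;
    const float delta = (float(matched_length) * length_penalty + flat_penalty)
                      * (ends_on_word || pt_starts_word ? 1.0f : mid_word_scale);
    printf("penalty = %f\n", delta); // 0.130000
    return 0;
}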

llama.h

Lines changed: 8 additions & 1 deletion
@@ -412,7 +412,14 @@ extern "C" {
     /// @params tolerance Tolerance for non-matching tokens in a sequence.
     /// @params flat_penalty Flat penalty applied to the token that can continue a repeated sequence.
     /// @params stacking_penalty Scaling penalty applied to the token that can continue a repeated sequence. The penalty is multiplied by the total length of sequences that are continued by this token.
-    LLAMA_API void llama_sample_seqrep_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens_p, size_t last_tokens_size, size_t min_length, size_t tolerance, float flat_penalty, float length_penalty);
+    /// @params mid_word_scale Scale applied to the penalty for mid-word tokens, i.e. when the penalized token doesn't start on a word boundary and the last seen token doesn't end on one.
+    LLAMA_API void llama_sample_seqrep_penalty(
+            struct llama_context * ctx,
+            llama_token_data_array * candidates,
+            const llama_token * last_tokens_p, size_t last_tokens_size,
+            size_t min_length, size_t tolerance,
+            float flat_penalty, float length_penalty,
+            float mid_word_scale);
 
     /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
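A minimal call-site sketch against the new signature (a hypothetical helper, not from this commit; it assumes a loaded llama_context and a populated candidate array as prepared in examples/main/main.cpp, and the numeric values are placeholders):

#include <vector>
#include "llama.h"

// Applies the seqrep penalty with mid-word scaling to prepared candidates.
static void apply_seqrep(struct llama_context * ctx,
                         llama_token_data_array * candidates,
                         const std::vector<llama_token> & last_tokens) {
    llama_sample_seqrep_penalty(ctx, candidates,
            last_tokens.data(), last_tokens.size(),
            /* min_length     */ 3,     // only penalize sequences of 3+ tokens
            /* tolerance      */ 0,     // exact sequence matching
            /* flat_penalty   */ 0.0f,  // disabled here
            /* length_penalty */ 0.2f,  // per matched token of sequence length
            /* mid_word_scale */ 0.1f); // 1.0f would disable mid-word scaling
}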
