diff --git a/CMakeLists.txt b/CMakeLists.txt
index f32df5fe52335..29b770236705b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -96,6 +96,8 @@ option(LLAMA_METAL_NDEBUG                    "llama: disable Metal debugging"
 option(LLAMA_MPI                             "llama: use MPI"                                  OFF)
 option(LLAMA_QKK_64                          "llama: use super-block size of 64 for k-quants"  OFF)
+option(LLAMA_SEQREP_SAMPLER                  "llama: build with support for seqrep sampler"    ON)
+
 option(LLAMA_BUILD_TESTS                     "llama: build tests"    ${LLAMA_STANDALONE})
 option(LLAMA_BUILD_EXAMPLES                  "llama: build examples" ${LLAMA_STANDALONE})
 option(LLAMA_BUILD_SERVER                    "llama: build server example"                     ON)
diff --git a/Makefile b/Makefile
index a6d2c2ec0f380..793ac372542c1 100644
--- a/Makefile
+++ b/Makefile
@@ -2,7 +2,7 @@
 BUILD_TARGETS = \
 	main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
 	simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \
-	speculative infill tokenize benchmark-matmult parallel finetune export-lora tests/test-c.o
+	speculative infill tokenize benchmark-matmult parallel finetune export-lora simple-inference tests/test-c.o
 
 # Binaries only useful for tests
 TEST_TARGETS = \
@@ -572,6 +572,14 @@ grammar-parser.o: common/grammar-parser.cpp common/grammar-parser.h
 train.o: common/train.cpp common/train.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 
+ifndef LLAMA_NO_SEQREP_SAMPLER
+COMMON_H_DEPS += common/seqrep-sampler.h
+COMMON_DEPS += seqrep-sampler.o
+
+seqrep-sampler.o: common/seqrep-sampler.cpp common/seqrep-sampler.h $(COMMON_H_DEPS)
+	$(CXX) $(CXXFLAGS) -c $< -o $@
+endif
+
 libllama.so: llama.o ggml.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS)
 
@@ -594,13 +602,16 @@ infill: examples/infill/infill.cpp ggml.o llama.o $(C
 simple: examples/simple/simple.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
+simple-inference: examples/simple-inference/simple-inference.cpp ggml.o llama.o $(COMMON_DEPS) console.o grammar-parser.o $(OBJS)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
+
 tokenize: examples/tokenize/tokenize.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
 batched: examples/batched/batched.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
-batched-bench: examples/batched-bench/batched-bench.cpp build-info.o ggml.o llama.o common.o $(OBJS)
+batched-bench: examples/batched-bench/batched-bench.cpp build-info.o ggml.o llama.o $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
 quantize: examples/quantize/quantize.cpp build-info.o ggml.o llama.o $(OBJS)
diff --git a/build.zig b/build.zig
index 699738f3dd509..004a89a185315 100644
--- a/build.zig
+++ b/build.zig
@@ -111,6 +111,8 @@ pub fn build(b: *std.build.Builder) !void {
     var make = try Maker.init(b);
     make.enable_lto = b.option(bool, "lto", "Enable LTO optimization, (default: false)") orelse false;
 
+    try make.addFlag("-DLLAMA_NO_SEQREP_SAMPLER");
+
     const ggml = make.obj("ggml", "ggml.c");
     const ggml_alloc = make.obj("ggml-alloc", "ggml-alloc.c");
     const ggml_backend = make.obj("ggml-backend", "ggml-backend.c");
diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt
index 4f930bdc59059..650106ff28ee1 100644
--- a/common/CMakeLists.txt
+++ b/common/CMakeLists.txt
@@ -54,6 +54,12 @@ add_library(${TARGET} STATIC
     train.cpp
     )
 
+if (LLAMA_SEQREP_SAMPLER)
+    target_sources(${TARGET} PRIVATE seqrep-sampler.h seqrep-sampler.cpp)
+else() + add_compile_definitions(LLAMA_NO_SEQREP_SAMPLER) +endif() + if (BUILD_SHARED_LIBS) set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON) endif() diff --git a/common/common.cpp b/common/common.cpp index 1dcc235eac0e6..ad87515c6eab6 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1,6 +1,10 @@ #include "common.h" #include "llama.h" +#ifndef LLAMA_NO_SEQREP_SAMPLER +#include "seqrep-sampler.h" +#endif + #include #include #include @@ -336,6 +340,24 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { break; } sparams.penalty_present = std::stof(argv[i]); +#ifndef LLAMA_NO_SEQREP_SAMPLER + } else if (arg == "-seqrep" || arg == "--seqrep-penalty") { + if (++i >= argc) { + invalid_param = true; + break; + } + if (std::strcmp(argv[i], "help") == 0) { + seqrep_sampler_help(); + exit(0); + } + llama_sampler_seqrep_params sr_params; + seqrep_sampler_params_init(&sr_params); + if (!seqrep_sampler_params_parse(argv[i], &sr_params)) { + seqrep_sampler_help(); + exit(1); + } + sparams.seqrep_params.push_back(sr_params); +#endif } else if (arg == "--mirostat") { if (++i >= argc) { invalid_param = true; @@ -770,6 +792,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_repeat); printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_present); printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_freq); +#ifndef LLAMA_NO_SEQREP_SAMPLER + printf(" -seqrep CFG, --seqrep-penalty CFG\n"); + printf(" add a copy of the sequence repetition penalty sampler. may be specified multiple times. 
for help: -seqrep help\n"); +#endif printf(" --mirostat N use Mirostat sampling.\n"); printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n"); printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat); diff --git a/common/sampling.cpp b/common/sampling.cpp index 1317024c2c11c..ec75b944c071b 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -103,7 +103,8 @@ llama_token llama_sampling_sample( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, struct llama_context * ctx_cfg, - const int idx) { + const int idx, + const std::vector & all_last_tokens) { const llama_sampling_params & params = ctx_sampling->params; const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); @@ -155,6 +156,13 @@ llama_token llama_sampling_sample( prev.data() + prev.size() - penalty_last_n, penalty_last_n, penalty_repeat, penalty_freq, penalty_present); +#ifndef LLAMA_NO_SEQREP_SAMPLER + for (auto & sr_params : params.seqrep_params) { + if ((sr_params.flags & LLAMA_SEQREP_REWIND_MODE) != 0) continue; + llama_sample_seqrep_penalty(ctx_main, &cur_p, all_last_tokens, &sr_params); + } +#endif + if (!penalize_nl) { for (size_t idx = 0; idx < cur_p.size; idx++) { if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) { diff --git a/common/sampling.h b/common/sampling.h index 7c9b8dcf23bcb..43b0d53d9de5d 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -4,6 +4,10 @@ #include "grammar-parser.h" +#ifndef LLAMA_NO_SEQREP_SAMPLER +#include "seqrep-sampler.h" +#endif + #include #include #include @@ -35,6 +39,11 @@ typedef struct llama_sampling_params { float cfg_scale = 1.f; // how strong is guidance std::unordered_map logit_bias; // logit bias for specific tokens + +#ifndef LLAMA_NO_SEQREP_SAMPLER + std::vector seqrep_params; +#endif + } llama_sampling_params; // general sampler context @@ -101,7 +110,8 @@ llama_token llama_sampling_sample( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, struct llama_context * ctx_cfg, - int idx = 0); + int idx = 0, + const std::vector & all_last_tokens = {}); void llama_sampling_accept( struct llama_sampling_context * ctx_sampling, diff --git a/common/seqrep-sampler.cpp b/common/seqrep-sampler.cpp new file mode 100644 index 0000000000000..9e0d5614a778f --- /dev/null +++ b/common/seqrep-sampler.cpp @@ -0,0 +1,1014 @@ +#include "llama.h" + +#include "ggml.h" +#include "common.h" + +#include "seqrep-sampler.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + + +#define SR_FLAG(flags, flag_val) (((flags) & (flag_val)) != 0) + +static std::wstring utf8_to_wstring(const char * start, const char * end) { + if (end == NULL) { + const size_t len = strlen(start); + end = len > 0 ? 
start + (len - 1) : start;
+    }
+    std::wstring_convert<std::codecvt_utf8<wchar_t>> temp;
+    return temp.from_bytes(start, end);
+}
+
+static std::string wstring_to_string(const std::wstring & s) {
+    std::wstring_convert<std::codecvt_utf8<wchar_t>> temp;
+    return temp.to_bytes(s);
+}
+
+void seqrep_sampler_params_init(llama_sampler_seqrep_params * params) {
+    assert(params != NULL);
+    *params = {};
+    params->max_length = 24;
+    params->last_n = 256;
+    params->mid_word_scale = 0.1f;
+    params->rewind_max_visits = 2;
+    params->rewind_ban_length = 1;
+}
+
+static void seqrep_params_dump_flags(const int flags) {
+    const char *flag_names[] = {
+        "tolerance_no_consecutive",
+        "tolerance_no_first",
+        "penalize_length_max_seen",
+        "absolute_penalty",
+        "rewind_mode",
+        "rewind_skip_ws_punct",
+        "rewind_use_shortest_match",
+        "rewind_require_wbound",
+        "rewind_persist_require_wbound"
+    };
+    for (int i = 0, fcount = 0; i <= 8; i++) {
+        if ((flags & (1 << i)) != 0) {
+            printf("%s%s", fcount > 0 ? " + " : "", flag_names[i]);
+            fcount++;
+        }
+    }
+}
+
+void seqrep_sampler_params_dump(const llama_sampler_seqrep_params * params) {
+    assert(params != NULL);
+    printf("seqrep(last_n = %d, min_length = %zd, max_length = %zd, start_offset = %zd, presence_penalty = %.4f, length_penalty = %.4f, tolerance = %.4f, mid_word_scale = %.4f, tolerance_match_credit = %.4f, tolerance_cap = %.4f, rewind_min_length = %zd, rewind_seek_word_boundary = %zd, flags = [",
+        params->last_n, params->min_length, params->max_length, params->start_offset, params->presence_penalty,
+        params->length_penalty, params->tolerance, params->mid_word_scale,
+        params->tolerance_match_credit, params->tolerance_cap,
+        params->rewind_min_length, params->rewind_seek_word_boundary);
+    seqrep_params_dump_flags(params->flags);
+    puts("])");
+}
+
+// FIXME: Error handling.
+static bool seqrep_load_file(const std::string & filename, std::vector<std::wstring> & result) {
+    std::ifstream fp(filename);
+    if (!fp) {
+        return false;
+    }
+    std::string buf;
+
+    while (std::getline(fp, buf)) {
+        while (!buf.empty() && (buf.back() == '\r' || buf.back() == '\n')) {
+            buf.resize(buf.size() - 1);
+        }
+        if (!buf.empty()) {
+            std::wstring temp = utf8_to_wstring(buf.data(), buf.data() + buf.size());
+            result.push_back(std::move(temp));
+        }
+    }
+    return true;
+}
+
+// FIXME: Error handling. More efficient loading?
+static bool seqrep_load_regex_file(const std::string & filename, std::vector<std::wregex> & result) {
+    std::vector<std::wstring> buf;
+
+    if (!seqrep_load_file(filename, buf)) {
+        return false;
+    }
+    result.clear();
+    for (const std::wstring & line : buf) {
+        if (line.empty() || line.front() == L'#') {
+            continue;
+        }
+        result.emplace_back(line);
+    }
+    return true;
+}
+
+void seqrep_sampler_help() {
+    llama_sampler_seqrep_params p;
+
+    seqrep_sampler_params_init(&p);
+    printf("==== Sequence Repetition Sampler Help ====\n\n");
+    printf("  The sequence repetition sampler takes a configuration string in the format:\n");
+    printf("  arg1:arg2:argN\n");
+    printf("  A colon separated argument can be a key value pair like xyz=1 or flag like xyz\n");
+    printf("\n- Available key/value arguments\n");
+    printf("  * repetition_mode=REPEAT_PENALTY\n    emulates the repetition penalty sampler. warning: 1.0 disables penalties since this preset divides the logit by the penalty rather than using flag_absolute_penalty. using 0.0 is probably not what you want\n");
+    printf("  * presence_mode=PRESENCE_PENALTY\n    emulates the presence penalty sampler\n");
+    printf("  * frequency_mode=FREQUENCY_PENALTY\n    emulates the frequency penalty sampler\n");
+    printf("  * rewind_mode\n    Enables rewind mode and sets skip_ws_punct, require_wbound and persist_require_wbound flags\n");
+    printf("  * last_n\n    last n tokens to consider for sequence penalizing (default: %d, 0 = disabled, -1 = ctx_size)\n", p.last_n);
+    printf("  * min_length\n    minimum matching sequence length (default: %zd, < 2 = disabled)\n", p.min_length);
+    printf("  * presence_penalty\n    presence penalty for tokens that can continue a sequence (default: %f)\n", p.presence_penalty);
+    printf("  * length_penalty\n    penalty for tokens that can continue a sequence, scaled by length (default: %f)\n", p.length_penalty);
+    printf("  * tolerance\n    tolerance for fuzzy matching sequences (default: %f, 0 = disabled)\n", p.tolerance);
+    printf("  * mid_word_scale\n    scale penalty for mid-word tokens. 1.0 would mean apply the full penalty (default: %f, 1.0 = disabled)\n", p.mid_word_scale);
+    printf("  * tolerance_match_credit\n    credit tolerance on matched tokens (default: %f, 0.0 = disabled)\n", p.tolerance_match_credit);
+    printf("  * tolerance_cap\n    Caps tolerance at the specified value. Only meaningful when tolerance_match_credit > 0 (default: %f)\n", p.tolerance_cap);
+    printf("  * start_offset\n    advanced option to set the initial offset for pattern matching. This is relative to the start of last_n. For example, you can set last_n=-1:start_offset=NUM_PROMPT_TOKENS to limit sequence matching to the prompt (default: %zu)\n", p.start_offset);
+    printf("  * rewind_min_length\n    Ensure the sequence is at least the specified length in rewind mode after whitespace skipping and other modifications (default: %zu)\n", p.rewind_min_length);
+    printf("  * rewind_max_visits\n    A position is limited to the specified number of rewinds. When the limit is exceeded, future rewinds cannot target it or earlier tokens. (default: %zu)\n", p.rewind_max_visits);
+    printf("  * rewind_persist_bans\n    Tokens banned by rewind remain banned for an additional number of positions equal to the value. i.e. setting this to 1 would mean the token is banned for 2 positions. (default: %zu)\n", p.rewind_persist_bans);
+    printf("  * rewind_ban_length\n    Number of tokens from the sequence to ban when rewinding. (default: %zu)\n", p.rewind_ban_length);
+    printf("  * include_re_file=FILENAME\n    loads a list of regexps from the file, seqrep matching will only occur if all the regexes match\n");
+    printf("  * exclude_re_file=FILENAME\n    loads a list of regexps from the file, seqrep matching will not occur if any of the regexes match\n");
+    printf("\n- Available flag arguments (currently all default to disabled)\n");
+    printf("  * flag_tolerance_no_consecutive\n    do not allow using tolerance consecutively\n");
+    printf("  * flag_tolerance_no_first\n    do not allow using tolerance before the first match\n");
+    printf("  * flag_penalize_length_max_seen\n    when applying length_penalty, use the maximum seen sequence length rather than the total length of seen sequences\n");
+    printf("  * flag_absolute_penalty\n    Apply an absolute penalty rather than dividing the logit by the penalty.\n");
+    printf("  * flag_rewind_mode\n    Rather than penalizing tokens that can continue a sequence, this mode will actually rewind and ban the token that _starts_ the sequence. Note: Requires support in the caller. Also only applies when min_length is at least 2.
Most other settings will be ignored in this mode\n"); + printf(" * flag_rewind_skip_ws_punct\n When rewinding, skip past whitespace and punctuation. For example, if the matched sequence was \"'hello\" then we will rewind to the token starting with 'h' and ban it.\n"); + printf(" * flag_rewind_use_shortest_match\n Rewind to the shortest matching sequence of at least min_length rather than the longest. Only meaningful when multiple rewind seqrep samplers are defined.\n"); + printf(" * flag_rewind_require_wbound\n Rewinding requires a word boundary. Only has an effect when rewind_seek_word_boundary isn't 0.\n"); + printf(" * flag_rewind_persist_require_wbound\n Persisted bans are only applied if at a word bound.\n"); + printf("\n- Regex file notes:\n"); + printf(" The regex file should contain one regex per line. Blank lines or lines that start with # are ignored.\n"); + printf(" When matching, the last max_length tokens are converted to a string and invalid unicode is trimmed from the beginning/end.\n"); + printf(" Note: Current regexes only apply for rewind mode seqrep samplers.\n"); + printf("\n- Examples:\n"); + printf(" * repetition_mode=1.2:last_n=32\n same as --repeat-last-n 32 --repeat-penalty 1.2\n"); + printf(" * presence_mode=.2:last_n=32\n same as --repeat-last-n 32 --presence-penalty .2\n"); + printf(" * frequency_mode=.2:last_n=32\n same as --repeat-last-n 32 --frequency-penalty .2\n"); + printf(" * min_length=3:tolerance=1:length_penalty=1.1:last_n=-1\n match repeated sequences of at least 3 tokens within the entire context and apply a penalty of 1 + 0.1*total_length to the token that would continue the sequence. allow one non-matching token in matched sequences.\n"); +} + +bool seqrep_sampler_params_parse(char * s, llama_sampler_seqrep_params * params) { + assert(params != NULL); + assert(s != NULL); + size_t offset = 0; + std::string sparams = s; + size_t slen = sparams.size(); + + while (offset < slen) { + size_t argsep = sparams.find_first_of(':', offset); + std::string argchunk; + if (argsep == std::string::npos) { + argchunk = sparams.substr(offset); + } else if (argsep > offset) { + argchunk = sparams.substr(offset, argsep - offset); + } + std::string argval; + size_t valsep = argchunk.find_first_of('='); + if (valsep != std::string::npos && valsep < argchunk.size()) { + argval = argchunk.substr(valsep + 1); + argchunk.resize(valsep); + } + if (argchunk.empty() && argval.empty()) { + // pass + } else if (argchunk == "repetition_mode") { + params->last_n = 64; + params->min_length = 1; + params->mid_word_scale = 1.0f; + params->flags = 0; + params->length_penalty = 1.0f; + params->presence_penalty = argval.empty() ? 
1.1f : std::atof(argval.c_str()); + } else if (argchunk == "presence_mode") { + params->last_n = 64; + params->min_length = 1; + params->mid_word_scale = 1.0f; + params->flags = LLAMA_SEQREP_ABSOLUTE_PENALTY; + params->length_penalty = 0.0f; + params->presence_penalty = std::atof(argval.c_str()); + } else if (argchunk == "frequency_mode") { + params->last_n = 64; + params->min_length = 1; + params->mid_word_scale = 1.0f; + params->flags = LLAMA_SEQREP_ABSOLUTE_PENALTY; + params->length_penalty = std::atof(argval.c_str()); + params->presence_penalty = 0.0f; + } else if (argchunk == "rewind_mode") { + params->flags = LLAMA_SEQREP_REWIND_REQUIRE_WBOUND + | LLAMA_SEQREP_REWIND_PERSIST_REQUIRE_WBOUND + | LLAMA_SEQREP_REWIND_SKIP_WS_PUNCT + | LLAMA_SEQREP_REWIND_MODE; + } else if (argchunk == "flag_tolerance_no_consecutive") { + params->flags |= LLAMA_SEQREP_TOLERANCE_NO_CONSECUTIVE; + } else if (argchunk == "flag_tolerance_no_first") { + params->flags |= LLAMA_SEQREP_TOLERANCE_NO_FIRST; + } else if (argchunk == "flag_penalize_length_max_seen") { + params->flags |= LLAMA_SEQREP_PENALIZE_LENGTH_MAX_SEEN; + } else if (argchunk == "flag_absolute_penalty") { + params->flags |= LLAMA_SEQREP_ABSOLUTE_PENALTY; + } else if (argchunk == "flag_rewind_mode") { + params->flags |= LLAMA_SEQREP_REWIND_MODE; + } else if (argchunk == "flag_rewind_skip_ws_punct") { + params->flags |= LLAMA_SEQREP_REWIND_SKIP_WS_PUNCT | LLAMA_SEQREP_REWIND_MODE; + } else if (argchunk == "flag_rewind_use_shortest_match") { + params->flags |= LLAMA_SEQREP_REWIND_USE_SHORTEST_MATCH | LLAMA_SEQREP_REWIND_MODE; + } else if (argchunk == "flag_rewind_require_wbound") { + params->flags |= LLAMA_SEQREP_REWIND_REQUIRE_WBOUND | LLAMA_SEQREP_REWIND_MODE; + } else if (argchunk == "flag_rewind_persist_require_wbound") { + params->flags |= LLAMA_SEQREP_REWIND_PERSIST_REQUIRE_WBOUND | LLAMA_SEQREP_REWIND_MODE; + } else if (argchunk == "min_length") { + params->min_length = std::atoi(argval.c_str()); + } else if (argchunk == "rewind_min_length") { + params->rewind_min_length = std::atoi(argval.c_str()); + } else if (argchunk == "rewind_seek_word_boundary") { + params->rewind_seek_word_boundary = std::atoi(argval.c_str()); + } else if (argchunk == "rewind_max_visits") { + params->rewind_max_visits = std::atoi(argval.c_str()); + } else if (argchunk == "rewind_persist_bans") { + params->rewind_persist_bans = std::atoi(argval.c_str()); + } else if (argchunk == "rewind_ban_length") { + params->rewind_ban_length = std::atoi(argval.c_str()); + } else if (argchunk == "start_offset") { + params->start_offset = std::atoi(argval.c_str()); + } else if (argchunk == "last_n") { + params->last_n = std::atoi(argval.c_str()); + } else if (argchunk == "tolerance") { + params->tolerance = std::atof(argval.c_str()); + } else if (argchunk == "tolerance_cap") { + params->tolerance_cap = std::atof(argval.c_str()); + } else if (argchunk == "presence_penalty") { + params->presence_penalty = std::atof(argval.c_str()); + } else if (argchunk == "length_penalty") { + params->length_penalty = std::atof(argval.c_str()); + } else if (argchunk == "mid_word_scale") { + params->mid_word_scale = std::atof(argval.c_str()); + } else if (argchunk == "tolerance_match_credit") { + params->tolerance_match_credit = std::atof(argval.c_str()); + } else if (argchunk == "include_re_file" && !argval.empty()) { + if (!seqrep_load_regex_file(argval, params->include_re)) { + fprintf(stderr, "seqrep: Failed to read include_re file: %s\n", argval.c_str()); + return false; + }; + } else if (argchunk 
== "exclude_re_file" && !argval.empty()) { + if (!seqrep_load_regex_file(argval, params->exclude_re)) { + fprintf(stderr, "seqrep: Failed to read exclude_re file: %s\n", argval.c_str()); + return false; + } + } else { + fprintf(stderr, "seqrep: Bad argument [%s]=[%s]!\n", argchunk.c_str(), argval.c_str()); + return false; + } + if (argsep != std::string::npos) { + offset = argsep + 1; + } else { + break; + } + } + if (params->tolerance_cap == 0.0f) { + params->tolerance_cap = params->tolerance; + } + return true; +} + + +// Internal helper function for sequence matching. +static size_t seqrep_find_match( + const llama_token * tail_tokens, + const size_t tail_tokens_size, + const llama_token * search_tokens, + const size_t search_tokens_size, + const bool overlaps, + const llama_sampler_seqrep_params *params) { + + if (params->min_length < 2 + || tail_tokens_size < params->min_length + || search_tokens_size < params->min_length) { + return 0; + } + + int flags = params->flags; + float tolerance = params->tolerance; + size_t tail_steps = 0, search_steps = 0; + + int matches = 0, pending_matches = 0; + bool last_matched = true; + + while (search_steps < search_tokens_size && tail_steps < tail_tokens_size) { + if (*(search_tokens - search_steps) == *(tail_tokens - tail_steps)) { + tail_steps++; + search_steps++; + matches += 1 + pending_matches; + pending_matches = 0; + tolerance += params->tolerance_match_credit; + if (params->tolerance_cap > 0.0f) { + tolerance = std::min(params->tolerance_cap, tolerance); + } + last_matched = true; + continue; + } + + + if (SR_FLAG(flags, LLAMA_SEQREP_TOLERANCE_NO_FIRST) + && search_steps + tail_steps == 0) { + break; + } else if (SR_FLAG(flags, LLAMA_SEQREP_TOLERANCE_NO_CONSECUTIVE) + && last_matched == false) { + break; + } + + last_matched = false; + + if (tolerance < 1.0f) { + break; + } + tolerance -= 1.0f; + if (search_steps + 1 < search_tokens_size + && *(search_tokens - (search_steps + 1)) == *(tail_tokens - tail_steps)) { + search_steps++; + continue; + } else if (!overlaps || tail_steps + 1 <= search_steps) { + if (tail_steps + 1 < tail_tokens_size && + *(tail_tokens - (tail_steps + 1)) == *(search_tokens - search_steps)) { + tail_steps++; + continue; + } + } + + // A tolerance charge can count as a match, but only if we can find a + // real match before the search is terminated. + pending_matches++; + + tail_steps++; + search_steps++; + } + return matches; +} + +// Note: Only handles partial sequences, can't handle ones that are simply malformed. +static void seqrep_check_utf8( + const char * s, const size_t len, + const char ** first_valid, + const char ** last_valid) { + size_t expect_bytes = 0; + const char * maybe_valid = NULL; + *first_valid = *last_valid = NULL; + + for (size_t i = 0; i < len; i++) { + const uint8_t c = uint8_t(s[i]); + + if (expect_bytes > 0) { + expect_bytes--; + // 10xxxxxxb -> 10b == 2 + if (c >> 6 == 2) { + if (expect_bytes == 0) { + if (*first_valid == NULL) { + *first_valid = maybe_valid != NULL ? maybe_valid : s + i; + } + *last_valid = s + i; + maybe_valid = NULL; + } + } else { + // Invalid sequence + maybe_valid = *first_valid = *last_valid = NULL; + expect_bytes = 0; + } + continue; + } + + // Not in a sequence. First check for a single byte character. + if ((c & 128) == 0) { + if (*first_valid == NULL) { + *first_valid = s + i; + } + *last_valid = s + i; + maybe_valid = NULL; + continue; + } + + // If we end up here it's either the start of a multi byte sequence or invalid. 
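The byte classification in seqrep_check_utf8() follows the standard UTF-8 bit patterns: 0xxxxxxx for a single byte, 10xxxxxx for a continuation byte, and 110/1110/11110 prefixes for the lead byte of 2-4 byte sequences. A standalone sketch of the same classification, not part of the patch and with an illustrative function name; the loop above resumes immediately after this note:

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative helper (not from the patch): expected number of continuation
// bytes for a UTF-8 lead byte, or -1 for a bare continuation/invalid byte.
// The shift comparisons mirror the ones used in seqrep_check_utf8().
static int utf8_expected_continuations(uint8_t c) {
    if ((c & 0x80) == 0) return 0; // 0xxxxxxx: single-byte (ASCII)
    if ((c >> 5) == 6)   return 1; // 110xxxxx: 2-byte sequence
    if ((c >> 4) == 14)  return 2; // 1110xxxx: 3-byte sequence
    if ((c >> 3) == 30)  return 3; // 11110xxx: 4-byte sequence
    return -1;                     // 10xxxxxx continuation byte or invalid
}

int main() {
    // 'A', then the two bytes of e-acute, then the three bytes of the euro sign.
    const uint8_t bytes[] = { 'A', 0xc3, 0xa9, 0xe2, 0x82, 0xac };
    for (const uint8_t b : bytes) {
        printf("%02x -> %d\n", b, utf8_expected_continuations(b));
    }
    return 0;
}
```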
+ maybe_valid = s + i; + // 110xxxxxb -> 110b == 6 + if (c >> 5 == 6) { + expect_bytes = 1; + // 1110xxxxb -> 1110b == 14 + } else if (c >> 4 == 14) { + expect_bytes = 2; + // 11110xxxb -> 11110b == 30 + } else if (c >> 3 == 30) { + expect_bytes = 3; + // Invalid + } else { + maybe_valid = *first_valid = *last_valid = NULL; + } + } +} + +// FIXME: Make this efficient. +static std::wstring seqrep_get_tail_string(const struct llama_context * ctx, const std::vector & last_tokens, size_t len) { + const size_t last_tokens_len = last_tokens.size(); + std::string buf; + + len = std::min(len, last_tokens_len); + if (len == 0) return std::wstring(); + + buf.reserve(8 * len); + + const llama_token *curr_token = last_tokens.data() + (last_tokens_len - len); + + for (size_t i = 0; i < len; i++, curr_token++) { + buf.append(llama_token_to_piece(ctx, *curr_token)); + } + + const char * first_valid = NULL, * last_valid = NULL; + + if (!buf.empty()) { + seqrep_check_utf8(buf.data(), buf.size(), &first_valid, &last_valid); + } + if (first_valid == NULL) return std::wstring(); + + return utf8_to_wstring(first_valid, last_valid + 1); +} + +// Helper function for sequence matching. +// Bit 1 set indicates token is a word boundary. NL, " blah", "," - word boundary. "blah", "blah:" - not a word boundary. +// Bit 2 set indicates token ends on a word boundary. NL, "blah:", "blah " - ends on word boundary. " blah", "blah" - doesn't end on word boundary. +// Bit 3 set indicates all codepoints in the character count as boundary. +// FIXME: Handle partial/invalid UTF8 (just crashes currently). +int llama_seqrep_check_word( + const struct llama_context * ctx, + const llama_token token, + std::vector & buf) { + const llama_model * model = llama_get_model(ctx); + if (token == llama_token_bos(model) || token == llama_token_eos(model) || token == llama_token_nl(model)) { + // BOS, EOS, NL are always a boundary. + return SEQREP_CW_START_IS_WBOUND | SEQREP_CW_END_IS_WBOUND | SEQREP_CW_ALL_WS_PUNCT; + } + if (buf.size() < 128) buf.resize(128); + + int n_chars = llama_token_to_piece(model, token, buf.data(), buf.size() - 1); + if (n_chars < 0) { + buf.resize(size_t(-n_chars) + 128); + const int check = llama_token_to_piece(model, token, buf.data(), buf.size() - 1); + GGML_ASSERT(check == -n_chars); + n_chars = check; + } else if (n_chars == 0) { + return 0; + } + buf[n_chars] = 0; + + const char * first_valid = NULL, * last_valid = NULL; + + seqrep_check_utf8(buf.data(), n_chars, &first_valid, &last_valid); + + // If first_valid != NULL then last_valid also must be != NULL. + if (first_valid == NULL) { + return SEQREP_CW_START_IS_WBOUND | SEQREP_CW_END_IS_WBOUND + | SEQREP_CW_START_IS_INVALID | SEQREP_CW_END_IS_INVALID; + } + + int result = 0; + const bool start_invalid = first_valid > buf.data(); + const bool end_invalid = last_valid < (buf.data() + (n_chars - 1)); + std::wstring decoded = utf8_to_wstring(first_valid, last_valid + 1); + size_t decoded_len = decoded.size(); + + if (start_invalid) result |= SEQREP_CW_START_IS_INVALID; + if (end_invalid) result |= SEQREP_CW_END_IS_INVALID; + if (decoded_len == 0) return result; + + // Can only be all punctuation if the full sequence is valid. + result |= !start_invalid && !end_invalid ? 
SEQREP_CW_ALL_WS_PUNCT : 0; + + for (size_t i = 0; i < decoded_len; i++) { + wchar_t c = decoded[i]; + bool iswbound = c != L'\'' && c != L'’' && (std::iswpunct(c) || std::iswspace(c)); + + if (!iswbound) { + result &= ~SEQREP_CW_ALL_WS_PUNCT; + continue; + } + + if (i == 0 && !start_invalid) + result |= SEQREP_CW_START_IS_WBOUND; + if (i == decoded_len - 1 && !end_invalid) + result |= SEQREP_CW_END_IS_WBOUND; + } + return result; +} + +static void seqrep_apply_penalties( + const struct llama_context * ctx, + const llama_token * last_tokens_p, + const size_t last_tokens_size, + llama_token_data_array * candidates, + const llama_sampler_seqrep_params * params, + const std::unordered_map & penalize_tokens) { + std::vector buf(128, 0); + const int flags = params->flags; + + const bool ends_on_word = params->mid_word_scale == 1.0f + || SR_FLAG(llama_seqrep_check_word(ctx, last_tokens_p[last_tokens_size - 1], buf), SEQREP_CW_END_IS_WBOUND); + + for (size_t i = 0; i < candidates->size; ++i) { + auto pt_iter = penalize_tokens.find(candidates->data[i].id); + if (pt_iter == penalize_tokens.end()) { + continue; + } + + const size_t count = pt_iter->second; + const bool pt_starts_word = params->mid_word_scale == 1.0f || + SR_FLAG(llama_seqrep_check_word(ctx, candidates->data[i].id, buf), SEQREP_CW_START_IS_WBOUND); + float penalty_scale = ends_on_word || pt_starts_word ? 1.0f : params->mid_word_scale; + float logit = candidates->data[i].logit; + + if (SR_FLAG(flags, LLAMA_SEQREP_ABSOLUTE_PENALTY)) { + float penalty = + ( float(count) * params->length_penalty + + float(count > 0) * params->presence_penalty ); + logit -= penalty * penalty_scale; + } else { + const float l_penalty = (params->length_penalty != 0 ? params->length_penalty : 1.0) - 1.0; + const float p_penalty = (params->presence_penalty != 0 ? params->presence_penalty : 1.0) - 1.0; + + // This looks complicated. The point is to scale be able to scale penalties like + // 1.2. For example, suppose length penalty is 1.2 and length is 3. 1.2 * 3 = 3.6 + // would be ridiculous. What we actually want is more like 1.6. + // An alternative approach would be to iteratively apply the scale. + // 10.0 / 1.6 == 6.25, however ((10.0 / 1.2) / 1.2) / 1.2 == 5.787 + float penalty = + ( (float(count) * l_penalty) + + (float(count > 0) * p_penalty) ) * penalty_scale + + 1.0f; + if (logit <= 0) { + logit *= penalty; + } else if (penalty != 0.0f) { + logit /= penalty; + } + } + candidates->data[i].logit = logit; + } + +} + + +size_t llama_sample_seqrep_penalty( + struct llama_context * ctx, + llama_token_data_array * candidates, + const std::vector & last_tokens, + const llama_sampler_seqrep_params * params) { + + const size_t min_length = params->min_length; + const int flags = params->flags; + size_t last_tokens_size = last_tokens.size(); + const llama_token *last_tokens_p = last_tokens.data(); + + if (params->last_n == 0 || params->min_length < 1) { + return 0; + } else if (params->last_n > 0) { + size_t window_offset = last_tokens_size - std::min(size_t(params->last_n), last_tokens_size); + + last_tokens_size -= window_offset; + last_tokens_p += window_offset; + } + + if (last_tokens_size == 0 || (min_length > 1 && last_tokens_size <= min_length)) { + return 0; + } else if (!SR_FLAG(params->flags, LLAMA_SEQREP_REWIND_MODE)) { + const float disabled = SR_FLAG(params->flags, LLAMA_SEQREP_ABSOLUTE_PENALTY) ? 
0.0f : 1.0f; + // We accept 0.0 here even when the penalty isn't absolute because a non-absolute + // penalty of 0.0 implies divide by zero which makes no sense. + if ( (params->presence_penalty == disabled || params->presence_penalty == 0) + && (params->length_penalty == disabled || params->length_penalty == 0)) { + return 0; + } + } + + if (params->mid_word_scale != 1.0f || SR_FLAG(params->flags, LLAMA_SEQREP_REWIND_SKIP_WS_PUNCT)) { + // Only need ctx when mid_word_scale or REWIND_SKIP_WS_PUNCT flag is in effect. + assert(ctx); + } + + // const int64_t t_start_sample_us = ggml_time_us(); + + // This will hold a map of token ids that can continue the sequence with its sequence length. + std::unordered_map penalize_tokens; + + if (min_length > 1) { + // Normal sequence matching mode. + size_t start_offset = params->start_offset; + size_t max_matched_length = 0; + size_t min_matched_length = last_tokens_size; + + if (start_offset == 0 || start_offset >= last_tokens_size - 1) { + start_offset = last_tokens_size - 2; + } + + const llama_token * tail_p = last_tokens_p + (last_tokens_size - 1); + const size_t tail_len = std::min(params->max_length, last_tokens_size); + + for (size_t offset = start_offset; offset >= min_length - 1; offset--) { + const llama_token * search_p = last_tokens_p + offset; + const size_t search_len = std::min(params->max_length, last_tokens_size - (offset + 1)); + const size_t matched_length = + seqrep_find_match(tail_p, tail_len, search_p, search_len, true, params); + + if (matched_length < min_length) { + continue; + } + + max_matched_length = std::max(max_matched_length, matched_length); + min_matched_length = std::min(min_matched_length, matched_length); + + // The token one past where we started trying to match is the one that could continue + // the previously observed sequence. + llama_token penalize_token = last_tokens_p[offset + 1]; + + auto pt_iter = penalize_tokens.find(penalize_token); + if (pt_iter == penalize_tokens.end() + || !SR_FLAG(flags, LLAMA_SEQREP_PENALIZE_LENGTH_MAX_SEEN)) { + penalize_tokens[penalize_token] += matched_length; + } else { + penalize_tokens[penalize_token] = std::max(pt_iter->second, matched_length); + } + } + + if ((flags & LLAMA_SEQREP_REWIND_MODE) != 0) { + size_t result = !SR_FLAG(flags, LLAMA_SEQREP_REWIND_USE_SHORTEST_MATCH) || max_matched_length < min_length + ? max_matched_length + : min_matched_length; + + if (max_matched_length > 0 && SR_FLAG(params->flags, LLAMA_SEQREP_REWIND_SKIP_WS_PUNCT)) { + std::vector buf(128, 0); + for (size_t i = last_tokens_size - result; i < last_tokens_size; i++) { + if (SR_FLAG(llama_seqrep_check_word(ctx, last_tokens_p[i], buf), SEQREP_CW_ALL_WS_PUNCT)) { + result--; + } else { + break; + } + } + } + return result; + } + } else { + // Single token matching mode. Can emulate existing repetition, presence and frequency samplers. 
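To make the emulation claim concrete: with min_length < 2 every token in the window is counted, and seqrep_apply_penalties() then reduces to the classic formulas. A standalone sketch of the two penalty modes (helper names are illustrative, and the mid_word_scale factor is assumed to be 1.0); single-token mode continues below:

```cpp
#include <cstdio>

// Absolute mode (LLAMA_SEQREP_ABSOLUTE_PENALTY): mirrors presence/frequency
// penalties: logit -= count * length_penalty + (count > 0) * presence_penalty.
static float penalize_absolute(float logit, int count, float length_penalty, float presence_penalty) {
    return logit - (float(count) * length_penalty + float(count > 0) * presence_penalty);
}

// Divide mode (default): mirrors the classic repeat penalty, rescaled so a
// value like 1.2 stays sane for long matches:
// penalty = count * (lp - 1) + (count > 0) * (pp - 1) + 1, then divide
// positive logits by it and multiply negative logits by it.
static float penalize_divide(float logit, int count, float length_penalty, float presence_penalty) {
    const float l = (length_penalty   != 0.0f ? length_penalty   : 1.0f) - 1.0f;
    const float p = (presence_penalty != 0.0f ? presence_penalty : 1.0f) - 1.0f;
    const float penalty = float(count) * l + float(count > 0) * p + 1.0f;
    return logit <= 0.0f ? logit * penalty : logit / penalty;
}

int main() {
    // frequency_mode=0.2 over a window where the token appeared 3 times:
    printf("absolute: %.3f\n", penalize_absolute(2.0f, 3, 0.2f, 0.0f)); // 2.0 - 0.6 = 1.4
    // repetition_mode=1.2 (length_penalty=1.0, presence_penalty=1.2), same window:
    printf("divide:   %.3f\n", penalize_divide(2.0f, 3, 1.0f, 1.2f));   // 2.0 / 1.2
    return 0;
}
```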
+        size_t start_offset = params->start_offset;
+
+        if (start_offset == 0 || start_offset >= last_tokens_size) {
+            start_offset = last_tokens_size - 1;
+        }
+
+        for (int i = int(start_offset); i >= 0; i--) {
+            llama_token penalize_token = last_tokens_p[i];
+
+            if (SR_FLAG(flags, LLAMA_SEQREP_PENALIZE_LENGTH_MAX_SEEN)) {
+                penalize_tokens[penalize_token] = 1;
+            } else {
+                penalize_tokens[penalize_token]++;
+            }
+        }
+    }
+
+    seqrep_apply_penalties(ctx, last_tokens_p, last_tokens_size, candidates, params, penalize_tokens);
+
+    if (!penalize_tokens.empty()) {
+        candidates->sorted = false;
+    }
+
+    // FIXME: Find a way to set stuff in ctx
+    // if (ctx) {
+    //     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    // }
+    return 0;
+}
+
+seqrep_logit_info::seqrep_logit_info(llama_context * ctx, const size_t k, const int32_t ith)
+    : n_vocab(llama_n_vocab(llama_get_model(ctx)))
+    , token_data(top_k(llama_get_logits_ith(ctx, ith), k))
+    { }
+
+const std::vector<llama_token_data> & seqrep_logit_info::get_token_data(void) {
+    return token_data;
+}
+
+llama_token_data seqrep_logit_info::get_token_id(const llama_token token_id) const {
+    for (const llama_token_data & td : token_data) {
+        if (td.id == token_id)
+            return td;
+    }
+    return {-1, 0, 0};
+}
+
+void seqrep_logit_info::rebuild(llama_context *ctx, const size_t k, const int32_t ith) {
+    token_data = top_k(llama_get_logits_ith(ctx, ith), k);
+}
+
+void seqrep_logit_info::populate_logits(float * logits) {
+    const float neginf = std::numeric_limits<float>::infinity() * -1;
+    for (int i = 0; i < n_vocab; i++) {
+        logits[i] = neginf;
+    }
+    for (const llama_token_data & td : token_data) {
+        logits[td.id] = td.logit;
+    }
+}
+
+// Yoinked from beam search code.
+// Return top k token_data by logit.
+std::vector<llama_token_data> seqrep_logit_info::top_k(
+        const float * const logits,
+        const size_t k) {
+
+    std::vector<llama_token_data> min_heap; // min-heap by logit
+    const llama_token k_min = std::min(static_cast<llama_token>(k), n_vocab);
+    min_heap.reserve(k_min);
+    constexpr auto p = std::numeric_limits<float>::quiet_NaN(); // never used
+    for (llama_token token_id = 0 ; token_id < k_min ; ++token_id) {
+        const llama_token_data td = {token_id, logits[token_id], p};
+        min_heap.push_back(td);
+    }
+    auto comp = [](const llama_token_data & a, const llama_token_data & b) { return a.logit > b.logit; };
+    std::make_heap(min_heap.begin(), min_heap.end(), comp);
+    for (llama_token token_id = k_min ; token_id < n_vocab ; ++token_id) {
+        if (min_heap.front().logit < logits[token_id]) {
+            std::pop_heap(min_heap.begin(), min_heap.end(), comp);
+            min_heap.back().id = token_id;
+            min_heap.back().logit = logits[token_id];
+            std::push_heap(min_heap.begin(), min_heap.end(), comp);
+        }
+    }
+    return min_heap;
+}
+
+
+seqrep_rewind_state::seqrep_rewind_state(
+        const size_t n_vocab,
+        const size_t n_ctx,
+        const size_t k)
+    : n_vocab(n_vocab)
+    , n_ctx(n_ctx)
+    , k(k)
+{
+    logit_slots.reserve(n_ctx);
+    rewind_slots.resize(n_ctx);
+}
+
+void seqrep_rewind_state::set_logits_slot(llama_context * ctx, const size_t idx, const int32_t ith) {
+    GGML_ASSERT(idx <= logit_slots.size());
+    if (idx == logit_slots.size()) {
+        logit_slots.emplace_back(ctx, k, ith);
+    } else {
+        logit_slots[idx].rebuild(ctx, k, ith);
+    }
+}
+
+struct seqrep_rewind_slot & seqrep_rewind_state::get_rewind_slot(const size_t idx) {
+    GGML_ASSERT(idx < rewind_slots.size());
+    return rewind_slots[idx];
+}
+
+void seqrep_rewind_state::populate_logits(llama_context * ctx, const size_t idx, const int32_t ith) {
+    logit_slots[idx].populate_logits(llama_get_logits_ith(ctx, ith));
+}
+
+static
size_t seqrep_check_rewind_internal( + struct llama_context * ctx, + const std::vector & last_tokens, + const std::vector & params_list, + const llama_sampler_seqrep_params & merged_params, + size_t * high_water_mark) { + const size_t last_tokens_size = last_tokens.size(); + + size_t min_matched_len = 0, max_matched_len = 0; + + for (auto & sr_params : params_list) { + if (!SR_FLAG(sr_params.flags, LLAMA_SEQREP_REWIND_MODE)) + continue; + const size_t matched_len = llama_sample_seqrep_penalty(ctx, NULL, last_tokens, &sr_params); + max_matched_len = std::max(max_matched_len, matched_len); + min_matched_len = min_matched_len == 0 + ? matched_len + : std::min(min_matched_len, matched_len); + } + if (max_matched_len < 2 || max_matched_len >= last_tokens_size) { + return 0; + } + + const size_t matched_len = !SR_FLAG(merged_params.flags, LLAMA_SEQREP_REWIND_USE_SHORTEST_MATCH) + ? max_matched_len + : min_matched_len; + size_t idx = last_tokens_size - matched_len; + + if (idx < *high_water_mark) { + if (*high_water_mark >= last_tokens_size - 2) { + return 0; + } + idx = *high_water_mark; + } + + if (merged_params.rewind_seek_word_boundary > 0) { + std::vector buf(128, 0); + const size_t orig_idx = idx; + bool found_idx = false; + + for (size_t steps = merged_params.rewind_seek_word_boundary + 1; idx >= *high_water_mark && steps > 0; idx--, steps--) { + if (SR_FLAG(llama_seqrep_check_word(ctx, last_tokens[idx], buf), SEQREP_CW_START_IS_WBOUND) + || SR_FLAG(llama_seqrep_check_word(ctx, last_tokens[idx - 1], buf), SEQREP_CW_END_IS_WBOUND)) { + found_idx = true; + break; + } + } + if (!found_idx) { + idx = orig_idx; + for (size_t steps = merged_params.rewind_seek_word_boundary + 1; idx < last_tokens_size && steps > 0; idx++, steps--) { + if (SR_FLAG(llama_seqrep_check_word(ctx, last_tokens[idx], buf), SEQREP_CW_START_IS_WBOUND) + || SR_FLAG(llama_seqrep_check_word(ctx, last_tokens[idx - 1], buf), SEQREP_CW_END_IS_WBOUND)) { + found_idx = true; + break; + } + } + if (!found_idx || last_tokens_size - idx < merged_params.rewind_min_length) { + if (SR_FLAG(merged_params.flags, LLAMA_SEQREP_REWIND_REQUIRE_WBOUND)) { + return 0; + } + idx = orig_idx; + } + } + } + + const size_t rewind_distance = last_tokens.size() - idx; + if (merged_params.rewind_min_length != 0 && rewind_distance < merged_params.rewind_min_length) { + return 0; + } + + return rewind_distance; +} + +size_t llama_seqrep_handle_rewind( + struct llama_context * ctx, + struct seqrep_rewind_state & rewind_state, + const std::vector & generated_tokens, + const size_t n_generated, + const std::vector & prompt_tokens, + const std::vector & params_list, + size_t * high_water_mark, + const int32_t ith) { + const size_t prompt_tokens_size = prompt_tokens.size(); + + if (n_generated < 3) { + return 0; + } + + // FIXME: This copying is inefficient. 
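Setting the FIXME aside for a moment: the clamping performed by seqrep_check_rewind_internal() above is easiest to verify in isolation. A toy sketch (function name and the simplified rules are illustrative; word-boundary seeking is omitted) of how a candidate rewind index is constrained by the high-water mark; llama_seqrep_handle_rewind() continues below with the copy the FIXME refers to:

```cpp
#include <cstddef>
#include <cstdio>

// Toy version of the clamping in seqrep_check_rewind_internal(): a rewind that
// would land before the high-water mark is moved up to it, and rewinds that
// leave fewer than two tokens of headroom are rejected entirely.
static size_t clamp_rewind(size_t n_tokens, size_t matched_len, size_t high_water_mark) {
    if (matched_len < 2 || matched_len >= n_tokens) return 0; // nothing to rewind
    size_t idx = n_tokens - matched_len;                      // first token of the match
    if (idx < high_water_mark) {
        if (high_water_mark >= n_tokens - 2) return 0;        // HWM too close to the end
        idx = high_water_mark;
    }
    return n_tokens - idx;                                    // rewind distance
}

int main() {
    printf("%zu\n", clamp_rewind(100, 8, 0));  // 8: unconstrained
    printf("%zu\n", clamp_rewind(100, 8, 95)); // 5: clamped to the high-water mark
    printf("%zu\n", clamp_rewind(100, 8, 99)); // 0: HWM leaves no room to rewind
    return 0;
}
```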
+ std::vector last_tokens; + // printf("<%zu,%zu,%zu>", prompt_tokens_size, generated_tokens.size(), n_generated); + // fflush(stdout); + last_tokens.resize(n_generated + prompt_tokens.size()); + std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_tokens.begin()); + std::copy(generated_tokens.begin(), generated_tokens.end(), last_tokens.begin() + prompt_tokens.size()); + + llama_sampler_seqrep_params merged_params = llama_seqrep_merge_params(params_list, LLAMA_SEQREP_REWIND_MODE, 0); + size_t rewind_distance = 0; + size_t slot_idx, token_idx; + std::vector rewind_token_text_buf(128, 0); + + while (true) { + rewind_distance = seqrep_check_rewind_internal( + ctx, last_tokens, params_list, merged_params, high_water_mark ); + + if (rewind_distance == 0) + break; + + GGML_ASSERT(rewind_distance < n_generated); + slot_idx = n_generated - rewind_distance; + token_idx = n_generated + prompt_tokens_size - rewind_distance; + + const size_t ban_length = std::min(rewind_distance, merged_params.rewind_ban_length); + struct seqrep_rewind_slot &rw_slot = rewind_state.get_rewind_slot(slot_idx); + const bool at_wbound = token_idx == 0 || + SR_FLAG(llama_seqrep_check_word(ctx, last_tokens[token_idx - 1], rewind_token_text_buf), SEQREP_CW_END_IS_WBOUND); + + for (size_t i = slot_idx; i < slot_idx + ban_length; i++) { + const llama_token penalize_token = generated_tokens[i]; + if (i > slot_idx + && !at_wbound + && !SR_FLAG(llama_seqrep_check_word(ctx, penalize_token, rewind_token_text_buf), SEQREP_CW_START_IS_WBOUND)) { + continue; + } + if (std::find(rw_slot.tokens.begin(), rw_slot.tokens.end(), penalize_token) == rw_slot.tokens.end()) { + rw_slot.tokens.push_back(penalize_token); + } + } + + if (++rw_slot.count >= merged_params.rewind_max_visits) { + // This slot already hit max visits so we can set the HWM to the index one past it. + *high_water_mark = token_idx + 1; + } + + GGML_ASSERT(slot_idx > 0); + break; + } + + if (rewind_distance == 0) return 0; + + { + const std::wstring tail = !merged_params.include_re.empty() || !merged_params.exclude_re.empty() + ? seqrep_get_tail_string(ctx, last_tokens, rewind_distance + 8) + : std::wstring(); + for (const auto & re : merged_params.include_re) { + if (!std::regex_search(tail, re)) return 0; + } + for (const auto & re : merged_params.exclude_re) { + if (std::regex_search(tail, re)) return 0; + } + // { + // std::string x = wstring_to_string(tail); + // printf(" [[ %s ]] ", x.c_str()); + // } + } + + GGML_ASSERT(slot_idx > 0 && "Invalid slot for populate logits"); + rewind_state.populate_logits(ctx, slot_idx, ith); + + float * logits = llama_get_logits_ith(ctx, ith); + const float neg_infinity = std::numeric_limits::infinity() * -1; + const size_t target_idx = token_idx; + const bool at_wbound = target_idx == 0 || + SR_FLAG(llama_seqrep_check_word(ctx, last_tokens[target_idx - 1], rewind_token_text_buf), SEQREP_CW_END_IS_WBOUND); + const bool persist_require_wbound = SR_FLAG(merged_params.flags, LLAMA_SEQREP_REWIND_PERSIST_REQUIRE_WBOUND); + const size_t persist_count = std::min(prompt_tokens_size - target_idx, merged_params.rewind_persist_bans); + + for (size_t i = target_idx - persist_count; i <= target_idx; i++) { + // FIXME: There's a better way to calculate this. + if (i < prompt_tokens_size) { + continue; + } + if (persist_require_wbound && i != target_idx && !at_wbound) { + // We don't apply this logic when i == target_idx because the previous + // checks should have taken it into account when the specific ban was applied + // initially. 
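(The continue statement directly below belongs to the comment above.) The ban application a few lines further down is plain logit masking: banned token ids get negative infinity so no downstream sampler can pick them. A self-contained sketch of that mechanism, with illustrative names:

```cpp
#include <cstdio>
#include <limits>
#include <vector>

// Ban tokens by forcing their logits to -infinity, as the persisted-ban loop
// in llama_seqrep_handle_rewind() does for each rewind slot's token list.
static void ban_tokens(std::vector<float> & logits, const std::vector<int> & banned) {
    const float neg_inf = -std::numeric_limits<float>::infinity();
    for (int token_id : banned) {
        if (token_id >= 0 && (size_t) token_id < logits.size()) {
            logits[token_id] = neg_inf;
        }
    }
}

int main() {
    std::vector<float> logits = {0.5f, 1.5f, -0.25f, 2.0f};
    ban_tokens(logits, {1, 3});
    for (float l : logits) printf("%f\n", l); // indices 1 and 3 are now -inf
    return 0;
}
```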
+ continue; + } + for (const llama_token token_id : rewind_state.get_rewind_slot(i - prompt_tokens_size).tokens) { + logits[token_id] = neg_infinity; + } + } + + return rewind_distance; +} + + +// Note: Doesn't merge presence or length penalties because of the divide_by_penalty flag. +struct llama_sampler_seqrep_params llama_seqrep_merge_params( + const std::vector & params_list, + const int and_flags, + const int not_flags) { + struct llama_sampler_seqrep_params result = {}; + + for (auto & sr_params : params_list) { + if ((sr_params.flags & and_flags) != and_flags || (sr_params.flags & not_flags) != 0) { + continue; + } + result.flags |= sr_params.flags; + result.min_length = std::max(result.min_length, sr_params.min_length); + result.max_length = std::max(result.max_length, sr_params.max_length); + result.last_n = sr_params.last_n < 0 || result.last_n < 0 + ? -1 + : std::max(result.last_n, sr_params.last_n); + result.tolerance = std::max(result.tolerance, sr_params.tolerance); + result.mid_word_scale = std::max(result.mid_word_scale, sr_params.mid_word_scale); + result.tolerance_match_credit = std::max(result.tolerance_match_credit, sr_params.tolerance_match_credit); + result.rewind_min_length = std::max(result.rewind_min_length, sr_params.rewind_min_length); + result.rewind_seek_word_boundary = std::max(result.rewind_seek_word_boundary, sr_params.rewind_seek_word_boundary); + result.rewind_max_visits = std::max(result.rewind_max_visits, sr_params.rewind_max_visits); + result.rewind_persist_bans = std::max(result.rewind_persist_bans, sr_params.rewind_persist_bans); + result.rewind_ban_length = std::max(result.rewind_ban_length, sr_params.rewind_ban_length); + // FIXME: Copying like this isn't ideal. + result.include_re.insert(result.include_re.end(), sr_params.include_re.begin(), sr_params.include_re.end()); + result.exclude_re.insert(result.exclude_re.end(), sr_params.exclude_re.begin(), sr_params.exclude_re.end()); + } + return result; +} diff --git a/common/seqrep-sampler.h b/common/seqrep-sampler.h new file mode 100644 index 0000000000000..f56b904ae0faa --- /dev/null +++ b/common/seqrep-sampler.h @@ -0,0 +1,200 @@ +#pragma once + +#include + +#include +#include + +#include "llama.h" + +enum llama_sampler_seqrep_flags { + // Tolerance charges can't be used consecutively. + LLAMA_SEQREP_TOLERANCE_NO_CONSECUTIVE = (1 << 0), + + // Tolerance charges can't be used before the first actual match. + LLAMA_SEQREP_TOLERANCE_NO_FIRST = (1 << 1), + + // When applying the length penalty, use the length of the longest observed + // sequence matching the token rather than the total length of + // sequences matching the token. In other words, if we find a sequence + // of length 3 and a sequence of length 4 continued by token 69 then + // with this flag on we penalize based on length 4, with it off we + // penalize based on length 7 (3 + 4). + LLAMA_SEQREP_PENALIZE_LENGTH_MAX_SEEN = (1 << 2), + + // Apply an absolute penalty rather than dividing the logit by the penalty. + LLAMA_SEQREP_ABSOLUTE_PENALTY = (1 << 3), + + // Rewind to cut off the head of sequences rather than the end. + // Ignored when min_length < 2. + // Since it wouldn't make sense to rewind and then let sampling pick + // the same token again, penalty values and mid_word_scale have no + // effect. + LLAMA_SEQREP_REWIND_MODE = (1 << 4), + + // When rewinding, skip past whitespace and punctuation. For example, + // if the matched sequence was "'hello" then we will rewind to the + // token starting with 'h' and ban it. 
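(The flag the comment above documents is declared on the next line.) Since these are bit flags, configurations compose with bitwise OR, and the SR_FLAG() macro in seqrep-sampler.cpp tests membership. A usage sketch, assuming it is compiled inside the llama.cpp tree so the header resolves:

```cpp
#include <cstdio>
#include "seqrep-sampler.h" // the header this enum lives in

int main() {
    // The same combination the "rewind_mode" config preset enables:
    const int flags = LLAMA_SEQREP_REWIND_MODE
                    | LLAMA_SEQREP_REWIND_SKIP_WS_PUNCT
                    | LLAMA_SEQREP_REWIND_REQUIRE_WBOUND
                    | LLAMA_SEQREP_REWIND_PERSIST_REQUIRE_WBOUND;

    // Equivalent of the SR_FLAG() test in seqrep-sampler.cpp:
    printf("rewind mode: %s\n", (flags & LLAMA_SEQREP_REWIND_MODE) != 0 ? "on" : "off");
    return 0;
}
```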
+ LLAMA_SEQREP_REWIND_SKIP_WS_PUNCT = (1 << 5), + + // Rewind to the shortest matching sequence of at least min_length rather than the longest. + LLAMA_SEQREP_REWIND_USE_SHORTEST_MATCH = (1 << 6), + + // Rewinding requires a word boundary. Only has an effect when rewind_seek_word_boundary isn't 0. + LLAMA_SEQREP_REWIND_REQUIRE_WBOUND = (1 << 7), + + // Persisted bans are only applied if at a word bound. + LLAMA_SEQREP_REWIND_PERSIST_REQUIRE_WBOUND = (1 << 8), +}; + +typedef struct llama_sampler_seqrep_params { + // The minimum length of a matching sequence of tokens. When this is < 2 then + // the sampler works in single token mode and tolerance is ignored. + size_t min_length; + + // Maximum length for a matching sequence of tokens. + size_t max_length; + + // Starting offset for matching against the end of the sequence. This can be used + // to only match against sequences in the initial prompt, for example. Matching + // starts at the offset and moves toward the beginning of the list. + // Use 0 for penultimate token when min_length > 1 otherwise 0 for last token. + size_t start_offset; + + // Window of last tokens to consider, starting from the end. < 0 means + // the whole list. + int last_n; + + // Flags based on llama_sampler_seqrep_flags enum values ORed together. + int flags; + + // Tolerance for non-matching tokens in a sequence. + float tolerance; + + // Flat penalty applied to the token that can continue a repeated sequence. + float presence_penalty; + + // Scaling penalty applied to the token that can continue a repeated sequence. + // The penalty is multiplied by the total length of sequences that are continued by this token unless + // the PENALIZE_LENGTH_MAX_SEEN is set. + float length_penalty; + + // Scale for penalizing tokens from repeated sequences that aren't at/form a word boundary. + float mid_word_scale; + + // Tolerance credit per real match. I.E. .5 means +1 tolerance per 2 matched tokens. + float tolerance_match_credit; + + // Caps tolerance at the specified value. Only meaningful when tolerance_match_credit > 0 + float tolerance_cap; + + // Ensure the sequence is at least the specified length in rewind mode after + // whitespace skipping and other modifications. + size_t rewind_min_length; + + // When rewinding, try to find a word boundary within the specified distance, starting with tokens earlier than the rewind point. + size_t rewind_seek_word_boundary; + + // A position is limited to the specified number of rewinds. When the limit is exceeded, future rewinds cannot target it or earlier tokens. + size_t rewind_max_visits; + + // Tokens banned by rewind remain banned for an additional number of positions equal to the value. i.e. setting this to 1 would mean the token is banned for 2 positions. + size_t rewind_persist_bans; + + // Number of tokens from the sequence to ban when rewinding. 
+ size_t rewind_ban_length; + + std::vector include_re; + std::vector exclude_re; +} llama_sampler_seqrep_params; + +enum seqrep_check_word_flags { + SEQREP_CW_START_IS_WBOUND = 1 << 0, + SEQREP_CW_END_IS_WBOUND = 1 << 1, + SEQREP_CW_ALL_WS_PUNCT = 1 << 2, + SEQREP_CW_START_IS_INVALID = 1 << 3, // Start of token is invalid/incomplete UTF8 + SEQREP_CW_END_IS_INVALID = 1 << 4 // End of token is invalid/incomplete UTF8 +}; + + +struct seqrep_logit_info { + const int n_vocab; + std::vector token_data; + + seqrep_logit_info(llama_context * ctx, const size_t k, const int32_t ith); + + const std::vector & get_token_data(void); + + llama_token_data get_token_id(const llama_token token_id) const; + + void rebuild(llama_context *ctx, const size_t k, const int32_t ith); + + void populate_logits(float * logits); + + // Yoinked from beam search code. + // Return top k token_data by logit. + std::vector top_k(const float * const logits, const size_t k); + + seqrep_logit_info(const int n_vocab, const std::vector & token_data = {}) + : n_vocab(n_vocab) + , token_data(token_data) + {} +}; + +struct seqrep_rewind_slot { + size_t count; + std::vector tokens; + struct llama_sampling_context * ctx_sampling = nullptr; +}; + +struct seqrep_rewind_state { + const size_t n_vocab; + const size_t n_ctx; + const size_t k; + + std::vector logit_slots; + std::vector rewind_slots; + + seqrep_rewind_state( + const size_t n_vocab, + const size_t n_ctx, + const size_t k = 2000); + + struct seqrep_rewind_slot & get_rewind_slot(const size_t idx); + + void set_logits_slot(llama_context * ctx, const size_t idx, const int32_t ith = 0); + + void populate_logits(llama_context * ctx, const size_t idx, const int32_t ith = 0); + +}; + +// Sequence repetition penalty with semi-fuzzy matching. Note: Handles the last_n window itself. 
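A usage sketch for the declaration that follows: build a llama_token_data_array over the candidate tokens, then let the sampler adjust logits in place. The token ids and logits here are fabricated, and passing NULL for the context relies on ctx only being dereferenced when mid_word_scale != 1.0 (per the assert in seqrep-sampler.cpp), so this is for illustration only:

```cpp
#include <cstdio>
#include <vector>
#include "seqrep-sampler.h"

int main() {
    // Fake 4-token candidate list; token 2 is the one that would continue a repeat.
    std::vector<llama_token_data> cand = {
        {0, 0.1f, 0.0f}, {1, 0.2f, 0.0f}, {2, 3.0f, 0.0f}, {3, 0.4f, 0.0f},
    };
    llama_token_data_array cand_arr = { cand.data(), cand.size(), false };

    // The tail "9 0 1" previously occurred at positions 1-3 and was followed by token 2.
    std::vector<llama_token> last_tokens = {7, 9, 0, 1, 2, 3, 9, 0, 1};

    llama_sampler_seqrep_params sr_params;
    seqrep_sampler_params_init(&sr_params);
    sr_params.min_length       = 2;
    sr_params.presence_penalty = 1.5f;
    sr_params.mid_word_scale   = 1.0f; // so ctx is never touched (see assert in the .cpp)

    llama_sample_seqrep_penalty(NULL, &cand_arr, last_tokens, &sr_params);
    // Divide mode: penalty = (count > 0) * (1.5 - 1) + 1 = 1.5, so 3.0 / 1.5 = 2.0.
    printf("token 2 logit after penalty: %f\n", cand_arr.data[2].logit);
    return 0;
}
```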
+size_t llama_sample_seqrep_penalty( + struct llama_context * ctx, + llama_token_data_array * candidates, + const std::vector & last_tokens, + const llama_sampler_seqrep_params * params); + +int llama_seqrep_check_word( + const struct llama_context * ctx, + const llama_token token, + std::vector & buf); + +size_t llama_seqrep_handle_rewind( + struct llama_context * ctx, + struct seqrep_rewind_state & rewind_state, + const std::vector & generated_tokens, + const size_t n_generated, + const std::vector & prompt_tokens, + const std::vector & params_list, + size_t * high_water_mark, + const int32_t ith = 0); + +void seqrep_sampler_help(); +void seqrep_sampler_params_init(llama_sampler_seqrep_params * params); +void seqrep_sampler_params_dump(const llama_sampler_seqrep_params * params); +bool seqrep_sampler_params_parse(char * s, llama_sampler_seqrep_params * params); +struct llama_sampler_seqrep_params llama_seqrep_merge_params( + const std::vector & params_list, + const int and_flags, + const int not_flags); diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 71bcb6893e20d..711241ad3efdd 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -31,6 +31,7 @@ else() add_subdirectory(quantize-stats) add_subdirectory(save-load-state) add_subdirectory(simple) + add_subdirectory(simple-inference) add_subdirectory(speculative) add_subdirectory(train-text-from-scratch) if (LLAMA_METAL) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 31ec8cade19be..f8eec2cdae019 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -430,6 +430,11 @@ int main(int argc, char ** argv) { } } LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str()); +#ifndef LLAMA_NO_SEQREP_SAMPLER + for (auto & sr_params : sparams.seqrep_params) { + seqrep_sampler_params_dump(&sr_params); + } +#endif LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); LOG_TEE("\n\n"); diff --git a/examples/simple-inference/CMakeLists.txt b/examples/simple-inference/CMakeLists.txt new file mode 100644 index 0000000000000..39294d4bb6a1b --- /dev/null +++ b/examples/simple-inference/CMakeLists.txt @@ -0,0 +1,8 @@ +set(TARGET simple-inference) +add_executable(${TARGET} simple-inference.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) +if(TARGET BUILD_INFO) + add_dependencies(${TARGET} BUILD_INFO) +endif() diff --git a/examples/simple-inference/simple-inference.cpp b/examples/simple-inference/simple-inference.cpp new file mode 100644 index 0000000000000..022b69baa0fcb --- /dev/null +++ b/examples/simple-inference/simple-inference.cpp @@ -0,0 +1,986 @@ +// Defines sigaction on msys: +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#include "common.h" + +#include "console.h" +#include "llama.h" +#include "grammar-parser.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) +#include +#include +#elif defined (_WIN32) +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#include +#endif + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + +#define SI_DUMP_SEQUENCES_INTERVAL 40 + +static std::atomic interrupted {false}; +static std::atomic done 
{false}; + +typedef struct tokens_chunk { + bool is_input; + size_t consumed; + std::vector tokens; + + tokens_chunk(const bool is_input = false, const size_t consumed = 0, const std::vector & tokens = {}) + : is_input(is_input) + , consumed(consumed) + , tokens(tokens) + {} +} tokens_chunk; + +enum seq_state { + SEQ_GENERATING, + SEQ_SHARE_PROMPT, + SEQ_INPUT, + SEQ_DONE, +}; + +typedef struct seq_ctx { + llama_seq_id seq_id; + int32_t batch_idx; + enum seq_state state; + size_t n_remain; + size_t n_toks; // Note: Does not include initial prompt size. + llama_sampling_context *ctx_sampling; + + llama_token last_sampled; + std::vector chunks; +#ifndef LLAMA_NO_SEQREP_SAMPLER + size_t high_water_mark; + struct seqrep_rewind_state rewind_state; + size_t rewind_count; + size_t rewind_tokens; +#endif +} seq_ctx; + + +typedef struct gen_ctx { + llama_context * ctx = nullptr; + llama_model * model = nullptr; + llama_sampling_context * ctx_sampling = nullptr; + + llama_batch batch; + gpt_params params; + llama_sampling_params & sparams = params.sparams; + + + int n_ctx; + int n_vocab; + + std::vector scratch; + std::vector prompt_tokens; + size_t prompt_size = 0; + + llama_seq_id focused_sequence = 0; + + size_t decode_count = 0; + int64_t decode_time_total = 0, decode_time_last = 0; + + std::vector ctxs_seq; + + private: + bool init_params(const int argc, char ** argv); + bool init_model(); + bool init_prompt(); + bool init_handlers(); + bool init_sampling(); + bool init_batches(); + + public: + gen_ctx(const int argc, char ** argv); + ~gen_ctx(); + void dump_batches(const size_t prompt_start = 0); + void dump_chunks(const std::vector & chunks, const size_t start_offset = 0); + void handle_seq(seq_ctx & sctx); +#ifndef LLAMA_NO_SEQREP_SAMPLER + void handle_seq_seqrep(seq_ctx & sctx); +#endif + bool feed_prompt( + const std::vector & tokens, + llama_pos pos = 0, + llama_seq_id seq = 0); + bool go(); +} gen_ctx; + + +static void concat_chunks(const std::vector & chunks, std::vector & dst, const size_t start_offset) { + size_t offset = 0; + + for (const tokens_chunk & chunk : chunks) { + if (offset + chunk.tokens.size() <= start_offset) { + offset += chunk.tokens.size(); + continue; + } + + const size_t chunk_offset = offset < start_offset ? 
+
+
+typedef struct gen_ctx {
+    llama_context * ctx = nullptr;
+    llama_model * model = nullptr;
+    llama_sampling_context * ctx_sampling = nullptr;
+
+    llama_batch batch;
+    gpt_params params;
+    llama_sampling_params & sparams = params.sparams;
+
+    int n_ctx;
+    int n_vocab;
+
+    std::vector<llama_token> scratch;
+    std::vector<llama_token> prompt_tokens;
+    size_t prompt_size = 0;
+
+    llama_seq_id focused_sequence = 0;
+
+    size_t decode_count = 0;
+    int64_t decode_time_total = 0, decode_time_last = 0;
+
+    std::vector<seq_ctx> ctxs_seq;
+
+  private:
+    bool init_params(const int argc, char ** argv);
+    bool init_model();
+    bool init_prompt();
+    bool init_handlers();
+    bool init_sampling();
+    bool init_batches();
+
+  public:
+    gen_ctx(const int argc, char ** argv);
+    ~gen_ctx();
+    void dump_batches(const size_t prompt_start = 0);
+    void dump_chunks(const std::vector<tokens_chunk> & chunks, const size_t start_offset = 0);
+    void handle_seq(seq_ctx & sctx);
+#ifndef LLAMA_NO_SEQREP_SAMPLER
+    void handle_seq_seqrep(seq_ctx & sctx);
+#endif
+    bool feed_prompt(
+        const std::vector<llama_token> & tokens,
+        llama_pos pos = 0,
+        llama_seq_id seq = 0);
+    bool go();
+} gen_ctx;
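+
+
+// Example: with chunk sizes [5, 3] and start_offset = 6, the first chunk is
+// skipped entirely (offset advances to 5) and only the second chunk's tokens
+// at chunk_offset 1..2 are appended, so dst receives 2 tokens.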
+static void concat_chunks(const std::vector<tokens_chunk> & chunks, std::vector<llama_token> & dst, const size_t start_offset) {
+    size_t offset = 0;
+
+    for (const tokens_chunk & chunk : chunks) {
+        if (offset + chunk.tokens.size() <= start_offset) {
+            offset += chunk.tokens.size();
+            continue;
+        }
+
+        const size_t chunk_offset = offset < start_offset ? start_offset - offset : 0;
+        const size_t chunk_size = chunk.tokens.size() - chunk_offset;
+        const llama_token * tp = chunk.tokens.data() + chunk_offset;
+
+        for (size_t i = 0; i < chunk_size; i++, tp++) {
+            dst.push_back(*tp);
+        }
+        offset += chunk.tokens.size();
+    }
+}
+
+
+static void write_logfile(
+        const llama_context * ctx, const gpt_params & params, const llama_model * model,
+        const std::vector<llama_token> input_tokens, const std::string output, const std::vector<llama_token> output_tokens) {
+
+    if (params.logdir.empty()) {
+        return;
+    }
+
+    const std::string timestamp = get_sortable_timestamp();
+
+    const bool success = create_directory_with_parents(params.logdir);
+    if (!success) {
+        fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n",
+                __func__, params.logdir.c_str());
+        return;
+    }
+
+    const std::string logfile_path = params.logdir + timestamp + ".yml";
+    FILE * logfile = fopen(logfile_path.c_str(), "w");
+
+    if (logfile == NULL) {
+        fprintf(stderr, "%s: failed to open logfile %s\n", __func__, logfile_path.c_str());
+        return;
+    }
+
+    fprintf(logfile, "binary: simple-inference\n");
+    char model_desc[128];
+    llama_model_desc(model, model_desc, sizeof(model_desc));
+    dump_non_result_info_yaml(logfile, params, ctx, timestamp, input_tokens, model_desc);
+
+    fprintf(logfile, "\n");
+    fprintf(logfile, "######################\n");
+    fprintf(logfile, "# Generation Results #\n");
+    fprintf(logfile, "######################\n");
+    fprintf(logfile, "\n");
+
+    dump_string_yaml_multiline(logfile, "output", output.c_str());
+    dump_vector_int_yaml(logfile, "output_tokens", output_tokens);
+
+    llama_dump_timing_info_yaml(logfile, ctx);
+    fclose(logfile);
+}
+
+#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
+static void sigint_handler(int signo) {
+    if (signo == SIGINT) {
+        if (interrupted) {
+            done.store(true);
+        } else {
+            interrupted.store(true);
+        }
+    }
+}
+#endif
+
+
+static bool check_unsupported(const gpt_params * params) {
+    std::string nope;
+    const llama_sampling_params & sparams = params->sparams;
+
+    if (params->embedding)
+        nope = "embedding";
+    else if (!sparams.grammar.empty())
+        nope = "grammar"; // Currently broken most likely
+    else if (sparams.cfg_scale != 1.0f)
+        nope = "cfg_scale";
+    else if (!sparams.cfg_negative_prompt.empty())
+        nope = "cfg_negative_prompt";
+    else if (!params->path_prompt_cache.empty())
+        nope = "prompt cache";
+    else if (params->escape)
+        nope = "prompt escaping";
+    else if (params->interactive_first || params->instruct)
+        nope = "interactive first or instruct mode";
+    else if (!params->input_prefix.empty() || !params->input_suffix.empty() || params->input_prefix_bos)
+        nope = "input prefix or suffix";
+    else if (params->hellaswag)
+        nope = "hellaswag";
+    else if (params->n_keep != 0)
+        nope = "keep";
+    else if (!params->antiprompt.empty())
+        nope = "reverse prompt";
+    if (!nope.empty()) {
+        printf("%s: error: We don't support %s here.\n", __func__, nope.c_str());
+        return false;
+    }
+    return true;
+}
+
+bool gen_ctx::init_params(const int argc, char ** argv) {
+    if (gpt_params_parse(argc, argv, params) == false) {
+        return false;
+    }
+
+    if (!check_unsupported(&params)) {
+        return false;
+    }
+
+    if (params.rope_freq_base != 10000.0) {
+        printf("%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
+    }
+
+    if (params.rope_freq_scale != 1.0) {
+        printf("%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale);
+    }
+
+    if (params.n_ctx < 8) {
+        printf("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
+        params.n_ctx = 8;
+    }
+
+    if (params.seed == LLAMA_DEFAULT_SEED) {
+        params.seed = time(NULL);
+    }
+
+    printf("%s: seed = %u\n", __func__, params.seed);
+
+    std::mt19937 rng(params.seed);
+    if (params.random_prompt) {
+        params.prompt = gpt_random_prompt(rng);
+    }
+
+    return true;
+}
+
+bool gen_ctx::init_model() {
+    LOG("%s: llama backend init\n", __func__);
+    llama_backend_init(params.numa);
+
+    // load the model and apply lora adapter, if any
+    LOG("%s: load the model and apply lora adapter, if any\n", __func__);
+    std::tie(model, ctx) = llama_init_from_gpt_params(params);
+
+    if (model == NULL) {
+        printf("%s: error: unable to load model\n", __func__);
+        return false;
+    }
+
+    // print system information
+    {
+        printf("\n");
+        printf("system_info: n_threads = %d / %d | %s\n",
+                params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
+    }
+
+    n_ctx = llama_n_ctx(ctx);
+    n_vocab = llama_n_vocab(llama_get_model(ctx));
+
+    return true;
+}
+
+bool gen_ctx::init_prompt() {
+    const bool add_bos = llama_should_add_bos_token(model);
+    LOG("add_bos: %d\n", add_bos);
+
+    if (!params.prompt.empty()) {
+        LOG("tokenize the prompt\n");
+        prompt_tokens = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+    }
+
+    LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
+    LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, prompt_tokens).c_str());
+
+    // Should not run without any tokens
+    if (prompt_tokens.empty()) {
+        prompt_tokens.push_back(llama_token_bos(model));
+        LOG("input was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, prompt_tokens).c_str());
+    }
+
+    LOG("n_ctx: %d\n", n_ctx);
+
+    if ((int) prompt_tokens.size() > n_ctx - 4) {
+        printf("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) prompt_tokens.size(), n_ctx - 4);
+        return false;
+    }
+    prompt_size = prompt_tokens.size();
+
+    if (params.verbose_prompt) {
+        printf("\n");
+        printf("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
+        printf("%s: number of tokens in prompt = %zu\n", __func__, prompt_tokens.size());
+        for (int i = 0; i < (int) prompt_tokens.size(); i++) {
+            printf("%6d -> '%s'\n", prompt_tokens[i], llama_token_to_piece(ctx, prompt_tokens[i]).c_str());
+        }
+        printf("\n");
+    }
+    return true;
+}
+
+bool gen_ctx::init_handlers() {
+    // save choice to use color for later
+    // (note for later: this is a slightly awkward choice)
+    console::init(params.simple_io, params.use_color);
+    atexit([]() { console::cleanup(); });
+
+#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+    struct sigaction sigint_action;
+    sigint_action.sa_handler = sigint_handler;
+    sigemptyset (&sigint_action.sa_mask);
+    sigint_action.sa_flags = 0;
+    sigaction(SIGINT, &sigint_action, NULL);
+#elif defined (_WIN32)
+    auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
+        return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false;
+    };
+    SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
+#endif
+    return true;
+}
+
+bool gen_ctx::init_sampling() {
+    printf("sampling: %s\n", llama_sampling_print(sparams).c_str());
+#ifndef LLAMA_NO_SEQREP_SAMPLER
+    for (auto & sr_params : sparams.seqrep_params) {
+        seqrep_sampler_params_dump(&sr_params);
+    }
+#endif
+    ctx_sampling = llama_sampling_init(sparams);
+    return true;
+}
+
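+// Sizing note: with the default n_predict == -1 (or a value that would
+// overflow the context), init_batches() clamps each sequence's budget to
+// n_ctx - prompt_size. E.g. (sketch) n_ctx = 512 with a 100-token prompt
+// leaves n_remain = 412 per sequence.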
+bool gen_ctx::init_batches() {
+    batch = llama_batch_init(std::max(int32_t(prompt_size), params.n_batch), 0, 1);
+    int n_remain = params.n_predict;
+
+    if (n_remain < 0 || n_remain + int(prompt_size) > n_ctx) {
+        n_remain = n_ctx - prompt_size;
+    }
+
+    ctxs_seq.reserve(params.n_parallel);
+    for (int32_t i = 0; i < params.n_parallel; i++) {
+        seq_ctx && bs = {
+            llama_seq_id(i),
+            -1,
+            i == 0 ? SEQ_INPUT : SEQ_SHARE_PROMPT,
+            size_t(n_remain),
+            0,
+            llama_sampling_init(params.sparams),
+            -1,
+            {},
+#ifndef LLAMA_NO_SEQREP_SAMPLER
+            prompt_size + 1,
+            seqrep_rewind_state(n_vocab, n_ctx, 2000),
+            0,
+            0,
+#endif
+        };
+        GGML_ASSERT(prompt_size > 0);
+        bs.chunks.emplace_back(true, 0, prompt_tokens);
+        if (i > 0) {
+            bs.chunks.emplace_back(false, 0, std::vector<llama_token>());
+        }
+#ifndef LLAMA_NO_SEQREP_SAMPLER
+        seqrep_rewind_slot & rw_slot = bs.rewind_state.get_rewind_slot(0);
+        rw_slot.ctx_sampling = llama_sampling_init(params.sparams);
+#endif
+        ctxs_seq.push_back(bs);
+    }
+    if (!ctxs_seq.empty()) {
+        focused_sequence = ctxs_seq.size() - 1;
+    }
+
+    return true;
+}
+
+
+gen_ctx::gen_ctx(const int argc, char ** argv) {
+    bool result = true;
+
+    result = result && init_params(argc, argv);
+    result = result && init_model();
+    result = result && init_prompt();
+    result = result && init_handlers();
+    result = result && init_sampling();
+    result = result && init_batches();
+    if (!result) {
+        throw std::runtime_error("Initialization failed");
+    }
+}
+
+gen_ctx::~gen_ctx() {
+    for (auto & sctx : ctxs_seq) {
+        llama_sampling_free(sctx.ctx_sampling);
+#ifndef LLAMA_NO_SEQREP_SAMPLER
+        for (auto & rs : sctx.rewind_state.rewind_slots) {
+            if (rs.ctx_sampling != nullptr) {
+                llama_sampling_free(rs.ctx_sampling);
+                rs.ctx_sampling = nullptr;
+            }
+        }
+#endif
+    }
+    llama_sampling_free(ctx_sampling);
+
+    llama_batch_free(batch);
+
+    llama_free(ctx);
+    llama_free_model(model);
+
+    llama_backend_free();
+}
+
+
+bool gen_ctx::feed_prompt(const std::vector<llama_token> & tokens, llama_pos pos, llama_seq_id seq) {
+    int32_t tokens_remain = tokens.size();
+    const llama_token * tokens_curr = tokens.data();
+
+    console::set_display(console::prompt);
+    while (tokens_remain > 0 && !interrupted) {
+        const int32_t chunk_size = std::min(int32_t(tokens_remain), params.n_batch);
+        llama_batch_clear(batch);
+        for (int i = 0; i < chunk_size; i++) {
+            llama_batch_add(batch, tokens_curr[i], pos + i, {seq}, false);
+        }
+        pos += batch.n_tokens;
+        tokens_remain -= batch.n_tokens;
+        batch.logits[batch.n_tokens - 1] = tokens_remain < 1;
+
+        if (llama_decode(ctx, batch) != 0) {
+            console::set_display(console::reset);
+            printf("%s : failed to eval\n", __func__);
+            return false;
+        }
+        decode_count++;
+
+        // display text
+        for (int i = 0; i < batch.n_tokens; i++) {
+            const std::string token_str = llama_token_to_piece(ctx, tokens_curr[i]);
+            fputs(token_str.c_str(), stdout);
+        }
+        fflush(stdout);
+
+        tokens_curr += batch.n_tokens;
+    }
+    console::set_display(console::reset);
+    return true;
+}
+
+void gen_ctx::dump_chunks(const std::vector<tokens_chunk> & chunks, const size_t start_offset) {
+    size_t offset = 0;
+    bool prompt_mode = false;
+    console::set_display(console::reset);
+
+    for (const tokens_chunk & chunk : chunks) {
+        if (offset + chunk.tokens.size() <= start_offset) {
+            offset += chunk.tokens.size();
+            continue;
+        }
+
+        const size_t chunk_offset = offset < start_offset ? start_offset - offset : 0;
+        const size_t chunk_size = chunk.tokens.size() - chunk_offset;
+        const llama_token * tp = chunk.tokens.data() + chunk_offset;
+
+        if (chunk.is_input != prompt_mode) {
+            prompt_mode = chunk.is_input;
+            console::set_display(prompt_mode ? console::prompt : console::reset);
+        }
+
+        for (size_t i = 0; i < chunk_size; i++, tp++) {
+            const std::string token_str = llama_token_to_piece(ctx, *tp);
+            fputs(token_str.c_str(), stdout);
+        }
+        offset += chunk.tokens.size();
+    }
+    if (prompt_mode) {
+        console::set_display(console::reset);
+    }
+    fflush(stdout);
+}
+
+void gen_ctx::dump_batches(const size_t prompt_start) {
+    bool first = true;
+
+    for (seq_ctx & sctx : ctxs_seq) {
+        if (sctx.seq_id == focused_sequence) continue;
+        printf("\n\n%s Result #%d (size: %zu",
+            !first ? "====================" : "####################",
+            sctx.seq_id + 1, prompt_size + sctx.n_toks);
+#ifndef LLAMA_NO_SEQREP_SAMPLER
+        printf(", rewind cnt/toks: %zu/%zu", sctx.rewind_count, sctx.rewind_tokens);
+#endif
+        printf("%s):", sctx.state == SEQ_DONE ? ", DONE" : "");
+        dump_chunks(sctx.chunks, prompt_start);
+        first = false;
+    }
+    seq_ctx & sctx = ctxs_seq[focused_sequence];
+    printf("\n\n%s Result #%d (size: %zu",
+        !first ? "====================" : "####################",
+        sctx.seq_id + 1, prompt_size + sctx.n_toks);
+#ifndef LLAMA_NO_SEQREP_SAMPLER
+    printf(", rewind cnt/toks: %zu/%zu", sctx.rewind_count, sctx.rewind_tokens);
+#endif
+    puts("):");
+    dump_chunks(sctx.chunks, prompt_start);
+}
+
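+// Per-step state machine (summary): SEQ_SHARE_PROMPT and SEQ_DONE contribute
+// nothing to the batch; SEQ_GENERATING samples one token from the previous
+// decode and queues it; SEQ_INPUT feeds queued user tokens up to n_batch at a
+// time and switches to SEQ_GENERATING once the input chunk is fully consumed.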
+void gen_ctx::handle_seq(seq_ctx & sctx) {
+    switch (sctx.state) {
+        case SEQ_DONE:
+        case SEQ_SHARE_PROMPT: break;
+
+        case SEQ_GENERATING: {
+            GGML_ASSERT(sctx.batch_idx >= 0);
+            scratch.resize(prompt_size);
+            concat_chunks(sctx.chunks, scratch, prompt_size);
+#ifndef LLAMA_NO_SEQREP_SAMPLER
+            handle_seq_seqrep(sctx);
+#endif
+            sctx.last_sampled = llama_sampling_sample(ctx_sampling, ctx, NULL, sctx.batch_idx, scratch);
+            llama_sampling_accept(sctx.ctx_sampling, ctx, sctx.last_sampled, true);
+            if (sctx.seq_id == focused_sequence) {
+                const std::string token_str = llama_token_to_piece(ctx, sctx.last_sampled);
+                fputs(token_str.c_str(), stdout);
+                fflush(stdout);
+            }
+            sctx.n_toks++;
+            sctx.n_remain--;
+            if (sctx.chunks.empty() || sctx.chunks.back().is_input) {
+                sctx.chunks.emplace_back(false, 0, std::vector<llama_token>());
+            }
+            sctx.chunks.back().tokens.push_back(sctx.last_sampled);
+            if (sctx.last_sampled == llama_token_eos(model) || sctx.n_remain == 0) {
+                sctx.state = SEQ_DONE;
+                llama_kv_cache_seq_rm(ctx, sctx.seq_id, -1, -1);
+                sctx.batch_idx = -1;
+                // printf(" [end of text]\n");
+                // break;
+            } else {
+                sctx.batch_idx = batch.n_tokens;
+                llama_batch_add(batch, sctx.last_sampled, prompt_size + sctx.n_toks, {sctx.seq_id}, true);
+            }
+        } break;
+
+        case SEQ_INPUT: {
+            sctx.last_sampled = -1;
+            GGML_ASSERT(!sctx.chunks.empty());
+            tokens_chunk & chunk = sctx.chunks.back();
+            GGML_ASSERT(chunk.is_input);
+            GGML_ASSERT(chunk.consumed < chunk.tokens.size());
+            GGML_ASSERT(!chunk.tokens.empty());
+
+            const size_t remain = chunk.tokens.size() - chunk.consumed;
+            const size_t to_consume = std::min(size_t(params.n_batch), remain);
+            for (size_t i = chunk.consumed; i < chunk.consumed + to_consume; ++i) {
+                llama_batch_add(batch, chunk.tokens[i], llama_pos(prompt_size + sctx.n_toks + i), {sctx.seq_id}, false);
+            }
+            chunk.consumed += to_consume;
+            sctx.n_remain -= to_consume;
+            sctx.n_toks += to_consume;
+            if (chunk.consumed == chunk.tokens.size()) {
+#ifndef LLAMA_NO_SEQREP_SAMPLER
+                // FIXME: Move this logic to a more appropriate place.
+                for (size_t i = 0; i < chunk.consumed; i++) {
+                    sctx.rewind_state.logit_slots.emplace_back(n_vocab);
+                }
+                sctx.high_water_mark = sctx.n_toks + 1;
+#endif
+                sctx.batch_idx = batch.n_tokens - 1;
+                batch.logits[sctx.batch_idx] = true;
+                sctx.chunks.emplace_back(false, 0, std::vector<llama_token>());
+                sctx.chunks.back().tokens.reserve(sctx.n_remain);
+                sctx.state = SEQ_GENERATING;
+            } else {
+                sctx.batch_idx = -1;
+            }
+        } break;
+
+        default:
+            throw std::runtime_error("Unexpected state in handle_seq");
+    }
+}
+
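+// Rewind flow (summary): each generated token gets a snapshot of the sampler
+// state plus its logits. When llama_seqrep_handle_rewind() reports a nonzero
+// rewind distance, the snapshot at the rewind point is restored, the rewound
+// tokens are removed from the KV cache and chunk list, and (for the focused
+// sequence) echoed between 【 and 】 markers. high_water_mark appears to bound
+// how far back a rewind may reach, so freshly fed input is never unwound.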
+#ifndef LLAMA_NO_SEQREP_SAMPLER
+void gen_ctx::handle_seq_seqrep(seq_ctx & sctx) {
+    if (sctx.n_toks > 0) {
+        seqrep_rewind_slot & rw_slot = sctx.rewind_state.get_rewind_slot(sctx.n_toks);
+        if (rw_slot.ctx_sampling == nullptr) {
+            rw_slot.ctx_sampling = llama_sampling_init(params.sparams);
+        }
+        llama_sampling_cp(sctx.ctx_sampling, rw_slot.ctx_sampling);
+        sctx.rewind_state.set_logits_slot(ctx, sctx.n_toks, sctx.batch_idx);
+    } else {
+        return;
+    }
+    std::vector<llama_token> seq_last_tokens;
+    seq_last_tokens.reserve(sctx.n_toks);
+    concat_chunks(sctx.chunks, seq_last_tokens, prompt_size);
+
+    size_t rewind_distance = llama_seqrep_handle_rewind(
+        ctx, sctx.rewind_state, seq_last_tokens, sctx.n_toks, prompt_tokens,
+        sparams.seqrep_params, &sctx.high_water_mark, sctx.batch_idx);
+    if (rewind_distance < 1) {
+        return;
+    }
+    GGML_ASSERT(rewind_distance <= sctx.n_toks && "Rewind index out of bounds somehow?");
+    const size_t slot_idx = sctx.n_toks - rewind_distance;
+    const llama_token nl_id = llama_token_nl(model);
+
+    seqrep_rewind_slot & rw_slot = sctx.rewind_state.get_rewind_slot(slot_idx);
+    llama_sampling_cp(rw_slot.ctx_sampling, sctx.ctx_sampling);
+
+    if (sctx.seq_id == focused_sequence) {
+        console::set_display(console::error);
+        fputs("\u3010", stdout);
+        for (size_t i = seq_last_tokens.size() - rewind_distance; i < seq_last_tokens.size(); i++) {
+            if (seq_last_tokens[i] == nl_id) {
+                fputs("\\n", stdout);
+                continue;
+            }
+            const std::string token_str = llama_token_to_piece(ctx, seq_last_tokens[i]);
+            fputs(token_str.c_str(), stdout);
+        }
+        fputs("\u3011", stdout);
+        console::set_display(console::reset);
+        fflush(stdout);
+    }
+
+    sctx.n_remain += rewind_distance;
+    sctx.n_toks -= rewind_distance;
+    sctx.rewind_count++;
+    sctx.rewind_tokens += rewind_distance;
+    llama_kv_cache_seq_rm(ctx, sctx.seq_id, prompt_size + sctx.n_toks + 1, -1);
+    while (!sctx.chunks.empty() && rewind_distance > 0) {
+        tokens_chunk & last_chunk = sctx.chunks.back();
+        GGML_ASSERT(!last_chunk.is_input);
+
+        if (last_chunk.tokens.size() >= rewind_distance) {
+            last_chunk.tokens.resize(last_chunk.tokens.size() - rewind_distance);
+            rewind_distance = 0;
+            break;
+        }
+        rewind_distance -= last_chunk.tokens.size();
+        sctx.chunks.pop_back();
+    }
+}
+#endif
+
+bool gen_ctx::go() {
+    if (ctxs_seq.empty()) {
+        return false;
+    }
+
+    if (decode_count == 0) {
+        scratch.reserve(n_ctx);
+        scratch.resize(prompt_size);
+        std::copy(prompt_tokens.begin(), prompt_tokens.end(), scratch.begin());
+        // FIXME: Hacky.
+        if (!feed_prompt(prompt_tokens)) {
+            throw std::runtime_error("Prompt processing failed");
+        }
+        for (auto & sctx : ctxs_seq) {
+            sctx.batch_idx = batch.n_tokens - 1;
+            sctx.state = SEQ_GENERATING;
+            if (sctx.seq_id == 0) {
+                sctx.chunks.back().consumed = prompt_size;
+                sctx.chunks.emplace_back(false, 0, std::vector<llama_token>());
+            } else {
+                sctx.chunks.front().consumed = prompt_size;
+                llama_kv_cache_seq_cp(ctx, 0, sctx.seq_id, 0, prompt_size);
+            }
+#ifndef LLAMA_NO_SEQREP_SAMPLER
+            seqrep_rewind_slot & rw_slot = sctx.rewind_state.get_rewind_slot(0);
+            rw_slot.ctx_sampling = llama_sampling_init(params.sparams);
+            llama_sampling_cp(sctx.ctx_sampling, rw_slot.ctx_sampling);
+            sctx.rewind_state.set_logits_slot(ctx, 0, sctx.batch_idx);
+#endif
+        }
+    }
+
+    llama_batch_clear(batch);
+    for (auto & sctx : ctxs_seq) {
+        handle_seq(sctx);
+    }
+    if (batch.n_tokens == 0) return false;
+
+    decode_time_last = ggml_time_us();
+    const int decode_result = llama_decode(ctx, batch);
+    decode_time_last = std::max(int64_t(0), ggml_time_us() - decode_time_last);
+    decode_time_total += decode_time_last;
+
+    // FIXME: Handle KV cache pressure better.
+    if (decode_result != 0) {
+        fprintf(stderr, "%s : failed to eval batch of size %d: %s\n", __func__, batch.n_tokens,
+            decode_result == 1 ? "couldn't find slot" : "unknown error");
+        return false;
+    }
+    decode_count++;
+    return true;
+}
+
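+// Command mode (reached via ^C when running with --interactive): commands are
+// slash-prefixed and may carry a target sequence, e.g. "/2add hello" queues
+// "hello" as input for sequence 2 and "/list" summarizes every sequence. See
+// the /help text below for the full set.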
"(focus)" : " ", + label.c_str(), sctx.n_toks, sctx.n_remain); + for (const tokens_chunk & chunk : sctx.chunks) { + if (chunk.is_input) { + printf("INP(%5zu,%5zu), ", chunk.tokens.size(), chunk.consumed); + + } else { + printf("GEN(%5zu), ", chunk.tokens.size()); + } + } + printf("\n"); + gctx.dump_chunks(sctx.chunks); + printf("\n- Done\n"); + continue; + } + + if (command == "d" || command == "dt" || command == "dump" || command == "dumptokens") { + seq_ctx & sctx = gctx.ctxs_seq[target < 0 ? gctx.focused_sequence : target]; + const bool with_id = command == "dt" || command == "dumptokens"; + const size_t max_n = sctx.n_toks + gctx.prompt_size; + size_t dump_n = size_t(std::max(0, atoi(rest.c_str()))); + if (dump_n == 0) dump_n = 200; + dump_n = std::min(dump_n, max_n); + + printf("- Dumping last %zu token%s from sequence %d\n", + dump_n, dump_n != 1 ? "s" : "", target + 1); + + std::vector result; + result.reserve(dump_n); + concat_chunks(sctx.chunks, result, max_n - dump_n); + GGML_ASSERT(result.size() == dump_n); + for (size_t i = 0; i < dump_n; i++) { + const llama_token tid = result[i]; + console::set_display(console::user_input); + printf("[%zu", dump_n - i); + if (with_id) { + printf(",%d", tid); + } + fputs("]", stdout); + console::set_display(console::reset); + fputs(llama_token_to_piece(gctx.ctx, tid).c_str(), stdout); + + } + console::set_display(console::reset); + printf("\n\n- Dump complete.\n"); + continue; + } + + printf("! Bad command\n"); + } + return true; +} + +int main(int argc, char ** argv) { + gen_ctx gctx(argc, argv); + + // This might look weird but done can get set while go() is running. + while (!done && gctx.go() && !done) { + bool need_dump = gctx.params.n_parallel > 1 && gctx.decode_count % SI_DUMP_SEQUENCES_INTERVAL == 0; + if (interrupted) { + if (!gctx.params.interactive || !handle_commands(gctx)) break; + // Double check that ^C wasn't hit again. + if (done) break; + interrupted = false; + need_dump = true; + } + if (need_dump) { + printf("\n-- Last decode[%zu]: %.3f, avg: %.3f", + gctx.decode_count, double(gctx.decode_time_last) / 1000000, + (double(gctx.decode_time_total) / 1000000) / double(gctx.decode_count)); + gctx.dump_batches((gctx.prompt_size > 20) ? (gctx.prompt_size - 10) : 0); + } + } + gctx.focused_sequence = gctx.ctxs_seq.size() - 1; + gctx.dump_batches(); + puts(""); + console::cleanup(); + + llama_print_timings(gctx.ctx); +} diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 32e58941c0ee0..b8b19081714fa 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -1,10 +1,14 @@ #include "ggml.h" #include "llama.h" +#ifndef LLAMA_NO_SEQREP_SAMPLER +#include "common/seqrep-sampler.h" +#endif #ifdef NDEBUG #undef NDEBUG #endif +#include #include #include #include @@ -128,6 +132,79 @@ static void test_repetition_penalties( } } +// FIXME: This should probably just be moved to a separate test executable. +#ifndef LLAMA_NO_SEQREP_SAMPLER +// NOTE: Compares expected_probs at id position, not sorted position like the other +// test functions. 
+static void test_seqrep_penalty(
+    const std::vector<float> & probs,
+    const std::vector<llama_token> & last_tokens,
+    const std::vector<float> & expected_probs,
+    const llama_sampler_seqrep_params * params) {
+    assert(probs.size() == expected_probs.size());
+
+    size_t n_vocab = probs.size();
+    std::vector<llama_token_data> candidates;
+    candidates.reserve(n_vocab);
+    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+        float logit = log(probs[token_id]);
+        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+    }
+
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+    llama_sample_softmax(nullptr, &candidates_p);
+    DUMP(&candidates_p);
+    llama_sample_seqrep_penalty(nullptr, &candidates_p, last_tokens, params);
+    llama_sample_softmax(nullptr, &candidates_p);
+    DUMP(&candidates_p);
+
+    assert(candidates_p.size == expected_probs.size());
+    for (size_t i = 0; i < candidates_p.size; i++) {
+        assert(fabs(candidates_p.data[i].p - expected_probs[candidates_p.data[i].id]) < 1e-3);
+    }
+}
+
+static void run_seqrep_tests(void) {
+    llama_sampler_seqrep_params params;
+
+    // Compatible with frequency/presence penalty
+    memset(&params, 0, sizeof(llama_sampler_seqrep_params));
+    params.last_n = 1024;
+    params.min_length = 1;
+    params.mid_word_scale = 1.0f;
+    params.presence_penalty = 5.0f;
+    params.length_penalty = 5.0f;
+    params.flags |= LLAMA_SEQREP_ABSOLUTE_PENALTY;
+    test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0},             {0.000011f, 0.249997f, 0.249997f, 0.249997f, 0.249997f}, &params);
+    test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2},       {0.000023f, 0.000023f, 0.000023f, 0.499966f, 0.499966f}, &params);
+    test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.000000f, 0.000023f, 0.000023f, 0.499977f, 0.499977f}, &params);
+
+    // Compatible with repetition penalty
+    memset(&params, 0, sizeof(llama_sampler_seqrep_params));
+    params.last_n = 1024;
+    params.min_length = 1;
+    params.mid_word_scale = 1.0f;
+    params.presence_penalty = 50.0f;
+    params.length_penalty = 1.0f;
+    test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0},             {0, 0.25f, 0.25f, 0.25f, 0.25f}, &params);
+    test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2},       {0, 0, 0, 0.5f, 0.5f}, &params);
+    test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0, 0, 0, 0.5f, 0.5f}, &params);
+
+    // Seqrep mode
+    // memset(&params, 0, sizeof(llama_sampler_seqrep_params));
+    // params.last_n = 1024;
+    // params.min_length = 3;
+    // params.mid_word_scale = 1.0f;
+    // params.presence_penalty = 50.0f;
+    // params.length_penalty = 1.0f;
+    // test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 0, 1, 2}, {0.25f, 0.25f, 0.25f, 0, 0.25f}, &params);
+    // test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 2, 2, 3, 0, 1, 2}, {0.20f, 0.20f, 0.20f, 0.20f, 0.20f}, &params);
+    // params.tolerance = 1.0f;
+    // test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 2, 2, 3, 0, 1, 2}, {0.25f, 0.25f, 0.25f, 0, 0.25f}, &params);
+}
+#endif
+
+
 int main(void) {
     ggml_time_init();
 
@@ -154,6 +231,10 @@ int main(void) {
     test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2},       {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
     test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
 
+#ifndef LLAMA_NO_SEQREP_SAMPLER
+    run_seqrep_tests();
+#endif
+
     printf("OK\n");
     return 0;