Skip to content

Commit 0ae4c07

Browse files
committed
Move out of llama lib and into common directory
1 parent 167762d commit 0ae4c07

File tree

9 files changed

+553
-468
lines changed

9 files changed

+553
-468
lines changed

Makefile

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,10 @@ ifdef LLAMA_DISABLE_LOGS
177177
MK_CPPFLAGS += -DLOG_DISABLE_LOGS
178178
endif # LLAMA_DISABLE_LOGS
179179

180+
ifdef LLAMA_DISABLE_SEQREP_SAMPLER
181+
MK_CPPFLAGS += -DLLAMA_NO_SEQREP_SAMPLER
182+
endif
183+
180184
# warnings
181185
MK_CFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith \
182186
-Wmissing-prototypes -Werror=implicit-int -Wno-unused-function
@@ -476,7 +480,13 @@ OBJS += ggml-alloc.o
476480
llama.o: llama.cpp ggml.h ggml-alloc.h ggml-cuda.h ggml-metal.h llama.h
477481
$(CXX) $(CXXFLAGS) -c $< -o $@
478482

479-
common.o: common/common.cpp common/common.h build-info.h common/log.h
483+
COMMON_DEPS = common/common.cpp common/common.h build-info.h common/log.h
484+
COMMON_OBJS = common.o
485+
ifndef LLAMA_DISABLE_SEQREP_SAMPLER
486+
COMMON_DEPS += common/seqrep-sampler.cpp common/seqrep-sampler.h
487+
COMMON_OBJS += seqrep-sampler.o
488+
endif
489+
common.o: $(COMMON_DEPS)
480490
$(CXX) $(CXXFLAGS) -c $< -o $@
481491

482492
console.o: common/console.cpp common/console.h
@@ -485,6 +495,9 @@ console.o: common/console.cpp common/console.h
485495
grammar-parser.o: common/grammar-parser.cpp common/grammar-parser.h
486496
$(CXX) $(CXXFLAGS) -c $< -o $@
487497

498+
seqrep-sampler.o: common/seqrep-sampler.cpp common/seqrep-sampler.h
499+
$(CXX) $(CXXFLAGS) -c $< -o $@
500+
488501
libllama.so: llama.o ggml.o $(OBJS)
489502
$(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS)
490503

@@ -495,7 +508,7 @@ clean:
495508
# Examples
496509
#
497510

498-
main: examples/main/main.cpp build-info.h ggml.o llama.o common.o console.o grammar-parser.o $(OBJS)
511+
main: examples/main/main.cpp build-info.h ggml.o llama.o $(COMMON_OBJS) console.o grammar-parser.o $(OBJS)
499512
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
500513
@echo
501514
@echo '==== Run ./main -h for help. ===='

common/common.cpp

Lines changed: 10 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
#include "build-info.h"
33
#include "llama.h"
44

5+
#ifndef LLAMA_NO_SEQREP_SAMPLER
6+
#include "seqrep-sampler.h"
7+
#endif
8+
59
#include <algorithm>
610
#include <cassert>
711
#include <cmath>
@@ -102,144 +106,6 @@ void process_escapes(std::string& input) {
102106
input.resize(output_idx);
103107
}
104108

105-
void seqrep_sampler_params_init(llama_sampler_seqrep_params * params) {
106-
assert(params != NULL);
107-
memset(params, 0, sizeof(llama_sampler_seqrep_params));
108-
params->last_n = 256;
109-
params->mid_word_scale = 0.1f;
110-
params->tolerance_half_step_cost = 1.0f;
111-
}
112-
113-
void seqrep_sampler_params_dump(const llama_sampler_seqrep_params * params) {
114-
assert(params != NULL);
115-
LOG_TEE("seqrep(last_n = %d, min_length = %zd, start_offset = %zd, presence_penalty = %.4f, length_penalty = %.4f, tolerance = %.4f, mid_word_scale = %.4f, tolerance_match_credit = %.4f, tolerance_half_step_cost = %.4f, flags = %d)\n",
116-
params->last_n, params->min_length, params->start_offset, params->presence_penalty,
117-
params->length_penalty, params->tolerance, params->mid_word_scale, params->tolerance_match_credit,
118-
params->tolerance_half_step_cost, params->flags);
119-
}
120-
121-
void seqrep_sampler_help() {
122-
llama_sampler_seqrep_params p;
123-
seqrep_sampler_params_init(&p);
124-
fprintf(stdout, "==== Sequence Repetition Sampler Help ====\n\n");
125-
fprintf(stdout, " The sequence repetition sampler takes a configuration string in the format:\n");
126-
fprintf(stdout, " arg1:arg2:argN\n");
127-
fprintf(stdout, " A colon separated argument can be a key value pair like xyz=1 or flag like xyz\n");
128-
fprintf(stdout, "\n- Available key/value arguments\n");
129-
fprintf(stdout, " * repetition_mode=REPEAT_PENALTY\n emulates the repetition penalty sampler. warning: 1.0 disables penalties since this preset enables flag_divide_by_penalty. using 0.0 is probably not what you want\n");
130-
fprintf(stdout, " * presence_mode=PRESENCE_PENALTY\n emulates the presence penalty sampler\n");
131-
fprintf(stdout, " * frequency_mode=FREQUENCY_PENALTY\n Emulates the repetition penalty sampler\n");
132-
fprintf(stdout, " * last_n\n last n tokens to consider for sequence penalizing (default: %d, 0 = disabled, -1 = ctx_size)\n", p.last_n);
133-
fprintf(stdout, " * min_length\n minimum matching sequence length (default: %zd, < 2 = disabled)\n", p.min_length);
134-
fprintf(stdout, " * presence_penalty\n presence penalty for tokens that can continue a sequence (default: %f, 0.0 = disabled)\n", p.presence_penalty);
135-
fprintf(stdout, " * length_penalty\n penalty for tokens that can continue a sequence, multiplied by length (default: %f, 0.0 = disabled)\n", p.length_penalty);
136-
fprintf(stdout, " * tolerance\n tolerance for fuzzy matching sequences (default: %f, 0 = disabled)\n", p.tolerance);
137-
fprintf(stdout, " * mid_word_scale\n scale penalty when for mid-word tokens. 1.0 would mean apply the full penalty (default: %f, 1.0 = disabled)\n", p.mid_word_scale);
138-
fprintf(stdout, " * tolerance_match_credit\n credit tolerance on matched tokens (default: %f, 0.0 = disabled)\n", p.tolerance_match_credit);
139-
fprintf(stdout, " * tolerance_half_step_cost\n advanced option to adjust tolerance cost for failed matches within a half step of a match (default: %f, 1.0 = normal)\n", p.tolerance_half_step_cost);
140-
fprintf(stdout, "\n- Available flags arguments (currently all default to disabled)\n");
141-
fprintf(stdout, " * flag_immediate_wildcard\n when tolerance is consumed, by default it doesn't count as a match until a real match is found\n");
142-
fprintf(stdout, " * flag_tolerance_no_consecutive\n do not allow using tolerance consecutively\n");
143-
fprintf(stdout, " * flag_tolerance_no_first\n do not allow using tolerance before the first match\n");
144-
fprintf(stdout, " * flag_tolerance_cap_initial\n only meaningful with match credit, prevents match credit adjusting tolerance higher than the initial value\n");
145-
fprintf(stdout, " * flag_penalize_length_max_seen\n when applying length_penalty, use the maximum seen sequence length rather than the total length of seen sequences\n");
146-
fprintf(stdout, " * flag_divide_by_penalty\n divide the logit when applying a penalty rather than subtracting it. warning: when this flag is enabled, 1.0 disables penalties not 0.0. 0.0 is probably not what you want\n");
147-
fprintf(stdout, "\n- Examples:\n");
148-
fprintf(stdout, " * repetition_mode=1.2:last_n=32\n same as --repeat-last-n 32 --repeat-penalty 1.2\n");
149-
fprintf(stdout, " * presence_mode=.2:last_n=32\n same as --repeat-last-n 32 --presence-penalty .2\n");
150-
fprintf(stdout, " * frequency_mode=.2:last_n=32\n same as --repeat-last-n 32 --frequency-penalty .2\n");
151-
fprintf(stdout, " * min_length=3:tolerance=1:length_penalty=.2:last_n=-1\n match repeated sequences of at least 3 tokens within the entire context and apply a penalty of 0.2*total_length to the token that would continue the sequence. allow one non-matching token in matched sequences.\n");
152-
}
153-
154-
bool seqrep_sampler_params_parse(char * s, llama_sampler_seqrep_params * params) {
155-
assert(params != NULL);
156-
assert(s != NULL);
157-
size_t offset = 0;
158-
std::string sparams = s;
159-
size_t slen = sparams.size();
160-
161-
while (offset < slen) {
162-
// printf("SR OFFS: %lu\n", offset);
163-
size_t argsep = sparams.find_first_of(':', offset);
164-
std::string argchunk;
165-
if (argsep == std::string::npos) {
166-
argchunk = sparams.substr(offset);
167-
} else if (argsep > offset) {
168-
argchunk = sparams.substr(offset, argsep - offset);
169-
}
170-
std::string argval;
171-
size_t valsep = argchunk.find_first_of('=');
172-
if (valsep != std::string::npos && valsep < argchunk.size()) {
173-
argval = argchunk.substr(valsep + 1);
174-
argchunk.resize(valsep);
175-
}
176-
// printf("SR: k[%s] = v[%s]\n", argchunk.c_str(), argval.c_str());
177-
if (argchunk.empty() && argval.empty()) {
178-
// pass
179-
} else if (argchunk == "repetition_mode") {
180-
params->last_n = 64;
181-
params->min_length = 1;
182-
params->mid_word_scale = 1.0f;
183-
params->flags = LLAMA_SEQREP_DIVIDE_BY_PENALTY;
184-
params->length_penalty = 1.0f;
185-
params->presence_penalty = argval.empty() ? 1.1f : std::atof(argval.c_str());
186-
} else if (argchunk == "presence_mode") {
187-
params->last_n = 64;
188-
params->min_length = 1;
189-
params->mid_word_scale = 1.0f;
190-
params->flags = 0;
191-
params->length_penalty = 0.0f;
192-
params->presence_penalty = std::atof(argval.c_str());
193-
} else if (argchunk == "frequency_mode") {
194-
params->last_n = 64;
195-
params->min_length = 1;
196-
params->mid_word_scale = 1.0f;
197-
params->flags = 0;
198-
params->length_penalty = std::atof(argval.c_str());
199-
params->presence_penalty = 0.0f;
200-
} else if (argchunk == "flag_immediate_wildcard") {
201-
params->flags |= LLAMA_SEQREP_IMMEDIATE_WILDCARD;
202-
} else if (argchunk == "flag_tolerance_no_consecutive") {
203-
params->flags |= LLAMA_SEQREP_TOLERANCE_NO_CONSECUTIVE;
204-
} else if (argchunk == "flag_tolerance_no_first") {
205-
params->flags |= LLAMA_SEQREP_TOLERANCE_NO_FIRST;
206-
} else if (argchunk == "flag_tolerance_cap_initial") {
207-
params->flags |= LLAMA_SEQREP_TOLERANCE_CAP_INITIAL;
208-
} else if (argchunk == "flag_penalize_length_max_seen") {
209-
params->flags |= LLAMA_SEQREP_PENALIZE_LENGTH_MAX_SEEN;
210-
} else if (argchunk == "flag_divide_by_penalty") {
211-
params->flags |= LLAMA_SEQREP_DIVIDE_BY_PENALTY;
212-
} else if (argchunk == "min_length") {
213-
params->min_length = std::atoi(argval.c_str());
214-
} else if (argchunk == "start_offset") {
215-
params->start_offset = std::atoi(argval.c_str());
216-
} else if (argchunk == "last_n") {
217-
params->last_n = std::atoi(argval.c_str());
218-
} else if (argchunk == "tolerance") {
219-
params->tolerance = std::atof(argval.c_str());
220-
} else if (argchunk == "presence_penalty") {
221-
params->presence_penalty = std::atof(argval.c_str());
222-
} else if (argchunk == "length_penalty") {
223-
params->length_penalty = std::atof(argval.c_str());
224-
} else if (argchunk == "mid_word_scale") {
225-
params->mid_word_scale = std::atof(argval.c_str());
226-
} else if (argchunk == "tolerance_match_credit") {
227-
params->tolerance_match_credit = std::atof(argval.c_str());
228-
} else if (argchunk == "tolerance_half_step_cost") {
229-
params->tolerance_half_step_cost = std::atof(argval.c_str());
230-
} else {
231-
fprintf(stderr, "seqrep: Bad argument [%s]=[%s]!\n", argchunk.c_str(), argval.c_str());
232-
return false;
233-
}
234-
if (argsep != std::string::npos) {
235-
offset = argsep + 1;
236-
} else {
237-
break;
238-
}
239-
}
240-
return true;
241-
}
242-
243109
bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
244110
bool invalid_param = false;
245111
std::string arg;
@@ -386,6 +252,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
386252
break;
387253
}
388254
params.presence_penalty = std::stof(argv[i]);
255+
#ifndef LLAMA_NO_SEQREP_SAMPLER
389256
} else if (arg == "-seqrep" || arg == "--seqrep-penalty") {
390257
if (++i >= argc) {
391258
invalid_param = true;
@@ -405,6 +272,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
405272
&& (sr_params.presence_penalty != 0.0f || sr_params.length_penalty != 0.0f)) {
406273
params.seqrep_params.push_back(sr_params);
407274
}
275+
#endif
408276
} else if (arg == "--mirostat") {
409277
if (++i >= argc) {
410278
invalid_param = true;
@@ -779,8 +647,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
779647
printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty);
780648
printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty);
781649
printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty);
650+
#ifndef LLAMA_NO_SEQREP_SAMPLER
782651
printf(" -seqrep CFG, --seqrep-penalty CFG\n");
783652
printf(" add a copy of the sequence repetition penalty sampler. may be specified multiple times. for help: -seqrep help\n");
653+
#endif
784654
printf(" --mirostat N use Mirostat sampling.\n");
785655
printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
786656
printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", params.mirostat);
@@ -1069,9 +939,11 @@ llama_token llama_sample_token(
1069939
last_tokens.data() + last_tokens.size() - last_n_repeat,
1070940
last_n_repeat, alpha_frequency, alpha_presence);
1071941

942+
#ifndef LLAMA_NO_SEQREP_SAMPLER
1072943
for (auto & sr_params : params.seqrep_params) {
1073944
llama_sample_seqrep_penalty(ctx, &cur_p, last_tokens.data(), last_tokens.size(), &sr_params);
1074945
}
946+
#endif
1075947

1076948
if (!penalize_nl) {
1077949
for (size_t idx = 0; idx < cur_p.size; idx++) {

common/common.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44

55
#include "llama.h"
66

7+
#ifndef LLAMA_NO_SEQREP_SAMPLER
8+
#include "seqrep-sampler.h"
9+
#endif
10+
711
#define LOG_NO_FILE_LINE_FUNCTION
812
#include "log.h"
913

@@ -55,7 +59,9 @@ struct gpt_params {
5559
int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
5660
float frequency_penalty = 0.00f; // 0.0 = disabled
5761
float presence_penalty = 0.00f; // 0.0 = disabled
62+
#ifndef LLAMA_NO_SEQREP_SAMPLER
5863
std::vector<llama_sampler_seqrep_params> seqrep_params;
64+
#endif
5965
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
6066
float mirostat_tau = 5.00f; // target entropy
6167
float mirostat_eta = 0.10f; // learning rate
@@ -205,7 +211,3 @@ std::string get_sortable_timestamp();
205211
void dump_non_result_info_yaml(
206212
FILE * stream, const gpt_params & params, const llama_context * lctx,
207213
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc);
208-
209-
void seqrep_sampler_params_init(llama_sampler_seqrep_params * params);
210-
void seqrep_sampler_params_dump(const llama_sampler_seqrep_params * params);
211-
bool seqrep_sampler_params_parse(char * s, llama_sampler_seqrep_params * params);

0 commit comments

Comments
 (0)