From d7a4f3e497a81d1920fc0b8cff90466a09687e23 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sun, 3 Nov 2024 01:07:27 +0100
Subject: [PATCH 1/2] main : add special commands

---
 examples/main/main.cpp | 41 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 41 insertions(+)

diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 374ed47ad6311..b812e4c96e392 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -31,6 +31,10 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+static const std::string CMD_READFILE  = "/readfile";
+static const std::string CMD_SAVE_SESS = "/savesess";
+static const std::string CMD_LOAD_SESS = "/loadsess";
+
 static llama_context ** g_ctx;
 static llama_model ** g_model;
 static common_sampler ** g_smpl;
@@ -851,6 +855,43 @@ int main(int argc, char ** argv) {
 
                 LOG_DBG("buffer: '%s'\n", buffer.c_str());
 
+                // check for special commands
+                if (buffer.rfind(CMD_READFILE, 0) == 0) {
+                    const std::string filename = string_strip(buffer.substr(CMD_READFILE.length()));
+                    LOG_DBG("reading file: '%s'\n", filename.c_str());
+                    std::ifstream text_file(filename);
+                    if (!text_file) {
+                        LOG("failed to open file '%s'\n", filename.c_str());
+                        continue;
+                    }
+                    std::stringstream tmp;
+                    tmp << text_file.rdbuf();
+                    buffer = tmp.str();
+                    LOG("%s\n", buffer.c_str());
+                } else if (buffer.rfind(CMD_SAVE_SESS, 0) == 0) {
+                    const std::string filename = string_strip(buffer.substr(CMD_SAVE_SESS.length()));
+                    LOG("save session file: '%s'\n", filename.c_str());
+                    size_t res = llama_state_save_file(ctx, filename.c_str(), embd_inp.data(), n_past);
+                    if (res == 0) {
+                        LOG("failed to save session file '%s'\n", filename.c_str());
+                    }
+                    continue;
+                } else if (buffer.rfind(CMD_LOAD_SESS, 0) == 0) {
+                    const std::string filename = string_strip(buffer.substr(CMD_LOAD_SESS.length()));
+                    LOG("load session file: '%s'\n", filename.c_str());
+                    std::vector<llama_token> sess_tokens;
+                    sess_tokens.resize(n_ctx);
+                    size_t n_loaded_tokens;
+                    size_t res = llama_state_load_file(ctx, filename.c_str(), sess_tokens.data(), sess_tokens.size(), &n_loaded_tokens);
+                    if (res == 0) {
+                        LOG("failed to load session file '%s'\n", filename.c_str());
+                    } else {
+                        n_past = n_loaded_tokens;
+                        LOG("loaded %zu tokens from session file '%s'\n", n_loaded_tokens, filename.c_str());
+                    }
+                    continue;
+                }
+
                 const size_t original_size = embd_inp.size();
 
                 if (params.escape) {

From 1716e6b25a665cea377281c1439728253761cb4b Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sun, 3 Nov 2024 22:16:14 +0100
Subject: [PATCH 2/2] add some other commands

---
 common/arg.cpp         |   7 +++
 common/common.h        |   1 +
 examples/main/main.cpp | 126 ++++++++++++++++++++++++++++++++++-------
 3 files changed, 113 insertions(+), 21 deletions(-)

diff --git a/common/arg.cpp b/common/arg.cpp
index 7c5c5e5cd5b88..ce1be878fc96f 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1939,6 +1939,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.simple_io = true;
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_INFILL}));
+    add_opt(common_arg(
+        {"-nsc", "--no-special-command"},
"enabled" : "disabled"), + [](common_params & params) { + params.special_cmds = false; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); add_opt(common_arg( {"-ld", "--logdir"}, "LOGDIR", "path under which to save YAML logs (no logging if unset)", diff --git a/common/common.h b/common/common.h index cd5a8e051d33a..9039ff6f76bdc 100644 --- a/common/common.h +++ b/common/common.h @@ -251,6 +251,7 @@ struct common_params { bool conversation = false; // conversation mode (does not print special tokens and suffix/prefix) bool prompt_cache_all = false; // save user input and generations to prompt cache bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it + bool special_cmds = true; // enable special commands in main example bool escape = true; // escape "\n", "\r", "\t", "\'", "\"", and "\\" bool multiline_input = false; // reverse the usage of `\` diff --git a/examples/main/main.cpp b/examples/main/main.cpp index b812e4c96e392..9d2946d64505a 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -31,10 +31,6 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif -static const std::string CMD_READFILE = "/readfile"; -static const std::string CMD_SAVE_SESS = "/savesess"; -static const std::string CMD_LOAD_SESS = "/loadsess"; - static llama_context ** g_ctx; static llama_model ** g_model; static common_sampler ** g_smpl; @@ -45,6 +41,13 @@ static std::vector * g_output_tokens; static bool is_interacting = false; static bool need_insert_eot = false; +static const char * help_special_cmds = "special commands in conversation mode:\n" + " /readfile FILE read prompt from file\n" + " /savesess FILE save session to file\n" + " /loadsess FILE load session from file\n" + " /regen regenerate the last response\n" + " /dump FILE dump chat content to a file\n"; + static void print_usage(int argc, char ** argv) { (void) argc; @@ -52,6 +55,8 @@ static void print_usage(int argc, char ** argv) { LOG("\n text generation: %s -m your_model.gguf -p \"I believe the meaning of life is\" -n 128\n", argv[0]); LOG("\n chat (conversation): %s -m your_model.gguf -p \"You are a helpful assistant\" -cnv\n", argv[0]); LOG("\n"); + LOG("%s", help_special_cmds); + LOG("\n"); } static bool file_exists(const std::string & path) { @@ -109,6 +114,21 @@ static void write_logfile( fclose(logfile); } +static std::vector try_parse_command(std::string text) { + if (text.empty() || text[0] != '/') { + return {}; + } + std::vector elem = string_split(text, ' '); + std::vector res; + // filter empty strings + for (const auto & e : elem) { + if (!e.empty()) { + res.push_back(string_strip(e)); + } + } + return res; +} + #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) static void sigint_handler(int signo) { if (signo == SIGINT) { @@ -131,7 +151,11 @@ static void sigint_handler(int signo) { } #endif +// return the formatted turn to be decoded static std::string chat_add_and_format(struct llama_model * model, std::vector & chat_msgs, const std::string & role, const std::string & content) { + if (content.empty()) { + return ""; + } common_chat_msg new_msg{role, content}; auto formatted = common_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user"); chat_msgs.push_back({role, content}); @@ -193,6 +217,7 @@ int main(int argc, char ** argv) { llama_context * ctx = nullptr; common_sampler * smpl = nullptr; + std::vector pos_history; // history of positions of chat messages std::vector chat_msgs; g_model = &model; @@ 
@@ -519,6 +544,7 @@ int main(int argc, char ** argv) {
     display = params.display_prompt;
 
     std::vector<llama_token> embd;
+    llama_batch batch = llama_batch_init(params.n_batch, 0, 1);
 
     // tokenized antiprompts
     std::vector<std::vector<llama_token>> antiprompt_ids;
@@ -546,6 +572,8 @@ int main(int argc, char ** argv) {
         embd_inp.push_back(decoder_start_token_id);
     }
 
+    std::stringstream pending_input; // used by "/readfile" command
+
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
         if (!embd.empty()) {
@@ -652,7 +680,19 @@ int main(int argc, char ** argv) {
 
             LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
 
-            if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
+            common_batch_clear(batch);
+            for (int j = 0; j < n_eval; j++) {
+                int idx = i + j;
+                common_batch_add(
+                    batch,
+                    embd[idx],
+                    n_past + idx,
+                    {0},
+                    idx == (int) embd.size() - 1
+                );
+            }
+
+            if (llama_decode(ctx, batch)) {
                 LOG_ERR("%s : failed to eval\n", __func__);
                 return 1;
             }
@@ -856,42 +896,83 @@ int main(int argc, char ** argv) {
                 LOG_DBG("buffer: '%s'\n", buffer.c_str());
 
                 // check for special commands
-                if (buffer.rfind(CMD_READFILE, 0) == 0) {
-                    const std::string filename = string_strip(buffer.substr(CMD_READFILE.length()));
+                const std::vector<std::string> cmd = params.special_cmds
+                    ? try_parse_command(buffer)
+                    : std::vector<std::string>();
+
+                if (cmd.size() == 2 && cmd[0] == "/readfile") {
+                    const std::string filename = cmd[1];
                     LOG_DBG("reading file: '%s'\n", filename.c_str());
                     std::ifstream text_file(filename);
                     if (!text_file) {
                         LOG("failed to open file '%s'\n", filename.c_str());
                         continue;
                     }
-                    std::stringstream tmp;
-                    tmp << text_file.rdbuf();
-                    buffer = tmp.str();
-                    LOG("%s\n", buffer.c_str());
-                } else if (buffer.rfind(CMD_SAVE_SESS, 0) == 0) {
-                    const std::string filename = string_strip(buffer.substr(CMD_SAVE_SESS.length()));
+                    pending_input << text_file.rdbuf() << "\n\n";
+                    LOG("read %zu characters from file\n", (size_t) text_file.tellg());
+                    continue;
+                } else if (cmd.size() == 2 && cmd[0] == "/savesess") {
+                    const std::string filename = cmd[1];
                     LOG("save session file: '%s'\n", filename.c_str());
                     size_t res = llama_state_save_file(ctx, filename.c_str(), embd_inp.data(), n_past);
                     if (res == 0) {
                         LOG("failed to save session file '%s'\n", filename.c_str());
                     }
                     continue;
-                } else if (buffer.rfind(CMD_LOAD_SESS, 0) == 0) {
-                    const std::string filename = string_strip(buffer.substr(CMD_LOAD_SESS.length()));
+                } else if (cmd.size() == 2 && cmd[0] == "/loadsess") {
+                    const std::string filename = cmd[1];
                     LOG("load session file: '%s'\n", filename.c_str());
-                    std::vector<llama_token> sess_tokens;
-                    sess_tokens.resize(n_ctx);
-                    size_t n_loaded_tokens;
-                    size_t res = llama_state_load_file(ctx, filename.c_str(), sess_tokens.data(), sess_tokens.size(), &n_loaded_tokens);
+                    session_tokens.resize(n_ctx);
+                    size_t n_token_count_out;
+                    size_t res = llama_state_load_file(ctx, filename.c_str(), session_tokens.data(), session_tokens.size(), &n_token_count_out);
                     if (res == 0) {
                         LOG("failed to load session file '%s'\n", filename.c_str());
                     } else {
-                        n_past = n_loaded_tokens;
-                        LOG("loaded %zu tokens from session file '%s'\n", n_loaded_tokens, filename.c_str());
+                        session_tokens.resize(n_token_count_out);
+                        embd_inp = session_tokens;
+                        n_past = n_token_count_out;
+                        llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
+                        LOG("loaded %zu tokens from session file '%s'\n", n_token_count_out, filename.c_str());
+                    }
+                    continue;
+                } else if (cmd.size() == 1 && cmd[0] == "/regen") {
+                    if (pos_history.empty()) {
+                        LOG("no previous assistant message to regenerate\n");
+                        continue;
+                    }
+                    int last_n_past = pos_history.back();
+                    int n_tokens_removed = n_past - last_n_past;
+                    llama_kv_cache_seq_rm(ctx, 0, last_n_past, -1);
+                    n_remain += n_tokens_removed;
+                    is_interacting = false;
+                    // we intentionally do not reset the sampling, so the new message will be more diverse
+                    continue;
+                } else if (cmd.size() == 2 && cmd[0] == "/dump") {
+                    const std::string filename = cmd[1];
+                    std::ofstream dump_file(filename);
+                    if (!dump_file) {
+                        LOG("failed to create file '%s'\n", filename.c_str());
+                        continue;
+                    }
+                    for (const auto & msg : chat_msgs) {
+                        dump_file << msg.role << ":\n" << msg.content << "\n---\n";
                     }
+                    dump_file.close();
+                    LOG("dumped chat messages to file '%s'\n", filename.c_str());
+                    continue;
+                } else if (!cmd.empty()) {
+                    LOG("unknown command: %s\n", buffer.c_str());
+                    LOG("%s", help_special_cmds);
                     continue;
                 }
 
+                if (pending_input.tellp() > 0) {
+                    // concatenate read file and the prompt
+                    pending_input << buffer;
+                    buffer = pending_input.str();
+                    pending_input.str(""); // reset the pending input (str(""), not clear(), which only resets error flags)
+                }
+
                 const size_t original_size = embd_inp.size();
 
                 if (params.escape) {
@@ -926,6 +1007,8 @@ int main(int argc, char ** argv) {
                     output_ss << common_token_to_piece(ctx, token);
                 }
 
+                pos_history.push_back(n_past + embd_inp.size() - original_size);
+
                 // reset assistant message
                 assistant_ss.str("");
 
@@ -971,6 +1054,7 @@ int main(int argc, char ** argv) {
 
     common_sampler_free(smpl);
 
+    llama_batch_free(batch);
     llama_free(ctx);
     llama_free_model(model);
 