From 8afc0f3784c1162b62573cd97086c94db03d4550 Mon Sep 17 00:00:00 2001
From: ngxson
Date: Fri, 24 May 2024 00:15:00 +0200
Subject: [PATCH] --hf-repo without --hf-file

---
 common/common.cpp | 138 ++++++++++++++++++++++++++++++++++------------
 common/common.h   |   2 +-
 2 files changed, 104 insertions(+), 36 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index 7500e08ff1be4..c9b4490593894 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -193,18 +193,15 @@ int32_t cpu_get_num_math() {
 void gpt_params_handle_model_default(gpt_params & params) {
     if (!params.hf_repo.empty()) {
         // short-hand to avoid specifying --hf-file -> default it to --model
-        if (params.hf_file.empty()) {
-            if (params.model.empty()) {
-                throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n");
-            }
-            params.hf_file = params.model;
-        } else if (params.model.empty()) {
+        params.model_url = llama_get_hf_model_url(params.hf_repo, params.hf_file);
+        if (params.model.empty()) {
             std::string cache_directory = fs_get_cache_directory();
             const bool success = fs_create_directory_with_parents(cache_directory);
             if (!success) {
                 throw std::runtime_error("failed to create cache directory: " + cache_directory);
             }
-            params.model = cache_directory + string_split(params.hf_file, '/').back();
+            // TODO: cache with params.hf_repo in directory
+            params.model = cache_directory + string_split(params.model_url, '/').back();
         }
     } else if (!params.model_url.empty()) {
         if (params.model.empty()) {
@@ -1888,9 +1885,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
 
     llama_model * model = nullptr;
 
-    if (!params.hf_repo.empty() && !params.hf_file.empty()) {
-        model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), mparams);
-    } else if (!params.model_url.empty()) {
+    if (!params.model_url.empty()) {
         model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams);
     } else {
         model = llama_load_model_from_file(params.model.c_str(), mparams);
@@ -2061,6 +2056,16 @@ static bool starts_with(const std::string & str, const std::string & prefix) {
     return str.rfind(prefix, 0) == 0;
 }
 
+static bool ends_with(const std::string & str, const std::string & suffix) {
+    return str.rfind(suffix) == str.length() - suffix.length();
+}
+
+static std::string tolower(std::string s) {
+    std::transform(s.begin(), s.end(), s.begin(),
+        [](unsigned char c){ return std::tolower(c); });
+    return s;
+}
+
 static bool llama_download_file(const std::string & url, const std::string & path) {
 
     // Initialize libcurl
@@ -2341,26 +2346,91 @@ struct llama_model * llama_load_model_from_url(
     return llama_load_model_from_file(path_model, params);
 }
 
-struct llama_model * llama_load_model_from_hf(
-        const char * repo,
-        const char * model,
-        const char * path_model,
-        const struct llama_model_params & params) {
-    // construct hugging face model url:
-    //
-    //  --repo ggml-org/models --file tinyllama-1.1b/ggml-model-f16.gguf
-    //    https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf
-    //
-    //  --repo TheBloke/Mixtral-8x7B-v0.1-GGUF --file mixtral-8x7b-v0.1.Q4_K_M.gguf
-    //    https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q4_K_M.gguf
-    //
-
-    std::string model_url = "https://huggingface.co/";
-    model_url += repo;
-    model_url += "/resolve/main/";
-    model_url += model;
-
-    return llama_load_model_from_url(model_url.c_str(), path_model, params);
+static std::string llama_get_hf_model_url(
+        std::string & repo,
+        std::string & custom_file_path) {
+    std::stringstream ss;
+    json repo_files;
+
+    if (!custom_file_path.empty()) {
+        ss << "https://huggingface.co/" << repo << "/resolve/main/" << custom_file_path;
+        return ss.str();
+    }
+
+    {
+        // Initialize libcurl
+        std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
+
+        // Make the request to Hub API
+        ss << "https://huggingface.co/api/models/" << repo << "/tree/main?recursive=true";
+        std::string url = ss.str();
+        std::string res_str;
+        curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
+        curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
+        typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
+        auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
+            static_cast<std::string *>(data)->append((char *) ptr, size * nmemb);
+            return size * nmemb;
+        };
+        curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
+        curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str);
+#if defined(_WIN32)
+        curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
+#endif
+        CURLcode res = curl_easy_perform(curl.get());
+
+        if (res != CURLE_OK) {
+            fprintf(stderr, "%s: cannot make GET request to Hugging Face Hub API\n", __func__);
+            return "";
+        }
+
+        long res_code;
+        curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code);
+        if (res_code != 200) {
+            fprintf(stderr, "%s: Hugging Face Hub API responded with status code %ld\n", __func__, res_code);
+            return "";
+        } else {
+            repo_files = json::parse(res_str);
+        }
+    }
+
+    if (!repo_files.is_array()) {
+        fprintf(stderr, "%s: response from Hugging Face Hub API is not an array\nRaw response:\n%s", __func__, repo_files.dump(4).c_str());
+        return "";
+    }
+
+    auto get_file_contains = [&](std::string piece) -> std::string {
+        for (auto elem : repo_files) {
+            std::string type = elem.at("type");
+            std::string path = elem.at("path");
+            if (
+                type == "file"
+                && ends_with(path, ".gguf")
+                && tolower(path).find(piece) != std::string::npos
+            ) return path;
+        }
+        return "";
+    };
+
+    std::string file_path = get_file_contains("q4_k_m");
+    if (file_path.empty()) {
+        file_path = get_file_contains("q4");
+    }
+    if (file_path.empty()) {
+        file_path = get_file_contains("00001");
+    }
+    if (file_path.empty()) {
+        file_path = get_file_contains("gguf");
+    }
+
+    if (file_path.empty()) {
+        fprintf(stderr, "%s: cannot find any gguf file in the given repository\n", __func__);
+        return "";
+    }
+
+    ss = std::stringstream();
+    ss << "https://huggingface.co/" << repo << "/resolve/main/" << file_path;
+    return ss.str();
 }
 
 #else
@@ -2373,11 +2443,9 @@ struct llama_model * llama_load_model_from_url(
     return nullptr;
 }
 
-struct llama_model * llama_load_model_from_hf(
-        const char * /*repo*/,
-        const char * /*model*/,
-        const char * /*path_model*/,
-        const struct llama_model_params & /*params*/) {
+static std::string llama_get_hf_model_url(
+        std::string & /*repo*/,
+        std::string & /*custom_file_path*/) {
     fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
-    return nullptr;
+    return "";
 }
diff --git a/common/common.h b/common/common.h
index f68f3c2979b94..bfe659a0f562b 100644
--- a/common/common.h
+++ b/common/common.h
@@ -223,7 +223,7 @@ struct llama_model_params llama_model_params_from_gpt_params (const gpt_param
 struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
 
 struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const struct llama_model_params & params);
-struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const struct llama_model_params & params);
+static std::string llama_get_hf_model_url(std::string & repo, std::string & custom_file_path);
 
 // Batch utils
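
For reference, the file-selection fallback added to common.cpp can be exercised on its own. The sketch below is not part of the patch: it replaces the Hub API response with a hypothetical hard-coded listing and simply applies the same preference order as llama_get_hf_model_url (q4_k_m, then any q4, then the first "00001" shard, then any .gguf).

// Standalone sketch of the quantization-preference fallback used by llama_get_hf_model_url.
// The file names below are hypothetical; the real code reads the paths from the Hub API JSON.
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <string>
#include <vector>

static bool path_ends_with(const std::string & s, const std::string & suffix) {
    return s.size() >= suffix.size() && s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0;
}

static std::string to_lower(std::string s) {
    std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::tolower(c); });
    return s;
}

// first .gguf path whose lower-cased name contains `piece`, or "" if there is none
static std::string first_gguf_containing(const std::vector<std::string> & paths, const std::string & piece) {
    for (const auto & path : paths) {
        if (path_ends_with(path, ".gguf") && to_lower(path).find(piece) != std::string::npos) {
            return path;
        }
    }
    return "";
}

int main() {
    // hypothetical repository listing
    const std::vector<std::string> paths = {
        "README.md",
        "model-00001-of-00002.gguf",
        "model-Q4_K_M.gguf",
        "model-Q8_0.gguf",
    };

    std::string file;
    for (const char * piece : { "q4_k_m", "q4", "00001", "gguf" }) {
        file = first_gguf_containing(paths, piece);
        if (!file.empty()) {
            break;
        }
    }

    printf("selected: %s\n", file.empty() ? "<none>" : file.c_str());  // prints "selected: model-Q4_K_M.gguf"
    return 0;
}

Assuming the rest of the download path is unchanged, a command along the lines of "./main --hf-repo TheBloke/Mixtral-8x7B-v0.1-GGUF" (the exact binary name and flags depend on the build) would then resolve to the repository's Q4_K_M file and cache it under the directory returned by fs_get_cache_directory(), with no --hf-file or --model required.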