From 97f2167356a0bb43cb15b5b335c6fcad4de23842 Mon Sep 17 00:00:00 2001 From: kangmeng3 Date: Fri, 22 Aug 2025 15:39:27 +0800 Subject: [PATCH 01/11] feat: add tokenizer and chat template factory. --- xllm_service/chat_template/CMakeLists.txt | 2 + .../chat_template/chat_template_factory.cpp | 50 +++++++++++++++++++ .../chat_template/chat_template_factory.h | 12 +++++ 3 files changed, 64 insertions(+) create mode 100644 xllm_service/chat_template/chat_template_factory.cpp create mode 100644 xllm_service/chat_template/chat_template_factory.h diff --git a/xllm_service/chat_template/CMakeLists.txt b/xllm_service/chat_template/CMakeLists.txt index d5c6081..d47e081 100644 --- a/xllm_service/chat_template/CMakeLists.txt +++ b/xllm_service/chat_template/CMakeLists.txt @@ -6,8 +6,10 @@ cc_library ( chat_template HDRS jinja_chat_template.h + chat_template_factory.h SRCS jinja_chat_template.cpp + chat_template_factory.cpp DEPS :minja :tokenizer diff --git a/xllm_service/chat_template/chat_template_factory.cpp b/xllm_service/chat_template/chat_template_factory.cpp new file mode 100644 index 0000000..8e347c4 --- /dev/null +++ b/xllm_service/chat_template/chat_template_factory.cpp @@ -0,0 +1,50 @@ +#include "chat_template/chat_template_factory.h" + +#include + +#include "chat_template/coded_chat_template.h" +#include "chat_template/common_chat_template.h" +#include "chat_template/jinja_chat_template.h" + +namespace xllm_service { + +constexpr std::array JINJA_CHAT_TEMPLATE_MODELS{ + "deepseek_v3_mtp", + "deepseek_v2", + "deepseek_v3", + "qwen2", + "qwen3"}; + +constexpr bool is_jinja_model(std::string_view model) { + for (auto m : JINJA_CHAT_TEMPLATE_MODELS) { + if (m == model) return true; + } + return false; +} + +std::unique_ptr create_chat_template( + const std::string& model_type, + const TokenizerArgs& tokenizer_args) { + if (is_jinja_model(model_type)) { + return std::make_unique(tokenizer_args); + } else if (model_type == "chatglm") { + return std::make_unique(); + } else if (model_type == "chatglm4") { + return std::make_unique(); + } else if (model_type == "llama") { + return std::make_unique(); + } else if (model_type == "llama3") { + return std::make_unique(); + } else if (model_type == "rhino") { + return std::make_unique(); + } else if (model_type == "minicpmv") { + return std::make_unique(); + } else if (model_type == "qwen") { + return std::make_unique(); + } else { + LOG(FATAL) << "Unknow model: " << model_type + << ", create ChatTemplate fail!"; + } +} + +} // namespace xllm_service diff --git a/xllm_service/chat_template/chat_template_factory.h b/xllm_service/chat_template/chat_template_factory.h new file mode 100644 index 0000000..9c2d939 --- /dev/null +++ b/xllm_service/chat_template/chat_template_factory.h @@ -0,0 +1,12 @@ +#pragma once + +#include "chat_template.h" +#include "tokenizer/tokenizer_args.h" + +namespace xllm_service { + +std::unique_ptr create_chat_template( + const std::string& model_type, + const TokenizerArgs& tokenizer_args); + +} // namespace xllm_service \ No newline at end of file From b2794eb3089810570d5937dc600c2aedaea6b361 Mon Sep 17 00:00:00 2001 From: kangmeng3 Date: Fri, 22 Aug 2025 15:42:39 +0800 Subject: [PATCH 02/11] feat: add hash util for token ids. --- .gitmodules | 3 ++ CMakeLists.txt | 1 + third_party/CMakeLists.txt | 3 +- vcpkg.json | 4 ++ xllm_service/common/CMakeLists.txt | 4 +- xllm_service/common/global_gflags.cpp | 2 + xllm_service/common/global_gflags.h | 2 + xllm_service/common/hash_util.cpp | 62 +++++++++++++++++++++++++++ xllm_service/common/hash_util.h | 59 +++++++++++++++++++++++++ 9 files changed, 138 insertions(+), 2 deletions(-) create mode 100644 xllm_service/common/hash_util.cpp create mode 100644 xllm_service/common/hash_util.h diff --git a/.gitmodules b/.gitmodules index 2c43fb2..507ceba 100644 --- a/.gitmodules +++ b/.gitmodules @@ -16,3 +16,6 @@ [submodule "third_party/minja"] path = third_party/minja url = https://gitcode.com/xLLM-AI/minja.git +[submodule "third_party/smhasher"] + path = third_party/smhasher + url = https://gitcode.com/xLLM-AI/smhasher.git diff --git a/CMakeLists.txt b/CMakeLists.txt index a5ee4cb..617a5e6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -98,6 +98,7 @@ endif() # find all dependencies from vcpkg find_package(Boost REQUIRED) +find_package(Boost REQUIRED COMPONENTS serialization) find_package(glog CONFIG REQUIRED) find_package(gflags CONFIG REQUIRED) find_package(leveldb CONFIG REQUIRED) diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index bb3c3a1..280b058 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -4,4 +4,5 @@ add_subdirectory(cpprestsdk) add_subdirectory(etcd_cpp_apiv3) add_subdirectory(brpc) add_subdirectory(minja) -add_subdirectory(sentencepiece) \ No newline at end of file +add_subdirectory(sentencepiece) +add_subdirectory(smhasher/src) \ No newline at end of file diff --git a/vcpkg.json b/vcpkg.json index 86653fd..d57f269 100644 --- a/vcpkg.json +++ b/vcpkg.json @@ -38,6 +38,10 @@ "name": "boost-random", "version>=": "1.84.0" }, + { + "name": "boost-serialization", + "version>=": "1.84.0" + }, { "name": "protobuf", "version>=": "3.21.12", diff --git a/xllm_service/common/CMakeLists.txt b/xllm_service/common/CMakeLists.txt index df21039..2ffc2f9 100644 --- a/xllm_service/common/CMakeLists.txt +++ b/xllm_service/common/CMakeLists.txt @@ -14,6 +14,7 @@ cc_library( threadpool.h types.h utils.h + hash_util.h xllm/output.h xllm/status.h xllm/uuid.h @@ -22,6 +23,7 @@ cc_library( json_reader.cpp threadpool.cpp utils.cpp + hash_util.cpp xllm/uuid.cpp DEPS absl::random_random @@ -29,6 +31,6 @@ cc_library( glog::glog gflags::gflags nlohmann_json::nlohmann_json + SMHasherSupport ) -target_link_libraries(common PRIVATE OpenSSL::SSL OpenSSL::Crypto) add_dependencies(common brpc-static) diff --git a/xllm_service/common/global_gflags.cpp b/xllm_service/common/global_gflags.cpp index ff84c32..d638379 100644 --- a/xllm_service/common/global_gflags.cpp +++ b/xllm_service/common/global_gflags.cpp @@ -55,6 +55,8 @@ DEFINE_string(etcd_addr, "0.0.0.0:2379", "etcd adderss for save instance meta info"); +DEFINE_uint32(murmur_hash3_seed, 1024, "default Murmur Hash seed"); + DEFINE_int32(port, 8888, "Port for xllm service to listen on"); DEFINE_int32(num_threads, 32, "Number of threads to process requests"); diff --git a/xllm_service/common/global_gflags.h b/xllm_service/common/global_gflags.h index 657a918..7409c0d 100644 --- a/xllm_service/common/global_gflags.h +++ b/xllm_service/common/global_gflags.h @@ -37,6 +37,8 @@ DECLARE_int32(rpc_server_num_threads); DECLARE_int32(rpc_server_max_concurrency); +DECLARE_uint32(murmur_hash3_seed); + DECLARE_string(test_instance_addr); DECLARE_int32(timeout_ms); diff --git a/xllm_service/common/hash_util.cpp b/xllm_service/common/hash_util.cpp new file mode 100644 index 0000000..5c131b7 --- /dev/null +++ b/xllm_service/common/hash_util.cpp @@ -0,0 +1,62 @@ + +#include "common/hash_util.h" + +#include +#include + +#include +#include +#include +#include +#include + +#include "common/global_gflags.h" + +namespace xllm_service { + +void murmur_hash3(const uint8_t* pre_hash_value, + const Slice& token_ids, + uint8_t* hash_value) { + if (pre_hash_value == nullptr) { + MurmurHash3_x64_128(reinterpret_cast(token_ids.data()), + sizeof(int32_t) * token_ids.size(), + FLAGS_murmur_hash3_seed, + hash_value); + } else { + uint8_t key[1024]; + + int32_t data_len = + sizeof(int32_t) * token_ids.size() + MURMUR_HASH3_VALUE_LEN; + assert(sizeof(key) > data_len); + + memcpy(key, pre_hash_value, MURMUR_HASH3_VALUE_LEN); + memcpy(key + MURMUR_HASH3_VALUE_LEN, + reinterpret_cast(token_ids.data()), + sizeof(int32_t) * token_ids.size()); + + // print_hex_array(key, data_len); + MurmurHash3_x64_128(reinterpret_cast(key), + data_len, + FLAGS_murmur_hash3_seed, + hash_value); + } +} + +void print_hex_array(uint8_t* array) { + for (size_t i = 0; i < MURMUR_HASH3_VALUE_LEN; ++i) { + unsigned char uc = static_cast(array[i]); + std::cout << std::hex << std::setw(2) << std::setfill('0') + << static_cast(uc); + + if (i % MURMUR_HASH3_VALUE_LEN == MURMUR_HASH3_VALUE_LEN - 1) { + std::cout << std::endl; + } + + else { + std::cout << " "; + } + } + std::cout << std::dec << std::endl; +} + +} // namespace xllm_service \ No newline at end of file diff --git a/xllm_service/common/hash_util.h b/xllm_service/common/hash_util.h new file mode 100644 index 0000000..435944e --- /dev/null +++ b/xllm_service/common/hash_util.h @@ -0,0 +1,59 @@ +#pragma once + +#include + +#include +#include +#include +#include + +#include "common/slice.h" + +namespace xllm_service { +constexpr uint32_t MURMUR_HASH3_VALUE_LEN = 16; + +struct Murmur3Key { + uint8_t data[MURMUR_HASH3_VALUE_LEN]; + + Murmur3Key() {} + Murmur3Key(const uint8_t* const input_data) { + memcpy(data, input_data, MURMUR_HASH3_VALUE_LEN); + } + Murmur3Key(const char* const input_data) { + memcpy(data, input_data, MURMUR_HASH3_VALUE_LEN); + } + + std::string to_string() const { + return std::string(reinterpret_cast(data), + MURMUR_HASH3_VALUE_LEN); + } + + bool operator==(const Murmur3Key& other) { + return strncmp(reinterpret_cast(data), + reinterpret_cast(other.data), + MURMUR_HASH3_VALUE_LEN); + } +}; + +struct FixedStringKeyHash { + size_t operator()(const Murmur3Key& key) const { + return std::hash()(std::string_view( + reinterpret_cast(key.data), sizeof(key.data))); + } +}; + +struct FixedStringKeyEqual { + bool operator()(const Murmur3Key& left, const Murmur3Key& right) const { + return strncmp(reinterpret_cast(left.data), + reinterpret_cast(right.data), + sizeof(left.data)) == 0; + } +}; + +void print_hex_array(uint8_t* array); + +void murmur_hash3(const uint8_t* pre_hash_value, + const Slice& token_ids, + uint8_t* hash_value); + +} // namespace xllm_service From d0662c178e34d2975510d8cdeba7da8e1f6888aa Mon Sep 17 00:00:00 2001 From: kangmeng3 Date: Fri, 22 Aug 2025 16:02:38 +0800 Subject: [PATCH 03/11] feat: add service routing proto. --- xllm_service/proto/xllm/chat.proto | 6 ++++-- xllm_service/proto/xllm/common.proto | 6 ++---- xllm_service/proto/xllm/completion.proto | 6 ++++-- xllm_service/proto/xllm_rpc_service.proto | 1 - 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/xllm_service/proto/xllm/chat.proto b/xllm_service/proto/xllm/chat.proto index 5ded588..1fdecd2 100644 --- a/xllm_service/proto/xllm/chat.proto +++ b/xllm_service/proto/xllm/chat.proto @@ -20,7 +20,7 @@ message ChatMessage { } -// Next Id: 25 +// Next Id: 28 message ChatRequest { // ID of the model to use. You can use the ListModels endpoint to list available models. @@ -110,7 +110,9 @@ message ChatRequest { optional string service_request_id = 25; - Routing routing = 28; + repeated int32 token_ids = 26; + + Routing routing = 27; } message ChatLogProbData { diff --git a/xllm_service/proto/xllm/common.proto b/xllm_service/proto/xllm/common.proto index b821d1d..8015e64 100644 --- a/xllm_service/proto/xllm/common.proto +++ b/xllm_service/proto/xllm/common.proto @@ -24,9 +24,7 @@ message Status { } message Routing { - repeated int32 token_ids = 1; + string prefill_name = 1; - string prefill_name = 2; - - string decode_name = 3; + string decode_name = 2; } diff --git a/xllm_service/proto/xllm/completion.proto b/xllm_service/proto/xllm/completion.proto index 9a8f70d..9b1c78f 100644 --- a/xllm_service/proto/xllm/completion.proto +++ b/xllm_service/proto/xllm/completion.proto @@ -5,7 +5,7 @@ option cc_enable_arenas = true; import "common.proto"; -// Next ID: 25 +// Next ID: 28 message CompletionRequest { // ID of the model to use. (required) // You can use the ListModels endpoint to list available models. @@ -87,7 +87,9 @@ message CompletionRequest { optional string service_request_id = 25; - Routing routing = 28; + repeated int32 token_ids = 26; + + Routing routing = 27; } message LogProbs { diff --git a/xllm_service/proto/xllm_rpc_service.proto b/xllm_service/proto/xllm_rpc_service.proto index d769ba3..023f7f9 100644 --- a/xllm_service/proto/xllm_rpc_service.proto +++ b/xllm_service/proto/xllm_rpc_service.proto @@ -51,7 +51,6 @@ message LoadMetrics { float gpu_cache_usage_perc = 2; } -// TODO: add metainfo/metrics ? message HeartbeatRequest { string name = 1; KvCacheEvent cache_event = 2; From d8121161b4151438c4fbc95ee8cd0047faf485de Mon Sep 17 00:00:00 2001 From: kangmeng3 Date: Fri, 22 Aug 2025 16:12:28 +0800 Subject: [PATCH 04/11] feat: update etcd client. --- xllm_service/common/types.h | 241 +++++++++++++++++++++-- xllm_service/rpc_service/etcd_client.cpp | 175 +++++++++++----- xllm_service/rpc_service/etcd_client.h | 118 ++++++++++- 3 files changed, 462 insertions(+), 72 deletions(-) diff --git a/xllm_service/common/types.h b/xllm_service/common/types.h index c681ab6..478e642 100644 --- a/xllm_service/common/types.h +++ b/xllm_service/common/types.h @@ -19,13 +19,23 @@ limitations under the License. #include #include +#include #include +#include +#include #include +#include "common/hash_util.h" #include "nlohmann/json.hpp" namespace xllm_service { +struct CacheLocations; +using Murmur3KeyCacheMap = std::unordered_map; + struct HttpServiceConfig { int num_threads = 16; int timeout_ms = -1; @@ -37,13 +47,32 @@ struct RpcServiceConfig { std::string etcd_addr = ""; std::string disagg_pd_policy = ""; int detect_disconnected_instance_interval = 15; // seconds + std::string service_name = ""; +}; + +struct ModelConfig { + int32_t block_size = 16; + std::string model_type = "chatglm"; + std::string tokenizer_path = ""; +}; + +struct Routing { + std::string prefill_name; + std::string decode_name; + + nlohmann::json serialize_to_json() const { + nlohmann::json json_val; + json_val["prefill_name"] = prefill_name; + json_val["decode_name"] = decode_name; + return json_val; + } + + std::string debug_string() const { return serialize_to_json().dump(2); } }; -// instances pair for prefill and decode in disagg PD mode. -struct InstancesPair { - std::string prefill_instance_http_addr = ""; - // empty means no decode instance, only prefill instance is available - std::string decode_instance_http_addr = ""; +struct SchduleResult { + std::vector token_ids; + Routing routing; }; enum class ErrorCode : int32_t { @@ -72,6 +101,42 @@ enum class InstanceType : int8_t { DECODE = 2, }; +struct LoadMetrics { + LoadMetrics() : waiting_requests_num(0), gpu_cache_usage_perc(0){}; + LoadMetrics(const uint64_t& waiting_reqs_num, const float& usage) + : waiting_requests_num(waiting_reqs_num), gpu_cache_usage_perc(usage){}; + + uint64_t waiting_requests_num; + float gpu_cache_usage_perc; + + nlohmann::json serialize_to_json() const { + nlohmann::json json_val; + json_val["waiting_requests_num"] = waiting_requests_num; + json_val["gpu_cache_usage_perc"] = gpu_cache_usage_perc; + return json_val; + } + + std::string debug_string() const { return serialize_to_json().dump(2); } + + bool parse_from_json(const std::string& json_str) { + try { + nlohmann::json json_value = nlohmann::json::parse(json_str); + + waiting_requests_num = + json_value.at("waiting_requests_num").get(); + gpu_cache_usage_perc = json_value.at("gpu_cache_usage_perc").get(); + + } catch (const std::exception& e) { + LOG(ERROR) << "json str:" << json_str + << ", parse to loadmetrics error: " << e.what(); + return false; + } + return true; + } + + bool empty() const { return false; } +}; + struct InstanceMetaInfo { public: InstanceMetaInfo() { set_init_timestamp(); } @@ -98,6 +163,61 @@ struct InstanceMetaInfo { // latest heatbeat timestamp uint64_t latest_timestamp = 0; + nlohmann::json serialize_to_json() const { + nlohmann::json json_val; + json_val["name"] = name; + json_val["rpc_address"] = rpc_address; + json_val["type"] = int8_t(type); + json_val["addrs"] = addrs; + json_val["cluster_ids"] = cluster_ids; + json_val["k_cache_ids"] = k_cache_ids; + json_val["v_cache_ids"] = v_cache_ids; + json_val["dp_size"] = dp_size; + return json_val; + } + + std::string debug_string() const { return serialize_to_json().dump(2); } + + bool parse_from_json(const std::string& json_str) { + try { + nlohmann::json json_value = nlohmann::json::parse(json_str); + name = json_value.at("name").get(); + rpc_address = json_value.at("rpc_address").get(); + type = static_cast(json_value.at("type").get()); + + for (const auto& item : + json_value.at("cluster_ids").get>()) { + cluster_ids.push_back(item); + } + + for (const auto& item : + json_value.at("k_cache_ids").get>()) { + k_cache_ids.push_back(item); + } + + for (const auto& item : + json_value.at("addrs").get>()) { + addrs.push_back(item); + } + + for (const auto& item : + json_value.at("v_cache_ids").get>()) { + v_cache_ids.push_back(item); + } + + dp_size = json_value.at("dp_size").get(); + + set_init_timestamp(); + } catch (const std::exception& e) { + LOG(ERROR) << "json str:" << json_str + << ", parse to instancemetainfo error: " << e.what(); + return false; + } + return true; + } + + bool empty() const { return rpc_address == ""; } + private: void set_init_timestamp() { auto now = std::chrono::system_clock::now(); @@ -108,17 +228,106 @@ struct InstanceMetaInfo { } }; -// the info be stored in etcd -struct InstanceIdentityInfo { - std::string instance_addr; - std::string rpc_addr; - int8_t instance_type; // convert to InstanceType - - const std::string debug_string() const { - std::string debug_str = - "instance_addr: " + instance_addr + ", rpc_addr: " + rpc_addr + - ", instance_type: " + std::to_string((int)(instance_type)); - return debug_str; +struct CacheLocations { + std::unordered_set hbm_instance_set; + std::unordered_set dram_instance_set; + std::unordered_set ssd_instance_set; + + nlohmann::json serialize_to_json() const { + nlohmann::json json_val; + json_val["hbm_instance_set"] = hbm_instance_set; + json_val["dram_instance_set"] = dram_instance_set; + json_val["ssd_instance_set"] = ssd_instance_set; + return json_val; + } + + std::string debug_string() { return serialize_to_json().dump(2); } + + bool parse_from_json(const std::string& json_str) { + try { + nlohmann::json json_value = nlohmann::json::parse(json_str); + for (const auto& item : + json_value.at("hbm_instance_set").get>()) { + hbm_instance_set.insert(item); + } + + for (const auto& item : + json_value.at("dram_instance_set").get>()) { + dram_instance_set.insert(item); + } + + for (const auto& item : + json_value.at("ssd_instance_set").get>()) { + ssd_instance_set.insert(item); + } + + } catch (const std::exception& e) { + LOG(ERROR) << "json str:" << json_str + << ", parse to cachelocation error: " << e.what(); + return false; + } + return true; + } + + bool empty() const { + return hbm_instance_set.empty() && dram_instance_set.empty() && + ssd_instance_set.empty(); + } +}; + +struct OverlapScores { + std::unordered_set instances; + std::unordered_map hbm_instance_score; + std::unordered_map dram_instance_score; + std::unordered_map ssd_instance_score; + uint32_t max_block_num = 0; + uint32_t max_matched_block_num = 0; + std::string max_matched_instance_name = ""; + + std::string debug_string() { + nlohmann::json json_val; + json_val["instances"] = instances; + json_val["hbm_instance_score"] = hbm_instance_score; + json_val["dram_instance_score"] = dram_instance_score; + json_val["ssd_instance_score"] = ssd_instance_score; + json_val["max_block_num"] = max_block_num; + json_val["max_matched_block_num"] = max_matched_block_num; + json_val["max_matched_instance_name"] = max_matched_instance_name; + return json_val.dump(2); + } +}; + +struct LoadBalanceInfos { + OverlapScores overlap_scores; + std::unordered_map prefill_load_metrics; + std::unordered_map decode_load_metrics; + uint64_t prefill_max_waiting_requests_num = 0; + uint64_t decode_max_waiting_requests_num = 0; + + std::string debug_string() { + nlohmann::json json_val; + + json_val["overlap_scores"] = + nlohmann::json::parse(overlap_scores.debug_string()); + + nlohmann::json prefill_json; + for (auto& [key, metrics] : prefill_load_metrics) { + prefill_json[key] = nlohmann::json::parse(metrics.debug_string()); + } + json_val["prefill_load_metrics"] = prefill_json; + + nlohmann::json decode_json; + for (auto& [key, metrics] : decode_load_metrics) { + decode_json[key] = nlohmann::json::parse(metrics.debug_string()); + } + json_val["decode_load_metrics"] = decode_json; + + json_val["prefill_max_waiting_requests_num"] = + prefill_max_waiting_requests_num; + json_val["decode_max_waiting_requests_num"] = + decode_max_waiting_requests_num; + + return json_val.dump(2); } }; diff --git a/xllm_service/rpc_service/etcd_client.cpp b/xllm_service/rpc_service/etcd_client.cpp index 26f7d13..10d1e08 100644 --- a/xllm_service/rpc_service/etcd_client.cpp +++ b/xllm_service/rpc_service/etcd_client.cpp @@ -23,92 +23,173 @@ namespace xllm_service { EtcdClient::EtcdClient(const std::string& etcd_addr) : client_(etcd_addr), etcd_addr_(etcd_addr) { + LOG(INFO) << "EtcdClient init put start!"; auto response = client_.put("XLLM_PING", "PING"); + LOG(INFO) << "EtcdClient init put end!"; if (!response.is_ok()) { LOG(FATAL) << "etcd connect to etcd server failed: " << response.error_message(); } } -EtcdClient::~EtcdClient() {} +EtcdClient::~EtcdClient() { stop_watch(); } -bool EtcdClient::get(const std::string& key, InstanceIdentityInfo& value) { - auto response = client_.get(key); +bool EtcdClient::set(const std::string& key, const std::string& value) { + auto response = client_.put(key, value); if (!response.is_ok()) { - LOG(ERROR) << "etcd get " << key << " failed: " << response.error_message(); + LOG(ERROR) << "etcd set " << key << " failed: " << response.error_message(); return false; } - auto json_str = response.value().as_string(); - try { - nlohmann::json json_value = nlohmann::json::parse(json_str); - value.instance_addr = json_value.at("instance_addr").get(); - value.instance_type = json_value.at("instance_type").get(); - } catch (const std::exception& e) { - LOG(ERROR) << "etcd get " << key - << " failed: json parse error: " << e.what(); + + return true; +} + +bool EtcdClient::set(const std::string& key, + const std::string& value, + const int ttl) { + auto keep_alive = std::make_shared(client_, ttl); + etcdv3::Transaction transaction; + transaction.add_compare_create(key, 0); + transaction.add_success_put(key, value, keep_alive->Lease()); + etcd::Response response = client_.txn(transaction); + if (response.is_ok()) { + keep_alives_.emplace_back(std::move(keep_alive)); + return true; + } else { + keep_alive->Cancel(); return false; } +} +bool EtcdClient::set(const std::string& key_prefix, + const Murmur3KeyCacheMap& values) { + bool rt = true; + for (const auto& iter : values) { + if (iter.second.empty()) { + rt = rt && client_.rm(key_prefix + iter.first.to_string()).is_ok(); + } else { + rt = rt && client_ + .put(key_prefix + iter.first.to_string(), + iter.second.serialize_to_json().dump()) + .is_ok(); + } + } return true; } -bool EtcdClient::get_prefix(const std::string& key_prefix, - std::vector& values) { - auto response = client_.ls(key_prefix); +bool EtcdClient::rm(const std::string& key) { + auto response = client_.rm(key); if (!response.is_ok()) { - LOG(ERROR) << "etcd get " << key_prefix - << " failed: " << response.error_message(); + LOG(ERROR) << "etcd rm " << key << " failed: " << response.error_message(); return false; } - for (const auto& v : response.values()) { - InstanceIdentityInfo value; - auto json_str = v.as_string(); - try { - nlohmann::json json_value = nlohmann::json::parse(json_str); - value.instance_addr = json_value.at("instance_addr").get(); - value.instance_type = json_value.at("instance_type").get(); - values.emplace_back(value); - } catch (const std::exception& e) { - LOG(ERROR) << "etcd get " << key_prefix - << " failed: json parse error: " << e.what(); - return false; - } - } return true; } -bool EtcdClient::set(const std::string& key, - const InstanceIdentityInfo& value) { - std::string json_str; - try { - nlohmann::json json_value; - json_value["instance_addr"] = value.instance_addr; - json_value["instance_type"] = value.instance_type; - json_str = json_value.dump(); - } catch (const std::exception& e) { - LOG(ERROR) << "etcd set " << key - << " failed: json dump error: " << e.what(); +bool EtcdClient::rm(const std::string& key_prefix, + const std::unordered_set& keys) { + etcdv3::Transaction transaction; + transaction.add_compare_version( + "XLLM:SERVICE:MASTER", etcdv3::CompareResult::GREATER, -1); + for (const auto& iter : keys) { + transaction.add_success_delete(key_prefix + iter); + } + return client_.txn(transaction).is_ok(); +} + +bool EtcdClient::get(const std::string& key, std::string* value) { + auto response = client_.get(key); + if (!response.is_ok()) { + LOG(ERROR) << "etcd get " << key << " failed: " << response.error_message(); return false; } + if (value) { + *value = response.value().as_string(); + } + return true; +} - auto response = client_.put(key, json_str); +bool EtcdClient::get_prefix(const std::string& key_prefix, + Murmur3KeyCacheMap* values) { + auto response = client_.ls(key_prefix); if (!response.is_ok()) { - LOG(ERROR) << "etcd set " << key << " failed: " << response.error_message(); + LOG(ERROR) << "etcd get " << key_prefix + << " failed: " << response.error_message(); return false; } + for (int i = 0; i < response.keys().size(); i++) { + Murmur3Key key(response.key(i).substr(key_prefix.size()).c_str()); + auto json_str = response.value(i).as_string(); + + CacheLocations value; + if (!value.parse_from_json(json_str)) { + LOG(ERROR) << "Parse json fail: " << json_str; + continue; + } + + values->insert_or_assign(std::move(key), std::move(value)); + } return true; } -bool EtcdClient::rm(const std::string& key) { - auto response = client_.rm(key); +bool EtcdClient::get_prefix( + const std::string& key_prefix, + std::unordered_map* values) { + auto response = client_.ls(key_prefix); if (!response.is_ok()) { - LOG(ERROR) << "etcd rm " << key << " failed: " << response.error_message(); + LOG(ERROR) << "etcd get " << key_prefix + << " failed: " << response.error_message(); return false; } + for (int i = 0; i < response.keys().size(); i++) { + auto key_str = response.key(i).substr(key_prefix.size()); + auto str = response.value(i).as_string(); + + values->insert_or_assign(std::move(key_str), std::move(str)); + } return true; } +void EtcdClient::add_watch(const std::string& key_prefix, + Callback callback, + bool recursive) { + std::lock_guard lock(watchers_mutex_); + + if (watchers_.find(key_prefix) != watchers_.end()) { + watchers_[key_prefix].watcher->Cancel(); + } + auto watcher = std::make_unique( + client_, + key_prefix, + [callback, key_prefix](etcd::Response response) { + callback(response, uint64_t(key_prefix.size())); + }, + recursive); + + watchers_[key_prefix] = {std::move(watcher), callback}; +} + +void EtcdClient::remove_watch(const std::string& key_prefix) { + std::lock_guard lock(watchers_mutex_); + + auto it = watchers_.find(key_prefix); + if (it != watchers_.end()) { + it->second.watcher->Cancel(); + watchers_.erase(it); + } +} + +void EtcdClient::stop_watch() { + std::lock_guard lock(watchers_mutex_); + + for (auto& pair : watchers_) { + pair.second.watcher->Cancel(); + } + + watchers_.clear(); +} + } // namespace xllm_service diff --git a/xllm_service/rpc_service/etcd_client.h b/xllm_service/rpc_service/etcd_client.h index 56ae128..06c9bb4 100644 --- a/xllm_service/rpc_service/etcd_client.h +++ b/xllm_service/rpc_service/etcd_client.h @@ -15,32 +15,132 @@ limitations under the License. #pragma once +#include #include +#include +#include #include +#include +#include "common/hash_util.h" #include "common/types.h" namespace xllm_service { -// the format is: -// key: XLLM:PREFILL:inst_id -> value -// or -// key: XLLM:DECODE:inst_id -> value +using Callback = std::function; + class EtcdClient { public: EtcdClient(const std::string& etcd_addr); ~EtcdClient(); - bool get(const std::string& key, InstanceIdentityInfo& value); - // get all keys with prefix - bool get_prefix(const std::string& key_prefix, - std::vector& values); - bool set(const std::string& key, const InstanceIdentityInfo& value); + template + bool set(const std::string& key, const T& value) { + auto response = client_.put(key, value.serialize_to_json().dump()); + if (!response.is_ok()) { + LOG(ERROR) << "etcd set " << key + << " failed: " << response.error_message(); + return false; + } + + return true; + } + + template + bool set(const std::string& key_prefix, + const unordered_map& values) { + bool rt = true; + for (const auto& iter : values) { + if (iter.second.empty()) { + rt = rt && client_.rm(key_prefix + iter.first).is_ok(); + } else { + rt = rt && client_ + .put(key_prefix + iter.first, + iter.second.serialize_to_json().dump()) + .is_ok(); + } + } + return true; + } + + bool set(const std::string& key_prefix, const Murmur3KeyCacheMap& values); + + bool set(const std::string& key, const std::string& value); + + // create key-value with lease and transaction + bool set(const std::string& key, const std::string& value, const int ttl); + bool rm(const std::string& key); + bool rm(const std::string& key_prefix, + const std::unordered_set& keys); + + template + bool get(const std::string& key, T* value) { + auto response = client_.get(key); + if (!response.is_ok()) { + LOG(ERROR) << "etcd get " << key + << " failed: " << response.error_message(); + return false; + } + if (value) { + return value->parse_from_json(response.value().as_string()); + } else { + return true; + } + } + + bool get(const std::string& key_prefix, std::string* value); + + template + bool get_prefix(const std::string& key_prefix, + std::unordered_map* values) { + auto response = client_.ls(key_prefix); + if (!response.is_ok()) { + LOG(ERROR) << "etcd get " << key_prefix + << " failed: " << response.error_message(); + return false; + } + + for (int i = 0; i < response.keys().size(); i++) { + auto key_str = response.key(i).substr(key_prefix.size()); + auto json_str = response.value(i).as_string(); + + T value; + if (!value.parse_from_json(json_str)) { + LOG(ERROR) << "Parse json fail: " << json_str; + continue; + } + + values->insert_or_assign(std::move(key_str), std::move(value)); + } + return true; + } + + bool get_prefix(const std::string& key_prefix, Murmur3KeyCacheMap* values); + + bool get_prefix(const std::string& key_prefix, + std::unordered_map* values); + + void add_watch(const std::string& key_prefix, + Callback callback, + bool recursive = true); + + void remove_watch(const std::string& key_prefix); + + void stop_watch(); + private: + struct WatcherInfo { + std::unique_ptr watcher; + Callback callback; + }; + etcd::SyncClient client_; std::string etcd_addr_; + std::mutex watchers_mutex_; + std::map watchers_; + std::vector> keep_alives_; }; } // namespace xllm_service From 6a038fdc30819f067b2a80f3ec9de9df009e1bd6 Mon Sep 17 00:00:00 2001 From: kangmeng3 Date: Fri, 22 Aug 2025 16:13:50 +0800 Subject: [PATCH 05/11] refactor: refactor InstanceMgr. --- xllm_service/rpc_service/instance_mgr.cpp | 451 ++++++++++++---------- xllm_service/rpc_service/instance_mgr.h | 76 ++-- 2 files changed, 303 insertions(+), 224 deletions(-) diff --git a/xllm_service/rpc_service/instance_mgr.cpp b/xllm_service/rpc_service/instance_mgr.cpp index 8ad59ec..1d0e188 100644 --- a/xllm_service/rpc_service/instance_mgr.cpp +++ b/xllm_service/rpc_service/instance_mgr.cpp @@ -20,9 +20,11 @@ limitations under the License. #include #include +#include #include "common/types.h" #include "common/utils.h" + namespace xllm_service { // magic number, TODO: move to config file or env var @@ -34,257 +36,314 @@ static std::unordered_map ETCD_KEYS_PREFIX_MAP = { }; static std::string ETCD_ALL_KEYS_PREFIX = "XLLM:"; static std::string DEFAULT_DISAGG_PD_POLICY = "RR"; - -InstanceMgr::InstanceMgr(const RpcServiceConfig& config) : config_(config) { - if (config.etcd_addr.empty()) { - LOG(INFO) << "Disable etcd meta server"; - use_etcd_ = false; - } else { - LOG(INFO) << "Connect to etcd meta server: " << config.etcd_addr; - use_etcd_ = true; - etcd_client_ = std::make_unique(config.etcd_addr); +static std::string ETCD_LOADMETRICS_PREFIX = "XLLM:LOADMETRICS:"; + +InstanceMgr::InstanceMgr(const std::shared_ptr& etcd_client, + const HttpServiceConfig& config, + const bool is_master_service) + : http_config_(config), + is_master_service_(is_master_service), + etcd_client_(etcd_client) { + auto handle_instance_metainfo = + std::bind(&InstanceMgr::handle_instance_metainfo_watch, + this, + std::placeholders::_1, + std::placeholders::_2); + for (auto& it : ETCD_KEYS_PREFIX_MAP) { + etcd_client_->add_watch(it.second, handle_instance_metainfo); + } + if (!is_master_service_) { + auto handle_load_metrics = + std::bind(&InstanceMgr::handle_load_metrics_watch, + this, + std::placeholders::_1, + std::placeholders::_2); + etcd_client_->add_watch(ETCD_LOADMETRICS_PREFIX, handle_load_metrics); } - internal_init(); + init(); } -void InstanceMgr::internal_init() { - std::string pd_policy = config_.disagg_pd_policy; - if (config_.disagg_pd_policy.empty()) { - LOG(WARNING) << "Not specify diasgg pd policy, use `RR` policy as default."; - pd_policy = DEFAULT_DISAGG_PD_POLICY; +void InstanceMgr::init() { + { + std::unique_lock lock(inst_mutex_); + for (auto& it : ETCD_KEYS_PREFIX_MAP) { + etcd_client_->get_prefix(it.second, &instances_); + } + LOG(INFO) << "Load instance info from etcd:" << instances_.size(); + for (const auto& name : instances_) { + if (!create_channel(name.first)) { + zombie_nodes_.insert(name.first); + instances_.erase(name.first); + } + } } - if (pd_policy == "RR") { - disagg_pd_policy_ = std::make_unique(); - } else { - LOG(FATAL) << "Not supported diasgg pd policy: " << pd_policy; - return; + { + std::unique_lock lock(load_metric_mutex_); + etcd_client_->get_prefix(ETCD_LOADMETRICS_PREFIX, &load_metrics_); } +} + +InstanceMgr::~InstanceMgr() { exited_ = true; } - heartbeat_thread_ = std::make_unique( - &InstanceMgr::detect_disconnected_instances, this); +InstanceMetaInfo InstanceMgr::get_instance_info( + const std::string& instance_name) { + std::shared_lock lock(inst_mutex_); + if (instances_.find(instance_name) == instances_.end()) { + LOG(ERROR) << "Get instance info failed, instance is not registered, " + "instance_name: " + << instance_name; + return InstanceMetaInfo(); + } + return instances_[instance_name]; } -InstanceMgr::~InstanceMgr() { - exited_ = true; - if (heartbeat_thread_) { - heartbeat_thread_->join(); +// TODO: refactor later, currently return all decode instances +std::vector InstanceMgr::get_static_decode_list( + const std::string& instance_name) { + std::vector decode_list; + std::shared_lock lock(inst_mutex_); + for (auto& inst : instances_) { + if (inst.second.type == InstanceType::DECODE) { + decode_list.emplace_back(inst.second.name); + } } + + return decode_list; } -void InstanceMgr::detect_disconnected_instances() { - while (!exited_) { - std::this_thread::sleep_for(std::chrono::seconds(kDetectIntervals)); - { - std::lock_guard guard(inst_mutex_); - auto now = std::chrono::system_clock::now(); - auto timestamp_ms = std::chrono::duration_cast( - now.time_since_epoch()) - .count(); - std::vector disconnected_instances_name; - for (const auto& [name, info] : instances_) { - if (timestamp_ms - info.latest_timestamp > kDetectIntervals * 1000) { - LOG(WARNING) << "Instance maybe disconnected, instance_name: " << name - << ", last heartbeat interval(s): " - << (timestamp_ms - info.latest_timestamp) / 1000.0; - disconnected_instances_name.emplace_back(name); - } - } +void InstanceMgr::get_load_metrics(LoadBalanceInfos* infos) { + std::shared_lock inst_lock(inst_mutex_); + std::shared_lock metric_lock(load_metric_mutex_); - // no instances disconnected, return - if (disconnected_instances_name.empty()) { - continue; - } + for (auto name : infos->overlap_scores.instances) { + auto it = load_metrics_.find(name); + if (it == load_metrics_.end()) { + continue; + } + auto instance_it = instances_.find(name); + if (instance_it == instances_.end()) { + continue; + } - if (utils::enable_debug_log()) { - const auto instance_names = - absl::StrJoin(disconnected_instances_name, ", "); - LOG(WARNING) << "Detect disconnected instance, instance_name: " - << instance_names; - } - // detele instance metainfo from etcd - delete_persistence_metainfo(disconnected_instances_name); - for (const auto& name : disconnected_instances_name) { - disagg_pd_policy_->remove_instance(name, instances_[name].type); - instances_.erase(name); - } + if (instance_it->second.type == InstanceType::DECODE) { + infos->decode_load_metrics.insert(std::make_pair(name, it->second)); + infos->decode_max_waiting_requests_num = + std::max(infos->decode_max_waiting_requests_num, + it->second.waiting_requests_num); + } else { + infos->prefill_load_metrics.insert(std::make_pair(name, it->second)); + infos->prefill_max_waiting_requests_num = + std::max(infos->prefill_max_waiting_requests_num, + it->second.waiting_requests_num); } } -} -ErrorCode InstanceMgr::register_instance(const std::string& instance_name) { - std::lock_guard guard(inst_mutex_); - if (utils::enable_debug_log()) { - LOG(WARNING) << "Register instance, instance_name: " << instance_name; + std::string least_loaded_prefill_instance; + int32_t least_loaded_prefill_waiting_reqs = INT32_MAX; + std::string least_loaded_decode_instance; + int32_t least_loaded_decode_waiting_reqs = INT32_MAX; + + if (infos->prefill_load_metrics.size() == 0 || + infos->decode_load_metrics.size() == 0) { + for (const auto& metric : load_metrics_) { + auto instance_it = instances_.find(metric.first); + if (instance_it != instances_.end()) { + if (instance_it->second.type != InstanceType::DECODE) { + if (metric.second.waiting_requests_num < + least_loaded_prefill_waiting_reqs) { + least_loaded_prefill_waiting_reqs = + metric.second.waiting_requests_num; + least_loaded_prefill_instance = metric.first; + } + } else { + if (metric.second.waiting_requests_num < + least_loaded_decode_waiting_reqs) { + least_loaded_decode_waiting_reqs = + metric.second.waiting_requests_num; + least_loaded_decode_instance = metric.first; + } + } + } + } } - if (instances_.find(instance_name) != instances_.end()) { - // update_instance_timestamp(instance_name); - LOG(ERROR) << "Instance is already registered, instance_name: " - << instance_name; - return ErrorCode::INSTANCE_EXISTED; + + if (infos->prefill_load_metrics.size() == 0 && + !least_loaded_prefill_instance.empty()) { + infos->prefill_load_metrics.insert( + std::make_pair(least_loaded_prefill_instance, + load_metrics_[least_loaded_prefill_instance])); + infos->prefill_max_waiting_requests_num = least_loaded_prefill_waiting_reqs; } - InstanceMetaInfo default_info(instance_name, ""); - instances_[instance_name] = default_info; - disagg_pd_policy_->insert_instance(instance_name, - &(instances_[instance_name])); - // save instance metainfo to etcd - save_persistence_metainfo(default_info); - return ErrorCode::OK; + if (infos->decode_load_metrics.size() == 0 && + !least_loaded_decode_instance.empty()) { + infos->decode_load_metrics.insert( + std::make_pair(least_loaded_decode_instance, + load_metrics_[least_loaded_decode_instance])); + infos->decode_max_waiting_requests_num = least_loaded_decode_waiting_reqs; + } } -ErrorCode InstanceMgr::register_instance(const std::string& instance_name, - const InstanceMetaInfo& metainfo) { - std::lock_guard guard(inst_mutex_); - if (utils::enable_debug_log()) { - LOG(WARNING) << "Register instance, instance_name: " << instance_name; - } - if (instances_.find(instance_name) != instances_.end()) { - // update_instance_timestamp(instance_name); - LOG(ERROR) << "Instance is already registered, instance_name: " - << instance_name; - return ErrorCode::INSTANCE_EXISTED; - } +void InstanceMgr::record_load_metrics_update( + const std::string& instance_name, + const proto::LoadMetrics& load_metrics) { + std::lock_guard lock(update_mutex_); - instances_[instance_name] = metainfo; - disagg_pd_policy_->insert_instance(instance_name, - &(instances_[instance_name])); - // save instance metainfo to etcd - save_persistence_metainfo(metainfo); - return ErrorCode::OK; + updated_metrics_.insert_or_assign( + instance_name, + LoadMetrics(load_metrics.waiting_requests_num(), + load_metrics.gpu_cache_usage_perc())); } -ErrorCode InstanceMgr::update_instance_metainfo( - const std::string& instance_name, - const InstanceMetaInfo& metainfo) { - std::lock_guard guard(inst_mutex_); - if (utils::enable_debug_log()) { - LOG(WARNING) << "Update instance metainfo, instance_name: " - << instance_name; - } - if (instances_.find(instance_name) == instances_.end()) { - LOG(ERROR) << "Instance is not registered, instance_name: " - << instance_name; - return ErrorCode::INSTANCE_NOT_EXISTED; +bool InstanceMgr::upload_load_metrics() { + std::lock_guard lock(update_mutex_); + bool status = etcd_client_->set(ETCD_LOADMETRICS_PREFIX, updated_metrics_); + status = + status && etcd_client_->rm(ETCD_LOADMETRICS_PREFIX, removed_instance_); + { + std::unique_lock lock(inst_mutex_); + for (auto& iter : updated_metrics_) { + load_metrics_.insert_or_assign(iter.first, std::move(iter.second)); + } + for (auto& iter : removed_instance_) { + load_metrics_.erase(iter); + } } + updated_metrics_.clear(); + removed_instance_.clear(); - instances_[instance_name] = metainfo; - update_instance_timestamp(instance_name); - disagg_pd_policy_->update_instance(instance_name, - &(instances_[instance_name])); - return ErrorCode::OK; + return status; } -void InstanceMgr::save_persistence_metainfo(const InstanceMetaInfo& metainfo) { - if (!use_etcd_) { - return; - } - std::string key = ETCD_KEYS_PREFIX_MAP[metainfo.type] + metainfo.name; - InstanceIdentityInfo value; - value.instance_addr = metainfo.name; - value.rpc_addr = metainfo.rpc_address; - value.instance_type = static_cast(metainfo.type); - bool ok = etcd_client_->set(key, value); - if (!ok) { - LOG(ERROR) << "Save instance metainfo to etcd failed, key: " << key; - return; +void InstanceMgr::set_as_master() { + is_master_service_ = true; + etcd_client_->remove_watch(ETCD_LOADMETRICS_PREFIX); +} + +std::shared_ptr InstanceMgr::get_channel( + const std::string& instance_name) { + std::shared_lock lock(inst_mutex_); + auto iter = cached_channels_.find(instance_name); + if (iter == cached_channels_.end()) { + return nullptr; } + return iter->second; +} - if (utils::enable_debug_log()) { - InstanceIdentityInfo debug_value; - bool ok = etcd_client_->get(key, debug_value); - if (!ok) { - LOG(ERROR) << "Get instance metainfo from etcd failed, key: " << key; - return; +bool InstanceMgr::create_channel(const std::string& instance_name) { + if (cached_channels_.find(instance_name) == cached_channels_.end()) { + auto channel = std::make_shared(); + brpc::ChannelOptions options; + // Add to params + options.protocol = "http"; + options.timeout_ms = http_config_.timeout_ms; /*milliseconds*/ + options.max_retry = 3; + std::string load_balancer = ""; + if (channel->Init(instance_name.c_str(), load_balancer.c_str(), &options) != + 0) { + LOG(ERROR) << "Fail to initialize channel for " << instance_name; + return false; } - LOG(WARNING) << "Instance after put: " << debug_value.debug_string(); + cached_channels_[instance_name] = std::move(channel); } + + return true; } -void InstanceMgr::delete_persistence_metainfo( - const std::vector& instance_names) { - if (!use_etcd_ || instance_names.empty()) { +void InstanceMgr::handle_instance_metainfo_watch(const etcd::Response& response, + const uint64_t& prefix_len) { + if (response.events().empty()) { return; } - // TODO: use batch delete later - for (const auto& name : instance_names) { - InstanceMetaInfo& metainfo = instances_[name]; - std::string key = ETCD_KEYS_PREFIX_MAP[metainfo.type] + metainfo.name; - bool ok = etcd_client_->rm(key); - if (!ok) { - LOG(ERROR) << "Delete instance metainfo from etcd failed, key: " << key; + + std::unordered_map put_map; + std::vector delete_list; + + for (const auto& event : response.events()) { + std::string instance_name = event.kv().key().substr(prefix_len); + + if (event.event_type() == etcd::Event::EventType::PUT) { + InstanceMetaInfo metainfo; + auto json_str = event.kv().as_string(); + if (!metainfo.parse_from_json(json_str)) { + LOG(ERROR) << "pase json:" << json_str << " error!"; + continue; + } + + put_map.insert(std::make_pair(instance_name, std::move(metainfo))); + + } else if (event.event_type() == etcd::Event::EventType::DELETE_) { + delete_list.push_back(instance_name); } } - if (utils::enable_debug_log()) { - std::vector debug_values; - bool ok = etcd_client_->get_prefix(ETCD_ALL_KEYS_PREFIX, debug_values); - if (!ok) { - LOG(ERROR) << "Get instance metainfo from etcd failed, key: " - << ETCD_ALL_KEYS_PREFIX; - return; + { + std::unique_lock lock(inst_mutex_); + for (auto& iter : put_map) { + if (instances_.find(iter.first) != instances_.end()) { + LOG(ERROR) << "Instance is already registered, instance_name: " + << iter.first; + continue; + } + instances_.insert(std::make_pair(iter.first, std::move(iter.second))); + create_channel(iter.first); } - std::string concat_debug_str; - for (const auto& v : debug_values) { - concat_debug_str += v.debug_string(); - concat_debug_str += "\n"; + + for (auto& iter : delete_list) { + if (instances_.find(iter) == instances_.end()) { + LOG(ERROR) << "Instance is already deleted, instance_name: " << iter; + continue; + } + // TODO: notify cache manager to clear expire cache + instances_.erase(iter); + cached_channels_.erase(iter); + { + std::lock_guard lock(update_mutex_); + updated_metrics_.erase(iter); + removed_instance_.insert(iter); + } } - LOG(WARNING) << "Instances after delete: " << concat_debug_str; } } -ErrorCode InstanceMgr::heartbeat(const std::string& instance_name) { - std::lock_guard guard(inst_mutex_); - if (utils::enable_debug_log()) { - LOG(WARNING) << "Receive heartbeat, instance_name: " << instance_name; - } - if (instances_.find(instance_name) == instances_.end()) { - LOG(ERROR) << "Instance is not registered, instance_name: " - << instance_name; - return ErrorCode::INSTANCE_NOT_EXISTED; +void InstanceMgr::handle_load_metrics_watch(const etcd::Response& response, + const uint64_t prefix_len) { + if (response.events().empty()) { + return; } - update_instance_timestamp(instance_name); + std::unordered_map put_map; + std::vector delete_list; - return ErrorCode::OK; -} + for (const auto& event : response.events()) { + std::string instance_name = event.kv().key().substr(prefix_len); -void InstanceMgr::update_instance_timestamp(const std::string& inst_name) { - auto now = std::chrono::system_clock::now(); - auto timestamp_ms = std::chrono::duration_cast( - now.time_since_epoch()) - .count(); - instances_[inst_name].latest_timestamp = timestamp_ms; -} + if (event.event_type() == etcd::Event::EventType::PUT) { + LoadMetrics load_metrics; + auto json_str = event.kv().as_string(); + if (!load_metrics.parse_from_json(json_str)) { + LOG(ERROR) << "pase json:" << json_str << " error!"; + continue; + } -InstancesPair InstanceMgr::select_instances_pair(bool only_prefill) { - return disagg_pd_policy_->select_instances_pair(only_prefill); -} + put_map.insert(std::make_pair(instance_name, std::move(load_metrics))); -InstanceMetaInfo InstanceMgr::get_instance_info( - const std::string& instance_name) { - std::lock_guard guard(inst_mutex_); - if (instances_.find(instance_name) == instances_.end()) { - LOG(ERROR) << "Get instance info failed, instance is not registered, " - "instance_name: " - << instance_name; - return InstanceMetaInfo(); + } else if (event.event_type() == etcd::Event::EventType::DELETE_) { + delete_list.push_back(instance_name); + } } - return instances_[instance_name]; -} -// TODO: refactor later, currently return all decode instances -std::vector InstanceMgr::get_static_decode_list( - const std::string& instance_name) { - std::vector decode_list; - std::lock_guard guard(inst_mutex_); - for (auto& inst : instances_) { - if (inst.second.type == InstanceType::DECODE) { - decode_list.emplace_back(inst.second.name); + { + std::unique_lock lock(load_metric_mutex_); + for (auto& iter : put_map) { + load_metrics_.insert_or_assign(iter.first, std::move(iter.second)); } - } - return decode_list; + for (auto& iter : delete_list) { + load_metrics_.erase(iter); + } + } } } // namespace xllm_service diff --git a/xllm_service/rpc_service/instance_mgr.h b/xllm_service/rpc_service/instance_mgr.h index 4f6bdef..175c91a 100644 --- a/xllm_service/rpc_service/instance_mgr.h +++ b/xllm_service/rpc_service/instance_mgr.h @@ -15,58 +15,78 @@ limitations under the License. #pragma once -#include +#include + +#include #include #include #include +#include "common/macros.h" #include "common/types.h" -#include "disagg_pd_policy.h" #include "etcd_client.h" +#include "xllm_rpc_service.pb.h" namespace xllm_service { -class InstanceMgr { +class InstanceMgr final { public: - explicit InstanceMgr(const RpcServiceConfig& config); + explicit InstanceMgr(const std::shared_ptr& etcd_client, + const HttpServiceConfig& config, + const bool is_master_service); + ~InstanceMgr(); - ErrorCode heartbeat(const std::string& instance_name); - ErrorCode register_instance(const std::string& instance_name); - ErrorCode register_instance(const std::string& instance_name, - const InstanceMetaInfo& metainfo); - ErrorCode update_instance_metainfo(const std::string& instance_name, - const InstanceMetaInfo& metainfo); InstanceMetaInfo get_instance_info(const std::string& instance_name); - // select instances(prefill/decode/default etc.) to handle request - // according the disagg pd policy (or some other policies.). - InstancesPair select_instances_pair(bool only_prefill = false); - std::vector get_static_decode_list( const std::string& instance_name); + void get_load_metrics(LoadBalanceInfos* infos); + + std::shared_ptr get_channel(const std::string& instance_name); + + void record_load_metrics_update(const std::string& instance_name, + const proto::LoadMetrics& load_metrics); + bool upload_load_metrics(); + + void set_as_master(); + private: - void internal_init(); - // save instance metainfo to etcd - void save_persistence_metainfo(const InstanceMetaInfo& metainfo); - // delete instance metainfo from etcd - void delete_persistence_metainfo( - const std::vector& instance_names); - void detect_disconnected_instances(); - void update_instance_timestamp(const std::string& inst_name); + DISALLOW_COPY_AND_ASSIGN(InstanceMgr); + + void init(); + + bool create_channel(const std::string& target_uri); + // use etcd as ServiceDiscovery + void handle_instance_metainfo_watch(const etcd::Response& response, + const uint64_t& prefix_len); + + void handle_load_metrics_watch(const etcd::Response& response, + const uint64_t prefix_len); private: - RpcServiceConfig config_; bool exited_ = false; - std::mutex inst_mutex_; + bool use_etcd_ = false; + std::atomic_bool is_master_service_ = false; + + RpcServiceConfig config_; + HttpServiceConfig http_config_; + + std::shared_ptr etcd_client_; + + std::shared_mutex inst_mutex_; std::unordered_map instances_; - std::unique_ptr heartbeat_thread_; - std::unique_ptr disagg_pd_policy_; + std::shared_mutex load_metric_mutex_; + std::unordered_map load_metrics_; + std::unordered_map> + cached_channels_; + std::unordered_set zombie_nodes_; - bool use_etcd_ = false; - std::unique_ptr etcd_client_; + std::mutex update_mutex_; + std::unordered_map updated_metrics_; + std::unordered_set removed_instance_; }; } // namespace xllm_service From 56a4971fbcf1935b7bf8fcc806d294b206ea21b0 Mon Sep 17 00:00:00 2001 From: kangmeng3 Date: Fri, 22 Aug 2025 16:14:40 +0800 Subject: [PATCH 06/11] feat: implement Scheduler and GlobalKVCacheMgr. --- xllm_service/rpc_service/CMakeLists.txt | 8 +- .../rpc_service/global_kvcache_mgr.cpp | 237 ++++++++++++++++++ xllm_service/rpc_service/global_kvcache_mgr.h | 48 ++++ xllm_service/rpc_service/scheduler.cpp | 160 ++++++++++++ xllm_service/rpc_service/scheduler.h | 86 +++++++ 5 files changed, 537 insertions(+), 2 deletions(-) create mode 100644 xllm_service/rpc_service/global_kvcache_mgr.cpp create mode 100644 xllm_service/rpc_service/global_kvcache_mgr.h create mode 100644 xllm_service/rpc_service/scheduler.cpp create mode 100644 xllm_service/rpc_service/scheduler.h diff --git a/xllm_service/rpc_service/CMakeLists.txt b/xllm_service/rpc_service/CMakeLists.txt index b876a6e..55fb4e9 100644 --- a/xllm_service/rpc_service/CMakeLists.txt +++ b/xllm_service/rpc_service/CMakeLists.txt @@ -1,24 +1,28 @@ include(cc_binary) include(cc_library) include(cc_test) +add_subdirectory(loadbalance_policy) cc_library( NAME xllm_rpc_service HDRS - disagg_pd_policy.h + scheduler.h etcd_client.h instance_mgr.h + global_kvcache_mgr.h response_handler.h service.h SRCS - disagg_pd_policy.cpp + scheduler.cpp etcd_client.cpp instance_mgr.cpp + global_kvcache_mgr.cpp response_handler.cpp service.cpp DEPS :common + :loadbalance_policy absl::random_random absl::strings cpprest diff --git a/xllm_service/rpc_service/global_kvcache_mgr.cpp b/xllm_service/rpc_service/global_kvcache_mgr.cpp new file mode 100644 index 0000000..44a28dd --- /dev/null +++ b/xllm_service/rpc_service/global_kvcache_mgr.cpp @@ -0,0 +1,237 @@ +#include "global_kvcache_mgr.h" + +#include + +#include "common/hash_util.h" + +namespace xllm_service { + +inline size_t round_down(size_t n, size_t multiple) { + return (n / multiple) * multiple; +} + +static std::string ETCD_CACHE_PREFIX = "XLLM:CACHE:"; + +GlobalKVCacheMgr::GlobalKVCacheMgr( + const std::shared_ptr& etcd_client, + const ModelConfig& model_config, + const bool is_master_service) + : model_config_(model_config), + is_master_service_(is_master_service), + etcd_client_(etcd_client) { + if (!is_master_service_) { + auto handle_kvcache = std::bind(&GlobalKVCacheMgr::handle_kvcache_watch, + this, + std::placeholders::_1, + std::placeholders::_2); + etcd_client_->add_watch(ETCD_CACHE_PREFIX, handle_kvcache); + } + + { + std::unique_lock lock(kvcache_mutex_); + etcd_client_->get_prefix(ETCD_CACHE_PREFIX, &kvcache_infos_); + DLOG(INFO) << "Load etcd cache infos:" << kvcache_infos_.size(); + } +} + +GlobalKVCacheMgr::~GlobalKVCacheMgr() { + exited_ = true; + etcd_client_->remove_watch(ETCD_CACHE_PREFIX); +} + +void set_score(const std::unordered_set& instance_names, + const uint32_t& match_length, + std::unordered_map* scores, + std::unordered_set* instances) { + for (const auto& name : instance_names) { + if (scores->count(name) == 0) { + scores->insert_or_assign(name, match_length); + } else { + (*scores)[name] = match_length; + } + instances->insert(name); + } +} + +void GlobalKVCacheMgr::match(const Slice& token_ids, + OverlapScores* overlap_scores) { + // allign tokens to block boundary + const size_t n_tokens = + round_down(token_ids.size(), model_config_.block_size); + if (n_tokens == 0) { + return; + } + + overlap_scores->max_block_num = n_tokens / model_config_.block_size; + + std::shared_lock lock(kvcache_mutex_); + Murmur3Key token_hash_key; + for (size_t i = 0; i < n_tokens; i += model_config_.block_size) { + if (i == 0) { + murmur_hash3(nullptr, + token_ids.slice(i, i + model_config_.block_size), + token_hash_key.data); + } else { + murmur_hash3(token_hash_key.data, + token_ids.slice(i, i + model_config_.block_size), + token_hash_key.data); + } + + auto iter = kvcache_infos_.find(token_hash_key); + if (iter != kvcache_infos_.end() && !iter->second.empty()) { + if (!iter->second.hbm_instance_set.empty()) { + set_score(iter->second.hbm_instance_set, + i / model_config_.block_size + 1, + &(overlap_scores->hbm_instance_score), + &(overlap_scores->instances)); + overlap_scores->max_matched_instance_name = + *iter->second.hbm_instance_set.begin(); + overlap_scores->max_matched_block_num = + i / model_config_.block_size + 1; + } + + if (!iter->second.dram_instance_set.empty()) { + set_score(iter->second.dram_instance_set, + i / model_config_.block_size + 1, + &(overlap_scores->dram_instance_score), + &(overlap_scores->instances)); + overlap_scores->max_matched_instance_name = + *iter->second.hbm_instance_set.begin(); + overlap_scores->max_matched_block_num = + i / model_config_.block_size + 1; + } + + if (!iter->second.ssd_instance_set.empty()) { + set_score(iter->second.ssd_instance_set, + i / model_config_.block_size + 1, + &(overlap_scores->ssd_instance_score), + &(overlap_scores->instances)); + overlap_scores->max_matched_instance_name = + *iter->second.hbm_instance_set.begin(); + overlap_scores->max_matched_block_num = + i / model_config_.block_size + 1; + } + } else { + break; + } + } +} + +void GlobalKVCacheMgr::handle_kvcache_watch(const etcd::Response& response, + const uint64_t prefix_len) { + if (response.events().empty() || exited_) { + return; + } + + Murmur3KeyCacheMap put_map; + std::vector delete_list; + + for (const auto& event : response.events()) { + auto key = event.kv().key().substr(prefix_len); + + if (event.event_type() == etcd::Event::EventType::PUT) { + CacheLocations cachelocations; + auto json_str = event.kv().as_string(); + if (!cachelocations.parse_from_json(json_str)) { + LOG(ERROR) << "pase json:" << json_str << " error!"; + continue; + } + + put_map.insert_or_assign(Murmur3Key{key.c_str()}, + std::move(cachelocations)); + + } else if (event.event_type() == etcd::Event::EventType::DELETE_) { + delete_list.emplace_back(Murmur3Key{key.c_str()}); + } + } + + { + std::unique_lock lock(kvcache_mutex_); + for (auto& iter : put_map) { + kvcache_infos_.insert_or_assign(iter.first, std::move(iter.second)); + } + + for (auto& iter : delete_list) { + kvcache_infos_.erase(iter); + } + } +} + +void GlobalKVCacheMgr::record_updated_kvcaches( + const std::string& instance_name, + const proto::KvCacheEvent& kvcache_event) { + std::lock_guard update_lock(update_mutex_); + std::shared_lock metric_lock(kvcache_mutex_); + for (int i = 0; i < kvcache_event.stored_cache_size(); i++) { + Murmur3Key key(kvcache_event.stored_cache(i).c_str()); + if (updated_kvcaches_.count(key) == 0) { + if (kvcache_infos_.count(key) == 0) { + updated_kvcaches_.insert_or_assign(key, CacheLocations()); + } else { + updated_kvcaches_.insert_or_assign(key, kvcache_infos_[key]); + } + } + updated_kvcaches_.at(key).hbm_instance_set.insert(instance_name); + } + + for (int i = 0; i < kvcache_event.offload_cache_size(); i++) { + Murmur3Key key(kvcache_event.offload_cache(i).c_str()); + if (updated_kvcaches_.count(key) == 0) { + if (kvcache_infos_.count(key) == 0) { + continue; + } else { + updated_kvcaches_.insert_or_assign(key, kvcache_infos_[key]); + } + } + if (updated_kvcaches_.at(key).hbm_instance_set.count(instance_name) != 0) { + updated_kvcaches_.at(key).hbm_instance_set.erase(instance_name); + updated_kvcaches_.at(key).dram_instance_set.insert(instance_name); + } else { + updated_kvcaches_.at(key).dram_instance_set.erase(instance_name); + updated_kvcaches_.at(key).ssd_instance_set.insert(instance_name); + } + } + + for (int i = 0; i < kvcache_event.removed_cache_size(); i++) { + Murmur3Key key(kvcache_event.removed_cache(i).c_str()); + if (updated_kvcaches_.count(key) == 0) { + if (kvcache_infos_.count(key) == 0) { + continue; + } else { + updated_kvcaches_.insert_or_assign(key, kvcache_infos_[key]); + } + } + updated_kvcaches_.at(key).hbm_instance_set.erase(instance_name); + updated_kvcaches_.at(key).dram_instance_set.erase(instance_name); + updated_kvcaches_.at(key).ssd_instance_set.erase(instance_name); + } +} + +bool GlobalKVCacheMgr::upload_kvcache() { + std::lock_guard update_lock(update_mutex_); + if (updated_kvcaches_.empty()) { + return true; + } + bool rt = etcd_client_->set(ETCD_CACHE_PREFIX, updated_kvcaches_); + { + std::unique_lock metric_lock(kvcache_mutex_); + for (auto& iter : updated_kvcaches_) { + if (iter.second.empty()) { + kvcache_infos_.erase(iter.first); + } else { + kvcache_infos_.insert_or_assign(iter.first, std::move(iter.second)); + } + } + } + if (rt) { + updated_kvcaches_.clear(); + } + return rt; +} + +void GlobalKVCacheMgr::set_as_master() { + is_master_service_ = true; + etcd_client_->remove_watch(ETCD_CACHE_PREFIX); +} + +} // namespace xllm_service diff --git a/xllm_service/rpc_service/global_kvcache_mgr.h b/xllm_service/rpc_service/global_kvcache_mgr.h new file mode 100644 index 0000000..f06abc3 --- /dev/null +++ b/xllm_service/rpc_service/global_kvcache_mgr.h @@ -0,0 +1,48 @@ +#pragma once + +#include +#include + +#include "common/hash_util.h" +#include "common/macros.h" +#include "common/slice.h" +#include "common/types.h" +#include "etcd_client.h" +#include "xllm_rpc_service.pb.h" + +namespace xllm_service { + +class GlobalKVCacheMgr final { + public: + explicit GlobalKVCacheMgr(const std::shared_ptr& etcd_client, + const ModelConfig& model_config, + const bool is_master_service); + ~GlobalKVCacheMgr(); + + void match(const Slice& token_ids, OverlapScores* overlap_scores); + + void record_updated_kvcaches(const std::string& instance_name, + const proto::KvCacheEvent& kvcache_event); + bool upload_kvcache(); + + void set_as_master(); + + private: + DISALLOW_COPY_AND_ASSIGN(GlobalKVCacheMgr); + + void handle_kvcache_watch(const etcd::Response& response, + const uint64_t prefix_len); + + private: + ModelConfig model_config_; + std::atomic_bool is_master_service_ = false; + bool exited_ = false; + std::shared_mutex kvcache_mutex_; + Murmur3KeyCacheMap kvcache_infos_; + std::shared_ptr etcd_client_; // not own + + std::mutex update_mutex_; + Murmur3KeyCacheMap updated_kvcaches_; +}; + +} // namespace xllm_service diff --git a/xllm_service/rpc_service/scheduler.cpp b/xllm_service/rpc_service/scheduler.cpp new file mode 100644 index 0000000..2d2e8e3 --- /dev/null +++ b/xllm_service/rpc_service/scheduler.cpp @@ -0,0 +1,160 @@ +#include "scheduler.h" + +#include +#include +#include + +#include "chat_template/chat_template_factory.h" +#include "common.pb.h" +#include "common/hash_util.h" +#include "tokenizer/tokenizer_factory.h" + +static constexpr int kHeartbeatInterval = 3; // in seconds +static std::string ETCD_MASTER_SERVICE_KEY = "XLLM:SERVICE:MASTER"; + +namespace xllm_service { + +Scheduler::Scheduler(const RpcServiceConfig& rpc_config, + const ModelConfig& model_config, + const HttpServiceConfig& http_config) + : rpc_config_(rpc_config), + model_config_(model_config), + http_config_(http_config) { + tokenizer_ = create_tokenizer(model_config_, &tokenizer_args_); + chat_template_ = + create_chat_template(model_config_.model_type, tokenizer_args_); + + etcd_client_ = std::make_shared(rpc_config_.etcd_addr); + + if (!etcd_client_->get(ETCD_MASTER_SERVICE_KEY, nullptr)) { + is_master_service_ = etcd_client_->set( + ETCD_MASTER_SERVICE_KEY, rpc_config_.service_name, kHeartbeatInterval); + LOG(INFO) << "Set current service as master!"; + } + + instance_mgr_ = std::make_unique( + etcd_client_, http_config_, is_master_service_); + + global_kvcache_mgr_ = std::make_unique( + etcd_client_, model_config_, is_master_service_); + + lb_policy_ = std::make_unique(); + + if (is_master_service_) { + heartbeat_thread_ = std::make_unique( + &Scheduler::update_master_service_heartbeat, this); + } else { + auto handle_master = std::bind(&Scheduler::handle_master_service_watch, + this, + std::placeholders::_1, + std::placeholders::_2); + etcd_client_->add_watch(ETCD_MASTER_SERVICE_KEY, handle_master); + } +} + +Scheduler::~Scheduler() { etcd_client_->stop_watch(); } + +bool Scheduler::schedule(const ChatMessages& messages, SchduleResult* res) { + if (chat_template_ == nullptr) { + LOG(ERROR) << "Chat template has not configured for model type: " + << model_config_.model_type; + return false; + } + + auto prompt = chat_template_->apply(messages); + if (!prompt.has_value()) { + LOG(ERROR) << "Failed to construct prompt from messages"; + return false; + } + + return schedule(prompt.value(), res); +} + +bool Scheduler::schedule(const std::string& prompt, SchduleResult* res) { + LoadBalanceInfos lb_infos; + if (prompt.size() != 0) { + if (!get_tls_tokenizer()->encode(prompt, &res->token_ids)) { + LOG(ERROR) << "Encode prompt faill: " << prompt; + return false; + } + + Slice token_ids(res->token_ids.data(), + res->token_ids.size()); + + global_kvcache_mgr_->match(token_ids, &lb_infos.overlap_scores); + DLOG(INFO) << lb_infos.debug_string(); + } + + instance_mgr_->get_load_metrics(&lb_infos); + DLOG(INFO) << lb_infos.debug_string(); + + if (lb_infos.prefill_load_metrics.size() == 0) { + LOG(INFO) << "No node available!"; + return false; + } + + lb_policy_->select_instances_pair(lb_infos, &res->routing); + + DLOG(INFO) << res->routing.debug_string(); + + return true; +} + +std::shared_ptr Scheduler::get_channel( + const std::string& target_name) { + return instance_mgr_->get_channel(target_name); +} + +void Scheduler::update_master_service_heartbeat() { + while (!exited_) { + std::this_thread::sleep_for(std::chrono::seconds(kHeartbeatInterval)); + + global_kvcache_mgr_->upload_kvcache(); + + instance_mgr_->upload_load_metrics(); + } +} + +void Scheduler::handle_instance_heartbeat(const proto::HeartbeatRequest* req) { + if (exited_) { + return; + } + global_kvcache_mgr_->record_updated_kvcaches(req->name(), req->cache_event()); + instance_mgr_->record_load_metrics_update(req->name(), req->load_metrics()); +} + +void Scheduler::handle_master_service_watch(const etcd::Response& response, + const uint64_t& prefix_len) { + if (exited_ || response.events().empty()) { + return; + } + + if (etcd_client_->set(ETCD_MASTER_SERVICE_KEY, + rpc_config_.service_name, + kHeartbeatInterval)) { + is_master_service_ = true; + + heartbeat_thread_ = std::make_unique( + &Scheduler::update_master_service_heartbeat, this); + + global_kvcache_mgr_->set_as_master(); + instance_mgr_->set_as_master(); + } +} + +InstanceMetaInfo Scheduler::get_instance_info( + const std::string& instance_name) { + return instance_mgr_->get_instance_info(instance_name); +} + +std::vector Scheduler::get_static_decode_list( + const std::string& instance_name) { + return instance_mgr_->get_static_decode_list(instance_name); +} + +Tokenizer* Scheduler::get_tls_tokenizer() { + thread_local std::unique_ptr tls_tokenizer(tokenizer_->clone()); + return tls_tokenizer.get(); +} + +} // namespace xllm_service \ No newline at end of file diff --git a/xllm_service/rpc_service/scheduler.h b/xllm_service/rpc_service/scheduler.h new file mode 100644 index 0000000..cbc37cc --- /dev/null +++ b/xllm_service/rpc_service/scheduler.h @@ -0,0 +1,86 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "chat_template/chat_template.h" +#include "common/hash_util.h" +#include "common/macros.h" +#include "common/types.h" +#include "etcd_client.h" +#include "global_kvcache_mgr.h" +#include "instance_mgr.h" +#include "loadbalance_policy/loadbalance_policy.h" +#include "tokenizer/tokenizer.h" +#include "tokenizer/tokenizer_args.h" +#include "xllm_rpc_service.pb.h" + +namespace xllm_service { + +class Scheduler { + public: + explicit Scheduler(const RpcServiceConfig& rpc_config, + const ModelConfig& model_config, + const HttpServiceConfig& http_config); + + ~Scheduler(); + + bool schedule(const ChatMessages& messages, SchduleResult* res); + + bool schedule(const std::string& prompt, SchduleResult* res); + + std::shared_ptr get_channel(const std::string& target_name); + + InstanceMetaInfo get_instance_info(const std::string& instance_name); + + std::vector get_static_decode_list( + const std::string& instance_name); + + void handle_instance_heartbeat(const proto::HeartbeatRequest* req); + + void exited() { exited_ = true; } + + private: + DISALLOW_COPY_AND_ASSIGN(Scheduler); + + void update_master_service_heartbeat(); + + void handle_master_service_watch(const etcd::Response& response, + const uint64_t& prefix_len); + + Tokenizer* get_tls_tokenizer(); + + private: + bool exited_ = false; + + bool is_master_service_ = false; + + TokenizerArgs tokenizer_args_; + + RpcServiceConfig rpc_config_; + + ModelConfig model_config_; + + HttpServiceConfig http_config_; + + // chat template instance + std::unique_ptr chat_template_; + + std::shared_ptr etcd_client_; + + std::unique_ptr tokenizer_; + + std::unique_ptr instance_mgr_; + + std::unique_ptr global_kvcache_mgr_; + + std::unique_ptr lb_policy_; + + std::unique_ptr heartbeat_thread_; +}; + +} // namespace xllm_service From a46b7c7461becb0f0e0f67a6b4aa7c241c13f651 Mon Sep 17 00:00:00 2001 From: kangmeng3 Date: Fri, 22 Aug 2025 16:16:11 +0800 Subject: [PATCH 07/11] feat: implement cache aware load balance. --- xllm_service/rpc_service/disagg_pd_policy.cpp | 207 ------------------ xllm_service/rpc_service/disagg_pd_policy.h | 77 ------- .../loadbalance_policy/CMakeLists.txt | 14 ++ .../loadbalance_policy/loadbalance_policy.cpp | 56 +++++ .../loadbalance_policy/loadbalance_policy.h | 32 +++ 5 files changed, 102 insertions(+), 284 deletions(-) delete mode 100644 xllm_service/rpc_service/disagg_pd_policy.cpp delete mode 100644 xllm_service/rpc_service/disagg_pd_policy.h create mode 100644 xllm_service/rpc_service/loadbalance_policy/CMakeLists.txt create mode 100644 xllm_service/rpc_service/loadbalance_policy/loadbalance_policy.cpp create mode 100644 xllm_service/rpc_service/loadbalance_policy/loadbalance_policy.h diff --git a/xllm_service/rpc_service/disagg_pd_policy.cpp b/xllm_service/rpc_service/disagg_pd_policy.cpp deleted file mode 100644 index 292e14b..0000000 --- a/xllm_service/rpc_service/disagg_pd_policy.cpp +++ /dev/null @@ -1,207 +0,0 @@ -/* Copyright 2025 The xLLM Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https://github.com/jd-opensource/xllm-service/blob/main/LICENSE - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "disagg_pd_policy.h" - -#include - -#include "common/utils.h" - -namespace xllm_service { - -namespace { - -void debug_print(const std::string& action, - const std::string& name, - const std::string& type, - int idx) { - if (utils::enable_debug_log()) { - LOG(INFO) << "DisaggPdPolicy " << action << " instance, name = " << name - << ", type = " << type << ", idx = " << idx; - } -} - -} // namespace - -DisaggPdPolicy::DisaggPdPolicy() {} - -DisaggPdPolicy::~DisaggPdPolicy() {} - -void DisaggPdPolicy::insert_instance(const std::string& name, - InstanceMetaInfo* info) { - std::lock_guard guard(mutex_); - InstanceType type = info->type; - if (type == InstanceType::DEFAULT || type == InstanceType::PREFILL) { - auto it = prefill_instance_to_index_.find(name); - if (it != prefill_instance_to_index_.end()) { - LOG(ERROR) << "Insert instance is already existed, name: " << name - << ", type: " << static_cast(type); - return; - } - prefill_instance_.emplace_back(info); - prefill_instance_to_index_[name] = prefill_instance_.size() - 1; - debug_print( - "insert", name, "prefill or default", prefill_instance_to_index_[name]); - } else { - auto it = decode_instance_to_index_.find(name); - if (it != decode_instance_to_index_.end()) { - LOG(ERROR) << "Insert instance is already existed, name: " << name - << ", type: " << static_cast(type); - return; - } - decode_instance_.emplace_back(info); - decode_instance_to_index_[name] = decode_instance_.size() - 1; - debug_print("insert", name, "decode", decode_instance_to_index_[name]); - } -} - -void DisaggPdPolicy::update_instance(const std::string& name, - InstanceMetaInfo* info) { - std::lock_guard guard(mutex_); - InstanceType type = info->type; - if (type == InstanceType::DEFAULT || type == InstanceType::PREFILL) { - auto it = prefill_instance_to_index_.find(name); - if (it == prefill_instance_to_index_.end()) { - LOG(ERROR) << "Update instance is not existed, name: " << name - << ", type: " << static_cast(type); - return; - } - prefill_instance_[it->second] = info; - debug_print( - "update", name, "prefill or default", prefill_instance_to_index_[name]); - } else { - auto it = decode_instance_to_index_.find(name); - if (it == decode_instance_to_index_.end()) { - LOG(ERROR) << "Update instance is not existed, name: " << name - << ", type: " << static_cast(type); - return; - } - decode_instance_[it->second] = info; - debug_print("update", name, "decode", decode_instance_to_index_[name]); - } -} - -void DisaggPdPolicy::remove_instance(const std::string& name, - InstanceType type) { - std::lock_guard guard(mutex_); - if (type == InstanceType::DEFAULT || type == InstanceType::PREFILL) { - auto it = prefill_instance_to_index_.find(name); - if (it == prefill_instance_to_index_.end()) { - LOG(ERROR) << "Remove instance not found, name: " << name - << ", type: " << static_cast(type); - return; - } - auto idx = it->second; - // Label the instance be deleted - prefill_instance_[idx] = nullptr; - prefill_instance_to_index_.erase(name); - debug_print("remove", name, "prefill or default", idx); - } else { - auto it = decode_instance_to_index_.find(name); - if (it == decode_instance_to_index_.end()) { - LOG(ERROR) << "Remove instance not found, name: " << name - << ", type: " << static_cast(type); - return; - } - auto idx = it->second; - // Label the instance be deleted - decode_instance_[idx] = nullptr; - decode_instance_to_index_.erase(name); - debug_print("remove", name, "decode", idx); - } -} - -RoundRobinDisaggPdPolicy::RoundRobinDisaggPdPolicy() { - LOG(INFO) << "Enable RoundRobin disaggregated pd policy."; -} - -RoundRobinDisaggPdPolicy::~RoundRobinDisaggPdPolicy() {} - -InstancesPair RoundRobinDisaggPdPolicy::select_instances_pair( - bool only_prefill) { - std::lock_guard guard(mutex_); - // return the first available prefill instance - if (only_prefill) { - InstancesPair inst_pair; - for (const auto& inst : prefill_instance_) { - if (inst != nullptr) { - inst_pair.prefill_instance_http_addr = inst->name; - break; - } - } - return inst_pair; - } - - int prefill_count = prefill_instance_.size(); - int decode_count = decode_instance_.size(); - InstancesPair inst_pair; - // select prefill instance - if (prefill_count > 0) { - auto start_idx = next_prefill_idx_; - bool inst_not_existed = false; - while (prefill_instance_[next_prefill_idx_] == nullptr) { - ++next_prefill_idx_; - next_prefill_idx_ %= prefill_count; - if (next_prefill_idx_ == start_idx) { - inst_not_existed = true; - break; - } - } - if (!inst_not_existed) { - inst_pair.prefill_instance_http_addr = - prefill_instance_[next_prefill_idx_]->name; - } - - ++next_prefill_idx_; - next_prefill_idx_ %= prefill_count; - } - - // select decode instance - if (decode_count > 0) { - auto start_idx = next_decode_idx_; - bool inst_not_existed = false; - while (decode_instance_[next_decode_idx_] == nullptr) { - ++next_decode_idx_; - next_decode_idx_ %= decode_count; - if (next_decode_idx_ == start_idx) { - inst_not_existed = true; - break; - } - } - if (!inst_not_existed) { - inst_pair.decode_instance_http_addr = - decode_instance_[next_decode_idx_]->name; - } - - ++next_decode_idx_; - next_decode_idx_ %= decode_count; - } - - return inst_pair; -} - -std::unordered_map -RoundRobinDisaggPdPolicy::reallocate_instances_type(/*params here*/) { - // TODO: implement this function - return {}; -} - -std::unordered_map> -RoundRobinDisaggPdPolicy::allocate_pd_pairs(/*params here*/) { - // TODO: implement this function - return {}; -} - -} // namespace xllm_service diff --git a/xllm_service/rpc_service/disagg_pd_policy.h b/xllm_service/rpc_service/disagg_pd_policy.h deleted file mode 100644 index 4e83d43..0000000 --- a/xllm_service/rpc_service/disagg_pd_policy.h +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright 2025 The xLLM Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https://github.com/jd-opensource/xllm-service/blob/main/LICENSE - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#pragma once - -#include -#include -#include -#include - -#include "common/types.h" - -namespace xllm_service { - -class DisaggPdPolicy { - public: - DisaggPdPolicy(); - virtual ~DisaggPdPolicy(); - - // re-allocate instance types: prefill or decode - virtual std::unordered_map - reallocate_instances_type(/*params here*/) = 0; - - // Allocate prefill and decode pairs, return prefill -> [decode instances] - // Allow multiple decode instances for each prefill instance and - // multiple prefill instances for each decode instance - virtual std::unordered_map> - allocate_pd_pairs(/*params here*/) = 0; - - // select instances(prefill/decode/default etc.) to handle request - // according the disagg pd policy. - virtual InstancesPair select_instances_pair(bool only_prefill = false) = 0; - - void insert_instance(const std::string& name, InstanceMetaInfo* info); - void update_instance(const std::string& name, InstanceMetaInfo* info); - void remove_instance(const std::string& name, InstanceType type); - - protected: - std::vector prefill_instance_; - std::vector decode_instance_; - // map the instance name to vector index - std::unordered_map prefill_instance_to_index_; - std::unordered_map decode_instance_to_index_; - - std::mutex mutex_; -}; - -class RoundRobinDisaggPdPolicy : public DisaggPdPolicy { - public: - RoundRobinDisaggPdPolicy(); - ~RoundRobinDisaggPdPolicy(); - - virtual std::unordered_map - reallocate_instances_type(/*params here*/) override; - virtual std::unordered_map> - allocate_pd_pairs(/*params here*/) override; - virtual InstancesPair select_instances_pair( - bool only_prefill = false) override; - - private: - int next_prefill_idx_ = 0; - int next_decode_idx_ = 0; -}; - -} // namespace xllm_service diff --git a/xllm_service/rpc_service/loadbalance_policy/CMakeLists.txt b/xllm_service/rpc_service/loadbalance_policy/CMakeLists.txt new file mode 100644 index 0000000..06c29ef --- /dev/null +++ b/xllm_service/rpc_service/loadbalance_policy/CMakeLists.txt @@ -0,0 +1,14 @@ +include(cc_binary) +include(cc_library) +include(cc_test) + +cc_library( + NAME + loadbalance_policy + HDRS + loadbalance_policy.h + SRCS + loadbalance_policy.cpp + DEPS + :common +) diff --git a/xllm_service/rpc_service/loadbalance_policy/loadbalance_policy.cpp b/xllm_service/rpc_service/loadbalance_policy/loadbalance_policy.cpp new file mode 100644 index 0000000..5af1048 --- /dev/null +++ b/xllm_service/rpc_service/loadbalance_policy/loadbalance_policy.cpp @@ -0,0 +1,56 @@ +#pragma once + +#include "loadbalance_policy.h" + +namespace xllm_service { + +constexpr float MIN_SCORE = -2.0; + +void LoadBalancePolicy::select_instances_pair(const LoadBalanceInfos& infos, + Routing* routing) { + // find preifll + cost_function(infos.overlap_scores.hbm_instance_score, + infos.overlap_scores.max_block_num, + infos.prefill_load_metrics, + infos.prefill_max_waiting_requests_num, + &routing->prefill_name); + + // find decode + if (infos.decode_load_metrics.size()) { + cost_function(infos.overlap_scores.hbm_instance_score, + infos.overlap_scores.max_block_num, + infos.decode_load_metrics, + infos.decode_max_waiting_requests_num, + &routing->decode_name); + } +} + +void LoadBalancePolicy::cost_function( + const std::unordered_map& overlap_scores, + const uint32_t& max_block_num, + const std::unordered_map& load_metrics, + const int64_t& max_waiting_requests_num, + std::string* best_choice) { + float best_score = MIN_SCORE; + for (const auto& it : load_metrics) { + const auto matched_blocks_it = overlap_scores.find(it.first); + uint32_t matched_blocks = 0; + if (matched_blocks_it != overlap_scores.end()) { + matched_blocks = matched_blocks_it->second; + } + + auto score = + (max_block_num == 0 ? 0 : matched_blocks / max_block_num) - + it.second.gpu_cache_usage_perc - + (max_waiting_requests_num == 0 + ? 0 + : it.second.waiting_requests_num / max_waiting_requests_num); + + if (score > best_score) { + best_score = score; + *best_choice = it.first; + } + } +} + +} // namespace xllm_service diff --git a/xllm_service/rpc_service/loadbalance_policy/loadbalance_policy.h b/xllm_service/rpc_service/loadbalance_policy/loadbalance_policy.h new file mode 100644 index 0000000..9e83f54 --- /dev/null +++ b/xllm_service/rpc_service/loadbalance_policy/loadbalance_policy.h @@ -0,0 +1,32 @@ +#pragma once + +#include +#include +#include + +#include "common/macros.h" +#include "common/types.h" + +namespace xllm_service { + +class LoadBalancePolicy { + public: + LoadBalancePolicy() = default; + + virtual ~LoadBalancePolicy() = default; + + virtual void select_instances_pair(const LoadBalanceInfos& infos, + Routing* routing); + + protected: + DISALLOW_COPY_AND_ASSIGN(LoadBalancePolicy); + + virtual void cost_function( + const std::unordered_map& overlap_scores, + const uint32_t& max_block_num, + const std::unordered_map& load_metrics, + const int64_t& max_waiting_requests_num, + std::string* best_choice); +}; + +} // namespace xllm_service From 1e9ffd48b0dc4be02ae6c34758eb1acb9959e854 Mon Sep 17 00:00:00 2001 From: kangmeng3 Date: Fri, 22 Aug 2025 16:17:55 +0800 Subject: [PATCH 08/11] feat: update services to support cache aware routing. --- xllm_service/common/global_gflags.cpp | 5 - xllm_service/common/global_gflags.h | 2 - xllm_service/common/types.h | 4 +- xllm_service/http_service/service.cpp | 150 ++++++++---------- xllm_service/http_service/service.h | 10 +- xllm_service/master.cpp | 57 ++++++- xllm_service/master.h | 3 + xllm_service/rpc_service/main.cpp | 20 ++- xllm_service/rpc_service/rpc_service_test.cpp | 68 ++++---- xllm_service/rpc_service/scheduler.cpp | 3 +- xllm_service/rpc_service/service.cpp | 75 +++------ xllm_service/rpc_service/service.h | 33 ++-- 12 files changed, 216 insertions(+), 214 deletions(-) diff --git a/xllm_service/common/global_gflags.cpp b/xllm_service/common/global_gflags.cpp index d638379..15e0279 100644 --- a/xllm_service/common/global_gflags.cpp +++ b/xllm_service/common/global_gflags.cpp @@ -33,11 +33,6 @@ DEFINE_int32(http_server_max_concurrency, 128, "Limit number of requests processed in parallel"); -DEFINE_string(rpc_server_host, - "", - "Rpc server listen address, may be IPV4/IPV6/UDS." - " If this is set, the flag port will be ignored"); - DEFINE_int32(rpc_server_port, 8889, "Port for xllm rpc service to listen on"); DEFINE_int32(rpc_server_idle_timeout_s, diff --git a/xllm_service/common/global_gflags.h b/xllm_service/common/global_gflags.h index 7409c0d..3c09ee8 100644 --- a/xllm_service/common/global_gflags.h +++ b/xllm_service/common/global_gflags.h @@ -27,8 +27,6 @@ DECLARE_int32(http_server_num_threads); DECLARE_int32(http_server_max_concurrency); -DECLARE_string(rpc_server_host); - DECLARE_int32(rpc_server_port); DECLARE_int32(rpc_server_idle_timeout_s); diff --git a/xllm_service/common/types.h b/xllm_service/common/types.h index 478e642..c1e591e 100644 --- a/xllm_service/common/types.h +++ b/xllm_service/common/types.h @@ -102,9 +102,9 @@ enum class InstanceType : int8_t { }; struct LoadMetrics { - LoadMetrics() : waiting_requests_num(0), gpu_cache_usage_perc(0){}; + LoadMetrics() : waiting_requests_num(0), gpu_cache_usage_perc(0) {}; LoadMetrics(const uint64_t& waiting_reqs_num, const float& usage) - : waiting_requests_num(waiting_reqs_num), gpu_cache_usage_perc(usage){}; + : waiting_requests_num(waiting_reqs_num), gpu_cache_usage_perc(usage) {}; uint64_t waiting_requests_num; float gpu_cache_usage_perc; diff --git a/xllm_service/http_service/service.cpp b/xllm_service/http_service/service.cpp index ea5999e..dbdac78 100644 --- a/xllm_service/http_service/service.cpp +++ b/xllm_service/http_service/service.cpp @@ -63,56 +63,6 @@ XllmHttpServiceImpl::XllmHttpServiceImpl(const HttpServiceConfig& config) XllmHttpServiceImpl::~XllmHttpServiceImpl() {} -bool XllmHttpServiceImpl::create_channel(const std::string& target_uri) { - std::lock_guard guard(channel_mutex_); - if (cached_channels_.find(target_uri) == cached_channels_.end()) { - brpc::Channel* channel = new brpc::Channel(); - brpc::ChannelOptions options; - // Add to params - options.protocol = "http"; - options.timeout_ms = config_.timeout_ms; /*milliseconds*/ - options.max_retry = 3; - std::string load_balancer = ""; - if (channel->Init(target_uri.c_str(), load_balancer.c_str(), &options) != - 0) { - LOG(ERROR) << "Fail to initialize channel for " << target_uri; - return false; - } - cached_channels_[target_uri] = channel; - } - - return true; -} - -std::string XllmHttpServiceImpl::get_redirect_uri(bool only_prefill) { - std::string target_instance_addr; - if (!rpc_service_) { - // for testing - if (config_.test_instance_addr.empty()) { - LOG(ERROR) << "Rpc service is not start."; - return ""; - } - target_instance_addr = config_.test_instance_addr; - } else { - InstancesPair instances_pair = - rpc_service_->select_instances_pair(only_prefill); - if (instances_pair.prefill_instance_http_addr.empty()) { - LOG(ERROR) << "No prefill instance available."; - return ""; - } - target_instance_addr = instances_pair.prefill_instance_http_addr; - - if (!only_prefill) { - if (instances_pair.decode_instance_http_addr.empty()) { - // TODO: - } - // TODO: add instances_pair.decode_instance_http_addr to request? - } - } - - return target_instance_addr; -} - void XllmHttpServiceImpl::Hello(::google::protobuf::RpcController* controller, const proto::HttpHelloRequest* request, proto::HttpHelloResponse* response, @@ -213,7 +163,8 @@ void XllmHttpServiceImpl::handle(std::shared_ptr call_data, // async redistribute the request and wait the response // TODO: optimize the thread pool to async mode. - auto channel_ptr = cached_channels_[target_uri]; + brpc::Channel* channel_ptr = rpc_service_->get_channel(target_uri).get(); + // send request to prefill instance. thread_pool_->schedule([this, service_request_id, @@ -375,24 +326,6 @@ void XllmHttpServiceImpl::post_serving( // create xllm_service request_id: service_request_id std::string service_request_id = generate_service_request_id(serving_method); json_value["service_request_id"] = service_request_id; - std::string req_attachment = json_value.dump(); - request_tracer_->log(service_request_id, req_attachment); - - // redistribute the request to the correct P/D instance - // TODO: redistribute policy to select the instance - std::string target_uri = get_redirect_uri(); - if (target_uri.empty()) { - cntl->SetFailed( - "Internal runtime error, can not found a running instance."); - return; - } - if (cached_channels_.find(target_uri) == cached_channels_.end()) { - if (!create_channel(target_uri)) { - LOG(ERROR) << "Create channel failed, target_uri is " << target_uri; - cntl->SetFailed("Internal runtime error."); - return; - } - } std::function trace_callback; if (config_.enable_request_trace) { @@ -403,33 +336,82 @@ void XllmHttpServiceImpl::post_serving( trace_callback = nullptr; } + SchduleResult schedule_res; if (serving_method == "/v1/completions") { + if (json_value.contains("prompt")) { + if (!rpc_service_->schedule(json_value.at("prompt").get(), + &schedule_res)) { + cntl->SetFailed("Schedule fail!"); + LOG(ERROR) << "XllmRpcServiceImpl::schedule error!"; + return; + } + } else { + cntl->SetFailed("Input has no prompt!"); + LOG(ERROR) << "Input has no prompt!"; + return; + } + json_value["token_ids"] = schedule_res.token_ids; + json_value["routing"] = schedule_res.routing.serialize_to_json(); + + std::string req_attachment = json_value.dump(); auto arena = response->GetArena(); auto resp_pb = google::protobuf::Arena::CreateMessage( arena); auto call_data = std::make_shared( - cntl, stream, done_guard.release(), resp_pb, trace_callback); + cntl, stream, done_guard.release(), resp_pb); handle_v1_completions(call_data, req_attachment, service_request_id, stream, model, include_usage, - target_uri); + schedule_res.routing.prefill_name); } else if (serving_method == "/v1/chat/completions") { + if (json_value.contains("messages") && json_value["messages"].is_array()) { + ChatMessages messages; + try { + const auto& msgs = json_value["messages"]; + messages.reserve(msgs.size()); + for (const auto& msg : msgs) { + if (msg.contains("role") && msg["role"].is_string() && + msg.contains("content") && msg["content"].is_string()) { + messages.emplace_back(msg["role"].get(), + msg["content"].get()); + } + } + } catch (const nlohmann::json::exception& e) { + cntl->SetFailed("Parse request fail, Invalid messages!"); + LOG(ERROR) << "Parse request fail, Invalid messages!"; + return; + } + + if (!rpc_service_->schedule(messages, &schedule_res)) { + cntl->SetFailed("Schedule fail!"); + LOG(ERROR) << "XllmRpcServiceImpl::schedule error!"; + return; + } + } else { + cntl->SetFailed("Input has no messages!"); + LOG(ERROR) << "Input has no messages!"; + return; + } + json_value["token_ids"] = schedule_res.token_ids; + json_value["routing"] = schedule_res.routing.serialize_to_json(); + + std::string req_attachment = json_value.dump(); auto arena = response->GetArena(); auto resp_pb = google::protobuf::Arena::CreateMessage(arena); auto call_data = std::make_shared( - cntl, stream, done_guard.release(), resp_pb, trace_callback); + cntl, stream, done_guard.release(), resp_pb); handle_v1_chat_completions(call_data, req_attachment, service_request_id, stream, model, include_usage, - target_uri); + schedule_res.routing.prefill_name); } else { LOG(ERROR) << "Not supported method: " << serving_method; cntl->SetFailed("Not supported method: " + serving_method); @@ -471,22 +453,18 @@ void XllmHttpServiceImpl::get_serving( // done_guard.release()); auto call_data = std::make_shared( cntl, false, done_guard.release(), nullptr); - std::string target_uri = get_redirect_uri(true /*only_prefill*/); - if (target_uri.empty()) { - cntl->SetFailed( - "Internal runtime error, can not found a running instance."); + + SchduleResult schedule_res; + if (!rpc_service_->schedule("", &schedule_res)) { + cntl->SetFailed("Schedule fail!"); + LOG(ERROR) << "XllmRpcServiceImpl::schedule error!"; return; } - if (cached_channels_.find(target_uri) == cached_channels_.end()) { - if (!create_channel(target_uri)) { - LOG(ERROR) << "Create channel failed, target_uri is " << target_uri; - cntl->SetFailed("Internal runtime error."); - return; - } - } - auto channel_ptr = cached_channels_[target_uri]; - target_uri += serving_method; + brpc::Channel* channel_ptr = + rpc_service_->get_channel(schedule_res.routing.prefill_name).get(); + std::string target_uri = schedule_res.routing.prefill_name + serving_method; + thread_pool_->schedule( [/*req_attachment, */ call_data, cntl, channel_ptr, target_uri]() { brpc::Controller* redirect_cntl = new brpc::Controller(); diff --git a/xllm_service/http_service/service.h b/xllm_service/http_service/service.h index 3bff5c4..69dd6cb 100644 --- a/xllm_service/http_service/service.h +++ b/xllm_service/http_service/service.h @@ -77,8 +77,7 @@ class XllmHttpServiceImpl : public proto::XllmHttpService { private: bool create_channel(const std::string& target_uri); - // only prefill is true means only prefill instance is returned - std::string get_redirect_uri(bool only_prefill = false); + void post_serving(const std::string& serving_method, ::google::protobuf::RpcController* controller, const proto::HttpRequest* request, @@ -124,13 +123,8 @@ class XllmHttpServiceImpl : public proto::XllmHttpService { std::shared_ptr rpc_service_; std::unique_ptr request_tracer_; - // uri -> channel - // e.g. 127.0.0.1:9999/v1/completions -> channel1 - // 127.0.0.1:9999/v1/chat/completions -> channel2 - // NOTE: different methods to one instance has different channels - std::unordered_map cached_channels_; + std::unique_ptr thread_pool_; - std::mutex channel_mutex_; // In disagg pd mode, we support receive generated token from // prefill or from decode directly. diff --git a/xllm_service/master.cpp b/xllm_service/master.cpp index a6ebe69..4156b68 100644 --- a/xllm_service/master.cpp +++ b/xllm_service/master.cpp @@ -15,6 +15,7 @@ limitations under the License. #include "master.h" +#include #include #include "common/global_gflags.h" @@ -35,14 +36,26 @@ Master::Master(const ServerOptions& server_options) rpc_config.detect_disconnected_instance_interval = server_options.detect_disconnected_instance_interval; - rpc_service_impl_ = - std::make_shared(rpc_config); - rpc_service_ = - std::make_unique(rpc_service_impl_); + rpc_config.service_name = server_options_.rpc_server_host + ":" + + std::to_string(server_options_.rpc_port); + + ModelConfig model_config; + model_config.block_size = server_options.block_size; + model_config.model_type = server_options.model_type; + model_config.tokenizer_path = server_options.tokenizer_path; - HttpServiceConfig http_config; + xllm_service::HttpServiceConfig http_config; http_config.num_threads = server_options.http_num_threads; + http_config.timeout_ms = server_options.timeout_ms; + http_config.test_instance_addr = server_options.test_instance_addr; http_config.enable_request_trace = server_options.enable_request_trace; + + rpc_service_impl_ = std::make_shared( + rpc_config, model_config, http_config); + + rpc_service_ = + std::make_unique(rpc_service_impl_); + http_service_ = std::make_unique( rpc_service_impl_, http_config); } @@ -160,13 +173,37 @@ void shutdown_handler(int signal) { exit(1); } +std::string get_local_ip() { + using namespace boost::asio; + io_service io; + ip::tcp::resolver resolver(io); + ip::tcp::resolver::query query(ip::host_name(), ""); + ip::tcp::resolver::iterator iter = resolver.resolve(query); + ip::tcp::resolver::iterator end; + + while (iter != end) { + ip::address addr = iter->endpoint().address(); + if (!addr.is_loopback() && addr.is_v4()) { + return addr.to_string(); + } + ++iter; + } + + LOG(FATAL) << "Get local ip faill!"; + return ""; +} + int main(int argc, char* argv[]) { // Initialize gflags gflags::ParseCommandLineFlags(&argc, &argv, true); // Initialize glog google::InitGoogleLogging(argv[0]); - FLAGS_logtostderr = true; + // FLAGS_logtostderr = true; + + LOG(INFO) << "Dump all gflags: " << std::endl + << google::CommandlineFlagsIntoString(); + google::FlushLogFiles(google::INFO); LOG(INFO) << "Starting xllm master service."; @@ -191,7 +228,7 @@ int main(int argc, char* argv[]) { server_options.http_idle_timeout_s = FLAGS_http_server_idle_timeout_s; server_options.http_num_threads = FLAGS_http_server_num_threads; server_options.http_max_concurrency = FLAGS_http_server_max_concurrency; - server_options.rpc_server_host = FLAGS_rpc_server_host; + server_options.rpc_server_host = get_local_ip(); server_options.rpc_port = FLAGS_rpc_server_port; server_options.rpc_idle_timeout_s = FLAGS_rpc_server_idle_timeout_s; server_options.rpc_num_threads = FLAGS_rpc_server_num_threads; @@ -201,10 +238,16 @@ int main(int argc, char* argv[]) { server_options.detect_disconnected_instance_interval = FLAGS_detect_disconnected_instance_interval; server_options.enable_request_trace = FLAGS_enable_request_trace; + + server_options.tokenizer_path = FLAGS_tokenizer_path; server_options.block_size = FLAGS_block_size; server_options.model_type = FLAGS_model_type; server_options.tokenizer_path = FLAGS_tokenizer_path; + server_options.num_threads = FLAGS_num_threads; + server_options.timeout_ms = FLAGS_timeout_ms; + server_options.test_instance_addr = FLAGS_test_instance_addr; + xllm_service::Master master(server_options); if (!master.start()) { diff --git a/xllm_service/master.h b/xllm_service/master.h index 92b9692..b242bb6 100644 --- a/xllm_service/master.h +++ b/xllm_service/master.h @@ -32,6 +32,9 @@ struct ServerOptions { int32_t http_num_threads = 32; int32_t http_max_concurrency = 128; bool enable_request_trace = false; + int num_threads = 16; + int timeout_ms = -1; + std::string test_instance_addr = ""; // rpc server options std::string rpc_server_host = ""; diff --git a/xllm_service/rpc_service/main.cpp b/xllm_service/rpc_service/main.cpp index fdd6b78..c96fb6e 100644 --- a/xllm_service/rpc_service/main.cpp +++ b/xllm_service/rpc_service/main.cpp @@ -18,6 +18,7 @@ limitations under the License. #include #include "common/global_gflags.h" +#include "common/types.h" #include "common/utils.h" #include "rpc_service/service.h" @@ -27,7 +28,10 @@ int main(int argc, char* argv[]) { // Initialize glog google::InitGoogleLogging(argv[0]); - FLAGS_logtostderr = true; + + LOG(INFO) << "Dump all gflags: " << std::endl + << google::CommandlineFlagsIntoString(); + google::FlushLogFiles(google::INFO); LOG(INFO) << "Starting xllm rpc service, port: " << FLAGS_port; @@ -43,9 +47,19 @@ int main(int argc, char* argv[]) { config.detect_disconnected_instance_interval = FLAGS_detect_disconnected_instance_interval; + xllm_service::ModelConfig model_config; + model_config.block_size = FLAGS_block_size; + model_config.model_type = FLAGS_model_type; + model_config.tokenizer_path = FLAGS_tokenizer_path; + + xllm_service::HttpServiceConfig http_config; + http_config.num_threads = FLAGS_num_threads; + http_config.timeout_ms = FLAGS_timeout_ms; + http_config.test_instance_addr = FLAGS_test_instance_addr; + // create xllm service - auto xllm_service_impl = - std::make_shared(config); + auto xllm_service_impl = std::make_shared( + config, model_config, http_config); xllm_service::XllmRpcService service(xllm_service_impl); // Initialize brpc server diff --git a/xllm_service/rpc_service/rpc_service_test.cpp b/xllm_service/rpc_service/rpc_service_test.cpp index d9f490d..36eb9cb 100644 --- a/xllm_service/rpc_service/rpc_service_test.cpp +++ b/xllm_service/rpc_service/rpc_service_test.cpp @@ -26,36 +26,44 @@ class XllmRpcServiceTest : public ::testing::Test { void TearDown() override { google::ShutdownGoogleLogging(); } }; +// TODO +// TEST_F(XllmRpcServiceTest, RegisterInstance) { +// RpcServiceConfig config; +// HttpServiceConfig http_config; +// ModelConfig model_config; +// auto xllm_service = +// std::make_shared(config, model_config, +// http_config); +// std::string inst_name = "127.0.0.1@nic0"; +// InstanceMetaInfo metainfo(inst_name, "127.0.0.1:7777", +// InstanceType::PREFILL); EXPECT_EQ(ErrorCode::OK, +// xllm_service->register_instance(inst_name, metainfo)); -TEST_F(XllmRpcServiceTest, RegisterInstance) { - RpcServiceConfig config; - auto xllm_service = std::make_shared(config); - std::string inst_name = "127.0.0.1@nic0"; - InstanceMetaInfo metainfo(inst_name, "127.0.0.1:7777", InstanceType::PREFILL); - EXPECT_EQ(ErrorCode::OK, - xllm_service->register_instance(inst_name, metainfo)); - - metainfo.type = InstanceType::DECODE; - EXPECT_EQ(ErrorCode::INSTANCE_EXISTED, - xllm_service->register_instance(inst_name, metainfo)); -} - -TEST_F(XllmRpcServiceTest, UpdateInstanceMetainfo) { - RpcServiceConfig config; - auto xllm_service = std::make_shared(config); - std::string inst_name = "127.0.0.1@nic0"; - InstanceMetaInfo metainfo(inst_name, "127.0.0.1:7777", InstanceType::PREFILL); - EXPECT_EQ(ErrorCode::OK, - xllm_service->register_instance(inst_name, metainfo)); - metainfo.type = InstanceType::DECODE; - EXPECT_EQ(ErrorCode::OK, - xllm_service->update_instance_metainfo(inst_name, metainfo)); - - std::string inst_name2 = "127.0.0.1@nic2"; - InstanceMetaInfo metainfo2( - inst_name2, "127.0.0.1:7778", InstanceType::PREFILL); - EXPECT_EQ(ErrorCode::INSTANCE_NOT_EXISTED, - xllm_service->update_instance_metainfo(inst_name2, metainfo)); -} +// metainfo.type = InstanceType::DECODE; +// EXPECT_EQ(ErrorCode::INSTANCE_EXISTED, +// xllm_service->register_instance(inst_name, metainfo)); +// } + +// TEST_F(XllmRpcServiceTest, UpdateInstanceMetainfo) { +// RpcServiceConfig config; +// HttpServiceConfig http_config; +// ModelConfig model_config; +// auto xllm_service = +// std::make_shared(config, model_config, +// http_config); +// std::string inst_name = "127.0.0.1@nic0"; +// InstanceMetaInfo metainfo(inst_name, "127.0.0.1:7777", +// InstanceType::PREFILL); EXPECT_EQ(ErrorCode::OK, +// xllm_service->register_instance(inst_name, metainfo)); +// metainfo.type = InstanceType::DECODE; +// EXPECT_EQ(ErrorCode::OK, +// xllm_service->update_instance_metainfo(inst_name, metainfo)); + +// std::string inst_name2 = "127.0.0.1@nic2"; +// InstanceMetaInfo metainfo2( +// inst_name2, "127.0.0.1:7778", InstanceType::PREFILL); +// EXPECT_EQ(ErrorCode::INSTANCE_NOT_EXISTED, +// xllm_service->update_instance_metainfo(inst_name2, metainfo)); +// } } // namespace xllm_service::test diff --git a/xllm_service/rpc_service/scheduler.cpp b/xllm_service/rpc_service/scheduler.cpp index 2d2e8e3..2c47597 100644 --- a/xllm_service/rpc_service/scheduler.cpp +++ b/xllm_service/rpc_service/scheduler.cpp @@ -78,8 +78,7 @@ bool Scheduler::schedule(const std::string& prompt, SchduleResult* res) { return false; } - Slice token_ids(res->token_ids.data(), - res->token_ids.size()); + Slice token_ids(res->token_ids.data(), res->token_ids.size()); global_kvcache_mgr_->match(token_ids, &lb_infos.overlap_scores); DLOG(INFO) << lb_infos.debug_string(); diff --git a/xllm_service/rpc_service/service.cpp b/xllm_service/rpc_service/service.cpp index e5ca3b6..75484ff 100644 --- a/xllm_service/rpc_service/service.cpp +++ b/xllm_service/rpc_service/service.cpp @@ -53,42 +53,45 @@ grpc::StatusCode to_grpc_status_code(llm::StatusCode code) { } } // namespace -XllmRpcServiceImpl::XllmRpcServiceImpl(const RpcServiceConfig& config) { +XllmRpcServiceImpl::XllmRpcServiceImpl(const RpcServiceConfig& rpc_config, + const ModelConfig& model_config, + const HttpServiceConfig& http_config) { enable_decode_response_to_service_ = utils::get_bool_env("ENABLE_DECODE_RESPONSE_TO_SERVICE", false); - instance_mgr_ = std::make_unique(config); + + scheduler_ = + std::make_unique(rpc_config, model_config, http_config); } -XllmRpcServiceImpl::~XllmRpcServiceImpl() {} +XllmRpcServiceImpl::~XllmRpcServiceImpl() { scheduler_->exited(); } -ErrorCode XllmRpcServiceImpl::heartbeat(const std::string& instance_name) { - return instance_mgr_->heartbeat(instance_name); +void XllmRpcServiceImpl::heartbeat(const proto::HeartbeatRequest* req) { + scheduler_->handle_instance_heartbeat(req); } -ErrorCode XllmRpcServiceImpl::register_instance( - const std::string& instance_name, - const InstanceMetaInfo& metainfo) { - return instance_mgr_->register_instance(instance_name, metainfo); +InstanceMetaInfo XllmRpcServiceImpl::get_instance_info( + const std::string& instance_name) { + return scheduler_->get_instance_info(instance_name); } -ErrorCode XllmRpcServiceImpl::update_instance_metainfo( - const std::string& instance_name, - const InstanceMetaInfo& metainfo) { - return instance_mgr_->update_instance_metainfo(instance_name, metainfo); +std::vector XllmRpcServiceImpl::get_static_decode_list( + const std::string& instance_name) { + return scheduler_->get_static_decode_list(instance_name); } -InstancesPair XllmRpcServiceImpl::select_instances_pair(bool only_prefill) { - return instance_mgr_->select_instances_pair(only_prefill); +bool XllmRpcServiceImpl::schedule(const std::string& prompt, + SchduleResult* res) { + return scheduler_->schedule(prompt, res); } -InstanceMetaInfo XllmRpcServiceImpl::get_instance_info( - const std::string& instance_name) { - return instance_mgr_->get_instance_info(instance_name); +bool XllmRpcServiceImpl::schedule(const ChatMessages& messages, + SchduleResult* res) { + return scheduler_->schedule(messages, res); } -std::vector XllmRpcServiceImpl::get_static_decode_list( - const std::string& instance_name) { - return instance_mgr_->get_static_decode_list(instance_name); +std::shared_ptr XllmRpcServiceImpl::get_channel( + const std::string& target_name) { + return scheduler_->get_channel(target_name); } bool XllmRpcServiceImpl::handle_generation( @@ -274,32 +277,6 @@ void XllmRpcService::Hello(google::protobuf::RpcController* cntl_base, resp->set_ok(true); } -void XllmRpcService::RegisterInstance( - google::protobuf::RpcController* cntl_base, - const proto::InstanceMetaInfo* req, - proto::StatusCode* resp, - google::protobuf::Closure* done) { - brpc::ClosureGuard done_guard(done); - InstanceType type = InstanceType::DEFAULT; - if (req->has_type() && req->type() == proto::InstanceType::PREFILL) { - type = InstanceType::PREFILL; - } else if (req->has_type() && req->type() == proto::InstanceType::DECODE) { - type = InstanceType::DECODE; - } - InstanceMetaInfo metainfo(req->name(), req->rpc_address(), type); - metainfo.cluster_ids = std::vector(req->cluster_ids().begin(), - req->cluster_ids().end()); - metainfo.addrs = - std::vector(req->addrs().begin(), req->addrs().end()); - metainfo.k_cache_ids = std::vector(req->k_cache_ids().begin(), - req->k_cache_ids().end()); - metainfo.v_cache_ids = std::vector(req->v_cache_ids().begin(), - req->v_cache_ids().end()); - metainfo.dp_size = req->dp_size(); - ErrorCode code = xllm_service_->register_instance(req->name(), metainfo); - resp->set_status_code(ConvertErrorCode::to_int(code)); -} - void XllmRpcService::GetInstanceInfo(google::protobuf::RpcController* cntl_base, const proto::InstanceID* req, proto::InstanceMetaInfo* resp, @@ -335,9 +312,7 @@ void XllmRpcService::Heartbeat(google::protobuf::RpcController* cntl_base, proto::Status* resp, google::protobuf::Closure* done) { brpc::ClosureGuard done_guard(done); - auto inst_name = req->name(); - // TODO: handle req.cache_event and req.load_metrics - xllm_service_->heartbeat(inst_name); + xllm_service_->heartbeat(req); resp->set_ok(true); } diff --git a/xllm_service/rpc_service/service.h b/xllm_service/rpc_service/service.h index 4575e11..dd92a30 100644 --- a/xllm_service/rpc_service/service.h +++ b/xllm_service/rpc_service/service.h @@ -26,6 +26,7 @@ limitations under the License. #include "completion.pb.h" #include "instance_mgr.h" #include "response_handler.h" +#include "scheduler.h" #include "xllm_rpc_service.pb.h" namespace xllm_service { @@ -41,25 +42,17 @@ struct ServiceConfig { class XllmRpcServiceImpl final { public: - XllmRpcServiceImpl(const RpcServiceConfig& config); + XllmRpcServiceImpl(const RpcServiceConfig& rpc_config, + const ModelConfig& model_config, + const HttpServiceConfig& http_config); ~XllmRpcServiceImpl(); - ErrorCode heartbeat(const std::string& instance_name); - ErrorCode register_instance(const std::string& instance_name, - const InstanceMetaInfo& metainfo); - ErrorCode update_instance_metainfo(const std::string& instance_name, - const InstanceMetaInfo& metainfo); + void heartbeat(const proto::HeartbeatRequest* req); InstanceMetaInfo get_instance_info(const std::string& instance_name); ServiceConfig get_config(); - // methods for master - - // select instances(prefill/decode/default etc.) to handle request - // according the disagg pd policy (or some other policies.). - InstancesPair select_instances_pair(bool only_prefill = false); - std::vector get_static_decode_list( const std::string& prefill_name); @@ -82,9 +75,13 @@ class XllmRpcServiceImpl final { bool include_usage); void finish_request(const std::string& service_request_id); - private: - std::unique_ptr instance_mgr_; + bool schedule(const std::string& prompt, SchduleResult* res); + + bool schedule(const ChatMessages& messages, SchduleResult* res); + std::shared_ptr get_channel(const std::string& target_name); + + private: // `request` -> `callback` map std::unordered_map callbacks_; std::mutex callback_mutex_; @@ -114,6 +111,9 @@ class XllmRpcServiceImpl final { // used when receive token from decode instance. ResponseHandler response_handler_; + + // instance discovery by register to etcd + std::unique_ptr scheduler_; }; // parse proto data and call XllmRpcService @@ -127,11 +127,6 @@ class XllmRpcService : public proto::XllmRpcService { proto::Status* resp, google::protobuf::Closure* done) override; - virtual void RegisterInstance(google::protobuf::RpcController* cntl_base, - const proto::InstanceMetaInfo* req, - proto::StatusCode* resp, - google::protobuf::Closure* done) override; - virtual void Heartbeat(google::protobuf::RpcController* cntl_base, const proto::HeartbeatRequest* req, proto::Status* resp, From a0f81ce1e8e821ef696c582bcc1885887cce2990 Mon Sep 17 00:00:00 2001 From: kangmeng3 Date: Tue, 26 Aug 2025 22:46:01 +0800 Subject: [PATCH 09/11] feat: add extra thread to update instance info, load metrics, and kvcache info. --- xllm_service/common/types.h | 18 +- xllm_service/common/utils.cpp | 21 ++ xllm_service/common/utils.h | 1 + xllm_service/http_service/service.cpp | 12 +- xllm_service/master.cpp | 23 +-- .../rpc_service/global_kvcache_mgr.cpp | 66 ++++--- xllm_service/rpc_service/global_kvcache_mgr.h | 7 +- xllm_service/rpc_service/instance_mgr.cpp | 184 +++++++++--------- xllm_service/rpc_service/instance_mgr.h | 12 +- xllm_service/rpc_service/scheduler.cpp | 4 +- xllm_service/rpc_service/scheduler.h | 4 +- xllm_service/rpc_service/service.cpp | 4 +- xllm_service/rpc_service/service.h | 4 +- 13 files changed, 196 insertions(+), 164 deletions(-) diff --git a/xllm_service/common/types.h b/xllm_service/common/types.h index c1e591e..5f9600f 100644 --- a/xllm_service/common/types.h +++ b/xllm_service/common/types.h @@ -70,7 +70,7 @@ struct Routing { std::string debug_string() const { return serialize_to_json().dump(2); } }; -struct SchduleResult { +struct ScheduleResult { std::vector token_ids; Routing routing; }; @@ -275,10 +275,26 @@ struct CacheLocations { } }; +/** + * @brief Records the prefix cache match lengths for different instances on + * current request + * + * This struct stores and manages prefix cache matching information across + * multiple instances, supporting different storage types (HBM, DRAM, SSD) for + * match length recording, and tracks information about the best matching + * instance. + */ struct OverlapScores { + // Set of matched instance names std::unordered_set instances; + // HBM storage type instance match length mapping (instance name -> match + // length) std::unordered_map hbm_instance_score; + // DRAM storage type instance match length mapping (instance name -> match + // length) std::unordered_map dram_instance_score; + // SSD storage type instance match length mapping (instance name -> match + // length) std::unordered_map ssd_instance_score; uint32_t max_block_num = 0; uint32_t max_matched_block_num = 0; diff --git a/xllm_service/common/utils.cpp b/xllm_service/common/utils.cpp index 5047257..99dacd2 100644 --- a/xllm_service/common/utils.cpp +++ b/xllm_service/common/utils.cpp @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include #include namespace xllm_service { @@ -73,5 +74,25 @@ bool get_bool_env(const std::string& key, bool defaultValue) { strVal == "True"); } +std::string get_local_ip() { + using namespace boost::asio; + io_service io; + ip::tcp::resolver resolver(io); + ip::tcp::resolver::query query(ip::host_name(), ""); + ip::tcp::resolver::iterator iter = resolver.resolve(query); + ip::tcp::resolver::iterator end; + + while (iter != end) { + ip::address addr = iter->endpoint().address(); + if (!addr.is_loopback() && addr.is_v4()) { + return addr.to_string(); + } + ++iter; + } + + LOG(FATAL) << "Get local ip faill!"; + return ""; +} + } // namespace utils } // namespace xllm_service diff --git a/xllm_service/common/utils.h b/xllm_service/common/utils.h index 0b00ac8..12d335b 100644 --- a/xllm_service/common/utils.h +++ b/xllm_service/common/utils.h @@ -23,6 +23,7 @@ namespace utils { bool enable_debug_log(); bool is_port_available(int port); bool get_bool_env(const std::string& key, bool defaultValue); +std::string get_local_ip(); } // namespace utils } // namespace xllm_service diff --git a/xllm_service/http_service/service.cpp b/xllm_service/http_service/service.cpp index dbdac78..d28e8d9 100644 --- a/xllm_service/http_service/service.cpp +++ b/xllm_service/http_service/service.cpp @@ -336,7 +336,7 @@ void XllmHttpServiceImpl::post_serving( trace_callback = nullptr; } - SchduleResult schedule_res; + ScheduleResult schedule_res; if (serving_method == "/v1/completions") { if (json_value.contains("prompt")) { if (!rpc_service_->schedule(json_value.at("prompt").get(), @@ -350,8 +350,8 @@ void XllmHttpServiceImpl::post_serving( LOG(ERROR) << "Input has no prompt!"; return; } - json_value["token_ids"] = schedule_res.token_ids; - json_value["routing"] = schedule_res.routing.serialize_to_json(); + json_value["token_ids"] = std::move(schedule_res.token_ids); + json_value["routing"] = std::move(schedule_res.routing.serialize_to_json()); std::string req_attachment = json_value.dump(); auto arena = response->GetArena(); @@ -396,8 +396,8 @@ void XllmHttpServiceImpl::post_serving( LOG(ERROR) << "Input has no messages!"; return; } - json_value["token_ids"] = schedule_res.token_ids; - json_value["routing"] = schedule_res.routing.serialize_to_json(); + json_value["token_ids"] = std::move(schedule_res.token_ids); + json_value["routing"] = std::move(schedule_res.routing.serialize_to_json()); std::string req_attachment = json_value.dump(); auto arena = response->GetArena(); @@ -454,7 +454,7 @@ void XllmHttpServiceImpl::get_serving( auto call_data = std::make_shared( cntl, false, done_guard.release(), nullptr); - SchduleResult schedule_res; + ScheduleResult schedule_res; if (!rpc_service_->schedule("", &schedule_res)) { cntl->SetFailed("Schedule fail!"); LOG(ERROR) << "XllmRpcServiceImpl::schedule error!"; diff --git a/xllm_service/master.cpp b/xllm_service/master.cpp index 4156b68..19247da 100644 --- a/xllm_service/master.cpp +++ b/xllm_service/master.cpp @@ -15,7 +15,6 @@ limitations under the License. #include "master.h" -#include #include #include "common/global_gflags.h" @@ -173,26 +172,6 @@ void shutdown_handler(int signal) { exit(1); } -std::string get_local_ip() { - using namespace boost::asio; - io_service io; - ip::tcp::resolver resolver(io); - ip::tcp::resolver::query query(ip::host_name(), ""); - ip::tcp::resolver::iterator iter = resolver.resolve(query); - ip::tcp::resolver::iterator end; - - while (iter != end) { - ip::address addr = iter->endpoint().address(); - if (!addr.is_loopback() && addr.is_v4()) { - return addr.to_string(); - } - ++iter; - } - - LOG(FATAL) << "Get local ip faill!"; - return ""; -} - int main(int argc, char* argv[]) { // Initialize gflags gflags::ParseCommandLineFlags(&argc, &argv, true); @@ -228,7 +207,7 @@ int main(int argc, char* argv[]) { server_options.http_idle_timeout_s = FLAGS_http_server_idle_timeout_s; server_options.http_num_threads = FLAGS_http_server_num_threads; server_options.http_max_concurrency = FLAGS_http_server_max_concurrency; - server_options.rpc_server_host = get_local_ip(); + server_options.rpc_server_host = xllm_service::utils::get_local_ip(); server_options.rpc_port = FLAGS_rpc_server_port; server_options.rpc_idle_timeout_s = FLAGS_rpc_server_idle_timeout_s; server_options.rpc_num_threads = FLAGS_rpc_server_num_threads; diff --git a/xllm_service/rpc_service/global_kvcache_mgr.cpp b/xllm_service/rpc_service/global_kvcache_mgr.cpp index 44a28dd..5be4e90 100644 --- a/xllm_service/rpc_service/global_kvcache_mgr.cpp +++ b/xllm_service/rpc_service/global_kvcache_mgr.cpp @@ -20,7 +20,7 @@ GlobalKVCacheMgr::GlobalKVCacheMgr( is_master_service_(is_master_service), etcd_client_(etcd_client) { if (!is_master_service_) { - auto handle_kvcache = std::bind(&GlobalKVCacheMgr::handle_kvcache_watch, + auto handle_kvcache = std::bind(&GlobalKVCacheMgr::update_kvcache, this, std::placeholders::_1, std::placeholders::_2); @@ -117,44 +117,48 @@ void GlobalKVCacheMgr::match(const Slice& token_ids, } } -void GlobalKVCacheMgr::handle_kvcache_watch(const etcd::Response& response, - const uint64_t prefix_len) { +void GlobalKVCacheMgr::update_kvcache(const etcd::Response& response, + const uint64_t prefix_len) { if (response.events().empty() || exited_) { return; } - - Murmur3KeyCacheMap put_map; - std::vector delete_list; - - for (const auto& event : response.events()) { - auto key = event.kv().key().substr(prefix_len); - - if (event.event_type() == etcd::Event::EventType::PUT) { - CacheLocations cachelocations; - auto json_str = event.kv().as_string(); - if (!cachelocations.parse_from_json(json_str)) { - LOG(ERROR) << "pase json:" << json_str << " error!"; - continue; + threadpool_.schedule([this, + response = std::move(response), + prefix_len = std::move(prefix_len)] { + if (exited_) return; + Murmur3KeyCacheMap put_map; + std::vector delete_list; + + for (const auto& event : response.events()) { + auto key = event.kv().key().substr(prefix_len); + + if (event.event_type() == etcd::Event::EventType::PUT) { + CacheLocations cachelocations; + auto json_str = event.kv().as_string(); + if (!cachelocations.parse_from_json(json_str)) { + LOG(ERROR) << "pase json:" << json_str << " error!"; + continue; + } + + put_map.insert_or_assign(Murmur3Key{key.c_str()}, + std::move(cachelocations)); + + } else if (event.event_type() == etcd::Event::EventType::DELETE_) { + delete_list.emplace_back(Murmur3Key{key.c_str()}); } - - put_map.insert_or_assign(Murmur3Key{key.c_str()}, - std::move(cachelocations)); - - } else if (event.event_type() == etcd::Event::EventType::DELETE_) { - delete_list.emplace_back(Murmur3Key{key.c_str()}); } - } - { - std::unique_lock lock(kvcache_mutex_); - for (auto& iter : put_map) { - kvcache_infos_.insert_or_assign(iter.first, std::move(iter.second)); - } + { + std::unique_lock lock(kvcache_mutex_); + for (auto& iter : put_map) { + kvcache_infos_.insert_or_assign(iter.first, std::move(iter.second)); + } - for (auto& iter : delete_list) { - kvcache_infos_.erase(iter); + for (auto& iter : delete_list) { + kvcache_infos_.erase(iter); + } } - } + }); } void GlobalKVCacheMgr::record_updated_kvcaches( diff --git a/xllm_service/rpc_service/global_kvcache_mgr.h b/xllm_service/rpc_service/global_kvcache_mgr.h index f06abc3..384c3c6 100644 --- a/xllm_service/rpc_service/global_kvcache_mgr.h +++ b/xllm_service/rpc_service/global_kvcache_mgr.h @@ -6,6 +6,7 @@ #include "common/hash_util.h" #include "common/macros.h" #include "common/slice.h" +#include "common/threadpool.h" #include "common/types.h" #include "etcd_client.h" #include "xllm_rpc_service.pb.h" @@ -30,8 +31,8 @@ class GlobalKVCacheMgr final { private: DISALLOW_COPY_AND_ASSIGN(GlobalKVCacheMgr); - void handle_kvcache_watch(const etcd::Response& response, - const uint64_t prefix_len); + void update_kvcache(const etcd::Response& response, + const uint64_t prefix_len); private: ModelConfig model_config_; @@ -43,6 +44,8 @@ class GlobalKVCacheMgr final { std::mutex update_mutex_; Murmur3KeyCacheMap updated_kvcaches_; + + ThreadPool threadpool_; }; } // namespace xllm_service diff --git a/xllm_service/rpc_service/instance_mgr.cpp b/xllm_service/rpc_service/instance_mgr.cpp index 1d0e188..1c30bd4 100644 --- a/xllm_service/rpc_service/instance_mgr.cpp +++ b/xllm_service/rpc_service/instance_mgr.cpp @@ -45,7 +45,7 @@ InstanceMgr::InstanceMgr(const std::shared_ptr& etcd_client, is_master_service_(is_master_service), etcd_client_(etcd_client) { auto handle_instance_metainfo = - std::bind(&InstanceMgr::handle_instance_metainfo_watch, + std::bind(&InstanceMgr::update_instance_metainfo, this, std::placeholders::_1, std::placeholders::_2); @@ -53,11 +53,10 @@ InstanceMgr::InstanceMgr(const std::shared_ptr& etcd_client, etcd_client_->add_watch(it.second, handle_instance_metainfo); } if (!is_master_service_) { - auto handle_load_metrics = - std::bind(&InstanceMgr::handle_load_metrics_watch, - this, - std::placeholders::_1, - std::placeholders::_2); + auto handle_load_metrics = std::bind(&InstanceMgr::update_load_metrics, + this, + std::placeholders::_1, + std::placeholders::_2); etcd_client_->add_watch(ETCD_LOADMETRICS_PREFIX, handle_load_metrics); } @@ -73,7 +72,7 @@ void InstanceMgr::init() { LOG(INFO) << "Load instance info from etcd:" << instances_.size(); for (const auto& name : instances_) { if (!create_channel(name.first)) { - zombie_nodes_.insert(name.first); + // TODO: add retry instances_.erase(name.first); } } @@ -140,9 +139,9 @@ void InstanceMgr::get_load_metrics(LoadBalanceInfos* infos) { } std::string least_loaded_prefill_instance; - int32_t least_loaded_prefill_waiting_reqs = INT32_MAX; + float least_loaded_prefill_gpu_cache_usage_perc = 1; std::string least_loaded_decode_instance; - int32_t least_loaded_decode_waiting_reqs = INT32_MAX; + float least_loaded_decode_gpu_cache_usage_perc = 1; if (infos->prefill_load_metrics.size() == 0 || infos->decode_load_metrics.size() == 0) { @@ -150,17 +149,17 @@ void InstanceMgr::get_load_metrics(LoadBalanceInfos* infos) { auto instance_it = instances_.find(metric.first); if (instance_it != instances_.end()) { if (instance_it->second.type != InstanceType::DECODE) { - if (metric.second.waiting_requests_num < - least_loaded_prefill_waiting_reqs) { - least_loaded_prefill_waiting_reqs = - metric.second.waiting_requests_num; + if (metric.second.gpu_cache_usage_perc < + least_loaded_prefill_gpu_cache_usage_perc) { + least_loaded_prefill_gpu_cache_usage_perc = + metric.second.gpu_cache_usage_perc; least_loaded_prefill_instance = metric.first; } } else { - if (metric.second.waiting_requests_num < - least_loaded_decode_waiting_reqs) { - least_loaded_decode_waiting_reqs = - metric.second.waiting_requests_num; + if (metric.second.gpu_cache_usage_perc < + least_loaded_decode_gpu_cache_usage_perc) { + least_loaded_decode_gpu_cache_usage_perc = + metric.second.gpu_cache_usage_perc; least_loaded_decode_instance = metric.first; } } @@ -173,7 +172,6 @@ void InstanceMgr::get_load_metrics(LoadBalanceInfos* infos) { infos->prefill_load_metrics.insert( std::make_pair(least_loaded_prefill_instance, load_metrics_[least_loaded_prefill_instance])); - infos->prefill_max_waiting_requests_num = least_loaded_prefill_waiting_reqs; } if (infos->decode_load_metrics.size() == 0 && @@ -181,7 +179,6 @@ void InstanceMgr::get_load_metrics(LoadBalanceInfos* infos) { infos->decode_load_metrics.insert( std::make_pair(least_loaded_decode_instance, load_metrics_[least_loaded_decode_instance])); - infos->decode_max_waiting_requests_num = least_loaded_decode_waiting_reqs; } } @@ -251,99 +248,108 @@ bool InstanceMgr::create_channel(const std::string& instance_name) { return true; } -void InstanceMgr::handle_instance_metainfo_watch(const etcd::Response& response, - const uint64_t& prefix_len) { - if (response.events().empty()) { +void InstanceMgr::update_instance_metainfo(const etcd::Response& response, + const uint64_t& prefix_len) { + if (response.events().empty() || exited_) { return; } - std::unordered_map put_map; - std::vector delete_list; + threadpool_.schedule([this, + response = std::move(response), + prefix_len = std::move(prefix_len)] { + if (exited_) return; + std::unordered_map put_map; + std::vector delete_list; + + for (const auto& event : response.events()) { + std::string instance_name = event.kv().key().substr(prefix_len); + + if (event.event_type() == etcd::Event::EventType::PUT) { + InstanceMetaInfo metainfo; + auto json_str = event.kv().as_string(); + if (!metainfo.parse_from_json(json_str)) { + LOG(ERROR) << "pase json:" << json_str << " error!"; + continue; + } - for (const auto& event : response.events()) { - std::string instance_name = event.kv().key().substr(prefix_len); + put_map.insert(std::make_pair(instance_name, std::move(metainfo))); - if (event.event_type() == etcd::Event::EventType::PUT) { - InstanceMetaInfo metainfo; - auto json_str = event.kv().as_string(); - if (!metainfo.parse_from_json(json_str)) { - LOG(ERROR) << "pase json:" << json_str << " error!"; - continue; + } else if (event.event_type() == etcd::Event::EventType::DELETE_) { + delete_list.push_back(instance_name); } - - put_map.insert(std::make_pair(instance_name, std::move(metainfo))); - - } else if (event.event_type() == etcd::Event::EventType::DELETE_) { - delete_list.push_back(instance_name); } - } - { - std::unique_lock lock(inst_mutex_); - for (auto& iter : put_map) { - if (instances_.find(iter.first) != instances_.end()) { - LOG(ERROR) << "Instance is already registered, instance_name: " - << iter.first; - continue; + { + std::unique_lock lock(inst_mutex_); + for (auto& iter : put_map) { + if (instances_.find(iter.first) != instances_.end()) { + LOG(ERROR) << "Instance is already registered, instance_name: " + << iter.first; + continue; + } + instances_.insert(std::make_pair(iter.first, std::move(iter.second))); + create_channel(iter.first); } - instances_.insert(std::make_pair(iter.first, std::move(iter.second))); - create_channel(iter.first); - } - for (auto& iter : delete_list) { - if (instances_.find(iter) == instances_.end()) { - LOG(ERROR) << "Instance is already deleted, instance_name: " << iter; - continue; - } - // TODO: notify cache manager to clear expire cache - instances_.erase(iter); - cached_channels_.erase(iter); - { - std::lock_guard lock(update_mutex_); - updated_metrics_.erase(iter); - removed_instance_.insert(iter); + for (auto& iter : delete_list) { + if (instances_.find(iter) == instances_.end()) { + LOG(ERROR) << "Instance is already deleted, instance_name: " << iter; + continue; + } + // TODO: notify cache manager to clear expire cache + instances_.erase(iter); + cached_channels_.erase(iter); + { + std::lock_guard lock(update_mutex_); + updated_metrics_.erase(iter); + removed_instance_.insert(iter); + } } } - } + }); } -void InstanceMgr::handle_load_metrics_watch(const etcd::Response& response, - const uint64_t prefix_len) { - if (response.events().empty()) { +void InstanceMgr::update_load_metrics(const etcd::Response& response, + const uint64_t& prefix_len) { + if (response.events().empty() || exited_) { return; } + threadpool_.schedule([this, + response = std::move(response), + prefix_len = std::move(prefix_len)] { + if (exited_) return; + std::unordered_map put_map; + std::vector delete_list; + + for (const auto& event : response.events()) { + std::string instance_name = event.kv().key().substr(prefix_len); + + if (event.event_type() == etcd::Event::EventType::PUT) { + LoadMetrics load_metrics; + auto json_str = event.kv().as_string(); + if (!load_metrics.parse_from_json(json_str)) { + LOG(ERROR) << "pase json:" << json_str << " error!"; + continue; + } - std::unordered_map put_map; - std::vector delete_list; - - for (const auto& event : response.events()) { - std::string instance_name = event.kv().key().substr(prefix_len); + put_map.insert(std::make_pair(instance_name, std::move(load_metrics))); - if (event.event_type() == etcd::Event::EventType::PUT) { - LoadMetrics load_metrics; - auto json_str = event.kv().as_string(); - if (!load_metrics.parse_from_json(json_str)) { - LOG(ERROR) << "pase json:" << json_str << " error!"; - continue; + } else if (event.event_type() == etcd::Event::EventType::DELETE_) { + delete_list.push_back(instance_name); } - - put_map.insert(std::make_pair(instance_name, std::move(load_metrics))); - - } else if (event.event_type() == etcd::Event::EventType::DELETE_) { - delete_list.push_back(instance_name); } - } - { - std::unique_lock lock(load_metric_mutex_); - for (auto& iter : put_map) { - load_metrics_.insert_or_assign(iter.first, std::move(iter.second)); - } + { + std::unique_lock lock(load_metric_mutex_); + for (auto& iter : put_map) { + load_metrics_.insert_or_assign(iter.first, std::move(iter.second)); + } - for (auto& iter : delete_list) { - load_metrics_.erase(iter); + for (auto& iter : delete_list) { + load_metrics_.erase(iter); + } } - } + }); } } // namespace xllm_service diff --git a/xllm_service/rpc_service/instance_mgr.h b/xllm_service/rpc_service/instance_mgr.h index 175c91a..9eb8ff7 100644 --- a/xllm_service/rpc_service/instance_mgr.h +++ b/xllm_service/rpc_service/instance_mgr.h @@ -23,6 +23,7 @@ limitations under the License. #include #include "common/macros.h" +#include "common/threadpool.h" #include "common/types.h" #include "etcd_client.h" #include "xllm_rpc_service.pb.h" @@ -59,11 +60,11 @@ class InstanceMgr final { bool create_channel(const std::string& target_uri); // use etcd as ServiceDiscovery - void handle_instance_metainfo_watch(const etcd::Response& response, - const uint64_t& prefix_len); + void update_instance_metainfo(const etcd::Response& response, + const uint64_t& prefix_len); - void handle_load_metrics_watch(const etcd::Response& response, - const uint64_t prefix_len); + void update_load_metrics(const etcd::Response& response, + const uint64_t& prefix_len); private: bool exited_ = false; @@ -82,11 +83,12 @@ class InstanceMgr final { std::unordered_map load_metrics_; std::unordered_map> cached_channels_; - std::unordered_set zombie_nodes_; std::mutex update_mutex_; std::unordered_map updated_metrics_; std::unordered_set removed_instance_; + + ThreadPool threadpool_; }; } // namespace xllm_service diff --git a/xllm_service/rpc_service/scheduler.cpp b/xllm_service/rpc_service/scheduler.cpp index 2c47597..5c4a5a1 100644 --- a/xllm_service/rpc_service/scheduler.cpp +++ b/xllm_service/rpc_service/scheduler.cpp @@ -54,7 +54,7 @@ Scheduler::Scheduler(const RpcServiceConfig& rpc_config, Scheduler::~Scheduler() { etcd_client_->stop_watch(); } -bool Scheduler::schedule(const ChatMessages& messages, SchduleResult* res) { +bool Scheduler::schedule(const ChatMessages& messages, ScheduleResult* res) { if (chat_template_ == nullptr) { LOG(ERROR) << "Chat template has not configured for model type: " << model_config_.model_type; @@ -70,7 +70,7 @@ bool Scheduler::schedule(const ChatMessages& messages, SchduleResult* res) { return schedule(prompt.value(), res); } -bool Scheduler::schedule(const std::string& prompt, SchduleResult* res) { +bool Scheduler::schedule(const std::string& prompt, ScheduleResult* res) { LoadBalanceInfos lb_infos; if (prompt.size() != 0) { if (!get_tls_tokenizer()->encode(prompt, &res->token_ids)) { diff --git a/xllm_service/rpc_service/scheduler.h b/xllm_service/rpc_service/scheduler.h index cbc37cc..209d865 100644 --- a/xllm_service/rpc_service/scheduler.h +++ b/xllm_service/rpc_service/scheduler.h @@ -29,9 +29,9 @@ class Scheduler { ~Scheduler(); - bool schedule(const ChatMessages& messages, SchduleResult* res); + bool schedule(const ChatMessages& messages, ScheduleResult* res); - bool schedule(const std::string& prompt, SchduleResult* res); + bool schedule(const std::string& prompt, ScheduleResult* res); std::shared_ptr get_channel(const std::string& target_name); diff --git a/xllm_service/rpc_service/service.cpp b/xllm_service/rpc_service/service.cpp index 75484ff..6e696db 100644 --- a/xllm_service/rpc_service/service.cpp +++ b/xllm_service/rpc_service/service.cpp @@ -80,12 +80,12 @@ std::vector XllmRpcServiceImpl::get_static_decode_list( } bool XllmRpcServiceImpl::schedule(const std::string& prompt, - SchduleResult* res) { + ScheduleResult* res) { return scheduler_->schedule(prompt, res); } bool XllmRpcServiceImpl::schedule(const ChatMessages& messages, - SchduleResult* res) { + ScheduleResult* res) { return scheduler_->schedule(messages, res); } diff --git a/xllm_service/rpc_service/service.h b/xllm_service/rpc_service/service.h index dd92a30..5a4f878 100644 --- a/xllm_service/rpc_service/service.h +++ b/xllm_service/rpc_service/service.h @@ -75,9 +75,9 @@ class XllmRpcServiceImpl final { bool include_usage); void finish_request(const std::string& service_request_id); - bool schedule(const std::string& prompt, SchduleResult* res); + bool schedule(const std::string& prompt, ScheduleResult* res); - bool schedule(const ChatMessages& messages, SchduleResult* res); + bool schedule(const ChatMessages& messages, ScheduleResult* res); std::shared_ptr get_channel(const std::string& target_name); From f30a8082b8244f6afc1a5fb7975e6772126efaaf Mon Sep 17 00:00:00 2001 From: kangmeng3 Date: Thu, 28 Aug 2025 23:19:25 +0800 Subject: [PATCH 10/11] feat: fix rebase confict. --- xllm_service/chat_template/CMakeLists.txt | 2 - .../chat_template/chat_template_factory.cpp | 50 ------------------- .../chat_template/chat_template_factory.h | 12 ----- xllm_service/rpc_service/scheduler.cpp | 8 +-- xllm_service/rpc_service/scheduler.h | 4 +- xllm_service/tokenizer/tokenizer_args.cpp | 4 +- xllm_service/tokenizer/tokenizer_args.h | 2 +- xllm_service/tokenizer/tokenizer_factory.cpp | 14 ++++-- xllm_service/tokenizer/tokenizer_factory.h | 2 +- 9 files changed, 18 insertions(+), 80 deletions(-) delete mode 100644 xllm_service/chat_template/chat_template_factory.cpp delete mode 100644 xllm_service/chat_template/chat_template_factory.h diff --git a/xllm_service/chat_template/CMakeLists.txt b/xllm_service/chat_template/CMakeLists.txt index d47e081..d5c6081 100644 --- a/xllm_service/chat_template/CMakeLists.txt +++ b/xllm_service/chat_template/CMakeLists.txt @@ -6,10 +6,8 @@ cc_library ( chat_template HDRS jinja_chat_template.h - chat_template_factory.h SRCS jinja_chat_template.cpp - chat_template_factory.cpp DEPS :minja :tokenizer diff --git a/xllm_service/chat_template/chat_template_factory.cpp b/xllm_service/chat_template/chat_template_factory.cpp deleted file mode 100644 index 8e347c4..0000000 --- a/xllm_service/chat_template/chat_template_factory.cpp +++ /dev/null @@ -1,50 +0,0 @@ -#include "chat_template/chat_template_factory.h" - -#include - -#include "chat_template/coded_chat_template.h" -#include "chat_template/common_chat_template.h" -#include "chat_template/jinja_chat_template.h" - -namespace xllm_service { - -constexpr std::array JINJA_CHAT_TEMPLATE_MODELS{ - "deepseek_v3_mtp", - "deepseek_v2", - "deepseek_v3", - "qwen2", - "qwen3"}; - -constexpr bool is_jinja_model(std::string_view model) { - for (auto m : JINJA_CHAT_TEMPLATE_MODELS) { - if (m == model) return true; - } - return false; -} - -std::unique_ptr create_chat_template( - const std::string& model_type, - const TokenizerArgs& tokenizer_args) { - if (is_jinja_model(model_type)) { - return std::make_unique(tokenizer_args); - } else if (model_type == "chatglm") { - return std::make_unique(); - } else if (model_type == "chatglm4") { - return std::make_unique(); - } else if (model_type == "llama") { - return std::make_unique(); - } else if (model_type == "llama3") { - return std::make_unique(); - } else if (model_type == "rhino") { - return std::make_unique(); - } else if (model_type == "minicpmv") { - return std::make_unique(); - } else if (model_type == "qwen") { - return std::make_unique(); - } else { - LOG(FATAL) << "Unknow model: " << model_type - << ", create ChatTemplate fail!"; - } -} - -} // namespace xllm_service diff --git a/xllm_service/chat_template/chat_template_factory.h b/xllm_service/chat_template/chat_template_factory.h deleted file mode 100644 index 9c2d939..0000000 --- a/xllm_service/chat_template/chat_template_factory.h +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -#include "chat_template.h" -#include "tokenizer/tokenizer_args.h" - -namespace xllm_service { - -std::unique_ptr create_chat_template( - const std::string& model_type, - const TokenizerArgs& tokenizer_args); - -} // namespace xllm_service \ No newline at end of file diff --git a/xllm_service/rpc_service/scheduler.cpp b/xllm_service/rpc_service/scheduler.cpp index 5c4a5a1..b63e523 100644 --- a/xllm_service/rpc_service/scheduler.cpp +++ b/xllm_service/rpc_service/scheduler.cpp @@ -4,7 +4,7 @@ #include #include -#include "chat_template/chat_template_factory.h" +#include "chat_template/jinja_chat_template.h" #include "common.pb.h" #include "common/hash_util.h" #include "tokenizer/tokenizer_factory.h" @@ -20,9 +20,9 @@ Scheduler::Scheduler(const RpcServiceConfig& rpc_config, : rpc_config_(rpc_config), model_config_(model_config), http_config_(http_config) { - tokenizer_ = create_tokenizer(model_config_, &tokenizer_args_); - chat_template_ = - create_chat_template(model_config_.model_type, tokenizer_args_); + tokenizer_ = TokenizerFactory::create_tokenizer(model_config_.tokenizer_path, + &tokenizer_args_); + chat_template_ = std::make_unique(tokenizer_args_); etcd_client_ = std::make_shared(rpc_config_.etcd_addr); diff --git a/xllm_service/rpc_service/scheduler.h b/xllm_service/rpc_service/scheduler.h index 209d865..df0a859 100644 --- a/xllm_service/rpc_service/scheduler.h +++ b/xllm_service/rpc_service/scheduler.h @@ -7,7 +7,7 @@ #include #include -#include "chat_template/chat_template.h" +#include "chat_template/jinja_chat_template.h" #include "common/hash_util.h" #include "common/macros.h" #include "common/types.h" @@ -68,7 +68,7 @@ class Scheduler { HttpServiceConfig http_config_; // chat template instance - std::unique_ptr chat_template_; + std::unique_ptr chat_template_; std::shared_ptr etcd_client_; diff --git a/xllm_service/tokenizer/tokenizer_args.cpp b/xllm_service/tokenizer/tokenizer_args.cpp index e463296..28f7bb3 100644 --- a/xllm_service/tokenizer/tokenizer_args.cpp +++ b/xllm_service/tokenizer/tokenizer_args.cpp @@ -27,7 +27,7 @@ std::optional load_chat_template_file(const std::string& dir) { } } // namespace -bool load_tokenizer_args(const std::string& model_weights_path, +void load_tokenizer_args(const std::string& model_weights_path, TokenizerArgs& tokenizer_args) { // tokenizer args from tokenizer_config.json JsonReader tokenizer_reader; @@ -68,8 +68,6 @@ bool load_tokenizer_args(const std::string& model_weights_path, tokenizer_args.pad_token() = v.value(); } } - - return true; } } // namespace xllm_service \ No newline at end of file diff --git a/xllm_service/tokenizer/tokenizer_args.h b/xllm_service/tokenizer/tokenizer_args.h index c1951e1..34e6443 100644 --- a/xllm_service/tokenizer/tokenizer_args.h +++ b/xllm_service/tokenizer/tokenizer_args.h @@ -93,7 +93,7 @@ inline std::ostream& operator<<(std::ostream& os, const TokenizerArgs& args) { return os; } -bool load_tokenizer_args(const std::string& model_weights_path, +void load_tokenizer_args(const std::string& model_weights_path, TokenizerArgs& tokenizer_args); } // namespace xllm_service diff --git a/xllm_service/tokenizer/tokenizer_factory.cpp b/xllm_service/tokenizer/tokenizer_factory.cpp index 204d868..120cb77 100644 --- a/xllm_service/tokenizer/tokenizer_factory.cpp +++ b/xllm_service/tokenizer/tokenizer_factory.cpp @@ -2,28 +2,32 @@ #include +#include "tokenizer_args.h" + namespace xllm_service { std::unique_ptr TokenizerFactory::create_tokenizer( const std::string& model_weights_path, - TokenizerArgs tokenizer_args) { + TokenizerArgs* tokenizer_args) { + load_tokenizer_args(model_weights_path, *tokenizer_args); + const std::string tokenizer_json_path = model_weights_path + "/tokenizer.json"; if (std::filesystem::exists(tokenizer_json_path)) { // 1. fast tokenizer LOG(INFO) << "Create fast tokenizer."; return std::make_unique(tokenizer_json_path); - } else if (tokenizer_args.tokenizer_type() == "tiktoken" || - tokenizer_args.tokenizer_class() == "TikTokenTokenizer") { + } else if (tokenizer_args->tokenizer_type() == "tiktoken" || + tokenizer_args->tokenizer_class() == "TikTokenTokenizer") { // 2. create tiktoken tokenizer LOG(INFO) << "Create Tiktoken tokenizer."; return std::make_unique(model_weights_path, - tokenizer_args); + *tokenizer_args); } else { // 3. create sentencepiece tokenizer LOG(INFO) << "Create SentencePiece tokenizer."; return std::make_unique(model_weights_path, - tokenizer_args); + *tokenizer_args); } } diff --git a/xllm_service/tokenizer/tokenizer_factory.h b/xllm_service/tokenizer/tokenizer_factory.h index 2120071..5afd20d 100644 --- a/xllm_service/tokenizer/tokenizer_factory.h +++ b/xllm_service/tokenizer/tokenizer_factory.h @@ -11,7 +11,7 @@ class TokenizerFactory { public: static std::unique_ptr create_tokenizer( const std::string& model_weights_path, - TokenizerArgs tokenizer_args); + TokenizerArgs* tokenizer_args); }; } // namespace xllm_service From 38947ab680710fc13ce59368ff824e9e80351dda Mon Sep 17 00:00:00 2001 From: kangmeng3 Date: Fri, 29 Aug 2025 17:10:06 +0800 Subject: [PATCH 11/11] feat: add roundrobin policy. --- xllm_service/common/global_gflags.cpp | 4 +- xllm_service/common/global_gflags.h | 2 +- xllm_service/common/types.h | 4 +- xllm_service/master.cpp | 4 +- xllm_service/master.h | 2 +- xllm_service/rpc_service/CMakeLists.txt | 12 +-- .../rpc_service/etcd_client/CMakeLists.txt | 17 +++ .../{ => etcd_client}/etcd_client.cpp | 0 .../{ => etcd_client}/etcd_client.h | 0 .../loadbalance_policy/CMakeLists.txt | 6 +- .../cache_aware_routing.cpp | 87 +++++++++++++++ .../loadbalance_policy/cache_aware_routing.h | 44 ++++++++ .../loadbalance_policy/loadbalance_policy.cpp | 56 ---------- .../loadbalance_policy/loadbalance_policy.h | 40 ++++--- .../loadbalance_policy/round_robin.cpp | 26 +++++ .../loadbalance_policy/round_robin.h | 37 +++++++ xllm_service/rpc_service/main.cpp | 2 +- .../rpc_service/managers/CMakeLists.txt | 22 ++++ .../{ => managers}/global_kvcache_mgr.cpp | 15 +++ .../{ => managers}/global_kvcache_mgr.h | 17 ++- .../{ => managers}/instance_mgr.cpp | 101 ++++++++++++++++-- .../rpc_service/{ => managers}/instance_mgr.h | 8 +- xllm_service/rpc_service/scheduler.cpp | 34 +++--- xllm_service/rpc_service/scheduler.h | 10 +- xllm_service/rpc_service/service.h | 2 +- 25 files changed, 429 insertions(+), 123 deletions(-) create mode 100644 xllm_service/rpc_service/etcd_client/CMakeLists.txt rename xllm_service/rpc_service/{ => etcd_client}/etcd_client.cpp (100%) rename xllm_service/rpc_service/{ => etcd_client}/etcd_client.h (100%) create mode 100644 xllm_service/rpc_service/loadbalance_policy/cache_aware_routing.cpp create mode 100644 xllm_service/rpc_service/loadbalance_policy/cache_aware_routing.h delete mode 100644 xllm_service/rpc_service/loadbalance_policy/loadbalance_policy.cpp create mode 100644 xllm_service/rpc_service/loadbalance_policy/round_robin.cpp create mode 100644 xllm_service/rpc_service/loadbalance_policy/round_robin.h create mode 100644 xllm_service/rpc_service/managers/CMakeLists.txt rename xllm_service/rpc_service/{ => managers}/global_kvcache_mgr.cpp (92%) rename xllm_service/rpc_service/{ => managers}/global_kvcache_mgr.h (64%) rename xllm_service/rpc_service/{ => managers}/instance_mgr.cpp (78%) rename xllm_service/rpc_service/{ => managers}/instance_mgr.h (91%) diff --git a/xllm_service/common/global_gflags.cpp b/xllm_service/common/global_gflags.cpp index 15e0279..bc15ad2 100644 --- a/xllm_service/common/global_gflags.cpp +++ b/xllm_service/common/global_gflags.cpp @@ -78,7 +78,9 @@ DEFINE_int32(idle_timeout_s, "Connection will be closed if there is no " "read/write operations during the last `idle_timeout_s'"); -DEFINE_string(disagg_pd_policy, "RR", "Disaggregated prefill-decode policy."); +DEFINE_string(load_balance_policy, + "RR", + "Disaggregated prefill-decode policy."); DEFINE_int32(detect_disconnected_instance_interval, 15, diff --git a/xllm_service/common/global_gflags.h b/xllm_service/common/global_gflags.h index 3c09ee8..b513c00 100644 --- a/xllm_service/common/global_gflags.h +++ b/xllm_service/common/global_gflags.h @@ -53,7 +53,7 @@ DECLARE_int32(max_concurrency); DECLARE_string(etcd_addr); -DECLARE_string(disagg_pd_policy); +DECLARE_string(load_balance_policy); DECLARE_int32(detect_disconnected_instance_interval); diff --git a/xllm_service/common/types.h b/xllm_service/common/types.h index 5f9600f..c17d834 100644 --- a/xllm_service/common/types.h +++ b/xllm_service/common/types.h @@ -45,7 +45,7 @@ struct HttpServiceConfig { struct RpcServiceConfig { std::string etcd_addr = ""; - std::string disagg_pd_policy = ""; + std::string load_balance_policy = ""; int detect_disconnected_instance_interval = 15; // seconds std::string service_name = ""; }; @@ -163,6 +163,8 @@ struct InstanceMetaInfo { // latest heatbeat timestamp uint64_t latest_timestamp = 0; + uint64_t instance_index = -1; + nlohmann::json serialize_to_json() const { nlohmann::json json_val; json_val["name"] = name; diff --git a/xllm_service/master.cpp b/xllm_service/master.cpp index 19247da..72983c4 100644 --- a/xllm_service/master.cpp +++ b/xllm_service/master.cpp @@ -31,7 +31,7 @@ Master::Master(const ServerOptions& server_options) } RpcServiceConfig rpc_config; rpc_config.etcd_addr = server_options.etcd_addr; - rpc_config.disagg_pd_policy = server_options.disagg_pd_policy; + rpc_config.load_balance_policy = server_options.load_balance_policy; rpc_config.detect_disconnected_instance_interval = server_options.detect_disconnected_instance_interval; @@ -213,7 +213,7 @@ int main(int argc, char* argv[]) { server_options.rpc_num_threads = FLAGS_rpc_server_num_threads; server_options.rpc_max_concurrency = FLAGS_rpc_server_max_concurrency; server_options.etcd_addr = FLAGS_etcd_addr; - server_options.disagg_pd_policy = FLAGS_disagg_pd_policy; + server_options.load_balance_policy = FLAGS_load_balance_policy; server_options.detect_disconnected_instance_interval = FLAGS_detect_disconnected_instance_interval; server_options.enable_request_trace = FLAGS_enable_request_trace; diff --git a/xllm_service/master.h b/xllm_service/master.h index b242bb6..cafea81 100644 --- a/xllm_service/master.h +++ b/xllm_service/master.h @@ -43,7 +43,7 @@ struct ServerOptions { int32_t rpc_num_threads = 32; int32_t rpc_max_concurrency = 128; std::string etcd_addr = ""; - std::string disagg_pd_policy = "RR"; + std::string load_balance_policy = "RR"; int32_t detect_disconnected_instance_interval = 15; int32_t block_size = 16; std::string model_type = "chatglm"; diff --git a/xllm_service/rpc_service/CMakeLists.txt b/xllm_service/rpc_service/CMakeLists.txt index 55fb4e9..cdcb57c 100644 --- a/xllm_service/rpc_service/CMakeLists.txt +++ b/xllm_service/rpc_service/CMakeLists.txt @@ -1,6 +1,8 @@ include(cc_binary) include(cc_library) include(cc_test) +add_subdirectory(etcd_client) +add_subdirectory(managers) add_subdirectory(loadbalance_policy) cc_library( @@ -8,25 +10,19 @@ cc_library( xllm_rpc_service HDRS scheduler.h - etcd_client.h - instance_mgr.h - global_kvcache_mgr.h response_handler.h service.h SRCS scheduler.cpp - etcd_client.cpp - instance_mgr.cpp - global_kvcache_mgr.cpp response_handler.cpp service.cpp DEPS :common + :etcd_client + :managers :loadbalance_policy absl::random_random absl::strings - cpprest - etcd-cpp-api glog::glog nlohmann_json::nlohmann_json proto::proto_rpc_service diff --git a/xllm_service/rpc_service/etcd_client/CMakeLists.txt b/xllm_service/rpc_service/etcd_client/CMakeLists.txt new file mode 100644 index 0000000..9a1df93 --- /dev/null +++ b/xllm_service/rpc_service/etcd_client/CMakeLists.txt @@ -0,0 +1,17 @@ +include(cc_library) +include(cc_test) + +cc_library( + NAME + etcd_client + HDRS + etcd_client.h + SRCS + etcd_client.cpp + DEPS + :common + cpprest + etcd-cpp-api + glog::glog + nlohmann_json::nlohmann_json +) \ No newline at end of file diff --git a/xllm_service/rpc_service/etcd_client.cpp b/xllm_service/rpc_service/etcd_client/etcd_client.cpp similarity index 100% rename from xllm_service/rpc_service/etcd_client.cpp rename to xllm_service/rpc_service/etcd_client/etcd_client.cpp diff --git a/xllm_service/rpc_service/etcd_client.h b/xllm_service/rpc_service/etcd_client/etcd_client.h similarity index 100% rename from xllm_service/rpc_service/etcd_client.h rename to xllm_service/rpc_service/etcd_client/etcd_client.h diff --git a/xllm_service/rpc_service/loadbalance_policy/CMakeLists.txt b/xllm_service/rpc_service/loadbalance_policy/CMakeLists.txt index 06c29ef..3d9fb9a 100644 --- a/xllm_service/rpc_service/loadbalance_policy/CMakeLists.txt +++ b/xllm_service/rpc_service/loadbalance_policy/CMakeLists.txt @@ -7,8 +7,12 @@ cc_library( loadbalance_policy HDRS loadbalance_policy.h + round_robin.h + cache_aware_routing.h SRCS - loadbalance_policy.cpp + round_robin.cpp + cache_aware_routing.cpp DEPS :common + :managers ) diff --git a/xllm_service/rpc_service/loadbalance_policy/cache_aware_routing.cpp b/xllm_service/rpc_service/loadbalance_policy/cache_aware_routing.cpp new file mode 100644 index 0000000..11ada6f --- /dev/null +++ b/xllm_service/rpc_service/loadbalance_policy/cache_aware_routing.cpp @@ -0,0 +1,87 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm-service/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include "cache_aware_routing.h" + +namespace xllm_service { + +constexpr float MIN_SCORE = -2.0; + +bool CacheAwareRouting::select_instances_pair(ScheduleResult* res) { + LoadBalanceInfos lb_infos; + if (!res->token_ids.empty()) { + Slice token_ids(res->token_ids.data(), res->token_ids.size()); + global_kvcache_mgr_->match(token_ids, &lb_infos.overlap_scores); + DLOG(INFO) << lb_infos.debug_string(); + } + + instance_mgr_->get_load_metrics(&lb_infos); + DLOG(INFO) << lb_infos.debug_string(); + + if (lb_infos.prefill_load_metrics.size() == 0) { + LOG(INFO) << "No node available!"; + return false; + } + + // find preifll + cost_function(lb_infos.overlap_scores.hbm_instance_score, + lb_infos.overlap_scores.max_block_num, + lb_infos.prefill_load_metrics, + lb_infos.prefill_max_waiting_requests_num, + &res->routing.prefill_name); + + // find decode + if (lb_infos.decode_load_metrics.size()) { + cost_function(lb_infos.overlap_scores.hbm_instance_score, + lb_infos.overlap_scores.max_block_num, + lb_infos.decode_load_metrics, + lb_infos.decode_max_waiting_requests_num, + &res->routing.decode_name); + } + + return true; +} + +void CacheAwareRouting::cost_function( + const std::unordered_map& overlap_scores, + const uint32_t& max_block_num, + const std::unordered_map& load_metrics, + const int64_t& max_waiting_requests_num, + std::string* best_choice) { + float best_score = MIN_SCORE; + for (const auto& it : load_metrics) { + const auto matched_blocks_it = overlap_scores.find(it.first); + uint32_t matched_blocks = 0; + if (matched_blocks_it != overlap_scores.end()) { + matched_blocks = matched_blocks_it->second; + } + + auto score = + (max_block_num == 0 ? 0 : matched_blocks / max_block_num) - + it.second.gpu_cache_usage_perc - + (max_waiting_requests_num == 0 + ? 0 + : it.second.waiting_requests_num / max_waiting_requests_num); + + if (score > best_score) { + best_score = score; + *best_choice = it.first; + } + } +} + +} // namespace xllm_service diff --git a/xllm_service/rpc_service/loadbalance_policy/cache_aware_routing.h b/xllm_service/rpc_service/loadbalance_policy/cache_aware_routing.h new file mode 100644 index 0000000..98c8605 --- /dev/null +++ b/xllm_service/rpc_service/loadbalance_policy/cache_aware_routing.h @@ -0,0 +1,44 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm-service/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include "common/macros.h" +#include "loadbalance_policy.h" + +namespace xllm_service { + +class CacheAwareRouting final : public LoadBalancePolicy { + public: + CacheAwareRouting(std::shared_ptr instance_mgr, + std::shared_ptr global_kvcache_mgr) + : LoadBalancePolicy(instance_mgr, global_kvcache_mgr) {}; + + virtual ~CacheAwareRouting() = default; + + bool select_instances_pair(ScheduleResult* res) override; + + protected: + DISALLOW_COPY_AND_ASSIGN(CacheAwareRouting); + + void cost_function( + const std::unordered_map& overlap_scores, + const uint32_t& max_block_num, + const std::unordered_map& load_metrics, + const int64_t& max_waiting_requests_num, + std::string* best_choice); +}; + +} // namespace xllm_service diff --git a/xllm_service/rpc_service/loadbalance_policy/loadbalance_policy.cpp b/xllm_service/rpc_service/loadbalance_policy/loadbalance_policy.cpp deleted file mode 100644 index 5af1048..0000000 --- a/xllm_service/rpc_service/loadbalance_policy/loadbalance_policy.cpp +++ /dev/null @@ -1,56 +0,0 @@ -#pragma once - -#include "loadbalance_policy.h" - -namespace xllm_service { - -constexpr float MIN_SCORE = -2.0; - -void LoadBalancePolicy::select_instances_pair(const LoadBalanceInfos& infos, - Routing* routing) { - // find preifll - cost_function(infos.overlap_scores.hbm_instance_score, - infos.overlap_scores.max_block_num, - infos.prefill_load_metrics, - infos.prefill_max_waiting_requests_num, - &routing->prefill_name); - - // find decode - if (infos.decode_load_metrics.size()) { - cost_function(infos.overlap_scores.hbm_instance_score, - infos.overlap_scores.max_block_num, - infos.decode_load_metrics, - infos.decode_max_waiting_requests_num, - &routing->decode_name); - } -} - -void LoadBalancePolicy::cost_function( - const std::unordered_map& overlap_scores, - const uint32_t& max_block_num, - const std::unordered_map& load_metrics, - const int64_t& max_waiting_requests_num, - std::string* best_choice) { - float best_score = MIN_SCORE; - for (const auto& it : load_metrics) { - const auto matched_blocks_it = overlap_scores.find(it.first); - uint32_t matched_blocks = 0; - if (matched_blocks_it != overlap_scores.end()) { - matched_blocks = matched_blocks_it->second; - } - - auto score = - (max_block_num == 0 ? 0 : matched_blocks / max_block_num) - - it.second.gpu_cache_usage_perc - - (max_waiting_requests_num == 0 - ? 0 - : it.second.waiting_requests_num / max_waiting_requests_num); - - if (score > best_score) { - best_score = score; - *best_choice = it.first; - } - } -} - -} // namespace xllm_service diff --git a/xllm_service/rpc_service/loadbalance_policy/loadbalance_policy.h b/xllm_service/rpc_service/loadbalance_policy/loadbalance_policy.h index 9e83f54..205e880 100644 --- a/xllm_service/rpc_service/loadbalance_policy/loadbalance_policy.h +++ b/xllm_service/rpc_service/loadbalance_policy/loadbalance_policy.h @@ -1,32 +1,40 @@ -#pragma once +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm-service/blob/main/LICENSE -#include -#include -#include +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ -#include "common/macros.h" +#pragma once + +#include "../managers/global_kvcache_mgr.h" +#include "../managers/instance_mgr.h" #include "common/types.h" namespace xllm_service { class LoadBalancePolicy { public: - LoadBalancePolicy() = default; + LoadBalancePolicy(std::shared_ptr instance_mgr, + std::shared_ptr global_kvcache_mgr) + : instance_mgr_(instance_mgr), global_kvcache_mgr_(global_kvcache_mgr) {} virtual ~LoadBalancePolicy() = default; - virtual void select_instances_pair(const LoadBalanceInfos& infos, - Routing* routing); + virtual bool select_instances_pair(ScheduleResult* res) = 0; protected: - DISALLOW_COPY_AND_ASSIGN(LoadBalancePolicy); - - virtual void cost_function( - const std::unordered_map& overlap_scores, - const uint32_t& max_block_num, - const std::unordered_map& load_metrics, - const int64_t& max_waiting_requests_num, - std::string* best_choice); + std::shared_ptr instance_mgr_; + + std::shared_ptr global_kvcache_mgr_; }; } // namespace xllm_service diff --git a/xllm_service/rpc_service/loadbalance_policy/round_robin.cpp b/xllm_service/rpc_service/loadbalance_policy/round_robin.cpp new file mode 100644 index 0000000..0ed8a1d --- /dev/null +++ b/xllm_service/rpc_service/loadbalance_policy/round_robin.cpp @@ -0,0 +1,26 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm-service/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include "round_robin.h" + +namespace xllm_service { + +bool RoundRobin::select_instances_pair(ScheduleResult* res) { + return instance_mgr_->get_next_instance_pair(&res->routing); +} + +} // namespace xllm_service diff --git a/xllm_service/rpc_service/loadbalance_policy/round_robin.h b/xllm_service/rpc_service/loadbalance_policy/round_robin.h new file mode 100644 index 0000000..2616ad8 --- /dev/null +++ b/xllm_service/rpc_service/loadbalance_policy/round_robin.h @@ -0,0 +1,37 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm-service/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include "common/macros.h" +#include "loadbalance_policy.h" + +namespace xllm_service { + +class RoundRobin final : public LoadBalancePolicy { + public: + RoundRobin(std::shared_ptr instance_mgr, + std::shared_ptr global_kvcache_mgr) + : LoadBalancePolicy(instance_mgr, global_kvcache_mgr) {}; + + virtual ~RoundRobin() = default; + + bool select_instances_pair(ScheduleResult* res) override; + + protected: + DISALLOW_COPY_AND_ASSIGN(RoundRobin); +}; + +} // namespace xllm_service diff --git a/xllm_service/rpc_service/main.cpp b/xllm_service/rpc_service/main.cpp index c96fb6e..3707eb5 100644 --- a/xllm_service/rpc_service/main.cpp +++ b/xllm_service/rpc_service/main.cpp @@ -43,7 +43,7 @@ int main(int argc, char* argv[]) { xllm_service::RpcServiceConfig config; config.etcd_addr = FLAGS_etcd_addr; - config.disagg_pd_policy = FLAGS_disagg_pd_policy; + config.load_balance_policy = FLAGS_load_balance_policy; config.detect_disconnected_instance_interval = FLAGS_detect_disconnected_instance_interval; diff --git a/xllm_service/rpc_service/managers/CMakeLists.txt b/xllm_service/rpc_service/managers/CMakeLists.txt new file mode 100644 index 0000000..79f5b49 --- /dev/null +++ b/xllm_service/rpc_service/managers/CMakeLists.txt @@ -0,0 +1,22 @@ +include(cc_library) +include(cc_test) + +cc_library( + NAME + managers + HDRS + instance_mgr.h + global_kvcache_mgr.h + SRCS + instance_mgr.cpp + global_kvcache_mgr.cpp + DEPS + :common + :etcd_client + absl::random_random + absl::strings + glog::glog + proto::proto_rpc_service + proto_xllm +) +target_link_libraries(managers PRIVATE brpc-static) diff --git a/xllm_service/rpc_service/global_kvcache_mgr.cpp b/xllm_service/rpc_service/managers/global_kvcache_mgr.cpp similarity index 92% rename from xllm_service/rpc_service/global_kvcache_mgr.cpp rename to xllm_service/rpc_service/managers/global_kvcache_mgr.cpp index 5be4e90..6d6484e 100644 --- a/xllm_service/rpc_service/global_kvcache_mgr.cpp +++ b/xllm_service/rpc_service/managers/global_kvcache_mgr.cpp @@ -1,3 +1,18 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm-service/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "global_kvcache_mgr.h" #include diff --git a/xllm_service/rpc_service/global_kvcache_mgr.h b/xllm_service/rpc_service/managers/global_kvcache_mgr.h similarity index 64% rename from xllm_service/rpc_service/global_kvcache_mgr.h rename to xllm_service/rpc_service/managers/global_kvcache_mgr.h index 384c3c6..0115aa5 100644 --- a/xllm_service/rpc_service/global_kvcache_mgr.h +++ b/xllm_service/rpc_service/managers/global_kvcache_mgr.h @@ -1,14 +1,29 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm-service/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #pragma once #include #include +#include "../etcd_client/etcd_client.h" #include "common/hash_util.h" #include "common/macros.h" #include "common/slice.h" #include "common/threadpool.h" #include "common/types.h" -#include "etcd_client.h" #include "xllm_rpc_service.pb.h" namespace xllm_service { diff --git a/xllm_service/rpc_service/instance_mgr.cpp b/xllm_service/rpc_service/managers/instance_mgr.cpp similarity index 78% rename from xllm_service/rpc_service/instance_mgr.cpp rename to xllm_service/rpc_service/managers/instance_mgr.cpp index 1c30bd4..65dc7e7 100644 --- a/xllm_service/rpc_service/instance_mgr.cpp +++ b/xllm_service/rpc_service/managers/instance_mgr.cpp @@ -35,7 +35,6 @@ static std::unordered_map ETCD_KEYS_PREFIX_MAP = { {InstanceType::DECODE, "XLLM:DECODE:"}, }; static std::string ETCD_ALL_KEYS_PREFIX = "XLLM:"; -static std::string DEFAULT_DISAGG_PD_POLICY = "RR"; static std::string ETCD_LOADMETRICS_PREFIX = "XLLM:LOADMETRICS:"; InstanceMgr::InstanceMgr(const std::shared_ptr& etcd_client, @@ -70,17 +69,43 @@ void InstanceMgr::init() { etcd_client_->get_prefix(it.second, &instances_); } LOG(INFO) << "Load instance info from etcd:" << instances_.size(); - for (const auto& name : instances_) { - if (!create_channel(name.first)) { - // TODO: add retry - instances_.erase(name.first); + std::vector channel_creat_fail_insts; + prefill_index_.reserve(instances_.size()); + decode_index_.reserve(instances_.size()); + + for (auto& ist : instances_) { + if (!create_channel(ist.first)) { + channel_creat_fail_insts.emplace_back(ist.first); + } else { + switch (ist.second.type) { + case InstanceType::DEFAULT: + case InstanceType::PREFILL: + ist.second.instance_index = prefill_index_.size(); + prefill_index_.emplace_back(ist.first); + break; + case InstanceType::DECODE: + ist.second.instance_index = decode_index_.size(); + decode_index_.emplace_back(ist.first); + break; + default: + LOG(WARNING) << "Unknown InstanceType: " << int(ist.second.type); + channel_creat_fail_insts.emplace_back(ist.first); + break; + } } } + for (auto& name : channel_creat_fail_insts) { + instances_.erase(name); + } } { std::unique_lock lock(load_metric_mutex_); etcd_client_->get_prefix(ETCD_LOADMETRICS_PREFIX, &load_metrics_); } + + for (int i = 0; i < prefill_index_.size(); i++) { + LOG(INFO) << i << " : " << prefill_index_[i]; + } } InstanceMgr::~InstanceMgr() { exited_ = true; } @@ -97,6 +122,24 @@ InstanceMetaInfo InstanceMgr::get_instance_info( return instances_[instance_name]; } +bool InstanceMgr::get_next_instance_pair(Routing* routing) { + std::unique_lock lock(inst_mutex_); + if (prefill_index_.empty()) { + LOG(ERROR) << "No prefill or default instance found!"; + return false; + } + next_prefill_index_ = next_prefill_index_ % prefill_index_.size(); + routing->prefill_name = prefill_index_[next_prefill_index_]; + next_prefill_index_++; + if (decode_index_.empty()) { + return true; + } + next_decode_index_ = next_decode_index_ % decode_index_.size(); + routing->decode_name = decode_index_[next_decode_index_]; + next_decode_index_++; + return true; +} + // TODO: refactor later, currently return all decode instances std::vector InstanceMgr::get_static_decode_list( const std::string& instance_name) { @@ -287,8 +330,28 @@ void InstanceMgr::update_instance_metainfo(const etcd::Response& response, << iter.first; continue; } + + if (!create_channel(iter.first)) { + LOG(ERROR) << "create channel fail: " << iter.first; + continue; + } + instances_.insert(std::make_pair(iter.first, std::move(iter.second))); - create_channel(iter.first); + + switch (iter.second.type) { + case InstanceType::DEFAULT: + case InstanceType::PREFILL: + iter.second.instance_index = prefill_index_.size(); + prefill_index_.emplace_back(iter.first); + break; + case InstanceType::DECODE: + iter.second.instance_index = decode_index_.size(); + decode_index_.emplace_back(iter.first); + break; + default: + LOG(WARNING) << "Unknown InstanceType: " << int(iter.second.type); + break; + } } for (auto& iter : delete_list) { @@ -297,6 +360,32 @@ void InstanceMgr::update_instance_metainfo(const etcd::Response& response, continue; } // TODO: notify cache manager to clear expire cache + uint64_t index = instances_[iter].instance_index; + + switch (instances_[iter].type) { + case InstanceType::DEFAULT: + case InstanceType::PREFILL: + if (index == -1 || index >= prefill_index_.size()) { + break; + } + std::swap(prefill_index_[index], prefill_index_.back()); + instances_[prefill_index_[index]].instance_index = index; + prefill_index_.pop_back(); + break; + case InstanceType::DECODE: + if (index == -1 || index >= decode_index_.size()) { + break; + } + std::swap(decode_index_[index], decode_index_.back()); + instances_[decode_index_[index]].instance_index = index; + decode_index_.pop_back(); + break; + default: + LOG(WARNING) << "Unknown InstanceType: " + << int(instances_[iter].type); + break; + } + instances_.erase(iter); cached_channels_.erase(iter); { diff --git a/xllm_service/rpc_service/instance_mgr.h b/xllm_service/rpc_service/managers/instance_mgr.h similarity index 91% rename from xllm_service/rpc_service/instance_mgr.h rename to xllm_service/rpc_service/managers/instance_mgr.h index 9eb8ff7..6905270 100644 --- a/xllm_service/rpc_service/instance_mgr.h +++ b/xllm_service/rpc_service/managers/instance_mgr.h @@ -22,10 +22,10 @@ limitations under the License. #include #include +#include "../etcd_client/etcd_client.h" #include "common/macros.h" #include "common/threadpool.h" #include "common/types.h" -#include "etcd_client.h" #include "xllm_rpc_service.pb.h" namespace xllm_service { @@ -40,6 +40,8 @@ class InstanceMgr final { InstanceMetaInfo get_instance_info(const std::string& instance_name); + bool get_next_instance_pair(Routing* routing); + std::vector get_static_decode_list( const std::string& instance_name); @@ -78,6 +80,10 @@ class InstanceMgr final { std::shared_mutex inst_mutex_; std::unordered_map instances_; + std::vector prefill_index_; + std::vector decode_index_; + uint64_t next_prefill_index_ = 0; + uint64_t next_decode_index_ = 0; std::shared_mutex load_metric_mutex_; std::unordered_map load_metrics_; diff --git a/xllm_service/rpc_service/scheduler.cpp b/xllm_service/rpc_service/scheduler.cpp index b63e523..a66f4a8 100644 --- a/xllm_service/rpc_service/scheduler.cpp +++ b/xllm_service/rpc_service/scheduler.cpp @@ -7,6 +7,8 @@ #include "chat_template/jinja_chat_template.h" #include "common.pb.h" #include "common/hash_util.h" +#include "loadbalance_policy/cache_aware_routing.h" +#include "loadbalance_policy/round_robin.h" #include "tokenizer/tokenizer_factory.h" static constexpr int kHeartbeatInterval = 3; // in seconds @@ -32,13 +34,19 @@ Scheduler::Scheduler(const RpcServiceConfig& rpc_config, LOG(INFO) << "Set current service as master!"; } - instance_mgr_ = std::make_unique( + instance_mgr_ = std::make_shared( etcd_client_, http_config_, is_master_service_); - global_kvcache_mgr_ = std::make_unique( + global_kvcache_mgr_ = std::make_shared( etcd_client_, model_config_, is_master_service_); - lb_policy_ = std::make_unique(); + if (rpc_config_.load_balance_policy == "CAR") { + lb_policy_ = + std::make_unique(instance_mgr_, global_kvcache_mgr_); + } else { + lb_policy_ = + std::make_unique(instance_mgr_, global_kvcache_mgr_); + } if (is_master_service_) { heartbeat_thread_ = std::make_unique( @@ -71,32 +79,16 @@ bool Scheduler::schedule(const ChatMessages& messages, ScheduleResult* res) { } bool Scheduler::schedule(const std::string& prompt, ScheduleResult* res) { - LoadBalanceInfos lb_infos; if (prompt.size() != 0) { if (!get_tls_tokenizer()->encode(prompt, &res->token_ids)) { LOG(ERROR) << "Encode prompt faill: " << prompt; return false; } - - Slice token_ids(res->token_ids.data(), res->token_ids.size()); - - global_kvcache_mgr_->match(token_ids, &lb_infos.overlap_scores); - DLOG(INFO) << lb_infos.debug_string(); - } - - instance_mgr_->get_load_metrics(&lb_infos); - DLOG(INFO) << lb_infos.debug_string(); - - if (lb_infos.prefill_load_metrics.size() == 0) { - LOG(INFO) << "No node available!"; - return false; } - - lb_policy_->select_instances_pair(lb_infos, &res->routing); - + auto ret = lb_policy_->select_instances_pair(res); DLOG(INFO) << res->routing.debug_string(); - return true; + return ret; } std::shared_ptr Scheduler::get_channel( diff --git a/xllm_service/rpc_service/scheduler.h b/xllm_service/rpc_service/scheduler.h index df0a859..429e52d 100644 --- a/xllm_service/rpc_service/scheduler.h +++ b/xllm_service/rpc_service/scheduler.h @@ -11,10 +11,10 @@ #include "common/hash_util.h" #include "common/macros.h" #include "common/types.h" -#include "etcd_client.h" -#include "global_kvcache_mgr.h" -#include "instance_mgr.h" +#include "etcd_client/etcd_client.h" #include "loadbalance_policy/loadbalance_policy.h" +#include "managers/global_kvcache_mgr.h" +#include "managers/instance_mgr.h" #include "tokenizer/tokenizer.h" #include "tokenizer/tokenizer_args.h" #include "xllm_rpc_service.pb.h" @@ -74,9 +74,9 @@ class Scheduler { std::unique_ptr tokenizer_; - std::unique_ptr instance_mgr_; + std::shared_ptr instance_mgr_; - std::unique_ptr global_kvcache_mgr_; + std::shared_ptr global_kvcache_mgr_; std::unique_ptr lb_policy_; diff --git a/xllm_service/rpc_service/service.h b/xllm_service/rpc_service/service.h index 5a4f878..2c15a29 100644 --- a/xllm_service/rpc_service/service.h +++ b/xllm_service/rpc_service/service.h @@ -24,7 +24,7 @@ limitations under the License. #include "common/xllm/output.h" #include "common/xllm/status.h" #include "completion.pb.h" -#include "instance_mgr.h" +#include "managers/instance_mgr.h" #include "response_handler.h" #include "scheduler.h" #include "xllm_rpc_service.pb.h"