diff --git a/.gitmodules b/.gitmodules index 2c43fb2..507ceba 100644 --- a/.gitmodules +++ b/.gitmodules @@ -16,3 +16,6 @@ [submodule "third_party/minja"] path = third_party/minja url = https://gitcode.com/xLLM-AI/minja.git +[submodule "third_party/smhasher"] + path = third_party/smhasher + url = https://gitcode.com/xLLM-AI/smhasher.git diff --git a/CMakeLists.txt b/CMakeLists.txt index a5ee4cb..617a5e6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -98,6 +98,7 @@ endif() # find all dependencies from vcpkg find_package(Boost REQUIRED) +find_package(Boost REQUIRED COMPONENTS serialization) find_package(glog CONFIG REQUIRED) find_package(gflags CONFIG REQUIRED) find_package(leveldb CONFIG REQUIRED) diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index bb3c3a1..280b058 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -4,4 +4,5 @@ add_subdirectory(cpprestsdk) add_subdirectory(etcd_cpp_apiv3) add_subdirectory(brpc) add_subdirectory(minja) -add_subdirectory(sentencepiece) \ No newline at end of file +add_subdirectory(sentencepiece) +add_subdirectory(smhasher/src) \ No newline at end of file diff --git a/vcpkg.json b/vcpkg.json index 86653fd..d57f269 100644 --- a/vcpkg.json +++ b/vcpkg.json @@ -38,6 +38,10 @@ "name": "boost-random", "version>=": "1.84.0" }, + { + "name": "boost-serialization", + "version>=": "1.84.0" + }, { "name": "protobuf", "version>=": "3.21.12", diff --git a/xllm_service/common/CMakeLists.txt b/xllm_service/common/CMakeLists.txt index df21039..2ffc2f9 100644 --- a/xllm_service/common/CMakeLists.txt +++ b/xllm_service/common/CMakeLists.txt @@ -14,6 +14,7 @@ cc_library( threadpool.h types.h utils.h + hash_util.h xllm/output.h xllm/status.h xllm/uuid.h @@ -22,6 +23,7 @@ cc_library( json_reader.cpp threadpool.cpp utils.cpp + hash_util.cpp xllm/uuid.cpp DEPS absl::random_random @@ -29,6 +31,6 @@ cc_library( glog::glog gflags::gflags nlohmann_json::nlohmann_json + SMHasherSupport ) -target_link_libraries(common PRIVATE OpenSSL::SSL OpenSSL::Crypto) add_dependencies(common brpc-static) diff --git a/xllm_service/common/global_gflags.cpp b/xllm_service/common/global_gflags.cpp index ff84c32..bc15ad2 100644 --- a/xllm_service/common/global_gflags.cpp +++ b/xllm_service/common/global_gflags.cpp @@ -33,11 +33,6 @@ DEFINE_int32(http_server_max_concurrency, 128, "Limit number of requests processed in parallel"); -DEFINE_string(rpc_server_host, - "", - "Rpc server listen address, may be IPV4/IPV6/UDS." 
- " If this is set, the flag port will be ignored"); - DEFINE_int32(rpc_server_port, 8889, "Port for xllm rpc service to listen on"); DEFINE_int32(rpc_server_idle_timeout_s, @@ -55,6 +50,8 @@ DEFINE_string(etcd_addr, "0.0.0.0:2379", "etcd adderss for save instance meta info"); +DEFINE_uint32(murmur_hash3_seed, 1024, "default Murmur Hash seed"); + DEFINE_int32(port, 8888, "Port for xllm service to listen on"); DEFINE_int32(num_threads, 32, "Number of threads to process requests"); @@ -81,7 +78,9 @@ DEFINE_int32(idle_timeout_s, "Connection will be closed if there is no " "read/write operations during the last `idle_timeout_s'"); -DEFINE_string(disagg_pd_policy, "RR", "Disaggregated prefill-decode policy."); +DEFINE_string(load_balance_policy, + "RR", + "Disaggregated prefill-decode policy."); DEFINE_int32(detect_disconnected_instance_interval, 15, diff --git a/xllm_service/common/global_gflags.h b/xllm_service/common/global_gflags.h index 657a918..b513c00 100644 --- a/xllm_service/common/global_gflags.h +++ b/xllm_service/common/global_gflags.h @@ -27,8 +27,6 @@ DECLARE_int32(http_server_num_threads); DECLARE_int32(http_server_max_concurrency); -DECLARE_string(rpc_server_host); - DECLARE_int32(rpc_server_port); DECLARE_int32(rpc_server_idle_timeout_s); @@ -37,6 +35,8 @@ DECLARE_int32(rpc_server_num_threads); DECLARE_int32(rpc_server_max_concurrency); +DECLARE_uint32(murmur_hash3_seed); + DECLARE_string(test_instance_addr); DECLARE_int32(timeout_ms); @@ -53,7 +53,7 @@ DECLARE_int32(max_concurrency); DECLARE_string(etcd_addr); -DECLARE_string(disagg_pd_policy); +DECLARE_string(load_balance_policy); DECLARE_int32(detect_disconnected_instance_interval); diff --git a/xllm_service/common/hash_util.cpp b/xllm_service/common/hash_util.cpp new file mode 100644 index 0000000..5c131b7 --- /dev/null +++ b/xllm_service/common/hash_util.cpp @@ -0,0 +1,62 @@ + +#include "common/hash_util.h" + +#include +#include + +#include +#include +#include +#include +#include + +#include "common/global_gflags.h" + +namespace xllm_service { + +void murmur_hash3(const uint8_t* pre_hash_value, + const Slice& token_ids, + uint8_t* hash_value) { + if (pre_hash_value == nullptr) { + MurmurHash3_x64_128(reinterpret_cast(token_ids.data()), + sizeof(int32_t) * token_ids.size(), + FLAGS_murmur_hash3_seed, + hash_value); + } else { + uint8_t key[1024]; + + int32_t data_len = + sizeof(int32_t) * token_ids.size() + MURMUR_HASH3_VALUE_LEN; + assert(sizeof(key) > data_len); + + memcpy(key, pre_hash_value, MURMUR_HASH3_VALUE_LEN); + memcpy(key + MURMUR_HASH3_VALUE_LEN, + reinterpret_cast(token_ids.data()), + sizeof(int32_t) * token_ids.size()); + + // print_hex_array(key, data_len); + MurmurHash3_x64_128(reinterpret_cast(key), + data_len, + FLAGS_murmur_hash3_seed, + hash_value); + } +} + +void print_hex_array(uint8_t* array) { + for (size_t i = 0; i < MURMUR_HASH3_VALUE_LEN; ++i) { + unsigned char uc = static_cast(array[i]); + std::cout << std::hex << std::setw(2) << std::setfill('0') + << static_cast(uc); + + if (i % MURMUR_HASH3_VALUE_LEN == MURMUR_HASH3_VALUE_LEN - 1) { + std::cout << std::endl; + } + + else { + std::cout << " "; + } + } + std::cout << std::dec << std::endl; +} + +} // namespace xllm_service \ No newline at end of file diff --git a/xllm_service/common/hash_util.h b/xllm_service/common/hash_util.h new file mode 100644 index 0000000..435944e --- /dev/null +++ b/xllm_service/common/hash_util.h @@ -0,0 +1,59 @@ +#pragma once + +#include + +#include +#include +#include +#include + +#include "common/slice.h" + 
+namespace xllm_service {
+constexpr uint32_t MURMUR_HASH3_VALUE_LEN = 16;
+
+struct Murmur3Key {
+  uint8_t data[MURMUR_HASH3_VALUE_LEN];
+
+  Murmur3Key() {}
+  Murmur3Key(const uint8_t* const input_data) {
+    memcpy(data, input_data, MURMUR_HASH3_VALUE_LEN);
+  }
+  Murmur3Key(const char* const input_data) {
+    memcpy(data, input_data, MURMUR_HASH3_VALUE_LEN);
+  }
+
+  std::string to_string() const {
+    return std::string(reinterpret_cast<const char*>(data),
+                       MURMUR_HASH3_VALUE_LEN);
+  }
+
+  bool operator==(const Murmur3Key& other) {
+    return strncmp(reinterpret_cast<const char*>(data),
+                   reinterpret_cast<const char*>(other.data),
+                   MURMUR_HASH3_VALUE_LEN) == 0;
+  }
+};
+
+struct FixedStringKeyHash {
+  size_t operator()(const Murmur3Key& key) const {
+    return std::hash<std::string_view>()(std::string_view(
+        reinterpret_cast<const char*>(key.data), sizeof(key.data)));
+  }
+};
+
+struct FixedStringKeyEqual {
+  bool operator()(const Murmur3Key& left, const Murmur3Key& right) const {
+    return strncmp(reinterpret_cast<const char*>(left.data),
+                   reinterpret_cast<const char*>(right.data),
+                   sizeof(left.data)) == 0;
+  }
+};
+
+void print_hex_array(uint8_t* array);
+
+void murmur_hash3(const uint8_t* pre_hash_value,
+                  const Slice<int32_t>& token_ids,
+                  uint8_t* hash_value);
+
+}  // namespace xllm_service
diff --git a/xllm_service/common/types.h b/xllm_service/common/types.h
index c681ab6..c17d834 100644
--- a/xllm_service/common/types.h
+++ b/xllm_service/common/types.h
@@ -19,13 +19,23 @@ limitations under the License.
 
 #include <chrono>
 #include <cstdint>
+#include <glog/logging.h>
 #include <string>
+#include <unordered_map>
+#include <unordered_set>
 #include <vector>
 
+#include "common/hash_util.h"
 #include "nlohmann/json.hpp"
 
 namespace xllm_service {
 
+struct CacheLocations;
+using Murmur3KeyCacheMap = std::unordered_map<Murmur3Key,
+                                              CacheLocations,
+                                              FixedStringKeyHash,
+                                              FixedStringKeyEqual>;
+
 struct HttpServiceConfig {
   int num_threads = 16;
   int timeout_ms = -1;
@@ -35,15 +45,34 @@
 
 struct RpcServiceConfig {
   std::string etcd_addr = "";
-  std::string disagg_pd_policy = "";
+  std::string load_balance_policy = "";
   int detect_disconnected_instance_interval = 15;  // seconds
+  std::string service_name = "";
+};
+
+struct ModelConfig {
+  int32_t block_size = 16;
+  std::string model_type = "chatglm";
+  std::string tokenizer_path = "";
+};
+
+struct Routing {
+  std::string prefill_name;
+  std::string decode_name;
+
+  nlohmann::json serialize_to_json() const {
+    nlohmann::json json_val;
+    json_val["prefill_name"] = prefill_name;
+    json_val["decode_name"] = decode_name;
+    return json_val;
+  }
+
+  std::string debug_string() const { return serialize_to_json().dump(2); }
 };
 
-// instances pair for prefill and decode in disagg PD mode.
-struct InstancesPair { - std::string prefill_instance_http_addr = ""; - // empty means no decode instance, only prefill instance is available - std::string decode_instance_http_addr = ""; +struct ScheduleResult { + std::vector token_ids; + Routing routing; }; enum class ErrorCode : int32_t { @@ -72,6 +101,42 @@ enum class InstanceType : int8_t { DECODE = 2, }; +struct LoadMetrics { + LoadMetrics() : waiting_requests_num(0), gpu_cache_usage_perc(0) {}; + LoadMetrics(const uint64_t& waiting_reqs_num, const float& usage) + : waiting_requests_num(waiting_reqs_num), gpu_cache_usage_perc(usage) {}; + + uint64_t waiting_requests_num; + float gpu_cache_usage_perc; + + nlohmann::json serialize_to_json() const { + nlohmann::json json_val; + json_val["waiting_requests_num"] = waiting_requests_num; + json_val["gpu_cache_usage_perc"] = gpu_cache_usage_perc; + return json_val; + } + + std::string debug_string() const { return serialize_to_json().dump(2); } + + bool parse_from_json(const std::string& json_str) { + try { + nlohmann::json json_value = nlohmann::json::parse(json_str); + + waiting_requests_num = + json_value.at("waiting_requests_num").get(); + gpu_cache_usage_perc = json_value.at("gpu_cache_usage_perc").get(); + + } catch (const std::exception& e) { + LOG(ERROR) << "json str:" << json_str + << ", parse to loadmetrics error: " << e.what(); + return false; + } + return true; + } + + bool empty() const { return false; } +}; + struct InstanceMetaInfo { public: InstanceMetaInfo() { set_init_timestamp(); } @@ -98,6 +163,63 @@ struct InstanceMetaInfo { // latest heatbeat timestamp uint64_t latest_timestamp = 0; + uint64_t instance_index = -1; + + nlohmann::json serialize_to_json() const { + nlohmann::json json_val; + json_val["name"] = name; + json_val["rpc_address"] = rpc_address; + json_val["type"] = int8_t(type); + json_val["addrs"] = addrs; + json_val["cluster_ids"] = cluster_ids; + json_val["k_cache_ids"] = k_cache_ids; + json_val["v_cache_ids"] = v_cache_ids; + json_val["dp_size"] = dp_size; + return json_val; + } + + std::string debug_string() const { return serialize_to_json().dump(2); } + + bool parse_from_json(const std::string& json_str) { + try { + nlohmann::json json_value = nlohmann::json::parse(json_str); + name = json_value.at("name").get(); + rpc_address = json_value.at("rpc_address").get(); + type = static_cast(json_value.at("type").get()); + + for (const auto& item : + json_value.at("cluster_ids").get>()) { + cluster_ids.push_back(item); + } + + for (const auto& item : + json_value.at("k_cache_ids").get>()) { + k_cache_ids.push_back(item); + } + + for (const auto& item : + json_value.at("addrs").get>()) { + addrs.push_back(item); + } + + for (const auto& item : + json_value.at("v_cache_ids").get>()) { + v_cache_ids.push_back(item); + } + + dp_size = json_value.at("dp_size").get(); + + set_init_timestamp(); + } catch (const std::exception& e) { + LOG(ERROR) << "json str:" << json_str + << ", parse to instancemetainfo error: " << e.what(); + return false; + } + return true; + } + + bool empty() const { return rpc_address == ""; } + private: void set_init_timestamp() { auto now = std::chrono::system_clock::now(); @@ -108,17 +230,122 @@ struct InstanceMetaInfo { } }; -// the info be stored in etcd -struct InstanceIdentityInfo { - std::string instance_addr; - std::string rpc_addr; - int8_t instance_type; // convert to InstanceType - - const std::string debug_string() const { - std::string debug_str = - "instance_addr: " + instance_addr + ", rpc_addr: " + rpc_addr + - ", 
instance_type: " + std::to_string((int)(instance_type)); - return debug_str; +struct CacheLocations { + std::unordered_set hbm_instance_set; + std::unordered_set dram_instance_set; + std::unordered_set ssd_instance_set; + + nlohmann::json serialize_to_json() const { + nlohmann::json json_val; + json_val["hbm_instance_set"] = hbm_instance_set; + json_val["dram_instance_set"] = dram_instance_set; + json_val["ssd_instance_set"] = ssd_instance_set; + return json_val; + } + + std::string debug_string() { return serialize_to_json().dump(2); } + + bool parse_from_json(const std::string& json_str) { + try { + nlohmann::json json_value = nlohmann::json::parse(json_str); + for (const auto& item : + json_value.at("hbm_instance_set").get>()) { + hbm_instance_set.insert(item); + } + + for (const auto& item : + json_value.at("dram_instance_set").get>()) { + dram_instance_set.insert(item); + } + + for (const auto& item : + json_value.at("ssd_instance_set").get>()) { + ssd_instance_set.insert(item); + } + + } catch (const std::exception& e) { + LOG(ERROR) << "json str:" << json_str + << ", parse to cachelocation error: " << e.what(); + return false; + } + return true; + } + + bool empty() const { + return hbm_instance_set.empty() && dram_instance_set.empty() && + ssd_instance_set.empty(); + } +}; + +/** + * @brief Records the prefix cache match lengths for different instances on + * current request + * + * This struct stores and manages prefix cache matching information across + * multiple instances, supporting different storage types (HBM, DRAM, SSD) for + * match length recording, and tracks information about the best matching + * instance. + */ +struct OverlapScores { + // Set of matched instance names + std::unordered_set instances; + // HBM storage type instance match length mapping (instance name -> match + // length) + std::unordered_map hbm_instance_score; + // DRAM storage type instance match length mapping (instance name -> match + // length) + std::unordered_map dram_instance_score; + // SSD storage type instance match length mapping (instance name -> match + // length) + std::unordered_map ssd_instance_score; + uint32_t max_block_num = 0; + uint32_t max_matched_block_num = 0; + std::string max_matched_instance_name = ""; + + std::string debug_string() { + nlohmann::json json_val; + json_val["instances"] = instances; + json_val["hbm_instance_score"] = hbm_instance_score; + json_val["dram_instance_score"] = dram_instance_score; + json_val["ssd_instance_score"] = ssd_instance_score; + json_val["max_block_num"] = max_block_num; + json_val["max_matched_block_num"] = max_matched_block_num; + json_val["max_matched_instance_name"] = max_matched_instance_name; + return json_val.dump(2); + } +}; + +struct LoadBalanceInfos { + OverlapScores overlap_scores; + std::unordered_map prefill_load_metrics; + std::unordered_map decode_load_metrics; + uint64_t prefill_max_waiting_requests_num = 0; + uint64_t decode_max_waiting_requests_num = 0; + + std::string debug_string() { + nlohmann::json json_val; + + json_val["overlap_scores"] = + nlohmann::json::parse(overlap_scores.debug_string()); + + nlohmann::json prefill_json; + for (auto& [key, metrics] : prefill_load_metrics) { + prefill_json[key] = nlohmann::json::parse(metrics.debug_string()); + } + json_val["prefill_load_metrics"] = prefill_json; + + nlohmann::json decode_json; + for (auto& [key, metrics] : decode_load_metrics) { + decode_json[key] = nlohmann::json::parse(metrics.debug_string()); + } + json_val["decode_load_metrics"] = decode_json; + + 
json_val["prefill_max_waiting_requests_num"] = + prefill_max_waiting_requests_num; + json_val["decode_max_waiting_requests_num"] = + decode_max_waiting_requests_num; + + return json_val.dump(2); } }; diff --git a/xllm_service/common/utils.cpp b/xllm_service/common/utils.cpp index 5047257..99dacd2 100644 --- a/xllm_service/common/utils.cpp +++ b/xllm_service/common/utils.cpp @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include #include namespace xllm_service { @@ -73,5 +74,25 @@ bool get_bool_env(const std::string& key, bool defaultValue) { strVal == "True"); } +std::string get_local_ip() { + using namespace boost::asio; + io_service io; + ip::tcp::resolver resolver(io); + ip::tcp::resolver::query query(ip::host_name(), ""); + ip::tcp::resolver::iterator iter = resolver.resolve(query); + ip::tcp::resolver::iterator end; + + while (iter != end) { + ip::address addr = iter->endpoint().address(); + if (!addr.is_loopback() && addr.is_v4()) { + return addr.to_string(); + } + ++iter; + } + + LOG(FATAL) << "Get local ip faill!"; + return ""; +} + } // namespace utils } // namespace xllm_service diff --git a/xllm_service/common/utils.h b/xllm_service/common/utils.h index 0b00ac8..12d335b 100644 --- a/xllm_service/common/utils.h +++ b/xllm_service/common/utils.h @@ -23,6 +23,7 @@ namespace utils { bool enable_debug_log(); bool is_port_available(int port); bool get_bool_env(const std::string& key, bool defaultValue); +std::string get_local_ip(); } // namespace utils } // namespace xllm_service diff --git a/xllm_service/http_service/service.cpp b/xllm_service/http_service/service.cpp index ea5999e..d28e8d9 100644 --- a/xllm_service/http_service/service.cpp +++ b/xllm_service/http_service/service.cpp @@ -63,56 +63,6 @@ XllmHttpServiceImpl::XllmHttpServiceImpl(const HttpServiceConfig& config) XllmHttpServiceImpl::~XllmHttpServiceImpl() {} -bool XllmHttpServiceImpl::create_channel(const std::string& target_uri) { - std::lock_guard guard(channel_mutex_); - if (cached_channels_.find(target_uri) == cached_channels_.end()) { - brpc::Channel* channel = new brpc::Channel(); - brpc::ChannelOptions options; - // Add to params - options.protocol = "http"; - options.timeout_ms = config_.timeout_ms; /*milliseconds*/ - options.max_retry = 3; - std::string load_balancer = ""; - if (channel->Init(target_uri.c_str(), load_balancer.c_str(), &options) != - 0) { - LOG(ERROR) << "Fail to initialize channel for " << target_uri; - return false; - } - cached_channels_[target_uri] = channel; - } - - return true; -} - -std::string XllmHttpServiceImpl::get_redirect_uri(bool only_prefill) { - std::string target_instance_addr; - if (!rpc_service_) { - // for testing - if (config_.test_instance_addr.empty()) { - LOG(ERROR) << "Rpc service is not start."; - return ""; - } - target_instance_addr = config_.test_instance_addr; - } else { - InstancesPair instances_pair = - rpc_service_->select_instances_pair(only_prefill); - if (instances_pair.prefill_instance_http_addr.empty()) { - LOG(ERROR) << "No prefill instance available."; - return ""; - } - target_instance_addr = instances_pair.prefill_instance_http_addr; - - if (!only_prefill) { - if (instances_pair.decode_instance_http_addr.empty()) { - // TODO: - } - // TODO: add instances_pair.decode_instance_http_addr to request? 
- } - } - - return target_instance_addr; -} - void XllmHttpServiceImpl::Hello(::google::protobuf::RpcController* controller, const proto::HttpHelloRequest* request, proto::HttpHelloResponse* response, @@ -213,7 +163,8 @@ void XllmHttpServiceImpl::handle(std::shared_ptr call_data, // async redistribute the request and wait the response // TODO: optimize the thread pool to async mode. - auto channel_ptr = cached_channels_[target_uri]; + brpc::Channel* channel_ptr = rpc_service_->get_channel(target_uri).get(); + // send request to prefill instance. thread_pool_->schedule([this, service_request_id, @@ -375,24 +326,6 @@ void XllmHttpServiceImpl::post_serving( // create xllm_service request_id: service_request_id std::string service_request_id = generate_service_request_id(serving_method); json_value["service_request_id"] = service_request_id; - std::string req_attachment = json_value.dump(); - request_tracer_->log(service_request_id, req_attachment); - - // redistribute the request to the correct P/D instance - // TODO: redistribute policy to select the instance - std::string target_uri = get_redirect_uri(); - if (target_uri.empty()) { - cntl->SetFailed( - "Internal runtime error, can not found a running instance."); - return; - } - if (cached_channels_.find(target_uri) == cached_channels_.end()) { - if (!create_channel(target_uri)) { - LOG(ERROR) << "Create channel failed, target_uri is " << target_uri; - cntl->SetFailed("Internal runtime error."); - return; - } - } std::function trace_callback; if (config_.enable_request_trace) { @@ -403,33 +336,82 @@ void XllmHttpServiceImpl::post_serving( trace_callback = nullptr; } + ScheduleResult schedule_res; if (serving_method == "/v1/completions") { + if (json_value.contains("prompt")) { + if (!rpc_service_->schedule(json_value.at("prompt").get(), + &schedule_res)) { + cntl->SetFailed("Schedule fail!"); + LOG(ERROR) << "XllmRpcServiceImpl::schedule error!"; + return; + } + } else { + cntl->SetFailed("Input has no prompt!"); + LOG(ERROR) << "Input has no prompt!"; + return; + } + json_value["token_ids"] = std::move(schedule_res.token_ids); + json_value["routing"] = std::move(schedule_res.routing.serialize_to_json()); + + std::string req_attachment = json_value.dump(); auto arena = response->GetArena(); auto resp_pb = google::protobuf::Arena::CreateMessage( arena); auto call_data = std::make_shared( - cntl, stream, done_guard.release(), resp_pb, trace_callback); + cntl, stream, done_guard.release(), resp_pb); handle_v1_completions(call_data, req_attachment, service_request_id, stream, model, include_usage, - target_uri); + schedule_res.routing.prefill_name); } else if (serving_method == "/v1/chat/completions") { + if (json_value.contains("messages") && json_value["messages"].is_array()) { + ChatMessages messages; + try { + const auto& msgs = json_value["messages"]; + messages.reserve(msgs.size()); + for (const auto& msg : msgs) { + if (msg.contains("role") && msg["role"].is_string() && + msg.contains("content") && msg["content"].is_string()) { + messages.emplace_back(msg["role"].get(), + msg["content"].get()); + } + } + } catch (const nlohmann::json::exception& e) { + cntl->SetFailed("Parse request fail, Invalid messages!"); + LOG(ERROR) << "Parse request fail, Invalid messages!"; + return; + } + + if (!rpc_service_->schedule(messages, &schedule_res)) { + cntl->SetFailed("Schedule fail!"); + LOG(ERROR) << "XllmRpcServiceImpl::schedule error!"; + return; + } + } else { + cntl->SetFailed("Input has no messages!"); + LOG(ERROR) << "Input has no messages!"; 
+ return; + } + json_value["token_ids"] = std::move(schedule_res.token_ids); + json_value["routing"] = std::move(schedule_res.routing.serialize_to_json()); + + std::string req_attachment = json_value.dump(); auto arena = response->GetArena(); auto resp_pb = google::protobuf::Arena::CreateMessage(arena); auto call_data = std::make_shared( - cntl, stream, done_guard.release(), resp_pb, trace_callback); + cntl, stream, done_guard.release(), resp_pb); handle_v1_chat_completions(call_data, req_attachment, service_request_id, stream, model, include_usage, - target_uri); + schedule_res.routing.prefill_name); } else { LOG(ERROR) << "Not supported method: " << serving_method; cntl->SetFailed("Not supported method: " + serving_method); @@ -471,22 +453,18 @@ void XllmHttpServiceImpl::get_serving( // done_guard.release()); auto call_data = std::make_shared( cntl, false, done_guard.release(), nullptr); - std::string target_uri = get_redirect_uri(true /*only_prefill*/); - if (target_uri.empty()) { - cntl->SetFailed( - "Internal runtime error, can not found a running instance."); + + ScheduleResult schedule_res; + if (!rpc_service_->schedule("", &schedule_res)) { + cntl->SetFailed("Schedule fail!"); + LOG(ERROR) << "XllmRpcServiceImpl::schedule error!"; return; } - if (cached_channels_.find(target_uri) == cached_channels_.end()) { - if (!create_channel(target_uri)) { - LOG(ERROR) << "Create channel failed, target_uri is " << target_uri; - cntl->SetFailed("Internal runtime error."); - return; - } - } - auto channel_ptr = cached_channels_[target_uri]; - target_uri += serving_method; + brpc::Channel* channel_ptr = + rpc_service_->get_channel(schedule_res.routing.prefill_name).get(); + std::string target_uri = schedule_res.routing.prefill_name + serving_method; + thread_pool_->schedule( [/*req_attachment, */ call_data, cntl, channel_ptr, target_uri]() { brpc::Controller* redirect_cntl = new brpc::Controller(); diff --git a/xllm_service/http_service/service.h b/xllm_service/http_service/service.h index 3bff5c4..69dd6cb 100644 --- a/xllm_service/http_service/service.h +++ b/xllm_service/http_service/service.h @@ -77,8 +77,7 @@ class XllmHttpServiceImpl : public proto::XllmHttpService { private: bool create_channel(const std::string& target_uri); - // only prefill is true means only prefill instance is returned - std::string get_redirect_uri(bool only_prefill = false); + void post_serving(const std::string& serving_method, ::google::protobuf::RpcController* controller, const proto::HttpRequest* request, @@ -124,13 +123,8 @@ class XllmHttpServiceImpl : public proto::XllmHttpService { std::shared_ptr rpc_service_; std::unique_ptr request_tracer_; - // uri -> channel - // e.g. 127.0.0.1:9999/v1/completions -> channel1 - // 127.0.0.1:9999/v1/chat/completions -> channel2 - // NOTE: different methods to one instance has different channels - std::unordered_map cached_channels_; + std::unique_ptr thread_pool_; - std::mutex channel_mutex_; // In disagg pd mode, we support receive generated token from // prefill or from decode directly. 
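The routing flow above relies on the chained murmur_hash3 helper introduced in common/hash_util.*: each block key is the hash of the previous block's key concatenated with the block's token ids, so a key identifies an entire prefix rather than a single block. The following is a minimal usage sketch, not part of the patch: compute_block_hashes is a hypothetical helper, and it assumes Slice<int32_t> from common/slice.h provides a (pointer, length) constructor plus data()/size(), with block_size taken from ModelConfig.

// Hypothetical sketch: derive chained per-block prefix-cache keys from a
// token sequence using murmur_hash3 and Murmur3Key from common/hash_util.h.
#include <cstdint>
#include <vector>

#include "common/hash_util.h"
#include "common/slice.h"

namespace xllm_service {

std::vector<Murmur3Key> compute_block_hashes(
    const std::vector<int32_t>& token_ids, int32_t block_size) {
  std::vector<Murmur3Key> keys;
  const uint8_t* pre_hash = nullptr;  // the first block has no predecessor
  for (size_t start = 0; start + block_size <= token_ids.size();
       start += block_size) {
    Murmur3Key key;
    // key = MurmurHash3(prev_key || block_tokens), so it fingerprints the
    // whole prefix up to and including this block.
    murmur_hash3(pre_hash,
                 Slice<int32_t>(token_ids.data() + start, block_size),
                 key.data);
    keys.push_back(key);
    pre_hash = keys.back().data;
  }
  return keys;
}

}  // namespace xllm_service

Keys built this way can index a Murmur3KeyCacheMap, which is presumably how the scheduler matches a request's prefix against the CacheLocations entries published to etcd when computing OverlapScores.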
diff --git a/xllm_service/master.cpp b/xllm_service/master.cpp index a6ebe69..72983c4 100644 --- a/xllm_service/master.cpp +++ b/xllm_service/master.cpp @@ -31,18 +31,30 @@ Master::Master(const ServerOptions& server_options) } RpcServiceConfig rpc_config; rpc_config.etcd_addr = server_options.etcd_addr; - rpc_config.disagg_pd_policy = server_options.disagg_pd_policy; + rpc_config.load_balance_policy = server_options.load_balance_policy; rpc_config.detect_disconnected_instance_interval = server_options.detect_disconnected_instance_interval; - rpc_service_impl_ = - std::make_shared(rpc_config); - rpc_service_ = - std::make_unique(rpc_service_impl_); + rpc_config.service_name = server_options_.rpc_server_host + ":" + + std::to_string(server_options_.rpc_port); + + ModelConfig model_config; + model_config.block_size = server_options.block_size; + model_config.model_type = server_options.model_type; + model_config.tokenizer_path = server_options.tokenizer_path; - HttpServiceConfig http_config; + xllm_service::HttpServiceConfig http_config; http_config.num_threads = server_options.http_num_threads; + http_config.timeout_ms = server_options.timeout_ms; + http_config.test_instance_addr = server_options.test_instance_addr; http_config.enable_request_trace = server_options.enable_request_trace; + + rpc_service_impl_ = std::make_shared( + rpc_config, model_config, http_config); + + rpc_service_ = + std::make_unique(rpc_service_impl_); + http_service_ = std::make_unique( rpc_service_impl_, http_config); } @@ -166,7 +178,11 @@ int main(int argc, char* argv[]) { // Initialize glog google::InitGoogleLogging(argv[0]); - FLAGS_logtostderr = true; + // FLAGS_logtostderr = true; + + LOG(INFO) << "Dump all gflags: " << std::endl + << google::CommandlineFlagsIntoString(); + google::FlushLogFiles(google::INFO); LOG(INFO) << "Starting xllm master service."; @@ -191,20 +207,26 @@ int main(int argc, char* argv[]) { server_options.http_idle_timeout_s = FLAGS_http_server_idle_timeout_s; server_options.http_num_threads = FLAGS_http_server_num_threads; server_options.http_max_concurrency = FLAGS_http_server_max_concurrency; - server_options.rpc_server_host = FLAGS_rpc_server_host; + server_options.rpc_server_host = xllm_service::utils::get_local_ip(); server_options.rpc_port = FLAGS_rpc_server_port; server_options.rpc_idle_timeout_s = FLAGS_rpc_server_idle_timeout_s; server_options.rpc_num_threads = FLAGS_rpc_server_num_threads; server_options.rpc_max_concurrency = FLAGS_rpc_server_max_concurrency; server_options.etcd_addr = FLAGS_etcd_addr; - server_options.disagg_pd_policy = FLAGS_disagg_pd_policy; + server_options.load_balance_policy = FLAGS_load_balance_policy; server_options.detect_disconnected_instance_interval = FLAGS_detect_disconnected_instance_interval; server_options.enable_request_trace = FLAGS_enable_request_trace; + + server_options.tokenizer_path = FLAGS_tokenizer_path; server_options.block_size = FLAGS_block_size; server_options.model_type = FLAGS_model_type; server_options.tokenizer_path = FLAGS_tokenizer_path; + server_options.num_threads = FLAGS_num_threads; + server_options.timeout_ms = FLAGS_timeout_ms; + server_options.test_instance_addr = FLAGS_test_instance_addr; + xllm_service::Master master(server_options); if (!master.start()) { diff --git a/xllm_service/master.h b/xllm_service/master.h index 92b9692..cafea81 100644 --- a/xllm_service/master.h +++ b/xllm_service/master.h @@ -32,6 +32,9 @@ struct ServerOptions { int32_t http_num_threads = 32; int32_t http_max_concurrency = 128; bool 
enable_request_trace = false; + int num_threads = 16; + int timeout_ms = -1; + std::string test_instance_addr = ""; // rpc server options std::string rpc_server_host = ""; @@ -40,7 +43,7 @@ struct ServerOptions { int32_t rpc_num_threads = 32; int32_t rpc_max_concurrency = 128; std::string etcd_addr = ""; - std::string disagg_pd_policy = "RR"; + std::string load_balance_policy = "RR"; int32_t detect_disconnected_instance_interval = 15; int32_t block_size = 16; std::string model_type = "chatglm"; diff --git a/xllm_service/proto/xllm/chat.proto b/xllm_service/proto/xllm/chat.proto index 5ded588..1fdecd2 100644 --- a/xllm_service/proto/xllm/chat.proto +++ b/xllm_service/proto/xllm/chat.proto @@ -20,7 +20,7 @@ message ChatMessage { } -// Next Id: 25 +// Next Id: 28 message ChatRequest { // ID of the model to use. You can use the ListModels endpoint to list available models. @@ -110,7 +110,9 @@ message ChatRequest { optional string service_request_id = 25; - Routing routing = 28; + repeated int32 token_ids = 26; + + Routing routing = 27; } message ChatLogProbData { diff --git a/xllm_service/proto/xllm/common.proto b/xllm_service/proto/xllm/common.proto index b821d1d..8015e64 100644 --- a/xllm_service/proto/xllm/common.proto +++ b/xllm_service/proto/xllm/common.proto @@ -24,9 +24,7 @@ message Status { } message Routing { - repeated int32 token_ids = 1; + string prefill_name = 1; - string prefill_name = 2; - - string decode_name = 3; + string decode_name = 2; } diff --git a/xllm_service/proto/xllm/completion.proto b/xllm_service/proto/xllm/completion.proto index 9a8f70d..9b1c78f 100644 --- a/xllm_service/proto/xllm/completion.proto +++ b/xllm_service/proto/xllm/completion.proto @@ -5,7 +5,7 @@ option cc_enable_arenas = true; import "common.proto"; -// Next ID: 25 +// Next ID: 28 message CompletionRequest { // ID of the model to use. (required) // You can use the ListModels endpoint to list available models. @@ -87,7 +87,9 @@ message CompletionRequest { optional string service_request_id = 25; - Routing routing = 28; + repeated int32 token_ids = 26; + + Routing routing = 27; } message LogProbs { diff --git a/xllm_service/proto/xllm_rpc_service.proto b/xllm_service/proto/xllm_rpc_service.proto index d769ba3..023f7f9 100644 --- a/xllm_service/proto/xllm_rpc_service.proto +++ b/xllm_service/proto/xllm_rpc_service.proto @@ -51,7 +51,6 @@ message LoadMetrics { float gpu_cache_usage_perc = 2; } -// TODO: add metainfo/metrics ? 
message HeartbeatRequest { string name = 1; KvCacheEvent cache_event = 2; diff --git a/xllm_service/rpc_service/CMakeLists.txt b/xllm_service/rpc_service/CMakeLists.txt index b876a6e..cdcb57c 100644 --- a/xllm_service/rpc_service/CMakeLists.txt +++ b/xllm_service/rpc_service/CMakeLists.txt @@ -1,28 +1,28 @@ include(cc_binary) include(cc_library) include(cc_test) +add_subdirectory(etcd_client) +add_subdirectory(managers) +add_subdirectory(loadbalance_policy) cc_library( NAME xllm_rpc_service HDRS - disagg_pd_policy.h - etcd_client.h - instance_mgr.h + scheduler.h response_handler.h service.h SRCS - disagg_pd_policy.cpp - etcd_client.cpp - instance_mgr.cpp + scheduler.cpp response_handler.cpp service.cpp DEPS :common + :etcd_client + :managers + :loadbalance_policy absl::random_random absl::strings - cpprest - etcd-cpp-api glog::glog nlohmann_json::nlohmann_json proto::proto_rpc_service diff --git a/xllm_service/rpc_service/disagg_pd_policy.cpp b/xllm_service/rpc_service/disagg_pd_policy.cpp deleted file mode 100644 index 292e14b..0000000 --- a/xllm_service/rpc_service/disagg_pd_policy.cpp +++ /dev/null @@ -1,207 +0,0 @@ -/* Copyright 2025 The xLLM Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https://github.com/jd-opensource/xllm-service/blob/main/LICENSE - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "disagg_pd_policy.h" - -#include - -#include "common/utils.h" - -namespace xllm_service { - -namespace { - -void debug_print(const std::string& action, - const std::string& name, - const std::string& type, - int idx) { - if (utils::enable_debug_log()) { - LOG(INFO) << "DisaggPdPolicy " << action << " instance, name = " << name - << ", type = " << type << ", idx = " << idx; - } -} - -} // namespace - -DisaggPdPolicy::DisaggPdPolicy() {} - -DisaggPdPolicy::~DisaggPdPolicy() {} - -void DisaggPdPolicy::insert_instance(const std::string& name, - InstanceMetaInfo* info) { - std::lock_guard guard(mutex_); - InstanceType type = info->type; - if (type == InstanceType::DEFAULT || type == InstanceType::PREFILL) { - auto it = prefill_instance_to_index_.find(name); - if (it != prefill_instance_to_index_.end()) { - LOG(ERROR) << "Insert instance is already existed, name: " << name - << ", type: " << static_cast(type); - return; - } - prefill_instance_.emplace_back(info); - prefill_instance_to_index_[name] = prefill_instance_.size() - 1; - debug_print( - "insert", name, "prefill or default", prefill_instance_to_index_[name]); - } else { - auto it = decode_instance_to_index_.find(name); - if (it != decode_instance_to_index_.end()) { - LOG(ERROR) << "Insert instance is already existed, name: " << name - << ", type: " << static_cast(type); - return; - } - decode_instance_.emplace_back(info); - decode_instance_to_index_[name] = decode_instance_.size() - 1; - debug_print("insert", name, "decode", decode_instance_to_index_[name]); - } -} - -void DisaggPdPolicy::update_instance(const std::string& name, - InstanceMetaInfo* info) { - std::lock_guard guard(mutex_); - InstanceType type = info->type; - if 
(type == InstanceType::DEFAULT || type == InstanceType::PREFILL) { - auto it = prefill_instance_to_index_.find(name); - if (it == prefill_instance_to_index_.end()) { - LOG(ERROR) << "Update instance is not existed, name: " << name - << ", type: " << static_cast(type); - return; - } - prefill_instance_[it->second] = info; - debug_print( - "update", name, "prefill or default", prefill_instance_to_index_[name]); - } else { - auto it = decode_instance_to_index_.find(name); - if (it == decode_instance_to_index_.end()) { - LOG(ERROR) << "Update instance is not existed, name: " << name - << ", type: " << static_cast(type); - return; - } - decode_instance_[it->second] = info; - debug_print("update", name, "decode", decode_instance_to_index_[name]); - } -} - -void DisaggPdPolicy::remove_instance(const std::string& name, - InstanceType type) { - std::lock_guard guard(mutex_); - if (type == InstanceType::DEFAULT || type == InstanceType::PREFILL) { - auto it = prefill_instance_to_index_.find(name); - if (it == prefill_instance_to_index_.end()) { - LOG(ERROR) << "Remove instance not found, name: " << name - << ", type: " << static_cast(type); - return; - } - auto idx = it->second; - // Label the instance be deleted - prefill_instance_[idx] = nullptr; - prefill_instance_to_index_.erase(name); - debug_print("remove", name, "prefill or default", idx); - } else { - auto it = decode_instance_to_index_.find(name); - if (it == decode_instance_to_index_.end()) { - LOG(ERROR) << "Remove instance not found, name: " << name - << ", type: " << static_cast(type); - return; - } - auto idx = it->second; - // Label the instance be deleted - decode_instance_[idx] = nullptr; - decode_instance_to_index_.erase(name); - debug_print("remove", name, "decode", idx); - } -} - -RoundRobinDisaggPdPolicy::RoundRobinDisaggPdPolicy() { - LOG(INFO) << "Enable RoundRobin disaggregated pd policy."; -} - -RoundRobinDisaggPdPolicy::~RoundRobinDisaggPdPolicy() {} - -InstancesPair RoundRobinDisaggPdPolicy::select_instances_pair( - bool only_prefill) { - std::lock_guard guard(mutex_); - // return the first available prefill instance - if (only_prefill) { - InstancesPair inst_pair; - for (const auto& inst : prefill_instance_) { - if (inst != nullptr) { - inst_pair.prefill_instance_http_addr = inst->name; - break; - } - } - return inst_pair; - } - - int prefill_count = prefill_instance_.size(); - int decode_count = decode_instance_.size(); - InstancesPair inst_pair; - // select prefill instance - if (prefill_count > 0) { - auto start_idx = next_prefill_idx_; - bool inst_not_existed = false; - while (prefill_instance_[next_prefill_idx_] == nullptr) { - ++next_prefill_idx_; - next_prefill_idx_ %= prefill_count; - if (next_prefill_idx_ == start_idx) { - inst_not_existed = true; - break; - } - } - if (!inst_not_existed) { - inst_pair.prefill_instance_http_addr = - prefill_instance_[next_prefill_idx_]->name; - } - - ++next_prefill_idx_; - next_prefill_idx_ %= prefill_count; - } - - // select decode instance - if (decode_count > 0) { - auto start_idx = next_decode_idx_; - bool inst_not_existed = false; - while (decode_instance_[next_decode_idx_] == nullptr) { - ++next_decode_idx_; - next_decode_idx_ %= decode_count; - if (next_decode_idx_ == start_idx) { - inst_not_existed = true; - break; - } - } - if (!inst_not_existed) { - inst_pair.decode_instance_http_addr = - decode_instance_[next_decode_idx_]->name; - } - - ++next_decode_idx_; - next_decode_idx_ %= decode_count; - } - - return inst_pair; -} - -std::unordered_map 
-RoundRobinDisaggPdPolicy::reallocate_instances_type(/*params here*/) { - // TODO: implement this function - return {}; -} - -std::unordered_map> -RoundRobinDisaggPdPolicy::allocate_pd_pairs(/*params here*/) { - // TODO: implement this function - return {}; -} - -} // namespace xllm_service diff --git a/xllm_service/rpc_service/disagg_pd_policy.h b/xllm_service/rpc_service/disagg_pd_policy.h deleted file mode 100644 index 4e83d43..0000000 --- a/xllm_service/rpc_service/disagg_pd_policy.h +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright 2025 The xLLM Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https://github.com/jd-opensource/xllm-service/blob/main/LICENSE - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#pragma once - -#include -#include -#include -#include - -#include "common/types.h" - -namespace xllm_service { - -class DisaggPdPolicy { - public: - DisaggPdPolicy(); - virtual ~DisaggPdPolicy(); - - // re-allocate instance types: prefill or decode - virtual std::unordered_map - reallocate_instances_type(/*params here*/) = 0; - - // Allocate prefill and decode pairs, return prefill -> [decode instances] - // Allow multiple decode instances for each prefill instance and - // multiple prefill instances for each decode instance - virtual std::unordered_map> - allocate_pd_pairs(/*params here*/) = 0; - - // select instances(prefill/decode/default etc.) to handle request - // according the disagg pd policy. - virtual InstancesPair select_instances_pair(bool only_prefill = false) = 0; - - void insert_instance(const std::string& name, InstanceMetaInfo* info); - void update_instance(const std::string& name, InstanceMetaInfo* info); - void remove_instance(const std::string& name, InstanceType type); - - protected: - std::vector prefill_instance_; - std::vector decode_instance_; - // map the instance name to vector index - std::unordered_map prefill_instance_to_index_; - std::unordered_map decode_instance_to_index_; - - std::mutex mutex_; -}; - -class RoundRobinDisaggPdPolicy : public DisaggPdPolicy { - public: - RoundRobinDisaggPdPolicy(); - ~RoundRobinDisaggPdPolicy(); - - virtual std::unordered_map - reallocate_instances_type(/*params here*/) override; - virtual std::unordered_map> - allocate_pd_pairs(/*params here*/) override; - virtual InstancesPair select_instances_pair( - bool only_prefill = false) override; - - private: - int next_prefill_idx_ = 0; - int next_decode_idx_ = 0; -}; - -} // namespace xllm_service diff --git a/xllm_service/rpc_service/etcd_client.cpp b/xllm_service/rpc_service/etcd_client.cpp deleted file mode 100644 index 26f7d13..0000000 --- a/xllm_service/rpc_service/etcd_client.cpp +++ /dev/null @@ -1,114 +0,0 @@ -/* Copyright 2025 The xLLM Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - https://github.com/jd-opensource/xllm-service/blob/main/LICENSE - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "etcd_client.h" - -#include - -#include - -namespace xllm_service { - -EtcdClient::EtcdClient(const std::string& etcd_addr) - : client_(etcd_addr), etcd_addr_(etcd_addr) { - auto response = client_.put("XLLM_PING", "PING"); - if (!response.is_ok()) { - LOG(FATAL) << "etcd connect to etcd server failed: " - << response.error_message(); - } -} - -EtcdClient::~EtcdClient() {} - -bool EtcdClient::get(const std::string& key, InstanceIdentityInfo& value) { - auto response = client_.get(key); - if (!response.is_ok()) { - LOG(ERROR) << "etcd get " << key << " failed: " << response.error_message(); - return false; - } - auto json_str = response.value().as_string(); - try { - nlohmann::json json_value = nlohmann::json::parse(json_str); - value.instance_addr = json_value.at("instance_addr").get(); - value.instance_type = json_value.at("instance_type").get(); - } catch (const std::exception& e) { - LOG(ERROR) << "etcd get " << key - << " failed: json parse error: " << e.what(); - return false; - } - - return true; -} - -bool EtcdClient::get_prefix(const std::string& key_prefix, - std::vector& values) { - auto response = client_.ls(key_prefix); - if (!response.is_ok()) { - LOG(ERROR) << "etcd get " << key_prefix - << " failed: " << response.error_message(); - return false; - } - for (const auto& v : response.values()) { - InstanceIdentityInfo value; - auto json_str = v.as_string(); - try { - nlohmann::json json_value = nlohmann::json::parse(json_str); - value.instance_addr = json_value.at("instance_addr").get(); - value.instance_type = json_value.at("instance_type").get(); - values.emplace_back(value); - } catch (const std::exception& e) { - LOG(ERROR) << "etcd get " << key_prefix - << " failed: json parse error: " << e.what(); - return false; - } - } - - return true; -} - -bool EtcdClient::set(const std::string& key, - const InstanceIdentityInfo& value) { - std::string json_str; - try { - nlohmann::json json_value; - json_value["instance_addr"] = value.instance_addr; - json_value["instance_type"] = value.instance_type; - json_str = json_value.dump(); - } catch (const std::exception& e) { - LOG(ERROR) << "etcd set " << key - << " failed: json dump error: " << e.what(); - return false; - } - - auto response = client_.put(key, json_str); - if (!response.is_ok()) { - LOG(ERROR) << "etcd set " << key << " failed: " << response.error_message(); - return false; - } - - return true; -} - -bool EtcdClient::rm(const std::string& key) { - auto response = client_.rm(key); - if (!response.is_ok()) { - LOG(ERROR) << "etcd rm " << key << " failed: " << response.error_message(); - return false; - } - - return true; -} - -} // namespace xllm_service diff --git a/xllm_service/rpc_service/etcd_client/CMakeLists.txt b/xllm_service/rpc_service/etcd_client/CMakeLists.txt new file mode 100644 index 0000000..9a1df93 --- /dev/null +++ b/xllm_service/rpc_service/etcd_client/CMakeLists.txt @@ -0,0 +1,17 @@ +include(cc_library) +include(cc_test) + +cc_library( + NAME + etcd_client + HDRS + etcd_client.h + SRCS + 
etcd_client.cpp + DEPS + :common + cpprest + etcd-cpp-api + glog::glog + nlohmann_json::nlohmann_json +) \ No newline at end of file diff --git a/xllm_service/rpc_service/etcd_client/etcd_client.cpp b/xllm_service/rpc_service/etcd_client/etcd_client.cpp new file mode 100644 index 0000000..10d1e08 --- /dev/null +++ b/xllm_service/rpc_service/etcd_client/etcd_client.cpp @@ -0,0 +1,195 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm-service/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "etcd_client.h" + +#include + +#include + +namespace xllm_service { + +EtcdClient::EtcdClient(const std::string& etcd_addr) + : client_(etcd_addr), etcd_addr_(etcd_addr) { + LOG(INFO) << "EtcdClient init put start!"; + auto response = client_.put("XLLM_PING", "PING"); + LOG(INFO) << "EtcdClient init put end!"; + if (!response.is_ok()) { + LOG(FATAL) << "etcd connect to etcd server failed: " + << response.error_message(); + } +} + +EtcdClient::~EtcdClient() { stop_watch(); } + +bool EtcdClient::set(const std::string& key, const std::string& value) { + auto response = client_.put(key, value); + if (!response.is_ok()) { + LOG(ERROR) << "etcd set " << key << " failed: " << response.error_message(); + return false; + } + + return true; +} + +bool EtcdClient::set(const std::string& key, + const std::string& value, + const int ttl) { + auto keep_alive = std::make_shared(client_, ttl); + etcdv3::Transaction transaction; + transaction.add_compare_create(key, 0); + transaction.add_success_put(key, value, keep_alive->Lease()); + etcd::Response response = client_.txn(transaction); + if (response.is_ok()) { + keep_alives_.emplace_back(std::move(keep_alive)); + return true; + } else { + keep_alive->Cancel(); + return false; + } +} + +bool EtcdClient::set(const std::string& key_prefix, + const Murmur3KeyCacheMap& values) { + bool rt = true; + for (const auto& iter : values) { + if (iter.second.empty()) { + rt = rt && client_.rm(key_prefix + iter.first.to_string()).is_ok(); + } else { + rt = rt && client_ + .put(key_prefix + iter.first.to_string(), + iter.second.serialize_to_json().dump()) + .is_ok(); + } + } + return true; +} + +bool EtcdClient::rm(const std::string& key) { + auto response = client_.rm(key); + if (!response.is_ok()) { + LOG(ERROR) << "etcd rm " << key << " failed: " << response.error_message(); + return false; + } + + return true; +} + +bool EtcdClient::rm(const std::string& key_prefix, + const std::unordered_set& keys) { + etcdv3::Transaction transaction; + transaction.add_compare_version( + "XLLM:SERVICE:MASTER", etcdv3::CompareResult::GREATER, -1); + for (const auto& iter : keys) { + transaction.add_success_delete(key_prefix + iter); + } + return client_.txn(transaction).is_ok(); +} + +bool EtcdClient::get(const std::string& key, std::string* value) { + auto response = client_.get(key); + if (!response.is_ok()) { + LOG(ERROR) << "etcd get " << key << " failed: " << response.error_message(); + return false; + } 
+ if (value) { + *value = response.value().as_string(); + } + return true; +} + +bool EtcdClient::get_prefix(const std::string& key_prefix, + Murmur3KeyCacheMap* values) { + auto response = client_.ls(key_prefix); + if (!response.is_ok()) { + LOG(ERROR) << "etcd get " << key_prefix + << " failed: " << response.error_message(); + return false; + } + + for (int i = 0; i < response.keys().size(); i++) { + Murmur3Key key(response.key(i).substr(key_prefix.size()).c_str()); + auto json_str = response.value(i).as_string(); + + CacheLocations value; + if (!value.parse_from_json(json_str)) { + LOG(ERROR) << "Parse json fail: " << json_str; + continue; + } + + values->insert_or_assign(std::move(key), std::move(value)); + } + return true; +} + +bool EtcdClient::get_prefix( + const std::string& key_prefix, + std::unordered_map* values) { + auto response = client_.ls(key_prefix); + if (!response.is_ok()) { + LOG(ERROR) << "etcd get " << key_prefix + << " failed: " << response.error_message(); + return false; + } + + for (int i = 0; i < response.keys().size(); i++) { + auto key_str = response.key(i).substr(key_prefix.size()); + auto str = response.value(i).as_string(); + + values->insert_or_assign(std::move(key_str), std::move(str)); + } + return true; +} + +void EtcdClient::add_watch(const std::string& key_prefix, + Callback callback, + bool recursive) { + std::lock_guard lock(watchers_mutex_); + + if (watchers_.find(key_prefix) != watchers_.end()) { + watchers_[key_prefix].watcher->Cancel(); + } + auto watcher = std::make_unique( + client_, + key_prefix, + [callback, key_prefix](etcd::Response response) { + callback(response, uint64_t(key_prefix.size())); + }, + recursive); + + watchers_[key_prefix] = {std::move(watcher), callback}; +} + +void EtcdClient::remove_watch(const std::string& key_prefix) { + std::lock_guard lock(watchers_mutex_); + + auto it = watchers_.find(key_prefix); + if (it != watchers_.end()) { + it->second.watcher->Cancel(); + watchers_.erase(it); + } +} + +void EtcdClient::stop_watch() { + std::lock_guard lock(watchers_mutex_); + + for (auto& pair : watchers_) { + pair.second.watcher->Cancel(); + } + + watchers_.clear(); +} + +} // namespace xllm_service diff --git a/xllm_service/rpc_service/etcd_client/etcd_client.h b/xllm_service/rpc_service/etcd_client/etcd_client.h new file mode 100644 index 0000000..06c9bb4 --- /dev/null +++ b/xllm_service/rpc_service/etcd_client/etcd_client.h @@ -0,0 +1,146 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm-service/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "common/hash_util.h" +#include "common/types.h" + +namespace xllm_service { + +using Callback = std::function; + +class EtcdClient { + public: + EtcdClient(const std::string& etcd_addr); + ~EtcdClient(); + + template + bool set(const std::string& key, const T& value) { + auto response = client_.put(key, value.serialize_to_json().dump()); + if (!response.is_ok()) { + LOG(ERROR) << "etcd set " << key + << " failed: " << response.error_message(); + return false; + } + + return true; + } + + template + bool set(const std::string& key_prefix, + const unordered_map& values) { + bool rt = true; + for (const auto& iter : values) { + if (iter.second.empty()) { + rt = rt && client_.rm(key_prefix + iter.first).is_ok(); + } else { + rt = rt && client_ + .put(key_prefix + iter.first, + iter.second.serialize_to_json().dump()) + .is_ok(); + } + } + return true; + } + + bool set(const std::string& key_prefix, const Murmur3KeyCacheMap& values); + + bool set(const std::string& key, const std::string& value); + + // create key-value with lease and transaction + bool set(const std::string& key, const std::string& value, const int ttl); + + bool rm(const std::string& key); + + bool rm(const std::string& key_prefix, + const std::unordered_set& keys); + + template + bool get(const std::string& key, T* value) { + auto response = client_.get(key); + if (!response.is_ok()) { + LOG(ERROR) << "etcd get " << key + << " failed: " << response.error_message(); + return false; + } + if (value) { + return value->parse_from_json(response.value().as_string()); + } else { + return true; + } + } + + bool get(const std::string& key_prefix, std::string* value); + + template + bool get_prefix(const std::string& key_prefix, + std::unordered_map* values) { + auto response = client_.ls(key_prefix); + if (!response.is_ok()) { + LOG(ERROR) << "etcd get " << key_prefix + << " failed: " << response.error_message(); + return false; + } + + for (int i = 0; i < response.keys().size(); i++) { + auto key_str = response.key(i).substr(key_prefix.size()); + auto json_str = response.value(i).as_string(); + + T value; + if (!value.parse_from_json(json_str)) { + LOG(ERROR) << "Parse json fail: " << json_str; + continue; + } + + values->insert_or_assign(std::move(key_str), std::move(value)); + } + return true; + } + + bool get_prefix(const std::string& key_prefix, Murmur3KeyCacheMap* values); + + bool get_prefix(const std::string& key_prefix, + std::unordered_map* values); + + void add_watch(const std::string& key_prefix, + Callback callback, + bool recursive = true); + + void remove_watch(const std::string& key_prefix); + + void stop_watch(); + + private: + struct WatcherInfo { + std::unique_ptr watcher; + Callback callback; + }; + + etcd::SyncClient client_; + std::string etcd_addr_; + std::mutex watchers_mutex_; + std::map watchers_; + std::vector> keep_alives_; +}; + +} // namespace xllm_service diff --git a/xllm_service/rpc_service/instance_mgr.cpp b/xllm_service/rpc_service/instance_mgr.cpp deleted file mode 100644 index 8ad59ec..0000000 --- a/xllm_service/rpc_service/instance_mgr.cpp +++ /dev/null @@ -1,290 +0,0 @@ -/* Copyright 2025 The xLLM Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - https://github.com/jd-opensource/xllm-service/blob/main/LICENSE - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "instance_mgr.h" - -#include -#include - -#include -#include - -#include "common/types.h" -#include "common/utils.h" -namespace xllm_service { - -// magic number, TODO: move to config file or env var -static constexpr int kDetectIntervals = 15; // 15seconds -static std::unordered_map ETCD_KEYS_PREFIX_MAP = { - {InstanceType::DEFAULT, "XLLM:DEFAULT:"}, - {InstanceType::PREFILL, "XLLM:PREFILL:"}, - {InstanceType::DECODE, "XLLM:DECODE:"}, -}; -static std::string ETCD_ALL_KEYS_PREFIX = "XLLM:"; -static std::string DEFAULT_DISAGG_PD_POLICY = "RR"; - -InstanceMgr::InstanceMgr(const RpcServiceConfig& config) : config_(config) { - if (config.etcd_addr.empty()) { - LOG(INFO) << "Disable etcd meta server"; - use_etcd_ = false; - } else { - LOG(INFO) << "Connect to etcd meta server: " << config.etcd_addr; - use_etcd_ = true; - etcd_client_ = std::make_unique(config.etcd_addr); - } - - internal_init(); -} - -void InstanceMgr::internal_init() { - std::string pd_policy = config_.disagg_pd_policy; - if (config_.disagg_pd_policy.empty()) { - LOG(WARNING) << "Not specify diasgg pd policy, use `RR` policy as default."; - pd_policy = DEFAULT_DISAGG_PD_POLICY; - } - if (pd_policy == "RR") { - disagg_pd_policy_ = std::make_unique(); - } else { - LOG(FATAL) << "Not supported diasgg pd policy: " << pd_policy; - return; - } - - heartbeat_thread_ = std::make_unique( - &InstanceMgr::detect_disconnected_instances, this); -} - -InstanceMgr::~InstanceMgr() { - exited_ = true; - if (heartbeat_thread_) { - heartbeat_thread_->join(); - } -} - -void InstanceMgr::detect_disconnected_instances() { - while (!exited_) { - std::this_thread::sleep_for(std::chrono::seconds(kDetectIntervals)); - { - std::lock_guard guard(inst_mutex_); - auto now = std::chrono::system_clock::now(); - auto timestamp_ms = std::chrono::duration_cast( - now.time_since_epoch()) - .count(); - std::vector disconnected_instances_name; - for (const auto& [name, info] : instances_) { - if (timestamp_ms - info.latest_timestamp > kDetectIntervals * 1000) { - LOG(WARNING) << "Instance maybe disconnected, instance_name: " << name - << ", last heartbeat interval(s): " - << (timestamp_ms - info.latest_timestamp) / 1000.0; - disconnected_instances_name.emplace_back(name); - } - } - - // no instances disconnected, return - if (disconnected_instances_name.empty()) { - continue; - } - - if (utils::enable_debug_log()) { - const auto instance_names = - absl::StrJoin(disconnected_instances_name, ", "); - LOG(WARNING) << "Detect disconnected instance, instance_name: " - << instance_names; - } - // detele instance metainfo from etcd - delete_persistence_metainfo(disconnected_instances_name); - for (const auto& name : disconnected_instances_name) { - disagg_pd_policy_->remove_instance(name, instances_[name].type); - instances_.erase(name); - } - } - } -} - -ErrorCode InstanceMgr::register_instance(const std::string& instance_name) { - std::lock_guard guard(inst_mutex_); - if (utils::enable_debug_log()) { - LOG(WARNING) << "Register instance, instance_name: " << 
instance_name; - } - if (instances_.find(instance_name) != instances_.end()) { - // update_instance_timestamp(instance_name); - LOG(ERROR) << "Instance is already registered, instance_name: " - << instance_name; - return ErrorCode::INSTANCE_EXISTED; - } - - InstanceMetaInfo default_info(instance_name, ""); - instances_[instance_name] = default_info; - disagg_pd_policy_->insert_instance(instance_name, - &(instances_[instance_name])); - // save instance metainfo to etcd - save_persistence_metainfo(default_info); - return ErrorCode::OK; -} - -ErrorCode InstanceMgr::register_instance(const std::string& instance_name, - const InstanceMetaInfo& metainfo) { - std::lock_guard guard(inst_mutex_); - if (utils::enable_debug_log()) { - LOG(WARNING) << "Register instance, instance_name: " << instance_name; - } - if (instances_.find(instance_name) != instances_.end()) { - // update_instance_timestamp(instance_name); - LOG(ERROR) << "Instance is already registered, instance_name: " - << instance_name; - return ErrorCode::INSTANCE_EXISTED; - } - - instances_[instance_name] = metainfo; - disagg_pd_policy_->insert_instance(instance_name, - &(instances_[instance_name])); - // save instance metainfo to etcd - save_persistence_metainfo(metainfo); - return ErrorCode::OK; -} - -ErrorCode InstanceMgr::update_instance_metainfo( - const std::string& instance_name, - const InstanceMetaInfo& metainfo) { - std::lock_guard guard(inst_mutex_); - if (utils::enable_debug_log()) { - LOG(WARNING) << "Update instance metainfo, instance_name: " - << instance_name; - } - if (instances_.find(instance_name) == instances_.end()) { - LOG(ERROR) << "Instance is not registered, instance_name: " - << instance_name; - return ErrorCode::INSTANCE_NOT_EXISTED; - } - - instances_[instance_name] = metainfo; - update_instance_timestamp(instance_name); - disagg_pd_policy_->update_instance(instance_name, - &(instances_[instance_name])); - return ErrorCode::OK; -} - -void InstanceMgr::save_persistence_metainfo(const InstanceMetaInfo& metainfo) { - if (!use_etcd_) { - return; - } - std::string key = ETCD_KEYS_PREFIX_MAP[metainfo.type] + metainfo.name; - InstanceIdentityInfo value; - value.instance_addr = metainfo.name; - value.rpc_addr = metainfo.rpc_address; - value.instance_type = static_cast(metainfo.type); - bool ok = etcd_client_->set(key, value); - if (!ok) { - LOG(ERROR) << "Save instance metainfo to etcd failed, key: " << key; - return; - } - - if (utils::enable_debug_log()) { - InstanceIdentityInfo debug_value; - bool ok = etcd_client_->get(key, debug_value); - if (!ok) { - LOG(ERROR) << "Get instance metainfo from etcd failed, key: " << key; - return; - } - LOG(WARNING) << "Instance after put: " << debug_value.debug_string(); - } -} - -void InstanceMgr::delete_persistence_metainfo( - const std::vector& instance_names) { - if (!use_etcd_ || instance_names.empty()) { - return; - } - // TODO: use batch delete later - for (const auto& name : instance_names) { - InstanceMetaInfo& metainfo = instances_[name]; - std::string key = ETCD_KEYS_PREFIX_MAP[metainfo.type] + metainfo.name; - bool ok = etcd_client_->rm(key); - if (!ok) { - LOG(ERROR) << "Delete instance metainfo from etcd failed, key: " << key; - } - } - - if (utils::enable_debug_log()) { - std::vector debug_values; - bool ok = etcd_client_->get_prefix(ETCD_ALL_KEYS_PREFIX, debug_values); - if (!ok) { - LOG(ERROR) << "Get instance metainfo from etcd failed, key: " - << ETCD_ALL_KEYS_PREFIX; - return; - } - std::string concat_debug_str; - for (const auto& v : debug_values) { - 
concat_debug_str += v.debug_string(); - concat_debug_str += "\n"; - } - LOG(WARNING) << "Instances after delete: " << concat_debug_str; - } -} - -ErrorCode InstanceMgr::heartbeat(const std::string& instance_name) { - std::lock_guard guard(inst_mutex_); - if (utils::enable_debug_log()) { - LOG(WARNING) << "Receive heartbeat, instance_name: " << instance_name; - } - if (instances_.find(instance_name) == instances_.end()) { - LOG(ERROR) << "Instance is not registered, instance_name: " - << instance_name; - return ErrorCode::INSTANCE_NOT_EXISTED; - } - - update_instance_timestamp(instance_name); - - return ErrorCode::OK; -} - -void InstanceMgr::update_instance_timestamp(const std::string& inst_name) { - auto now = std::chrono::system_clock::now(); - auto timestamp_ms = std::chrono::duration_cast( - now.time_since_epoch()) - .count(); - instances_[inst_name].latest_timestamp = timestamp_ms; -} - -InstancesPair InstanceMgr::select_instances_pair(bool only_prefill) { - return disagg_pd_policy_->select_instances_pair(only_prefill); -} - -InstanceMetaInfo InstanceMgr::get_instance_info( - const std::string& instance_name) { - std::lock_guard guard(inst_mutex_); - if (instances_.find(instance_name) == instances_.end()) { - LOG(ERROR) << "Get instance info failed, instance is not registered, " - "instance_name: " - << instance_name; - return InstanceMetaInfo(); - } - return instances_[instance_name]; -} - -// TODO: refactor later, currently return all decode instances -std::vector InstanceMgr::get_static_decode_list( - const std::string& instance_name) { - std::vector decode_list; - std::lock_guard guard(inst_mutex_); - for (auto& inst : instances_) { - if (inst.second.type == InstanceType::DECODE) { - decode_list.emplace_back(inst.second.name); - } - } - - return decode_list; -} - -} // namespace xllm_service diff --git a/xllm_service/rpc_service/instance_mgr.h b/xllm_service/rpc_service/instance_mgr.h deleted file mode 100644 index 4f6bdef..0000000 --- a/xllm_service/rpc_service/instance_mgr.h +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright 2025 The xLLM Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https://github.com/jd-opensource/xllm-service/blob/main/LICENSE - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#pragma once - -#include -#include -#include -#include - -#include "common/types.h" -#include "disagg_pd_policy.h" -#include "etcd_client.h" - -namespace xllm_service { - -class InstanceMgr { - public: - explicit InstanceMgr(const RpcServiceConfig& config); - ~InstanceMgr(); - ErrorCode heartbeat(const std::string& instance_name); - - ErrorCode register_instance(const std::string& instance_name); - ErrorCode register_instance(const std::string& instance_name, - const InstanceMetaInfo& metainfo); - ErrorCode update_instance_metainfo(const std::string& instance_name, - const InstanceMetaInfo& metainfo); - InstanceMetaInfo get_instance_info(const std::string& instance_name); - - // select instances(prefill/decode/default etc.) 
to handle request - // according the disagg pd policy (or some other policies.). - InstancesPair select_instances_pair(bool only_prefill = false); - - std::vector get_static_decode_list( - const std::string& instance_name); - - private: - void internal_init(); - // save instance metainfo to etcd - void save_persistence_metainfo(const InstanceMetaInfo& metainfo); - // delete instance metainfo from etcd - void delete_persistence_metainfo( - const std::vector& instance_names); - void detect_disconnected_instances(); - void update_instance_timestamp(const std::string& inst_name); - - private: - RpcServiceConfig config_; - bool exited_ = false; - std::mutex inst_mutex_; - std::unordered_map instances_; - std::unique_ptr heartbeat_thread_; - - std::unique_ptr disagg_pd_policy_; - - bool use_etcd_ = false; - std::unique_ptr etcd_client_; -}; - -} // namespace xllm_service diff --git a/xllm_service/rpc_service/loadbalance_policy/CMakeLists.txt b/xllm_service/rpc_service/loadbalance_policy/CMakeLists.txt new file mode 100644 index 0000000..3d9fb9a --- /dev/null +++ b/xllm_service/rpc_service/loadbalance_policy/CMakeLists.txt @@ -0,0 +1,18 @@ +include(cc_binary) +include(cc_library) +include(cc_test) + +cc_library( + NAME + loadbalance_policy + HDRS + loadbalance_policy.h + round_robin.h + cache_aware_routing.h + SRCS + round_robin.cpp + cache_aware_routing.cpp + DEPS + :common + :managers +) diff --git a/xllm_service/rpc_service/loadbalance_policy/cache_aware_routing.cpp b/xllm_service/rpc_service/loadbalance_policy/cache_aware_routing.cpp new file mode 100644 index 0000000..11ada6f --- /dev/null +++ b/xllm_service/rpc_service/loadbalance_policy/cache_aware_routing.cpp @@ -0,0 +1,87 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm-service/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#pragma once
+
+#include "cache_aware_routing.h"
+
+namespace xllm_service {
+
+constexpr float MIN_SCORE = -2.0;
+
+bool CacheAwareRouting::select_instances_pair(ScheduleResult* res) {
+  LoadBalanceInfos lb_infos;
+  if (!res->token_ids.empty()) {
+    Slice token_ids(res->token_ids.data(), res->token_ids.size());
+    global_kvcache_mgr_->match(token_ids, &lb_infos.overlap_scores);
+    DLOG(INFO) << lb_infos.debug_string();
+  }
+
+  instance_mgr_->get_load_metrics(&lb_infos);
+  DLOG(INFO) << lb_infos.debug_string();
+
+  if (lb_infos.prefill_load_metrics.size() == 0) {
+    LOG(INFO) << "No node available!";
+    return false;
+  }
+
+  // find prefill
+  cost_function(lb_infos.overlap_scores.hbm_instance_score,
+                lb_infos.overlap_scores.max_block_num,
+                lb_infos.prefill_load_metrics,
+                lb_infos.prefill_max_waiting_requests_num,
+                &res->routing.prefill_name);
+
+  // find decode
+  if (lb_infos.decode_load_metrics.size()) {
+    cost_function(lb_infos.overlap_scores.hbm_instance_score,
+                  lb_infos.overlap_scores.max_block_num,
+                  lb_infos.decode_load_metrics,
+                  lb_infos.decode_max_waiting_requests_num,
+                  &res->routing.decode_name);
+  }
+
+  return true;
+}
+
+void CacheAwareRouting::cost_function(
+    const std::unordered_map<std::string, uint32_t>& overlap_scores,
+    const uint32_t& max_block_num,
+    const std::unordered_map<std::string, LoadMetrics>& load_metrics,
+    const int64_t& max_waiting_requests_num,
+    std::string* best_choice) {
+  float best_score = MIN_SCORE;
+  for (const auto& it : load_metrics) {
+    const auto matched_blocks_it = overlap_scores.find(it.first);
+    uint32_t matched_blocks = 0;
+    if (matched_blocks_it != overlap_scores.end()) {
+      matched_blocks = matched_blocks_it->second;
+    }
+
+    // score = matched-prefix fraction - KV-cache pressure - queueing pressure;
+    // use float division so the two fractions are not truncated by integer math.
+    float score =
+        (max_block_num == 0
+             ? 0.0f
+             : static_cast<float>(matched_blocks) / max_block_num) -
+        it.second.gpu_cache_usage_perc -
+        (max_waiting_requests_num == 0
+             ? 0.0f
+             : static_cast<float>(it.second.waiting_requests_num) /
+                   max_waiting_requests_num);
+
+    if (score > best_score) {
+      best_score = score;
+      *best_choice = it.first;
+    }
+  }
+}
+
+}  // namespace xllm_service
diff --git a/xllm_service/rpc_service/loadbalance_policy/cache_aware_routing.h b/xllm_service/rpc_service/loadbalance_policy/cache_aware_routing.h
new file mode 100644
index 0000000..98c8605
--- /dev/null
+++ b/xllm_service/rpc_service/loadbalance_policy/cache_aware_routing.h
@@ -0,0 +1,44 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm-service/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#pragma once + +#include "common/macros.h" +#include "loadbalance_policy.h" + +namespace xllm_service { + +class CacheAwareRouting final : public LoadBalancePolicy { + public: + CacheAwareRouting(std::shared_ptr instance_mgr, + std::shared_ptr global_kvcache_mgr) + : LoadBalancePolicy(instance_mgr, global_kvcache_mgr) {}; + + virtual ~CacheAwareRouting() = default; + + bool select_instances_pair(ScheduleResult* res) override; + + protected: + DISALLOW_COPY_AND_ASSIGN(CacheAwareRouting); + + void cost_function( + const std::unordered_map& overlap_scores, + const uint32_t& max_block_num, + const std::unordered_map& load_metrics, + const int64_t& max_waiting_requests_num, + std::string* best_choice); +}; + +} // namespace xllm_service diff --git a/xllm_service/rpc_service/etcd_client.h b/xllm_service/rpc_service/loadbalance_policy/loadbalance_policy.h similarity index 56% rename from xllm_service/rpc_service/etcd_client.h rename to xllm_service/rpc_service/loadbalance_policy/loadbalance_policy.h index 56ae128..205e880 100644 --- a/xllm_service/rpc_service/etcd_client.h +++ b/xllm_service/rpc_service/loadbalance_policy/loadbalance_policy.h @@ -15,32 +15,26 @@ limitations under the License. #pragma once -#include -#include - +#include "../managers/global_kvcache_mgr.h" +#include "../managers/instance_mgr.h" #include "common/types.h" namespace xllm_service { -// the format is: -// key: XLLM:PREFILL:inst_id -> value -// or -// key: XLLM:DECODE:inst_id -> value -class EtcdClient { +class LoadBalancePolicy { public: - EtcdClient(const std::string& etcd_addr); - ~EtcdClient(); - - bool get(const std::string& key, InstanceIdentityInfo& value); - // get all keys with prefix - bool get_prefix(const std::string& key_prefix, - std::vector& values); - bool set(const std::string& key, const InstanceIdentityInfo& value); - bool rm(const std::string& key); - - private: - etcd::SyncClient client_; - std::string etcd_addr_; + LoadBalancePolicy(std::shared_ptr instance_mgr, + std::shared_ptr global_kvcache_mgr) + : instance_mgr_(instance_mgr), global_kvcache_mgr_(global_kvcache_mgr) {} + + virtual ~LoadBalancePolicy() = default; + + virtual bool select_instances_pair(ScheduleResult* res) = 0; + + protected: + std::shared_ptr instance_mgr_; + + std::shared_ptr global_kvcache_mgr_; }; } // namespace xllm_service diff --git a/xllm_service/rpc_service/loadbalance_policy/round_robin.cpp b/xllm_service/rpc_service/loadbalance_policy/round_robin.cpp new file mode 100644 index 0000000..0ed8a1d --- /dev/null +++ b/xllm_service/rpc_service/loadbalance_policy/round_robin.cpp @@ -0,0 +1,26 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm-service/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include "round_robin.h" + +namespace xllm_service { + +bool RoundRobin::select_instances_pair(ScheduleResult* res) { + return instance_mgr_->get_next_instance_pair(&res->routing); +} + +} // namespace xllm_service diff --git a/xllm_service/rpc_service/loadbalance_policy/round_robin.h b/xllm_service/rpc_service/loadbalance_policy/round_robin.h new file mode 100644 index 0000000..2616ad8 --- /dev/null +++ b/xllm_service/rpc_service/loadbalance_policy/round_robin.h @@ -0,0 +1,37 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm-service/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include "common/macros.h" +#include "loadbalance_policy.h" + +namespace xllm_service { + +class RoundRobin final : public LoadBalancePolicy { + public: + RoundRobin(std::shared_ptr instance_mgr, + std::shared_ptr global_kvcache_mgr) + : LoadBalancePolicy(instance_mgr, global_kvcache_mgr) {}; + + virtual ~RoundRobin() = default; + + bool select_instances_pair(ScheduleResult* res) override; + + protected: + DISALLOW_COPY_AND_ASSIGN(RoundRobin); +}; + +} // namespace xllm_service diff --git a/xllm_service/rpc_service/main.cpp b/xllm_service/rpc_service/main.cpp index fdd6b78..3707eb5 100644 --- a/xllm_service/rpc_service/main.cpp +++ b/xllm_service/rpc_service/main.cpp @@ -18,6 +18,7 @@ limitations under the License. 
#include #include "common/global_gflags.h" +#include "common/types.h" #include "common/utils.h" #include "rpc_service/service.h" @@ -27,7 +28,10 @@ int main(int argc, char* argv[]) { // Initialize glog google::InitGoogleLogging(argv[0]); - FLAGS_logtostderr = true; + + LOG(INFO) << "Dump all gflags: " << std::endl + << google::CommandlineFlagsIntoString(); + google::FlushLogFiles(google::INFO); LOG(INFO) << "Starting xllm rpc service, port: " << FLAGS_port; @@ -39,13 +43,23 @@ int main(int argc, char* argv[]) { xllm_service::RpcServiceConfig config; config.etcd_addr = FLAGS_etcd_addr; - config.disagg_pd_policy = FLAGS_disagg_pd_policy; + config.load_balance_policy = FLAGS_load_balance_policy; config.detect_disconnected_instance_interval = FLAGS_detect_disconnected_instance_interval; + xllm_service::ModelConfig model_config; + model_config.block_size = FLAGS_block_size; + model_config.model_type = FLAGS_model_type; + model_config.tokenizer_path = FLAGS_tokenizer_path; + + xllm_service::HttpServiceConfig http_config; + http_config.num_threads = FLAGS_num_threads; + http_config.timeout_ms = FLAGS_timeout_ms; + http_config.test_instance_addr = FLAGS_test_instance_addr; + // create xllm service - auto xllm_service_impl = - std::make_shared(config); + auto xllm_service_impl = std::make_shared( + config, model_config, http_config); xllm_service::XllmRpcService service(xllm_service_impl); // Initialize brpc server diff --git a/xllm_service/rpc_service/managers/CMakeLists.txt b/xllm_service/rpc_service/managers/CMakeLists.txt new file mode 100644 index 0000000..79f5b49 --- /dev/null +++ b/xllm_service/rpc_service/managers/CMakeLists.txt @@ -0,0 +1,22 @@ +include(cc_library) +include(cc_test) + +cc_library( + NAME + managers + HDRS + instance_mgr.h + global_kvcache_mgr.h + SRCS + instance_mgr.cpp + global_kvcache_mgr.cpp + DEPS + :common + :etcd_client + absl::random_random + absl::strings + glog::glog + proto::proto_rpc_service + proto_xllm +) +target_link_libraries(managers PRIVATE brpc-static) diff --git a/xllm_service/rpc_service/managers/global_kvcache_mgr.cpp b/xllm_service/rpc_service/managers/global_kvcache_mgr.cpp new file mode 100644 index 0000000..6d6484e --- /dev/null +++ b/xllm_service/rpc_service/managers/global_kvcache_mgr.cpp @@ -0,0 +1,256 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm-service/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "global_kvcache_mgr.h" + +#include + +#include "common/hash_util.h" + +namespace xllm_service { + +inline size_t round_down(size_t n, size_t multiple) { + return (n / multiple) * multiple; +} + +static std::string ETCD_CACHE_PREFIX = "XLLM:CACHE:"; + +GlobalKVCacheMgr::GlobalKVCacheMgr( + const std::shared_ptr& etcd_client, + const ModelConfig& model_config, + const bool is_master_service) + : model_config_(model_config), + is_master_service_(is_master_service), + etcd_client_(etcd_client) { + if (!is_master_service_) { + auto handle_kvcache = std::bind(&GlobalKVCacheMgr::update_kvcache, + this, + std::placeholders::_1, + std::placeholders::_2); + etcd_client_->add_watch(ETCD_CACHE_PREFIX, handle_kvcache); + } + + { + std::unique_lock lock(kvcache_mutex_); + etcd_client_->get_prefix(ETCD_CACHE_PREFIX, &kvcache_infos_); + DLOG(INFO) << "Load etcd cache infos:" << kvcache_infos_.size(); + } +} + +GlobalKVCacheMgr::~GlobalKVCacheMgr() { + exited_ = true; + etcd_client_->remove_watch(ETCD_CACHE_PREFIX); +} + +void set_score(const std::unordered_set& instance_names, + const uint32_t& match_length, + std::unordered_map* scores, + std::unordered_set* instances) { + for (const auto& name : instance_names) { + if (scores->count(name) == 0) { + scores->insert_or_assign(name, match_length); + } else { + (*scores)[name] = match_length; + } + instances->insert(name); + } +} + +void GlobalKVCacheMgr::match(const Slice& token_ids, + OverlapScores* overlap_scores) { + // allign tokens to block boundary + const size_t n_tokens = + round_down(token_ids.size(), model_config_.block_size); + if (n_tokens == 0) { + return; + } + + overlap_scores->max_block_num = n_tokens / model_config_.block_size; + + std::shared_lock lock(kvcache_mutex_); + Murmur3Key token_hash_key; + for (size_t i = 0; i < n_tokens; i += model_config_.block_size) { + if (i == 0) { + murmur_hash3(nullptr, + token_ids.slice(i, i + model_config_.block_size), + token_hash_key.data); + } else { + murmur_hash3(token_hash_key.data, + token_ids.slice(i, i + model_config_.block_size), + token_hash_key.data); + } + + auto iter = kvcache_infos_.find(token_hash_key); + if (iter != kvcache_infos_.end() && !iter->second.empty()) { + if (!iter->second.hbm_instance_set.empty()) { + set_score(iter->second.hbm_instance_set, + i / model_config_.block_size + 1, + &(overlap_scores->hbm_instance_score), + &(overlap_scores->instances)); + overlap_scores->max_matched_instance_name = + *iter->second.hbm_instance_set.begin(); + overlap_scores->max_matched_block_num = + i / model_config_.block_size + 1; + } + + if (!iter->second.dram_instance_set.empty()) { + set_score(iter->second.dram_instance_set, + i / model_config_.block_size + 1, + &(overlap_scores->dram_instance_score), + &(overlap_scores->instances)); + overlap_scores->max_matched_instance_name = + *iter->second.hbm_instance_set.begin(); + overlap_scores->max_matched_block_num = + i / model_config_.block_size + 1; + } + + if (!iter->second.ssd_instance_set.empty()) { + set_score(iter->second.ssd_instance_set, + i / model_config_.block_size + 1, + &(overlap_scores->ssd_instance_score), + &(overlap_scores->instances)); + overlap_scores->max_matched_instance_name = + *iter->second.hbm_instance_set.begin(); + overlap_scores->max_matched_block_num = + i / model_config_.block_size + 1; + } + } else { + break; + } + } +} + +void GlobalKVCacheMgr::update_kvcache(const etcd::Response& response, + const uint64_t prefix_len) { 
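+  // Watch callback for the "XLLM:CACHE:" prefix (registered by non-master
+  // services only). The etcd response is handed off to the thread pool so the
+  // watcher thread is never blocked: PUT events are parsed into CacheLocations
+  // and staged in put_map, DELETE events are collected into delete_list, and
+  // both are then applied to kvcache_infos_ under a single unique_lock.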
+ if (response.events().empty() || exited_) { + return; + } + threadpool_.schedule([this, + response = std::move(response), + prefix_len = std::move(prefix_len)] { + if (exited_) return; + Murmur3KeyCacheMap put_map; + std::vector delete_list; + + for (const auto& event : response.events()) { + auto key = event.kv().key().substr(prefix_len); + + if (event.event_type() == etcd::Event::EventType::PUT) { + CacheLocations cachelocations; + auto json_str = event.kv().as_string(); + if (!cachelocations.parse_from_json(json_str)) { + LOG(ERROR) << "pase json:" << json_str << " error!"; + continue; + } + + put_map.insert_or_assign(Murmur3Key{key.c_str()}, + std::move(cachelocations)); + + } else if (event.event_type() == etcd::Event::EventType::DELETE_) { + delete_list.emplace_back(Murmur3Key{key.c_str()}); + } + } + + { + std::unique_lock lock(kvcache_mutex_); + for (auto& iter : put_map) { + kvcache_infos_.insert_or_assign(iter.first, std::move(iter.second)); + } + + for (auto& iter : delete_list) { + kvcache_infos_.erase(iter); + } + } + }); +} + +void GlobalKVCacheMgr::record_updated_kvcaches( + const std::string& instance_name, + const proto::KvCacheEvent& kvcache_event) { + std::lock_guard update_lock(update_mutex_); + std::shared_lock metric_lock(kvcache_mutex_); + for (int i = 0; i < kvcache_event.stored_cache_size(); i++) { + Murmur3Key key(kvcache_event.stored_cache(i).c_str()); + if (updated_kvcaches_.count(key) == 0) { + if (kvcache_infos_.count(key) == 0) { + updated_kvcaches_.insert_or_assign(key, CacheLocations()); + } else { + updated_kvcaches_.insert_or_assign(key, kvcache_infos_[key]); + } + } + updated_kvcaches_.at(key).hbm_instance_set.insert(instance_name); + } + + for (int i = 0; i < kvcache_event.offload_cache_size(); i++) { + Murmur3Key key(kvcache_event.offload_cache(i).c_str()); + if (updated_kvcaches_.count(key) == 0) { + if (kvcache_infos_.count(key) == 0) { + continue; + } else { + updated_kvcaches_.insert_or_assign(key, kvcache_infos_[key]); + } + } + if (updated_kvcaches_.at(key).hbm_instance_set.count(instance_name) != 0) { + updated_kvcaches_.at(key).hbm_instance_set.erase(instance_name); + updated_kvcaches_.at(key).dram_instance_set.insert(instance_name); + } else { + updated_kvcaches_.at(key).dram_instance_set.erase(instance_name); + updated_kvcaches_.at(key).ssd_instance_set.insert(instance_name); + } + } + + for (int i = 0; i < kvcache_event.removed_cache_size(); i++) { + Murmur3Key key(kvcache_event.removed_cache(i).c_str()); + if (updated_kvcaches_.count(key) == 0) { + if (kvcache_infos_.count(key) == 0) { + continue; + } else { + updated_kvcaches_.insert_or_assign(key, kvcache_infos_[key]); + } + } + updated_kvcaches_.at(key).hbm_instance_set.erase(instance_name); + updated_kvcaches_.at(key).dram_instance_set.erase(instance_name); + updated_kvcaches_.at(key).ssd_instance_set.erase(instance_name); + } +} + +bool GlobalKVCacheMgr::upload_kvcache() { + std::lock_guard update_lock(update_mutex_); + if (updated_kvcaches_.empty()) { + return true; + } + bool rt = etcd_client_->set(ETCD_CACHE_PREFIX, updated_kvcaches_); + { + std::unique_lock metric_lock(kvcache_mutex_); + for (auto& iter : updated_kvcaches_) { + if (iter.second.empty()) { + kvcache_infos_.erase(iter.first); + } else { + kvcache_infos_.insert_or_assign(iter.first, std::move(iter.second)); + } + } + } + if (rt) { + updated_kvcaches_.clear(); + } + return rt; +} + +void GlobalKVCacheMgr::set_as_master() { + is_master_service_ = true; + etcd_client_->remove_watch(ETCD_CACHE_PREFIX); +} + +} // 
namespace xllm_service diff --git a/xllm_service/rpc_service/managers/global_kvcache_mgr.h b/xllm_service/rpc_service/managers/global_kvcache_mgr.h new file mode 100644 index 0000000..0115aa5 --- /dev/null +++ b/xllm_service/rpc_service/managers/global_kvcache_mgr.h @@ -0,0 +1,66 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm-service/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include + +#include "../etcd_client/etcd_client.h" +#include "common/hash_util.h" +#include "common/macros.h" +#include "common/slice.h" +#include "common/threadpool.h" +#include "common/types.h" +#include "xllm_rpc_service.pb.h" + +namespace xllm_service { + +class GlobalKVCacheMgr final { + public: + explicit GlobalKVCacheMgr(const std::shared_ptr& etcd_client, + const ModelConfig& model_config, + const bool is_master_service); + ~GlobalKVCacheMgr(); + + void match(const Slice& token_ids, OverlapScores* overlap_scores); + + void record_updated_kvcaches(const std::string& instance_name, + const proto::KvCacheEvent& kvcache_event); + bool upload_kvcache(); + + void set_as_master(); + + private: + DISALLOW_COPY_AND_ASSIGN(GlobalKVCacheMgr); + + void update_kvcache(const etcd::Response& response, + const uint64_t prefix_len); + + private: + ModelConfig model_config_; + std::atomic_bool is_master_service_ = false; + bool exited_ = false; + std::shared_mutex kvcache_mutex_; + Murmur3KeyCacheMap kvcache_infos_; + std::shared_ptr etcd_client_; // not own + + std::mutex update_mutex_; + Murmur3KeyCacheMap updated_kvcaches_; + + ThreadPool threadpool_; +}; + +} // namespace xllm_service diff --git a/xllm_service/rpc_service/managers/instance_mgr.cpp b/xllm_service/rpc_service/managers/instance_mgr.cpp new file mode 100644 index 0000000..65dc7e7 --- /dev/null +++ b/xllm_service/rpc_service/managers/instance_mgr.cpp @@ -0,0 +1,444 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm-service/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "instance_mgr.h" + +#include +#include + +#include +#include +#include + +#include "common/types.h" +#include "common/utils.h" + +namespace xllm_service { + +// magic number, TODO: move to config file or env var +static constexpr int kDetectIntervals = 15; // 15seconds +static std::unordered_map ETCD_KEYS_PREFIX_MAP = { + {InstanceType::DEFAULT, "XLLM:DEFAULT:"}, + {InstanceType::PREFILL, "XLLM:PREFILL:"}, + {InstanceType::DECODE, "XLLM:DECODE:"}, +}; +static std::string ETCD_ALL_KEYS_PREFIX = "XLLM:"; +static std::string ETCD_LOADMETRICS_PREFIX = "XLLM:LOADMETRICS:"; + +InstanceMgr::InstanceMgr(const std::shared_ptr& etcd_client, + const HttpServiceConfig& config, + const bool is_master_service) + : http_config_(config), + is_master_service_(is_master_service), + etcd_client_(etcd_client) { + auto handle_instance_metainfo = + std::bind(&InstanceMgr::update_instance_metainfo, + this, + std::placeholders::_1, + std::placeholders::_2); + for (auto& it : ETCD_KEYS_PREFIX_MAP) { + etcd_client_->add_watch(it.second, handle_instance_metainfo); + } + if (!is_master_service_) { + auto handle_load_metrics = std::bind(&InstanceMgr::update_load_metrics, + this, + std::placeholders::_1, + std::placeholders::_2); + etcd_client_->add_watch(ETCD_LOADMETRICS_PREFIX, handle_load_metrics); + } + + init(); +} + +void InstanceMgr::init() { + { + std::unique_lock lock(inst_mutex_); + for (auto& it : ETCD_KEYS_PREFIX_MAP) { + etcd_client_->get_prefix(it.second, &instances_); + } + LOG(INFO) << "Load instance info from etcd:" << instances_.size(); + std::vector channel_creat_fail_insts; + prefill_index_.reserve(instances_.size()); + decode_index_.reserve(instances_.size()); + + for (auto& ist : instances_) { + if (!create_channel(ist.first)) { + channel_creat_fail_insts.emplace_back(ist.first); + } else { + switch (ist.second.type) { + case InstanceType::DEFAULT: + case InstanceType::PREFILL: + ist.second.instance_index = prefill_index_.size(); + prefill_index_.emplace_back(ist.first); + break; + case InstanceType::DECODE: + ist.second.instance_index = decode_index_.size(); + decode_index_.emplace_back(ist.first); + break; + default: + LOG(WARNING) << "Unknown InstanceType: " << int(ist.second.type); + channel_creat_fail_insts.emplace_back(ist.first); + break; + } + } + } + for (auto& name : channel_creat_fail_insts) { + instances_.erase(name); + } + } + { + std::unique_lock lock(load_metric_mutex_); + etcd_client_->get_prefix(ETCD_LOADMETRICS_PREFIX, &load_metrics_); + } + + for (int i = 0; i < prefill_index_.size(); i++) { + LOG(INFO) << i << " : " << prefill_index_[i]; + } +} + +InstanceMgr::~InstanceMgr() { exited_ = true; } + +InstanceMetaInfo InstanceMgr::get_instance_info( + const std::string& instance_name) { + std::shared_lock lock(inst_mutex_); + if (instances_.find(instance_name) == instances_.end()) { + LOG(ERROR) << "Get instance info failed, instance is not registered, " + "instance_name: " + << instance_name; + return InstanceMetaInfo(); + } + return instances_[instance_name]; +} + +bool InstanceMgr::get_next_instance_pair(Routing* routing) { + std::unique_lock lock(inst_mutex_); + if (prefill_index_.empty()) { + LOG(ERROR) << "No prefill or default instance found!"; + return false; + } + next_prefill_index_ = next_prefill_index_ % prefill_index_.size(); + routing->prefill_name = prefill_index_[next_prefill_index_]; + next_prefill_index_++; + if (decode_index_.empty()) { + return true; + } + 
next_decode_index_ = next_decode_index_ % decode_index_.size(); + routing->decode_name = decode_index_[next_decode_index_]; + next_decode_index_++; + return true; +} + +// TODO: refactor later, currently return all decode instances +std::vector InstanceMgr::get_static_decode_list( + const std::string& instance_name) { + std::vector decode_list; + std::shared_lock lock(inst_mutex_); + for (auto& inst : instances_) { + if (inst.second.type == InstanceType::DECODE) { + decode_list.emplace_back(inst.second.name); + } + } + + return decode_list; +} + +void InstanceMgr::get_load_metrics(LoadBalanceInfos* infos) { + std::shared_lock inst_lock(inst_mutex_); + std::shared_lock metric_lock(load_metric_mutex_); + + for (auto name : infos->overlap_scores.instances) { + auto it = load_metrics_.find(name); + if (it == load_metrics_.end()) { + continue; + } + auto instance_it = instances_.find(name); + if (instance_it == instances_.end()) { + continue; + } + + if (instance_it->second.type == InstanceType::DECODE) { + infos->decode_load_metrics.insert(std::make_pair(name, it->second)); + infos->decode_max_waiting_requests_num = + std::max(infos->decode_max_waiting_requests_num, + it->second.waiting_requests_num); + } else { + infos->prefill_load_metrics.insert(std::make_pair(name, it->second)); + infos->prefill_max_waiting_requests_num = + std::max(infos->prefill_max_waiting_requests_num, + it->second.waiting_requests_num); + } + } + + std::string least_loaded_prefill_instance; + float least_loaded_prefill_gpu_cache_usage_perc = 1; + std::string least_loaded_decode_instance; + float least_loaded_decode_gpu_cache_usage_perc = 1; + + if (infos->prefill_load_metrics.size() == 0 || + infos->decode_load_metrics.size() == 0) { + for (const auto& metric : load_metrics_) { + auto instance_it = instances_.find(metric.first); + if (instance_it != instances_.end()) { + if (instance_it->second.type != InstanceType::DECODE) { + if (metric.second.gpu_cache_usage_perc < + least_loaded_prefill_gpu_cache_usage_perc) { + least_loaded_prefill_gpu_cache_usage_perc = + metric.second.gpu_cache_usage_perc; + least_loaded_prefill_instance = metric.first; + } + } else { + if (metric.second.gpu_cache_usage_perc < + least_loaded_decode_gpu_cache_usage_perc) { + least_loaded_decode_gpu_cache_usage_perc = + metric.second.gpu_cache_usage_perc; + least_loaded_decode_instance = metric.first; + } + } + } + } + } + + if (infos->prefill_load_metrics.size() == 0 && + !least_loaded_prefill_instance.empty()) { + infos->prefill_load_metrics.insert( + std::make_pair(least_loaded_prefill_instance, + load_metrics_[least_loaded_prefill_instance])); + } + + if (infos->decode_load_metrics.size() == 0 && + !least_loaded_decode_instance.empty()) { + infos->decode_load_metrics.insert( + std::make_pair(least_loaded_decode_instance, + load_metrics_[least_loaded_decode_instance])); + } +} + +void InstanceMgr::record_load_metrics_update( + const std::string& instance_name, + const proto::LoadMetrics& load_metrics) { + std::lock_guard lock(update_mutex_); + + updated_metrics_.insert_or_assign( + instance_name, + LoadMetrics(load_metrics.waiting_requests_num(), + load_metrics.gpu_cache_usage_perc())); +} + +bool InstanceMgr::upload_load_metrics() { + std::lock_guard lock(update_mutex_); + bool status = etcd_client_->set(ETCD_LOADMETRICS_PREFIX, updated_metrics_); + status = + status && etcd_client_->rm(ETCD_LOADMETRICS_PREFIX, removed_instance_); + { + std::unique_lock lock(inst_mutex_); + for (auto& iter : updated_metrics_) { + 
load_metrics_.insert_or_assign(iter.first, std::move(iter.second)); + } + for (auto& iter : removed_instance_) { + load_metrics_.erase(iter); + } + } + updated_metrics_.clear(); + removed_instance_.clear(); + + return status; +} + +void InstanceMgr::set_as_master() { + is_master_service_ = true; + etcd_client_->remove_watch(ETCD_LOADMETRICS_PREFIX); +} + +std::shared_ptr InstanceMgr::get_channel( + const std::string& instance_name) { + std::shared_lock lock(inst_mutex_); + auto iter = cached_channels_.find(instance_name); + if (iter == cached_channels_.end()) { + return nullptr; + } + return iter->second; +} + +bool InstanceMgr::create_channel(const std::string& instance_name) { + if (cached_channels_.find(instance_name) == cached_channels_.end()) { + auto channel = std::make_shared(); + brpc::ChannelOptions options; + // Add to params + options.protocol = "http"; + options.timeout_ms = http_config_.timeout_ms; /*milliseconds*/ + options.max_retry = 3; + std::string load_balancer = ""; + if (channel->Init(instance_name.c_str(), load_balancer.c_str(), &options) != + 0) { + LOG(ERROR) << "Fail to initialize channel for " << instance_name; + return false; + } + cached_channels_[instance_name] = std::move(channel); + } + + return true; +} + +void InstanceMgr::update_instance_metainfo(const etcd::Response& response, + const uint64_t& prefix_len) { + if (response.events().empty() || exited_) { + return; + } + + threadpool_.schedule([this, + response = std::move(response), + prefix_len = std::move(prefix_len)] { + if (exited_) return; + std::unordered_map put_map; + std::vector delete_list; + + for (const auto& event : response.events()) { + std::string instance_name = event.kv().key().substr(prefix_len); + + if (event.event_type() == etcd::Event::EventType::PUT) { + InstanceMetaInfo metainfo; + auto json_str = event.kv().as_string(); + if (!metainfo.parse_from_json(json_str)) { + LOG(ERROR) << "pase json:" << json_str << " error!"; + continue; + } + + put_map.insert(std::make_pair(instance_name, std::move(metainfo))); + + } else if (event.event_type() == etcd::Event::EventType::DELETE_) { + delete_list.push_back(instance_name); + } + } + + { + std::unique_lock lock(inst_mutex_); + for (auto& iter : put_map) { + if (instances_.find(iter.first) != instances_.end()) { + LOG(ERROR) << "Instance is already registered, instance_name: " + << iter.first; + continue; + } + + if (!create_channel(iter.first)) { + LOG(ERROR) << "create channel fail: " << iter.first; + continue; + } + + instances_.insert(std::make_pair(iter.first, std::move(iter.second))); + + switch (iter.second.type) { + case InstanceType::DEFAULT: + case InstanceType::PREFILL: + iter.second.instance_index = prefill_index_.size(); + prefill_index_.emplace_back(iter.first); + break; + case InstanceType::DECODE: + iter.second.instance_index = decode_index_.size(); + decode_index_.emplace_back(iter.first); + break; + default: + LOG(WARNING) << "Unknown InstanceType: " << int(iter.second.type); + break; + } + } + + for (auto& iter : delete_list) { + if (instances_.find(iter) == instances_.end()) { + LOG(ERROR) << "Instance is already deleted, instance_name: " << iter; + continue; + } + // TODO: notify cache manager to clear expire cache + uint64_t index = instances_[iter].instance_index; + + switch (instances_[iter].type) { + case InstanceType::DEFAULT: + case InstanceType::PREFILL: + if (index == -1 || index >= prefill_index_.size()) { + break; + } + std::swap(prefill_index_[index], prefill_index_.back()); + 
instances_[prefill_index_[index]].instance_index = index; + prefill_index_.pop_back(); + break; + case InstanceType::DECODE: + if (index == -1 || index >= decode_index_.size()) { + break; + } + std::swap(decode_index_[index], decode_index_.back()); + instances_[decode_index_[index]].instance_index = index; + decode_index_.pop_back(); + break; + default: + LOG(WARNING) << "Unknown InstanceType: " + << int(instances_[iter].type); + break; + } + + instances_.erase(iter); + cached_channels_.erase(iter); + { + std::lock_guard lock(update_mutex_); + updated_metrics_.erase(iter); + removed_instance_.insert(iter); + } + } + } + }); +} + +void InstanceMgr::update_load_metrics(const etcd::Response& response, + const uint64_t& prefix_len) { + if (response.events().empty() || exited_) { + return; + } + threadpool_.schedule([this, + response = std::move(response), + prefix_len = std::move(prefix_len)] { + if (exited_) return; + std::unordered_map put_map; + std::vector delete_list; + + for (const auto& event : response.events()) { + std::string instance_name = event.kv().key().substr(prefix_len); + + if (event.event_type() == etcd::Event::EventType::PUT) { + LoadMetrics load_metrics; + auto json_str = event.kv().as_string(); + if (!load_metrics.parse_from_json(json_str)) { + LOG(ERROR) << "pase json:" << json_str << " error!"; + continue; + } + + put_map.insert(std::make_pair(instance_name, std::move(load_metrics))); + + } else if (event.event_type() == etcd::Event::EventType::DELETE_) { + delete_list.push_back(instance_name); + } + } + + { + std::unique_lock lock(load_metric_mutex_); + for (auto& iter : put_map) { + load_metrics_.insert_or_assign(iter.first, std::move(iter.second)); + } + + for (auto& iter : delete_list) { + load_metrics_.erase(iter); + } + } + }); +} + +} // namespace xllm_service diff --git a/xllm_service/rpc_service/managers/instance_mgr.h b/xllm_service/rpc_service/managers/instance_mgr.h new file mode 100644 index 0000000..6905270 --- /dev/null +++ b/xllm_service/rpc_service/managers/instance_mgr.h @@ -0,0 +1,100 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm-service/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include + +#include +#include +#include +#include + +#include "../etcd_client/etcd_client.h" +#include "common/macros.h" +#include "common/threadpool.h" +#include "common/types.h" +#include "xllm_rpc_service.pb.h" + +namespace xllm_service { + +class InstanceMgr final { + public: + explicit InstanceMgr(const std::shared_ptr& etcd_client, + const HttpServiceConfig& config, + const bool is_master_service); + + ~InstanceMgr(); + + InstanceMetaInfo get_instance_info(const std::string& instance_name); + + bool get_next_instance_pair(Routing* routing); + + std::vector get_static_decode_list( + const std::string& instance_name); + + void get_load_metrics(LoadBalanceInfos* infos); + + std::shared_ptr get_channel(const std::string& instance_name); + + void record_load_metrics_update(const std::string& instance_name, + const proto::LoadMetrics& load_metrics); + bool upload_load_metrics(); + + void set_as_master(); + + private: + DISALLOW_COPY_AND_ASSIGN(InstanceMgr); + + void init(); + + bool create_channel(const std::string& target_uri); + // use etcd as ServiceDiscovery + void update_instance_metainfo(const etcd::Response& response, + const uint64_t& prefix_len); + + void update_load_metrics(const etcd::Response& response, + const uint64_t& prefix_len); + + private: + bool exited_ = false; + bool use_etcd_ = false; + std::atomic_bool is_master_service_ = false; + + RpcServiceConfig config_; + HttpServiceConfig http_config_; + + std::shared_ptr etcd_client_; + + std::shared_mutex inst_mutex_; + std::unordered_map instances_; + std::vector prefill_index_; + std::vector decode_index_; + uint64_t next_prefill_index_ = 0; + uint64_t next_decode_index_ = 0; + + std::shared_mutex load_metric_mutex_; + std::unordered_map load_metrics_; + std::unordered_map> + cached_channels_; + + std::mutex update_mutex_; + std::unordered_map updated_metrics_; + std::unordered_set removed_instance_; + + ThreadPool threadpool_; +}; + +} // namespace xllm_service diff --git a/xllm_service/rpc_service/rpc_service_test.cpp b/xllm_service/rpc_service/rpc_service_test.cpp index d9f490d..36eb9cb 100644 --- a/xllm_service/rpc_service/rpc_service_test.cpp +++ b/xllm_service/rpc_service/rpc_service_test.cpp @@ -26,36 +26,44 @@ class XllmRpcServiceTest : public ::testing::Test { void TearDown() override { google::ShutdownGoogleLogging(); } }; +// TODO +// TEST_F(XllmRpcServiceTest, RegisterInstance) { +// RpcServiceConfig config; +// HttpServiceConfig http_config; +// ModelConfig model_config; +// auto xllm_service = +// std::make_shared(config, model_config, +// http_config); +// std::string inst_name = "127.0.0.1@nic0"; +// InstanceMetaInfo metainfo(inst_name, "127.0.0.1:7777", +// InstanceType::PREFILL); EXPECT_EQ(ErrorCode::OK, +// xllm_service->register_instance(inst_name, metainfo)); -TEST_F(XllmRpcServiceTest, RegisterInstance) { - RpcServiceConfig config; - auto xllm_service = std::make_shared(config); - std::string inst_name = "127.0.0.1@nic0"; - InstanceMetaInfo metainfo(inst_name, "127.0.0.1:7777", InstanceType::PREFILL); - EXPECT_EQ(ErrorCode::OK, - xllm_service->register_instance(inst_name, metainfo)); - - metainfo.type = InstanceType::DECODE; - EXPECT_EQ(ErrorCode::INSTANCE_EXISTED, - xllm_service->register_instance(inst_name, metainfo)); -} - -TEST_F(XllmRpcServiceTest, UpdateInstanceMetainfo) { - RpcServiceConfig config; - auto xllm_service = std::make_shared(config); - std::string inst_name = "127.0.0.1@nic0"; - 
InstanceMetaInfo metainfo(inst_name, "127.0.0.1:7777", InstanceType::PREFILL); - EXPECT_EQ(ErrorCode::OK, - xllm_service->register_instance(inst_name, metainfo)); - metainfo.type = InstanceType::DECODE; - EXPECT_EQ(ErrorCode::OK, - xllm_service->update_instance_metainfo(inst_name, metainfo)); - - std::string inst_name2 = "127.0.0.1@nic2"; - InstanceMetaInfo metainfo2( - inst_name2, "127.0.0.1:7778", InstanceType::PREFILL); - EXPECT_EQ(ErrorCode::INSTANCE_NOT_EXISTED, - xllm_service->update_instance_metainfo(inst_name2, metainfo)); -} +// metainfo.type = InstanceType::DECODE; +// EXPECT_EQ(ErrorCode::INSTANCE_EXISTED, +// xllm_service->register_instance(inst_name, metainfo)); +// } + +// TEST_F(XllmRpcServiceTest, UpdateInstanceMetainfo) { +// RpcServiceConfig config; +// HttpServiceConfig http_config; +// ModelConfig model_config; +// auto xllm_service = +// std::make_shared(config, model_config, +// http_config); +// std::string inst_name = "127.0.0.1@nic0"; +// InstanceMetaInfo metainfo(inst_name, "127.0.0.1:7777", +// InstanceType::PREFILL); EXPECT_EQ(ErrorCode::OK, +// xllm_service->register_instance(inst_name, metainfo)); +// metainfo.type = InstanceType::DECODE; +// EXPECT_EQ(ErrorCode::OK, +// xllm_service->update_instance_metainfo(inst_name, metainfo)); + +// std::string inst_name2 = "127.0.0.1@nic2"; +// InstanceMetaInfo metainfo2( +// inst_name2, "127.0.0.1:7778", InstanceType::PREFILL); +// EXPECT_EQ(ErrorCode::INSTANCE_NOT_EXISTED, +// xllm_service->update_instance_metainfo(inst_name2, metainfo)); +// } } // namespace xllm_service::test diff --git a/xllm_service/rpc_service/scheduler.cpp b/xllm_service/rpc_service/scheduler.cpp new file mode 100644 index 0000000..a66f4a8 --- /dev/null +++ b/xllm_service/rpc_service/scheduler.cpp @@ -0,0 +1,151 @@ +#include "scheduler.h" + +#include +#include +#include + +#include "chat_template/jinja_chat_template.h" +#include "common.pb.h" +#include "common/hash_util.h" +#include "loadbalance_policy/cache_aware_routing.h" +#include "loadbalance_policy/round_robin.h" +#include "tokenizer/tokenizer_factory.h" + +static constexpr int kHeartbeatInterval = 3; // in seconds +static std::string ETCD_MASTER_SERVICE_KEY = "XLLM:SERVICE:MASTER"; + +namespace xllm_service { + +Scheduler::Scheduler(const RpcServiceConfig& rpc_config, + const ModelConfig& model_config, + const HttpServiceConfig& http_config) + : rpc_config_(rpc_config), + model_config_(model_config), + http_config_(http_config) { + tokenizer_ = TokenizerFactory::create_tokenizer(model_config_.tokenizer_path, + &tokenizer_args_); + chat_template_ = std::make_unique(tokenizer_args_); + + etcd_client_ = std::make_shared(rpc_config_.etcd_addr); + + if (!etcd_client_->get(ETCD_MASTER_SERVICE_KEY, nullptr)) { + is_master_service_ = etcd_client_->set( + ETCD_MASTER_SERVICE_KEY, rpc_config_.service_name, kHeartbeatInterval); + LOG(INFO) << "Set current service as master!"; + } + + instance_mgr_ = std::make_shared( + etcd_client_, http_config_, is_master_service_); + + global_kvcache_mgr_ = std::make_shared( + etcd_client_, model_config_, is_master_service_); + + if (rpc_config_.load_balance_policy == "CAR") { + lb_policy_ = + std::make_unique(instance_mgr_, global_kvcache_mgr_); + } else { + lb_policy_ = + std::make_unique(instance_mgr_, global_kvcache_mgr_); + } + + if (is_master_service_) { + heartbeat_thread_ = std::make_unique( + &Scheduler::update_master_service_heartbeat, this); + } else { + auto handle_master = std::bind(&Scheduler::handle_master_service_watch, + this, + 
std::placeholders::_1, + std::placeholders::_2); + etcd_client_->add_watch(ETCD_MASTER_SERVICE_KEY, handle_master); + } +} + +Scheduler::~Scheduler() { etcd_client_->stop_watch(); } + +bool Scheduler::schedule(const ChatMessages& messages, ScheduleResult* res) { + if (chat_template_ == nullptr) { + LOG(ERROR) << "Chat template has not configured for model type: " + << model_config_.model_type; + return false; + } + + auto prompt = chat_template_->apply(messages); + if (!prompt.has_value()) { + LOG(ERROR) << "Failed to construct prompt from messages"; + return false; + } + + return schedule(prompt.value(), res); +} + +bool Scheduler::schedule(const std::string& prompt, ScheduleResult* res) { + if (prompt.size() != 0) { + if (!get_tls_tokenizer()->encode(prompt, &res->token_ids)) { + LOG(ERROR) << "Encode prompt faill: " << prompt; + return false; + } + } + auto ret = lb_policy_->select_instances_pair(res); + DLOG(INFO) << res->routing.debug_string(); + + return ret; +} + +std::shared_ptr Scheduler::get_channel( + const std::string& target_name) { + return instance_mgr_->get_channel(target_name); +} + +void Scheduler::update_master_service_heartbeat() { + while (!exited_) { + std::this_thread::sleep_for(std::chrono::seconds(kHeartbeatInterval)); + + global_kvcache_mgr_->upload_kvcache(); + + instance_mgr_->upload_load_metrics(); + } +} + +void Scheduler::handle_instance_heartbeat(const proto::HeartbeatRequest* req) { + if (exited_) { + return; + } + global_kvcache_mgr_->record_updated_kvcaches(req->name(), req->cache_event()); + instance_mgr_->record_load_metrics_update(req->name(), req->load_metrics()); +} + +void Scheduler::handle_master_service_watch(const etcd::Response& response, + const uint64_t& prefix_len) { + if (exited_ || response.events().empty()) { + return; + } + + if (etcd_client_->set(ETCD_MASTER_SERVICE_KEY, + rpc_config_.service_name, + kHeartbeatInterval)) { + is_master_service_ = true; + + heartbeat_thread_ = std::make_unique( + &Scheduler::update_master_service_heartbeat, this); + + global_kvcache_mgr_->set_as_master(); + instance_mgr_->set_as_master(); + } +} + +InstanceMetaInfo Scheduler::get_instance_info( + const std::string& instance_name) { + return instance_mgr_->get_instance_info(instance_name); +} + +std::vector Scheduler::get_static_decode_list( + const std::string& instance_name) { + return instance_mgr_->get_static_decode_list(instance_name); +} + +Tokenizer* Scheduler::get_tls_tokenizer() { + thread_local std::unique_ptr tls_tokenizer(tokenizer_->clone()); + return tls_tokenizer.get(); +} + +} // namespace xllm_service \ No newline at end of file diff --git a/xllm_service/rpc_service/scheduler.h b/xllm_service/rpc_service/scheduler.h new file mode 100644 index 0000000..429e52d --- /dev/null +++ b/xllm_service/rpc_service/scheduler.h @@ -0,0 +1,86 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "chat_template/jinja_chat_template.h" +#include "common/hash_util.h" +#include "common/macros.h" +#include "common/types.h" +#include "etcd_client/etcd_client.h" +#include "loadbalance_policy/loadbalance_policy.h" +#include "managers/global_kvcache_mgr.h" +#include "managers/instance_mgr.h" +#include "tokenizer/tokenizer.h" +#include "tokenizer/tokenizer_args.h" +#include "xllm_rpc_service.pb.h" + +namespace xllm_service { + +class Scheduler { + public: + explicit Scheduler(const RpcServiceConfig& rpc_config, + const ModelConfig& model_config, + const HttpServiceConfig& http_config); + + ~Scheduler(); + + bool schedule(const 
ChatMessages& messages, ScheduleResult* res); + + bool schedule(const std::string& prompt, ScheduleResult* res); + + std::shared_ptr get_channel(const std::string& target_name); + + InstanceMetaInfo get_instance_info(const std::string& instance_name); + + std::vector get_static_decode_list( + const std::string& instance_name); + + void handle_instance_heartbeat(const proto::HeartbeatRequest* req); + + void exited() { exited_ = true; } + + private: + DISALLOW_COPY_AND_ASSIGN(Scheduler); + + void update_master_service_heartbeat(); + + void handle_master_service_watch(const etcd::Response& response, + const uint64_t& prefix_len); + + Tokenizer* get_tls_tokenizer(); + + private: + bool exited_ = false; + + bool is_master_service_ = false; + + TokenizerArgs tokenizer_args_; + + RpcServiceConfig rpc_config_; + + ModelConfig model_config_; + + HttpServiceConfig http_config_; + + // chat template instance + std::unique_ptr chat_template_; + + std::shared_ptr etcd_client_; + + std::unique_ptr tokenizer_; + + std::shared_ptr instance_mgr_; + + std::shared_ptr global_kvcache_mgr_; + + std::unique_ptr lb_policy_; + + std::unique_ptr heartbeat_thread_; +}; + +} // namespace xllm_service diff --git a/xllm_service/rpc_service/service.cpp b/xllm_service/rpc_service/service.cpp index e5ca3b6..6e696db 100644 --- a/xllm_service/rpc_service/service.cpp +++ b/xllm_service/rpc_service/service.cpp @@ -53,42 +53,45 @@ grpc::StatusCode to_grpc_status_code(llm::StatusCode code) { } } // namespace -XllmRpcServiceImpl::XllmRpcServiceImpl(const RpcServiceConfig& config) { +XllmRpcServiceImpl::XllmRpcServiceImpl(const RpcServiceConfig& rpc_config, + const ModelConfig& model_config, + const HttpServiceConfig& http_config) { enable_decode_response_to_service_ = utils::get_bool_env("ENABLE_DECODE_RESPONSE_TO_SERVICE", false); - instance_mgr_ = std::make_unique(config); + + scheduler_ = + std::make_unique(rpc_config, model_config, http_config); } -XllmRpcServiceImpl::~XllmRpcServiceImpl() {} +XllmRpcServiceImpl::~XllmRpcServiceImpl() { scheduler_->exited(); } -ErrorCode XllmRpcServiceImpl::heartbeat(const std::string& instance_name) { - return instance_mgr_->heartbeat(instance_name); +void XllmRpcServiceImpl::heartbeat(const proto::HeartbeatRequest* req) { + scheduler_->handle_instance_heartbeat(req); } -ErrorCode XllmRpcServiceImpl::register_instance( - const std::string& instance_name, - const InstanceMetaInfo& metainfo) { - return instance_mgr_->register_instance(instance_name, metainfo); +InstanceMetaInfo XllmRpcServiceImpl::get_instance_info( + const std::string& instance_name) { + return scheduler_->get_instance_info(instance_name); } -ErrorCode XllmRpcServiceImpl::update_instance_metainfo( - const std::string& instance_name, - const InstanceMetaInfo& metainfo) { - return instance_mgr_->update_instance_metainfo(instance_name, metainfo); +std::vector XllmRpcServiceImpl::get_static_decode_list( + const std::string& instance_name) { + return scheduler_->get_static_decode_list(instance_name); } -InstancesPair XllmRpcServiceImpl::select_instances_pair(bool only_prefill) { - return instance_mgr_->select_instances_pair(only_prefill); +bool XllmRpcServiceImpl::schedule(const std::string& prompt, + ScheduleResult* res) { + return scheduler_->schedule(prompt, res); } -InstanceMetaInfo XllmRpcServiceImpl::get_instance_info( - const std::string& instance_name) { - return instance_mgr_->get_instance_info(instance_name); +bool XllmRpcServiceImpl::schedule(const ChatMessages& messages, + ScheduleResult* res) { + return 
diff --git a/xllm_service/rpc_service/service.cpp b/xllm_service/rpc_service/service.cpp
index e5ca3b6..6e696db 100644
--- a/xllm_service/rpc_service/service.cpp
+++ b/xllm_service/rpc_service/service.cpp
@@ -53,42 +53,45 @@ grpc::StatusCode to_grpc_status_code(llm::StatusCode code) {
 }
 }  // namespace

-XllmRpcServiceImpl::XllmRpcServiceImpl(const RpcServiceConfig& config) {
+XllmRpcServiceImpl::XllmRpcServiceImpl(const RpcServiceConfig& rpc_config,
+                                       const ModelConfig& model_config,
+                                       const HttpServiceConfig& http_config) {
   enable_decode_response_to_service_ =
       utils::get_bool_env("ENABLE_DECODE_RESPONSE_TO_SERVICE", false);
-  instance_mgr_ = std::make_unique(config);
+
+  scheduler_ =
+      std::make_unique(rpc_config, model_config, http_config);
 }

-XllmRpcServiceImpl::~XllmRpcServiceImpl() {}
+XllmRpcServiceImpl::~XllmRpcServiceImpl() { scheduler_->exited(); }

-ErrorCode XllmRpcServiceImpl::heartbeat(const std::string& instance_name) {
-  return instance_mgr_->heartbeat(instance_name);
+void XllmRpcServiceImpl::heartbeat(const proto::HeartbeatRequest* req) {
+  scheduler_->handle_instance_heartbeat(req);
 }

-ErrorCode XllmRpcServiceImpl::register_instance(
-    const std::string& instance_name,
-    const InstanceMetaInfo& metainfo) {
-  return instance_mgr_->register_instance(instance_name, metainfo);
+InstanceMetaInfo XllmRpcServiceImpl::get_instance_info(
+    const std::string& instance_name) {
+  return scheduler_->get_instance_info(instance_name);
 }

-ErrorCode XllmRpcServiceImpl::update_instance_metainfo(
-    const std::string& instance_name,
-    const InstanceMetaInfo& metainfo) {
-  return instance_mgr_->update_instance_metainfo(instance_name, metainfo);
+std::vector XllmRpcServiceImpl::get_static_decode_list(
+    const std::string& instance_name) {
+  return scheduler_->get_static_decode_list(instance_name);
 }

-InstancesPair XllmRpcServiceImpl::select_instances_pair(bool only_prefill) {
-  return instance_mgr_->select_instances_pair(only_prefill);
+bool XllmRpcServiceImpl::schedule(const std::string& prompt,
+                                  ScheduleResult* res) {
+  return scheduler_->schedule(prompt, res);
 }

-InstanceMetaInfo XllmRpcServiceImpl::get_instance_info(
-    const std::string& instance_name) {
-  return instance_mgr_->get_instance_info(instance_name);
+bool XllmRpcServiceImpl::schedule(const ChatMessages& messages,
+                                  ScheduleResult* res) {
+  return scheduler_->schedule(messages, res);
 }

-std::vector XllmRpcServiceImpl::get_static_decode_list(
-    const std::string& instance_name) {
-  return instance_mgr_->get_static_decode_list(instance_name);
+std::shared_ptr XllmRpcServiceImpl::get_channel(
+    const std::string& target_name) {
+  return scheduler_->get_channel(target_name);
 }

 bool XllmRpcServiceImpl::handle_generation(
@@ -274,32 +277,6 @@ void XllmRpcService::Hello(google::protobuf::RpcController* cntl_base,
   resp->set_ok(true);
 }

-void XllmRpcService::RegisterInstance(
-    google::protobuf::RpcController* cntl_base,
-    const proto::InstanceMetaInfo* req,
-    proto::StatusCode* resp,
-    google::protobuf::Closure* done) {
-  brpc::ClosureGuard done_guard(done);
-  InstanceType type = InstanceType::DEFAULT;
-  if (req->has_type() && req->type() == proto::InstanceType::PREFILL) {
-    type = InstanceType::PREFILL;
-  } else if (req->has_type() && req->type() == proto::InstanceType::DECODE) {
-    type = InstanceType::DECODE;
-  }
-  InstanceMetaInfo metainfo(req->name(), req->rpc_address(), type);
-  metainfo.cluster_ids = std::vector(req->cluster_ids().begin(),
-                                     req->cluster_ids().end());
-  metainfo.addrs =
-      std::vector(req->addrs().begin(), req->addrs().end());
-  metainfo.k_cache_ids = std::vector(req->k_cache_ids().begin(),
-                                     req->k_cache_ids().end());
-  metainfo.v_cache_ids = std::vector(req->v_cache_ids().begin(),
-                                     req->v_cache_ids().end());
-  metainfo.dp_size = req->dp_size();
-  ErrorCode code = xllm_service_->register_instance(req->name(), metainfo);
-  resp->set_status_code(ConvertErrorCode::to_int(code));
-}
-
 void XllmRpcService::GetInstanceInfo(google::protobuf::RpcController* cntl_base,
                                      const proto::InstanceID* req,
                                      proto::InstanceMetaInfo* resp,
@@ -335,9 +312,7 @@ void XllmRpcService::Heartbeat(google::protobuf::RpcController* cntl_base,
                                proto::Status* resp,
                                google::protobuf::Closure* done) {
   brpc::ClosureGuard done_guard(done);
-  auto inst_name = req->name();
-  // TODO: handle req.cache_event and req.load_metrics
-  xllm_service_->heartbeat(inst_name);
+  xllm_service_->heartbeat(req);
   resp->set_ok(true);
 }
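For context on the reshaped Heartbeat handler above, which now hands the whole request to the scheduler rather than just the instance name: an instance-side caller would reach it through the generated brpc stub roughly as below. The address, timeout, and field population are illustrative, and cache_event / load_metrics would be filled per the real proto definition, which is not shown here.

#include <string>

#include <brpc/channel.h>
#include <brpc/controller.h>

#include "xllm_rpc_service.pb.h"

namespace xllm_service {

// Illustrative client call: send one heartbeat to the xllm service.
bool send_heartbeat(const std::string& service_addr,
                    const std::string& instance_name) {
  brpc::Channel channel;
  brpc::ChannelOptions options;
  options.timeout_ms = 500;  // illustrative
  if (channel.Init(service_addr.c_str(), &options) != 0) {
    return false;
  }

  proto::XllmRpcService_Stub stub(&channel);
  proto::HeartbeatRequest req;
  req.set_name(instance_name);
  // populate cache_event / load_metrics here as defined by the proto

  proto::Status resp;
  brpc::Controller cntl;
  stub.Heartbeat(&cntl, &req, &resp, /*done=*/nullptr);  // synchronous when done is null
  return !cntl.Failed() && resp.ok();
}

}  // namespace xllm_service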
diff --git a/xllm_service/rpc_service/service.h b/xllm_service/rpc_service/service.h
index 4575e11..2c15a29 100644
--- a/xllm_service/rpc_service/service.h
+++ b/xllm_service/rpc_service/service.h
@@ -24,8 +24,9 @@ limitations under the License.
 #include "common/xllm/output.h"
 #include "common/xllm/status.h"
 #include "completion.pb.h"
-#include "instance_mgr.h"
+#include "managers/instance_mgr.h"
 #include "response_handler.h"
+#include "scheduler.h"
 #include "xllm_rpc_service.pb.h"

 namespace xllm_service {
@@ -41,25 +42,17 @@ struct ServiceConfig {

 class XllmRpcServiceImpl final {
  public:
-  XllmRpcServiceImpl(const RpcServiceConfig& config);
+  XllmRpcServiceImpl(const RpcServiceConfig& rpc_config,
+                     const ModelConfig& model_config,
+                     const HttpServiceConfig& http_config);
   ~XllmRpcServiceImpl();

-  ErrorCode heartbeat(const std::string& instance_name);
-  ErrorCode register_instance(const std::string& instance_name,
-                              const InstanceMetaInfo& metainfo);
-  ErrorCode update_instance_metainfo(const std::string& instance_name,
-                                     const InstanceMetaInfo& metainfo);
+  void heartbeat(const proto::HeartbeatRequest* req);

   InstanceMetaInfo get_instance_info(const std::string& instance_name);

   ServiceConfig get_config();

-  // methods for master
-
-  // select instances(prefill/decode/default etc.) to handle request
-  // according the disagg pd policy (or some other policies.).
-  InstancesPair select_instances_pair(bool only_prefill = false);
-
   std::vector get_static_decode_list(
       const std::string& prefill_name);

@@ -82,9 +75,13 @@ class XllmRpcServiceImpl final {
                       bool include_usage);
   void finish_request(const std::string& service_request_id);

- private:
-  std::unique_ptr instance_mgr_;
+  bool schedule(const std::string& prompt, ScheduleResult* res);
+
+  bool schedule(const ChatMessages& messages, ScheduleResult* res);

+  std::shared_ptr get_channel(const std::string& target_name);
+
+ private:
   // `request` -> `callback` map
   std::unordered_map callbacks_;
   std::mutex callback_mutex_;
@@ -114,6 +111,9 @@ class XllmRpcServiceImpl final {

   // used when receive token from decode instance.
   ResponseHandler response_handler_;
+
+  // instance discovery via registration to etcd
+  std::unique_ptr scheduler_;
 };

 // parse proto data and call XllmRpcService
@@ -127,11 +127,6 @@ class XllmRpcService : public proto::XllmRpcService {
                      proto::Status* resp,
                      google::protobuf::Closure* done) override;

-  virtual void RegisterInstance(google::protobuf::RpcController* cntl_base,
-                                const proto::InstanceMetaInfo* req,
-                                proto::StatusCode* resp,
-                                google::protobuf::Closure* done) override;
-
   virtual void Heartbeat(google::protobuf::RpcController* cntl_base,
                          const proto::HeartbeatRequest* req,
                          proto::Status* resp,
diff --git a/xllm_service/tokenizer/tokenizer_args.cpp b/xllm_service/tokenizer/tokenizer_args.cpp
index e463296..28f7bb3 100644
--- a/xllm_service/tokenizer/tokenizer_args.cpp
+++ b/xllm_service/tokenizer/tokenizer_args.cpp
@@ -27,7 +27,7 @@ std::optional load_chat_template_file(const std::string& dir) {
 }
 }  // namespace

-bool load_tokenizer_args(const std::string& model_weights_path,
+void load_tokenizer_args(const std::string& model_weights_path,
                          TokenizerArgs& tokenizer_args) {
   // tokenizer args from tokenizer_config.json
   JsonReader tokenizer_reader;
@@ -68,8 +68,6 @@ bool load_tokenizer_args(const std::string& model_weights_path,
       tokenizer_args.pad_token() = v.value();
     }
   }
-
-  return true;
 }

 }  // namespace xllm_service
\ No newline at end of file
diff --git a/xllm_service/tokenizer/tokenizer_args.h b/xllm_service/tokenizer/tokenizer_args.h
index c1951e1..34e6443 100644
--- a/xllm_service/tokenizer/tokenizer_args.h
+++ b/xllm_service/tokenizer/tokenizer_args.h
@@ -93,7 +93,7 @@ inline std::ostream& operator<<(std::ostream& os, const TokenizerArgs& args) {
   return os;
 }

-bool load_tokenizer_args(const std::string& model_weights_path,
+void load_tokenizer_args(const std::string& model_weights_path,
                          TokenizerArgs& tokenizer_args);

 }  // namespace xllm_service
diff --git a/xllm_service/tokenizer/tokenizer_factory.cpp b/xllm_service/tokenizer/tokenizer_factory.cpp
index 204d868..120cb77 100644
--- a/xllm_service/tokenizer/tokenizer_factory.cpp
+++ b/xllm_service/tokenizer/tokenizer_factory.cpp
@@ -2,28 +2,32 @@

 #include

+#include "tokenizer_args.h"
+
 namespace xllm_service {

 std::unique_ptr TokenizerFactory::create_tokenizer(
     const std::string& model_weights_path,
-    TokenizerArgs tokenizer_args) {
+    TokenizerArgs* tokenizer_args) {
+  load_tokenizer_args(model_weights_path, *tokenizer_args);
+
   const std::string tokenizer_json_path =
       model_weights_path + "/tokenizer.json";
   if (std::filesystem::exists(tokenizer_json_path)) {
     // 1. fast tokenizer
     LOG(INFO) << "Create fast tokenizer.";
     return std::make_unique(tokenizer_json_path);
-  } else if (tokenizer_args.tokenizer_type() == "tiktoken" ||
-             tokenizer_args.tokenizer_class() == "TikTokenTokenizer") {
+  } else if (tokenizer_args->tokenizer_type() == "tiktoken" ||
+             tokenizer_args->tokenizer_class() == "TikTokenTokenizer") {
     // 2. create tiktoken tokenizer
     LOG(INFO) << "Create Tiktoken tokenizer.";
     return std::make_unique(model_weights_path,
-                            tokenizer_args);
+                            *tokenizer_args);
   } else {
     // 3. create sentencepiece tokenizer
     LOG(INFO) << "Create SentencePiece tokenizer.";
     return std::make_unique(model_weights_path,
-                            tokenizer_args);
+                            *tokenizer_args);
   }
 }
diff --git a/xllm_service/tokenizer/tokenizer_factory.h b/xllm_service/tokenizer/tokenizer_factory.h
index 2120071..5afd20d 100644
--- a/xllm_service/tokenizer/tokenizer_factory.h
+++ b/xllm_service/tokenizer/tokenizer_factory.h
@@ -11,7 +11,7 @@ class TokenizerFactory {
  public:
   static std::unique_ptr create_tokenizer(
       const std::string& model_weights_path,
-      TokenizerArgs tokenizer_args);
+      TokenizerArgs* tokenizer_args);
 };

 }  // namespace xllm_service
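With create_tokenizer now taking TokenizerArgs* and calling load_tokenizer_args itself, a caller hands in an empty args object and gets it filled as a side effect; the dropped bool return from load_tokenizer_args was unconditionally true, so the void signature loses nothing. A minimal usage sketch (the factory header path and the unique_ptr element type are assumed from the surrounding sources):

#include <memory>
#include <string>

#include "tokenizer/tokenizer.h"
#include "tokenizer/tokenizer_args.h"
#include "tokenizer/tokenizer_factory.h"

// Sketch only: error handling omitted, path supplied by the caller.
std::unique_ptr<xllm_service::Tokenizer> make_tokenizer(
    const std::string& model_weights_path) {
  xllm_service::TokenizerArgs args;  // filled by create_tokenizer via load_tokenizer_args()
  auto tokenizer = xllm_service::TokenizerFactory::create_tokenizer(
      model_weights_path, &args);
  // args now reflects tokenizer_config.json (e.g. pad token, chat template).
  return tokenizer;
}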