Skip to content

Commit 90c2fb5

Browse files
committed
feat: update services to support cache aware routing.
1 parent 4b5e3ae commit 90c2fb5

File tree

10 files changed

+214
-211
lines changed

10 files changed

+214
-211
lines changed

xllm_service/common/global_gflags.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,6 @@ DEFINE_int32(http_server_max_concurrency,
1818
128,
1919
"Limit number of requests processed in parallel");
2020

21-
DEFINE_string(rpc_server_host,
22-
"",
23-
"Rpc server listen address, may be IPV4/IPV6/UDS."
24-
" If this is set, the flag port will be ignored");
25-
2621
DEFINE_int32(rpc_server_port, 8889, "Port for xllm rpc service to listen on");
2722

2823
DEFINE_int32(rpc_server_idle_timeout_s,

xllm_service/common/global_gflags.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ DECLARE_int32(http_server_num_threads);
1212

1313
DECLARE_int32(http_server_max_concurrency);
1414

15-
DECLARE_string(rpc_server_host);
16-
1715
DECLARE_int32(rpc_server_port);
1816

1917
DECLARE_int32(rpc_server_idle_timeout_s);

xllm_service/http_service/service.cpp

Lines changed: 64 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -48,56 +48,6 @@ XllmHttpServiceImpl::XllmHttpServiceImpl(const HttpServiceConfig& config)
4848

4949
XllmHttpServiceImpl::~XllmHttpServiceImpl() {}
5050

51-
bool XllmHttpServiceImpl::create_channel(const std::string& target_uri) {
52-
std::lock_guard<std::mutex> guard(channel_mutex_);
53-
if (cached_channels_.find(target_uri) == cached_channels_.end()) {
54-
brpc::Channel* channel = new brpc::Channel();
55-
brpc::ChannelOptions options;
56-
// Add to params
57-
options.protocol = "http";
58-
options.timeout_ms = config_.timeout_ms; /*milliseconds*/
59-
options.max_retry = 3;
60-
std::string load_balancer = "";
61-
if (channel->Init(target_uri.c_str(), load_balancer.c_str(), &options) !=
62-
0) {
63-
LOG(ERROR) << "Fail to initialize channel for " << target_uri;
64-
return false;
65-
}
66-
cached_channels_[target_uri] = channel;
67-
}
68-
69-
return true;
70-
}
71-
72-
std::string XllmHttpServiceImpl::get_redirect_uri(bool only_prefill) {
73-
std::string target_instance_addr;
74-
if (!rpc_service_) {
75-
// for testing
76-
if (config_.test_instance_addr.empty()) {
77-
LOG(ERROR) << "Rpc service is not start.";
78-
return "";
79-
}
80-
target_instance_addr = config_.test_instance_addr;
81-
} else {
82-
InstancesPair instances_pair =
83-
rpc_service_->select_instances_pair(only_prefill);
84-
if (instances_pair.prefill_instance_http_addr.empty()) {
85-
LOG(ERROR) << "No prefill instance available.";
86-
return "";
87-
}
88-
target_instance_addr = instances_pair.prefill_instance_http_addr;
89-
90-
if (!only_prefill) {
91-
if (instances_pair.decode_instance_http_addr.empty()) {
92-
// TODO:
93-
}
94-
// TODO: add instances_pair.decode_instance_http_addr to request?
95-
}
96-
}
97-
98-
return target_instance_addr;
99-
}
100-
10151
void XllmHttpServiceImpl::Hello(::google::protobuf::RpcController* controller,
10252
const proto::HttpHelloRequest* request,
10353
proto::HttpHelloResponse* response,
@@ -198,7 +148,8 @@ void XllmHttpServiceImpl::handle(std::shared_ptr<T> call_data,
198148

199149
// async redistribute the request and wait the response
200150
// TODO: optimize the thread pool to async mode.
201-
auto channel_ptr = cached_channels_[target_uri];
151+
brpc::Channel* channel_ptr = rpc_service_->get_channel(target_uri).get();
152+
202153
// send request to prefill instance.
203154
thread_pool_->schedule([this,
204155
service_request_id,
@@ -360,24 +311,6 @@ void XllmHttpServiceImpl::post_serving(
360311
// create xllm_service request_id: service_request_id
361312
std::string service_request_id = generate_service_request_id(serving_method);
362313
json_value["service_request_id"] = service_request_id;
363-
std::string req_attachment = json_value.dump();
364-
request_tracer_->log(service_request_id, req_attachment);
365-
366-
// redistribute the request to the correct P/D instance
367-
// TODO: redistribute policy to select the instance
368-
std::string target_uri = get_redirect_uri();
369-
if (target_uri.empty()) {
370-
cntl->SetFailed(
371-
"Internal runtime error, can not found a running instance.");
372-
return;
373-
}
374-
if (cached_channels_.find(target_uri) == cached_channels_.end()) {
375-
if (!create_channel(target_uri)) {
376-
LOG(ERROR) << "Create channel failed, target_uri is " << target_uri;
377-
cntl->SetFailed("Internal runtime error.");
378-
return;
379-
}
380-
}
381314

382315
std::function<void(const std::string&)> trace_callback;
383316
if (config_.enable_request_trace) {
@@ -388,33 +321,82 @@ void XllmHttpServiceImpl::post_serving(
388321
trace_callback = nullptr;
389322
}
390323

324+
SchduleResult schedule_res;
391325
if (serving_method == "/v1/completions") {
326+
if (json_value.contains("prompt")) {
327+
if (!rpc_service_->schedule(json_value.at("prompt").get<std::string>(),
328+
&schedule_res)) {
329+
cntl->SetFailed("Schedule fail!");
330+
LOG(ERROR) << "XllmRpcServiceImpl::schedule error!";
331+
return;
332+
}
333+
} else {
334+
cntl->SetFailed("Input has no prompt!");
335+
LOG(ERROR) << "Input has no prompt!";
336+
return;
337+
}
338+
json_value["token_ids"] = schedule_res.token_ids;
339+
json_value["routing"] = schedule_res.routing.serialize_to_json();
340+
341+
std::string req_attachment = json_value.dump();
392342
auto arena = response->GetArena();
393343
auto resp_pb =
394344
google::protobuf::Arena::CreateMessage<llm::proto::CompletionResponse>(
395345
arena);
396346
auto call_data = std::make_shared<CompletionCallData>(
397-
cntl, stream, done_guard.release(), resp_pb, trace_callback);
347+
cntl, stream, done_guard.release(), resp_pb);
398348
handle_v1_completions(call_data,
399349
req_attachment,
400350
service_request_id,
401351
stream,
402352
model,
403353
include_usage,
404-
target_uri);
354+
schedule_res.routing.prefill_name);
405355
} else if (serving_method == "/v1/chat/completions") {
356+
if (json_value.contains("messages") && json_value["messages"].is_array()) {
357+
ChatMessages messages;
358+
try {
359+
const auto& msgs = json_value["messages"];
360+
messages.reserve(msgs.size());
361+
for (const auto& msg : msgs) {
362+
if (msg.contains("role") && msg["role"].is_string() &&
363+
msg.contains("content") && msg["content"].is_string()) {
364+
messages.emplace_back(msg["role"].get<std::string>(),
365+
msg["content"].get<std::string>());
366+
}
367+
}
368+
} catch (const nlohmann::json::exception& e) {
369+
cntl->SetFailed("Parse request fail, Invalid messages!");
370+
LOG(ERROR) << "Parse request fail, Invalid messages!";
371+
return;
372+
}
373+
374+
if (!rpc_service_->schedule(messages, &schedule_res)) {
375+
cntl->SetFailed("Schedule fail!");
376+
LOG(ERROR) << "XllmRpcServiceImpl::schedule error!";
377+
return;
378+
}
379+
} else {
380+
cntl->SetFailed("Input has no messages!");
381+
LOG(ERROR) << "Input has no messages!";
382+
return;
383+
}
384+
json_value["token_ids"] = schedule_res.token_ids;
385+
json_value["routing"] = schedule_res.routing.serialize_to_json();
386+
387+
std::string req_attachment = json_value.dump();
406388
auto arena = response->GetArena();
407389
auto resp_pb =
408390
google::protobuf::Arena::CreateMessage<llm::proto::ChatResponse>(arena);
409391
auto call_data = std::make_shared<ChatCallData>(
410-
cntl, stream, done_guard.release(), resp_pb, trace_callback);
392+
cntl, stream, done_guard.release(), resp_pb);
411393
handle_v1_chat_completions(call_data,
412394
req_attachment,
413395
service_request_id,
414396
stream,
415397
model,
416398
include_usage,
417-
target_uri);
399+
schedule_res.routing.prefill_name);
418400
} else {
419401
LOG(ERROR) << "Not supported method: " << serving_method;
420402
cntl->SetFailed("Not supported method: " + serving_method);
@@ -456,22 +438,18 @@ void XllmHttpServiceImpl::get_serving(
456438
// done_guard.release());
457439
auto call_data = std::make_shared<CompletionCallData>(
458440
cntl, false, done_guard.release(), nullptr);
459-
std::string target_uri = get_redirect_uri(true /*only_prefill*/);
460-
if (target_uri.empty()) {
461-
cntl->SetFailed(
462-
"Internal runtime error, can not found a running instance.");
441+
442+
SchduleResult schedule_res;
443+
if (!rpc_service_->schedule("", &schedule_res)) {
444+
cntl->SetFailed("Schedule fail!");
445+
LOG(ERROR) << "XllmRpcServiceImpl::schedule error!";
463446
return;
464447
}
465-
if (cached_channels_.find(target_uri) == cached_channels_.end()) {
466-
if (!create_channel(target_uri)) {
467-
LOG(ERROR) << "Create channel failed, target_uri is " << target_uri;
468-
cntl->SetFailed("Internal runtime error.");
469-
return;
470-
}
471-
}
472448

473-
auto channel_ptr = cached_channels_[target_uri];
474-
target_uri += serving_method;
449+
brpc::Channel* channel_ptr =
450+
rpc_service_->get_channel(schedule_res.routing.prefill_name).get();
451+
std::string target_uri = schedule_res.routing.prefill_name + serving_method;
452+
475453
thread_pool_->schedule(
476454
[/*req_attachment, */ call_data, cntl, channel_ptr, target_uri]() {
477455
brpc::Controller* redirect_cntl = new brpc::Controller();

xllm_service/http_service/service.h

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ class XllmHttpServiceImpl : public proto::XllmHttpService {
6262

6363
private:
6464
bool create_channel(const std::string& target_uri);
65-
// only prefill is true means only prefill instance is returned
66-
std::string get_redirect_uri(bool only_prefill = false);
65+
6766
void post_serving(const std::string& serving_method,
6867
::google::protobuf::RpcController* controller,
6968
const proto::HttpRequest* request,
@@ -109,13 +108,8 @@ class XllmHttpServiceImpl : public proto::XllmHttpService {
109108
std::shared_ptr<XllmRpcServiceImpl> rpc_service_;
110109

111110
std::unique_ptr<RequestTracer> request_tracer_;
112-
// uri -> channel
113-
// e.g. 127.0.0.1:9999/v1/completions -> channel1
114-
// 127.0.0.1:9999/v1/chat/completions -> channel2
115-
// NOTE: different methods to one instance has different channels
116-
std::unordered_map<std::string, brpc::Channel*> cached_channels_;
111+
117112
std::unique_ptr<ThreadPool> thread_pool_;
118-
std::mutex channel_mutex_;
119113

120114
// In disagg pd mode, we support receive generated token from
121115
// prefill or from decode directly.

xllm_service/master.cpp

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "master.h"
22

3+
#include <boost/asio.hpp>
34
#include <csignal>
45

56
#include "common/global_gflags.h"
@@ -20,14 +21,26 @@ Master::Master(const ServerOptions& server_options)
2021
rpc_config.detect_disconnected_instance_interval =
2122
server_options.detect_disconnected_instance_interval;
2223

23-
rpc_service_impl_ =
24-
std::make_shared<xllm_service::XllmRpcServiceImpl>(rpc_config);
25-
rpc_service_ =
26-
std::make_unique<xllm_service::XllmRpcService>(rpc_service_impl_);
24+
rpc_config.service_name = server_options_.rpc_server_host + ":" +
25+
std::to_string(server_options_.rpc_port);
26+
27+
ModelConfig model_config;
28+
model_config.block_size = server_options.block_size;
29+
model_config.model_type = server_options.model_type;
30+
model_config.tokenizer_path = server_options.tokenizer_path;
2731

28-
HttpServiceConfig http_config;
32+
xllm_service::HttpServiceConfig http_config;
2933
http_config.num_threads = server_options.http_num_threads;
34+
http_config.timeout_ms = server_options.timeout_ms;
35+
http_config.test_instance_addr = server_options.test_instance_addr;
3036
http_config.enable_request_trace = server_options.enable_request_trace;
37+
38+
rpc_service_impl_ = std::make_shared<xllm_service::XllmRpcServiceImpl>(
39+
rpc_config, model_config, http_config);
40+
41+
rpc_service_ =
42+
std::make_unique<xllm_service::XllmRpcService>(rpc_service_impl_);
43+
3144
http_service_ = std::make_unique<xllm_service::XllmHttpServiceImpl>(
3245
rpc_service_impl_, http_config);
3346
}
@@ -145,13 +158,37 @@ void shutdown_handler(int signal) {
145158
exit(1);
146159
}
147160

161+
std::string get_local_ip() {
162+
using namespace boost::asio;
163+
io_service io;
164+
ip::tcp::resolver resolver(io);
165+
ip::tcp::resolver::query query(ip::host_name(), "");
166+
ip::tcp::resolver::iterator iter = resolver.resolve(query);
167+
ip::tcp::resolver::iterator end;
168+
169+
while (iter != end) {
170+
ip::address addr = iter->endpoint().address();
171+
if (!addr.is_loopback() && addr.is_v4()) {
172+
return addr.to_string();
173+
}
174+
++iter;
175+
}
176+
177+
LOG(FATAL) << "Get local ip fail!";
178+
return "";
179+
}
180+
148181
int main(int argc, char* argv[]) {
149182
// Initialize gflags
150183
gflags::ParseCommandLineFlags(&argc, &argv, true);
151184

152185
// Initialize glog
153186
google::InitGoogleLogging(argv[0]);
154-
FLAGS_logtostderr = true;
187+
// FLAGS_logtostderr = true;
188+
189+
LOG(INFO) << "Dump all gflags: " << std::endl
190+
<< google::CommandlineFlagsIntoString();
191+
google::FlushLogFiles(google::INFO);
155192

156193
LOG(INFO) << "Starting xllm master service.";
157194

@@ -176,7 +213,7 @@ int main(int argc, char* argv[]) {
176213
server_options.http_idle_timeout_s = FLAGS_http_server_idle_timeout_s;
177214
server_options.http_num_threads = FLAGS_http_server_num_threads;
178215
server_options.http_max_concurrency = FLAGS_http_server_max_concurrency;
179-
server_options.rpc_server_host = FLAGS_rpc_server_host;
216+
server_options.rpc_server_host = get_local_ip();
180217
server_options.rpc_port = FLAGS_rpc_server_port;
181218
server_options.rpc_idle_timeout_s = FLAGS_rpc_server_idle_timeout_s;
182219
server_options.rpc_num_threads = FLAGS_rpc_server_num_threads;
@@ -186,10 +223,16 @@ int main(int argc, char* argv[]) {
186223
server_options.detect_disconnected_instance_interval =
187224
FLAGS_detect_disconnected_instance_interval;
188225
server_options.enable_request_trace = FLAGS_enable_request_trace;
226+
227+
server_options.tokenizer_path = FLAGS_tokenizer_path;
189228
server_options.block_size = FLAGS_block_size;
190229
server_options.model_type = FLAGS_model_type;
191230
server_options.tokenizer_path = FLAGS_tokenizer_path;
192231

232+
server_options.num_threads = FLAGS_num_threads;
233+
server_options.timeout_ms = FLAGS_timeout_ms;
234+
server_options.test_instance_addr = FLAGS_test_instance_addr;
235+
193236
xllm_service::Master master(server_options);
194237

195238
if (!master.start()) {

xllm_service/master.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ struct ServerOptions {
1717
int32_t http_num_threads = 32;
1818
int32_t http_max_concurrency = 128;
1919
bool enable_request_trace = false;
20+
int num_threads = 16;
21+
int timeout_ms = -1;
22+
std::string test_instance_addr = "";
2023

2124
// rpc server options
2225
std::string rpc_server_host = "";

0 commit comments

Comments
 (0)