Dp balancer #991
base: main
Changes from 2 commits
@@ -0,0 +1,13 @@
from .dp_base_balancer import RoundRobinDpBalancer
from typing import List
from lightllm.server.router.req_queue.base_queue import BaseQueue
from .dp_balancer_for_pd import DpBalancerForPd


def get_dp_balancer(args, dp_size_in_node: int, inner_queues: List[BaseQueue]):
    if args.dp_balancer == "round_robin":
        return DpBalancerForPd(dp_size_in_node, inner_queues)
    if args.run_mode in ["prefill", "decode"]:
        return DpBalancerForPd(dp_size_in_node, inner_queues)
    else:
        raise ValueError(f"Invalid dp balancer: {args.dp_balancer}")
@@ -0,0 +1,63 @@
from typing import List, Union
from lightllm.server.router.req_queue.base_queue import BaseQueue
from lightllm.server.router.batch import Batch, Req
from lightllm.utils.log_utils import init_logger
from .dp_base_balancer import DpBalancer

logger = init_logger(__name__)


class DpBalancerForPd(DpBalancer):
    """
    This balancer mainly balances the batch size across dp ranks.
    In dp mode, if any dp rank has no request, a padding request is
    inserted for it, which wastes GPU compute resources.
    """

    def __init__(self, dp_size_in_node: int, inner_queues: List[BaseQueue]):
        super().__init__(dp_size_in_node, inner_queues)

    def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[Union[Req, List[Req]]]) -> None:
        if len(reqs_waiting_for_dp_index) == 0:
            return
        # calculate the total load of each dp rank
        if current_batch is not None:
            all_dp_req_num = current_batch.get_all_dp_req_num()
            total_load_per_dp = [
                all_dp_req_num[i] + len(self.inner_queues[i].waiting_req_list) for i in range(self.dp_size_in_node)
            ]
        else:
            total_load_per_dp = [len(self.inner_queues[i].waiting_req_list) for i in range(self.dp_size_in_node)]
        for req_group in reqs_waiting_for_dp_index:
            # calculate the length of this request group
            if isinstance(req_group, list):
                req_length = len(req_group)
            else:
                req_length = 1

            # find the dp ranks with the minimum load
            min_load = min(total_load_per_dp)
            select_dp_indexes = [i for i in range(self.dp_size_in_node) if total_load_per_dp[i] == min_load]

            # select the dp rank with the minimum load
            if len(select_dp_indexes) == 1:
                suggested_dp_index = select_dp_indexes[0]
            else:
                # if multiple dp ranks share the minimum load, pick one at random
                import random

                suggested_dp_index = random.choice(select_dp_indexes)

            # assign the requests to the chosen dp rank
            if not isinstance(req_group, list):
                req_group = [req_group]

            for req in req_group:
                req.sample_params.suggested_dp_index = suggested_dp_index
                self.inner_queues[suggested_dp_index].append(req)

            # update the load count for this dp rank
            total_load_per_dp[suggested_dp_index] += req_length

        reqs_waiting_for_dp_index.clear()
        return
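The loop above is a greedy least-loaded policy. As a minimal standalone sketch of the same idea, with plain integer loads standing in for the batch and queue sizes (the helper name and data here are hypothetical, not part of the PR):

```python
import random

# Hypothetical standalone sketch of the greedy least-loaded assignment in
# DpBalancerForPd.assign_reqs_to_dp: each request group goes to the rank
# with the smallest current load, ties broken at random, and the chosen
# rank's load is then bumped by the group size.
def assign_least_loaded(group_sizes, total_load_per_dp):
    assignments = []
    for size in group_sizes:
        min_load = min(total_load_per_dp)
        candidates = [i for i, load in enumerate(total_load_per_dp) if load == min_load]
        dp_index = random.choice(candidates)
        assignments.append(dp_index)
        total_load_per_dp[dp_index] += size
    return assignments

# Three groups of sizes 2, 1, 1 over ranks loaded [3, 1, 1]: the first group
# lands on rank 1 or 2, and later groups keep filling the lightest rank.
print(assign_least_loaded([2, 1, 1], [3, 1, 1]))
```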
@@ -0,0 +1,65 @@
import random
from abc import ABC, abstractmethod
from typing import List, Union
from lightllm.server.router.req_queue.base_queue import BaseQueue
from lightllm.server.router.batch import Batch, Req
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


class DpBalancer(ABC):
    """
    Base class for DP load balancers.
    Defines the load-balancing strategy interface; subclasses can
    implement different balancing algorithms.
    """

    def __init__(self, dp_size_in_node: int, inner_queues: List[BaseQueue]):
        self.dp_size_in_node = dp_size_in_node
        self.inner_queues = inner_queues
        self.pre_select_dp_index = self.dp_size_in_node - 1

    @abstractmethod
    def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[Union[Req, List[Req]]]) -> None:
        pass


class RoundRobinDpBalancer(DpBalancer):
    """
    Round-robin load balancer.
    Selects round-robin among the dp ranks whose queues are shortest.
    """

    def get_suggest_dp_index(
        self,
    ) -> int:
        min_length = min(len(queue.waiting_req_list) for queue in self.inner_queues)
        select_dp_indexes = [
            i for i, queue in enumerate(self.inner_queues) if len(queue.waiting_req_list) == min_length
        ]

        # if there is no candidate index, pick one at random
        if not select_dp_indexes:
            self.pre_select_dp_index = random.randint(0, self.dp_size_in_node - 1)
            return self.pre_select_dp_index

        # round-robin selection
        for i in range(self.dp_size_in_node):
            next_dp_index = (self.pre_select_dp_index + i + 1) % self.dp_size_in_node
            if next_dp_index in select_dp_indexes:
                self.pre_select_dp_index = next_dp_index
                return self.pre_select_dp_index

        self.pre_select_dp_index = random.choice(select_dp_indexes)
        return self.pre_select_dp_index

    def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[Union[Req, List[Req]]]) -> None:
        for req_group in reqs_waiting_for_dp_index:
            suggested_dp_index = self.get_suggest_dp_index()
            if not isinstance(req_group, list):
                req_group = [req_group]
            for req in req_group:
                req.sample_params.suggested_dp_index = suggested_dp_index
                self.inner_queues[suggested_dp_index].append(req)
        reqs_waiting_for_dp_index.clear()
        return
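The round-robin-among-minima rule can be shown in isolation. A minimal sketch, with plain queue lengths standing in for the `waiting_req_list` sizes (the function name and data are hypothetical):

```python
# Hypothetical sketch of RoundRobinDpBalancer's selection rule: among the
# dp ranks whose queues are shortest, advance round-robin starting just
# after the rank chosen last time.
def round_robin_among_minima(queue_lengths, pre_select_dp_index):
    dp_size = len(queue_lengths)
    min_length = min(queue_lengths)
    candidates = [i for i, n in enumerate(queue_lengths) if n == min_length]
    for i in range(dp_size):
        next_dp_index = (pre_select_dp_index + i + 1) % dp_size
        if next_dp_index in candidates:
            return next_dp_index
    return candidates[0]  # unreachable when candidates is non-empty

# Ranks 1 and 3 are tied at the minimum; starting after rank 1, rank 3 wins.
print(round_robin_among_minima([4, 2, 5, 2], pre_select_dp_index=1))  # -> 3
```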
@@ -2,6 +2,7 @@
 from typing import List
 from ..batch import Batch, Req
 from lightllm.server.router.req_queue.base_queue import BaseQueue
+from lightllm.server.router.req_queue.dp_balancer import get_dp_balancer
 from lightllm.common.basemodel.infer_lock import g_router_lock
 from lightllm.utils.log_utils import init_logger

@@ -12,14 +13,14 @@ class DpQueue:
     def __init__(self, args, router, base_queue_class, dp_size_in_node) -> None:
         self.dp_size_in_node = dp_size_in_node
         self.base_queue_class = base_queue_class
-        self.pre_select_dp_index = self.dp_size_in_node - 1
         from lightllm.server.router.manager import RouterManager

         self.router: RouterManager = router
         self.inner_queues: List[BaseQueue] = [
             base_queue_class(args, router, dp_index, dp_size_in_node) for dp_index in range(self.dp_size_in_node)
         ]
+        self.dp_balancer = get_dp_balancer(args, dp_size_in_node, self.inner_queues)
+        self.reqs_waiting_for_dp_index = []
         return

     def get_dp_queue(self, dp_index: int):
@@ -31,10 +32,16 @@ def get_wait_req_num(self):

     # @calculate_time(show=True, min_cost_ms=10)
     def generate_new_batch(self, current_batch: Batch):
-        batches = [
-            self.inner_queues[dp_index].generate_new_batch(current_batch) for dp_index in range(self.dp_size_in_node)
-        ]
-        return self._merge_batch(batches)
+        try:
+            self.dp_balancer.assign_reqs_to_dp(current_batch, self.reqs_waiting_for_dp_index)
+            batches = [
+                self.inner_queues[dp_index].generate_new_batch(current_batch)
+                for dp_index in range(self.dp_size_in_node)
+            ]
+            return self._merge_batch(batches)
+        except Exception as e:
+            logger.error(f"generate new batch failed: {e}")
+            raise e
Comment on lines +43 to +44:
When an exception occurs during batch generation, using `raise e` re-raises the exception from the current statement rather than propagating it unchanged; a bare `raise` preserves the original traceback.
Suggested change: replace `raise e` with `raise`.
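If the truncated suggestion indeed targets the `raise e` above, a minimal sketch of the suggested pattern would be (stand-in code, not the diff itself):

```python
# A bare `raise` re-raises the active exception unchanged, preserving the
# original traceback, whereas `raise e` re-raises it from this statement.
try:
    risky_operation()  # hypothetical stand-in for the batch-generation calls
except Exception as e:
    logger.error(f"generate new batch failed: {e}")
    raise  # instead of `raise e`
```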
     def _merge_batch(self, dp_batches: List[Batch]):
         merged_batch: Batch = None

@@ -48,28 +55,20 @@ def _merge_batch(self, dp_batches: List[Batch]):
     def append(self, req: Req):
         suggested_dp_index = req.sample_params.suggested_dp_index
         if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0:
-            logger.warning(f"input req {req.request_id} dp index {suggested_dp_index} is invalid")
-            suggested_dp_index = self._get_suggest_dp_index()
-            self.pre_select_dp_index = suggested_dp_index
-            req.sample_params.suggested_dp_index = suggested_dp_index
-            self.inner_queues[suggested_dp_index].append(req)
+            # the dp index is assigned uniformly at schedule time
+            self.reqs_waiting_for_dp_index.append(req)
         else:
             self.inner_queues[suggested_dp_index].append(req)
         return

     def extend(self, req_group: List[Req]):
-        # requests in the same group should go to the same dp rank; this is most efficient
-        index = self._get_suggest_dp_index()
-        for req in req_group:
-            suggested_dp_index = req.sample_params.suggested_dp_index
-            if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0:
-                logger.warning(f"input req {req.request_id} dp index {suggested_dp_index} is invalid")
-                self.pre_select_dp_index = index
-                req.sample_params.suggested_dp_index = index
-                self.inner_queues[index].append(req)
-            else:
-                self.inner_queues[suggested_dp_index].append(req)
+        suggested_dp_index = req_group[0].sample_params.suggested_dp_index
+        if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0:
+            # requests in the same group must be assigned to the same dp rank
+            self.reqs_waiting_for_dp_index.append(req_group)
+        else:
+            for req in req_group:
+                self.inner_queues[suggested_dp_index].append(req)

         return

     def is_busy(self):
@@ -87,21 +86,3 @@ def update_token_load(self, current_batch: Batch, force_update=False):
             self.router.shared_token_load.set_estimated_peak_token_count(estimated_peak_token_count, dp_index)
             self.router.shared_token_load.set_dynamic_max_load(dynamic_max_load, dp_index)
         return

-    def _get_suggest_dp_index(self):
-        min_length = min(len(queue.waiting_req_list) for queue in self.inner_queues)
-        select_dp_indexes = [
-            i for i, queue in enumerate(self.inner_queues) if len(queue.waiting_req_list) == min_length
-        ]
-
-        # multi thread safe keep
-        if not select_dp_indexes:
-            return random.randint(0, self.dp_size_in_node - 1)
-
-        # round_robin select.
-        for i in range(self.dp_size_in_node):
-            next_dp_index = (self.pre_select_dp_index + i + 1) % self.dp_size_in_node
-            if next_dp_index in select_dp_indexes:
-                return next_dp_index
-
-        return random.choice(select_dp_indexes)
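Taken together, the diff defers dp-index assignment: `append` and `extend` now park requests whose `suggested_dp_index` is invalid, and the balancer routes them when the next batch is generated. A rough usage sketch under that reading (the request constructor and its flag are hypothetical):

```python
# Hypothetical end-to-end flow through the new DpQueue: an unassigned request
# is parked in reqs_waiting_for_dp_index, then routed by the dp balancer the
# next time a batch is generated.
queue = DpQueue(args, router, base_queue_class, dp_size_in_node=2)
req = make_req(suggested_dp_index=-1)  # hypothetical helper; -1 means "unassigned"
queue.append(req)                      # parked, not yet routed to a rank
batch = queue.generate_new_batch(current_batch=None)  # balancer assigns parked reqs first
```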
The `get_dp_balancer` function returns `DpBalancerForPd` when `args.dp_balancer` is "round_robin". This seems incorrect, as it should return `RoundRobinDpBalancer` in this case, and could lead to unexpected behavior. Consider swapping the return values for the `round_robin` case to ensure the correct balancer is used.
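A minimal sketch of the dispatch the reviewer appears to have in mind, assuming `round_robin` should map to `RoundRobinDpBalancer` while PD run modes keep `DpBalancerForPd` (an illustration, not the code as submitted):

```python
# Hypothetical corrected factory per the review comment: prefill/decode run
# modes get the PD balancer; otherwise "round_robin" selects the
# round-robin balancer, and anything else is rejected.
def get_dp_balancer(args, dp_size_in_node: int, inner_queues: List[BaseQueue]):
    if args.run_mode in ["prefill", "decode"]:
        return DpBalancerForPd(dp_size_in_node, inner_queues)
    if args.dp_balancer == "round_robin":
        return RoundRobinDpBalancer(dp_size_in_node, inner_queues)
    raise ValueError(f"Invalid dp balancer: {args.dp_balancer}")
```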