diff --git a/litellm/proxy/hooks/__init__.py b/litellm/proxy/hooks/__init__.py
index 790ebcd8791..2bedae7fb7a 100644
--- a/litellm/proxy/hooks/__init__.py
+++ b/litellm/proxy/hooks/__init__.py
@@ -9,6 +9,7 @@
 from .max_iterations_limiter import _PROXY_MaxIterationsHandler
 from .parallel_request_limiter import _PROXY_MaxParallelRequestsHandler
 from .parallel_request_limiter_v3 import _PROXY_MaxParallelRequestsHandler_v3
+from .max_available_capacity_limiter import _PROXY_MaxAvailableCapacityLimiter
 from .responses_id_security import ResponsesIDSecurity
 
 ### CHECK IF ENTERPRISE HOOKS ARE AVAILABLE ####
@@ -22,6 +23,7 @@
 PROXY_HOOKS = {
     "max_budget_limiter": _PROXY_MaxBudgetLimiter,
     "parallel_request_limiter": _PROXY_MaxParallelRequestsHandler_v3,
+    "max_available_capacity_limiter": _PROXY_MaxAvailableCapacityLimiter,
     "cache_control_check": _PROXY_CacheControlCheck,
     "responses_id_security": ResponsesIDSecurity,
     "litellm_skills": SkillsInjectionHook,
@@ -43,6 +45,7 @@ def get_proxy_hook(
     hook_name: Union[
         Literal[
             "max_budget_limiter",
+            "max_available_capacity_limiter",
             "managed_files",
             "parallel_request_limiter",
             "cache_control_check",
diff --git a/litellm/proxy/hooks/max_available_capacity_limiter.py b/litellm/proxy/hooks/max_available_capacity_limiter.py
new file mode 100644
index 00000000000..85227e8ed99
--- /dev/null
+++ b/litellm/proxy/hooks/max_available_capacity_limiter.py
@@ -0,0 +1,298 @@
+import datetime
+from typing import TYPE_CHECKING, Optional
+
+from fastapi import HTTPException
+
+from litellm._logging import verbose_proxy_logger
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.types.router import Deployment
+
+if TYPE_CHECKING:
+    from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
+
+    InternalUsageCache = _InternalUsageCache
+else:
+    InternalUsageCache = object
+
+
+DEFAULT_REQUEST_BUDGET = 5
+DEFAULT_TTL_TIME = 300  # seconds (5 minutes)
+BASE_RATE = 0.5  # base refill rate, in requests per second
+
+# Per-model cache keys
+CACHE_KEY_TOKENS_USED = "{model}:tokens_used"
+CACHE_KEY_REQUESTS_USED = "{model}:requests_used"
+CACHE_KEY_LAST_REFILL = "{model}:last_refill"
+CACHE_KEY_WORKLOAD = "{model}:workload"
+CACHE_KEY_TPM_LIMIT = "{model}:tpm_limit"
+CACHE_KEY_RPM_LIMIT = "{model}:rpm_limit"
+
+# Per-user (api_key + model) cache keys
+CACHE_KEY_USER_REQUESTS_LEFT = "{cache_key}:requests_left"
+CACHE_KEY_USER_LAST_REFILL = "{cache_key}:last_refill"
+
+
+class _PROXY_MaxAvailableCapacityLimiter(CustomLogger):
+    """
+    Limits user token consumption based on available system capacity.
+
+    Tokens are refilled over time at a rate determined by current workload.
+    Lower workload = faster refill, higher workload = slower refill.
+    """
+
+    def __init__(self, internal_usage_cache: InternalUsageCache):
+        self.cache = internal_usage_cache.dual_cache
+
+    # ==================== Hooks ====================
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: str,
+    ) -> None:
+        verbose_proxy_logger.debug("Inside method async_pre_call_hook")
+        model = data["model"]
+        api_key = user_api_key_dict.api_key
+        cache_key = f"{api_key}:{model}"
+
+        try:
+            user_requests_left = await self._get_user_budget(model, cache_key)
+        except HTTPException:
+            raise
+        except Exception as e:
+            verbose_proxy_logger.error(
+                f"Error in max available capacity rate limiter: {e}, allowing request"
+            )
+            return None  # request allowed
+
+        if user_requests_left <= 0:
+            raise HTTPException(
+                status_code=429,
+                detail={"error": f"Model capacity reached for {model}."},
+            )
+
+        await self.cache.async_increment_cache(CACHE_KEY_USER_REQUESTS_LEFT.format(cache_key=cache_key), -1, ttl=DEFAULT_TTL_TIME)
+        return None
+
+    async def async_post_call_success_hook(
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        response,
+    ) -> None:
+        verbose_proxy_logger.debug("Inside method async_post_call_success_hook")
+        verbose_proxy_logger.debug("MaxAvailableCapacityLimiter: post call success")
+        verbose_proxy_logger.debug(f"data: {data}")
+        verbose_proxy_logger.debug(f"response: {response}")
+
+    async def async_log_success_event(
+        self, kwargs, response_obj, start_time, end_time
+    ) -> None:
+        verbose_proxy_logger.debug("Inside method async_log_success_event")
+
+        model = response_obj["model"]
+        total_tokens = (response_obj.get("usage") or {}).get("total_tokens", 0)
+
+        await self.cache.async_increment_cache(CACHE_KEY_TOKENS_USED.format(model=model), total_tokens, ttl=DEFAULT_TTL_TIME)
+        await self.cache.async_increment_cache(CACHE_KEY_REQUESTS_USED.format(model=model), 1, ttl=DEFAULT_TTL_TIME)
+        return None
+
+    # ==================== Budget Management ====================
+
+    async def _get_user_budget(self, model: str, cache_key: str) -> int:
+        """Get existing budget from cache or create new one."""
+        verbose_proxy_logger.debug("Inside method _get_user_budget")
+
+        return await self._refill_user_budget(model, cache_key)
+
+    async def _create_user_budget(self, cache_key: str) -> int:
+        """Initialize a fresh per-user request budget."""
+        verbose_proxy_logger.debug("Inside method _create_user_budget")
+
+        await self.cache.async_set_cache(CACHE_KEY_USER_REQUESTS_LEFT.format(cache_key=cache_key), DEFAULT_REQUEST_BUDGET, ttl=DEFAULT_TTL_TIME)
+        await self.cache.async_set_cache(CACHE_KEY_USER_LAST_REFILL.format(cache_key=cache_key), datetime.datetime.now(datetime.timezone.utc).isoformat(), ttl=DEFAULT_TTL_TIME)
+
+        return DEFAULT_REQUEST_BUDGET
+
+    async def _refill_user_budget(self, model: str, cache_key: str) -> int:
+        """Refill user budget based on elapsed time and current workload."""
+        last_refill = await self.cache.async_get_cache(CACHE_KEY_USER_LAST_REFILL.format(cache_key=cache_key))
+        verbose_proxy_logger.debug(f"Inside method _refill_user_budget {last_refill}")
+
+        if last_refill is None:
+            return await self._create_user_budget(cache_key)
+
+        now = datetime.datetime.now(datetime.timezone.utc)
+        timestamp = datetime.datetime.fromisoformat(last_refill)
+        elapsed_seconds = (now - timestamp).total_seconds()
+
+        refill_rate = await self._calculate_refill_rate(model)
+        requests_to_add = int(elapsed_seconds * refill_rate)
+
+        await self.cache.async_set_cache(CACHE_KEY_USER_LAST_REFILL.format(cache_key=cache_key), now.isoformat(), ttl=DEFAULT_TTL_TIME)
+        return int(await self.cache.async_increment_cache(CACHE_KEY_USER_REQUESTS_LEFT.format(cache_key=cache_key), requests_to_add, ttl=DEFAULT_TTL_TIME) or 0)
+
+    # ==================== Refill Rate Calculation ====================
+
+    async def _calculate_refill_rate(self, model: str) -> float:
+        """
+        Calculate REQUEST refill rate based on system workload.
+
+        Args:
+            model: model name
+
+        Returns:
+            Effective refill rate in requests per second
+        """
+        verbose_proxy_logger.debug("Inside method _calculate_refill_rate")
+
+        workload = await self._get_model_workload(model)
+
+        if workload < 0.5:
+            # Green zone: 20% bonus over BASE_RATE (0.6 req/s at BASE_RATE = 0.5)
+            return BASE_RATE * 1.2
+
+        if workload < 0.8:
+            # Yellow zone: linear decrease from 100% to 40% of BASE_RATE
+            # workload 0.5 -> 0.5 req/s, workload 0.8 -> 0.2 req/s
+            factor = 1.0 - (workload - 0.5) * 2  # slope 2 = (1.0 - 0.4) / (0.8 - 0.5)
+            return BASE_RATE * max(factor, 0.4)
+
+        if workload <= 1:
+            # Red zone: quadratic decrease from 40% down to a 2% floor
+            # workload 0.8 -> 0.2 req/s, 0.9 -> 0.05 req/s, 1.0 -> 0.01 req/s
+            factor = 0.4 * ((1.0 - workload) / 0.2) ** 2
+            return BASE_RATE * max(factor, 0.02)
+
+        return BASE_RATE * 0.01
+
+    # ==================== Workload Calculation ====================
+
+    async def _get_model_workload(self, model: str) -> float:
+        """Calculate current workload for a model."""
+        verbose_proxy_logger.debug("Inside method _get_model_workload")
+
+        await self._refill_model_capacity(model)
+
+        tokens_used, requests_used = await self._fetch_tokens_requests_used_in_window(model)
+        tpm_limit, rpm_limit = await self._get_model_limits(model)
+        # scale the per-minute limits up to the full window (DEFAULT_TTL_TIME seconds)
+        tpm_limit *= DEFAULT_TTL_TIME // 60
+        rpm_limit *= DEFAULT_TTL_TIME // 60
+        workload = self._calculate_load(tokens_used, tpm_limit, requests_used, rpm_limit)
+
+        await self.cache.async_set_cache(CACHE_KEY_WORKLOAD.format(model=model), workload, ttl=DEFAULT_TTL_TIME)
+        return workload
+
+    async def _refill_model_capacity(self, model: str) -> None:
+        """Decay the per-model usage counters according to elapsed time."""
+        verbose_proxy_logger.debug("Inside method _refill_model_capacity")
+
+        now = datetime.datetime.now(datetime.timezone.utc)
+        cache_last = CACHE_KEY_LAST_REFILL.format(model=model)
+        cache_tokens = CACHE_KEY_TOKENS_USED.format(model=model)
+        cache_requests = CACHE_KEY_REQUESTS_USED.format(model=model)
+
+        last_refill_str = await self.cache.async_get_cache(cache_last)
+        if last_refill_str is None:
+            await self.cache.async_set_cache(cache_last, now.isoformat(), ttl=DEFAULT_TTL_TIME)
+            return
+
+        last_refill = datetime.datetime.fromisoformat(last_refill_str)
+        elapsed_seconds = (now - last_refill).total_seconds()
+
+        if elapsed_seconds <= 0:
+            return
+
+        tpm_limit, rpm_limit = await self._get_model_limits(model)
+        tokens_per_sec = tpm_limit / 60.0
+        requests_per_sec = rpm_limit / 60.0
+
+        did_refill = False
+
+        # --- TOKENS ---
+        if tokens_per_sec > 0:
+            tokens_to_refill = int(elapsed_seconds * tokens_per_sec)
+            if tokens_to_refill > 0:
+                current_tokens = await self.cache.async_get_cache(cache_tokens) or 0
+                new_tokens = max(0, current_tokens - tokens_to_refill)
+                await self.cache.async_set_cache(cache_tokens, new_tokens, ttl=DEFAULT_TTL_TIME)
+                did_refill = True
+
+        # --- REQUESTS ---
+        if requests_per_sec > 0:
+            requests_to_refill = int(elapsed_seconds * requests_per_sec)
+            if requests_to_refill > 0:
+                current_requests = await self.cache.async_get_cache(cache_requests) or 0
+                new_requests = max(0, current_requests - requests_to_refill)
+                await self.cache.async_set_cache(cache_requests, new_requests, ttl=DEFAULT_TTL_TIME)
+                did_refill = True
+
+        # Update the timestamp only if something was actually refilled
+        if did_refill:
+            await self.cache.async_set_cache(cache_last, now.isoformat(), ttl=DEFAULT_TTL_TIME)
+
+    async def _fetch_tokens_requests_used_in_window(self, model: str) -> tuple[int, int]:
+        verbose_proxy_logger.debug("Inside method _fetch_tokens_requests_used_in_window")
+        tokens_used_interval = await self.cache.async_get_cache(CACHE_KEY_TOKENS_USED.format(model=model)) or 0
+        requests_used_interval = await self.cache.async_get_cache(CACHE_KEY_REQUESTS_USED.format(model=model)) or 0
+
+        return tokens_used_interval, requests_used_interval
+
+    async def _get_model_limits(self, model: str) -> tuple[int, int]:
+        """Return (tpm, rpm) limits for a model, cached after the first lookup."""
+        verbose_proxy_logger.debug("Inside method _get_model_limits")
+
+        tpm_limit = await self.cache.async_get_cache(CACHE_KEY_TPM_LIMIT.format(model=model))
+        rpm_limit = await self.cache.async_get_cache(CACHE_KEY_RPM_LIMIT.format(model=model))
+
+        if tpm_limit is not None and rpm_limit is not None:
+            return tpm_limit, rpm_limit
+
+        deployment = self._get_deployment(model)
+
+        if deployment is None:
+            return 0, 0
+
+        if deployment.litellm_params is None:
+            return 0, 0
+
+        tpm_limit = deployment.litellm_params.tpm
+        rpm_limit = deployment.litellm_params.rpm
+
+        await self.cache.async_set_cache(CACHE_KEY_TPM_LIMIT.format(model=model), tpm_limit, ttl=DEFAULT_TTL_TIME)
+        await self.cache.async_set_cache(CACHE_KEY_RPM_LIMIT.format(model=model), rpm_limit, ttl=DEFAULT_TTL_TIME)
+
+        return (tpm_limit, rpm_limit) if tpm_limit is not None and rpm_limit is not None else (0, 0)
+
+    def _get_deployment(self, model: str) -> Optional[Deployment]:  # TODO: consider caching this lookup
+        """Get deployment configuration for a model."""
+        verbose_proxy_logger.debug("Inside method _get_deployment")
+        from litellm.proxy.proxy_server import llm_router
+
+        if llm_router is None:
+            return None
+
+        return llm_router.get_deployment_by_model_group_name(model_group_name=model)
+
+    # ==================== Load Calculation ====================
+
+    def _calculate_load(self, tokens_used: int, tokens_limit: int, requests_used: int, requests_limit: int) -> float:
+        """Return the utilisation of the busier dimension (tokens vs. requests)."""
+        verbose_proxy_logger.debug("Inside method _calculate_load")
+        if tokens_limit == 0 and requests_limit == 0:
+            return 0.0
+
+        tpm_load = tokens_used / tokens_limit if tokens_limit > 0 else 0.0
+        rpm_load = requests_used / requests_limit if requests_limit > 0 else 0.0
+
+        verbose_proxy_logger.debug(f"Inside calculate load, tpm load: {tpm_load} rpm load: {rpm_load} tokens used: {tokens_used} requests used: {requests_used}")
+
+        return max(tpm_load, rpm_load)
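
Illustrative sketch, not part of the patch: the workload-to-refill-rate mapping implemented by _calculate_refill_rate, reproduced as a standalone function so the curve can be sanity-checked in isolation. It assumes the same BASE_RATE (0.5 requests per second) and the same green/yellow/red zone boundaries as the new module.

    BASE_RATE = 0.5  # requests per second, matching the constant in the new module


    def refill_rate(workload: float) -> float:
        if workload < 0.5:
            return BASE_RATE * 1.2  # green zone: 20% bonus -> 0.6 req/s
        if workload < 0.8:
            factor = 1.0 - (workload - 0.5) * 2  # yellow zone: linear decrease, 100% -> 40%
            return BASE_RATE * max(factor, 0.4)
        if workload <= 1:
            factor = 0.4 * ((1.0 - workload) / 0.2) ** 2  # red zone: quadratic decrease, 40% -> 2% floor
            return BASE_RATE * max(factor, 0.02)
        return BASE_RATE * 0.01  # overloaded: 1% of BASE_RATE


    if __name__ == "__main__":
        for w in (0.0, 0.4, 0.5, 0.65, 0.8, 0.9, 1.0, 1.2):
            # e.g. workload 0.00 -> 0.600 req/s, 0.80 -> 0.200 req/s, 1.00 -> 0.010 req/s
            print(f"workload={w:.2f} -> {refill_rate(w):.3f} req/s")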
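
The per-user bucket in _refill_user_budget and the per-model decay in _refill_model_capacity share the same elapsed-time arithmetic: restored capacity equals elapsed seconds multiplied by the per-second rate, truncated with int(), so calls arriving within the same second may add nothing back. A minimal sketch of that calculation, assuming a rate of 0.05 req/s (the red-zone rate at workload 0.9 with BASE_RATE = 0.5):

    import datetime


    def requests_to_add(last_refill_iso: str, refill_rate: float, now: datetime.datetime) -> int:
        # elapsed wall-clock time since the stored ISO-8601 timestamp
        elapsed_seconds = (now - datetime.datetime.fromisoformat(last_refill_iso)).total_seconds()
        return int(elapsed_seconds * refill_rate)  # truncation: 30 s at 0.05 req/s -> 1 request


    now = datetime.datetime.now(datetime.timezone.utc)
    thirty_seconds_ago = (now - datetime.timedelta(seconds=30)).isoformat()
    print(requests_to_add(thirty_seconds_ago, 0.05, now))  # -> 1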