diff --git a/.gitignore b/.gitignore
index d454a74c..92291199 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,9 +4,14 @@ __pycache__
 /litellm
 
-pipelines/*
 !pipelines/.gitignore
 
 .DS_Store
+# Ignore everything in pipelines
+pipelines/*
+
+# But keep files directly inside pipelines/
+!pipelines/*.*
 
 .venv
-venv/
\ No newline at end of file
+venv/
+.idea/
diff --git a/config.py b/config.py
index 28b10310..4e758d24 100644
--- a/config.py
+++ b/config.py
@@ -22,3 +22,5 @@ API_KEY = os.getenv("PIPELINES_API_KEY", "0p3n-w3bu!")
 
 PIPELINES_DIR = os.getenv("PIPELINES_DIR", "./pipelines")
 
+DATABASE_URL = os.getenv("DATABASE_URL")
+REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")
diff --git a/main.py b/main.py
index e277d3a3..9485b135 100644
--- a/main.py
+++ b/main.py
@@ -2,12 +2,10 @@
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.concurrency import run_in_threadpool
 
-
 from starlette.responses import StreamingResponse, Response
 from pydantic import BaseModel, ConfigDict
 from typing import List, Union, Generator, Iterator
 
-
 from utils.pipelines.auth import bearer_security, get_current_user
 from utils.pipelines.main import get_last_user_message, stream_message_template
 from utils.pipelines.misc import convert_to_raw_url
@@ -28,13 +26,13 @@
 import sys
 import subprocess
 
-
-from config import API_KEY, PIPELINES_DIR, LOG_LEVELS
+from config import API_KEY, PIPELINES_DIR, LOG_LEVELS, DATABASE_URL, REDIS_URL
+from utils.pipelines.database import init_database, close_database, is_database_available
+from utils.pipelines.redis_client import init_redis, close_redis, is_redis_available
 
 if not os.path.exists(PIPELINES_DIR):
     os.makedirs(PIPELINES_DIR)
-
 
 PIPELINES = {}
 PIPELINE_MODULES = {}
 PIPELINE_NAMES = {}
@@ -88,13 +86,13 @@ def get_all_pipelines():
             "pipelines": (
                 pipeline.valves.pipelines
                 if hasattr(pipeline, "valves")
-                   and hasattr(pipeline.valves, "pipelines")
+                and hasattr(pipeline.valves, "pipelines")
                 else []
             ),
             "priority": (
                 pipeline.valves.priority
                 if hasattr(pipeline, "valves")
-                   and hasattr(pipeline.valves, "priority")
+                and hasattr(pipeline.valves, "priority")
                 else 0
             ),
             "valves": pipeline.valves if hasattr(pipeline, "valves") else None,
@@ -131,7 +129,6 @@ def install_frontmatter_requirements(requirements):
 
 
 async def load_module_from_path(module_name, module_path):
-
     try:
         # Read the module content
         with open(module_path, "r") as file:
@@ -224,6 +221,22 @@ async def load_modules_from_directory(directory):
 
 
 async def on_startup():
+    # Initialize database if DATABASE_URL is provided
+    if DATABASE_URL is not None:
+        try:
+            await init_database()
+        except Exception as e:
+            logging.error(f"Failed to initialize database: {e}")
+            # Continue without database if initialization fails
+
+    # Initialize Redis if REDIS_URL is provided
+    if REDIS_URL is not None:
+        try:
+            await init_redis()
+        except Exception as e:
+            logging.error(f"Failed to initialize Redis: {e}")
+            # Continue without Redis if it fails
+
     await load_modules_from_directory(PIPELINES_DIR)
 
     for module in PIPELINE_MODULES.values():
@@ -236,6 +249,14 @@ async def on_shutdown():
         if hasattr(module, "on_shutdown"):
             await module.on_shutdown()
 
+    # Close database connection if it was initialized
+    if is_database_available():
+        await close_database()
+
+    # Close Redis connection if it was initialized
+    if is_redis_available():
+        await close_redis()
+
 
 async def reload():
     await on_shutdown()
@@ -258,10 +279,8 @@ async def lifespan(app: FastAPI):
 
 app.state.PIPELINES = PIPELINES
 
-
 origins = ["*"]
 
-
 app.add_middleware(
     CORSMiddleware,
     allow_origins=origins,
@@ -324,7 +343,21 @@ async def get_models(user: str = Depends(get_current_user)):
 
 @app.get("/v1")
 @app.get("/")
 async def get_status():
-    return {"status": True}
+    status_info = {"status": True}
+
+    # Add database status if available
+    if DATABASE_URL is not None:
+        status_info["database"] = {
+            "available": is_database_available(),
+            "url_configured": True
+        }
+    else:
+        status_info["database"] = {
+            "available": False,
+            "url_configured": False
+        }
+
+    return status_info
 
 @app.get("/v1/pipelines")
@@ -387,7 +420,7 @@ async def download_file(url: str, dest_folder: str):
 @app.post("/v1/pipelines/add")
 @app.post("/pipelines/add")
 async def add_pipeline(
-        form_data: AddPipelineForm, user: str = Depends(get_current_user)
+    form_data: AddPipelineForm, user: str = Depends(get_current_user)
 ):
     if user != API_KEY:
         raise HTTPException(
@@ -417,7 +450,7 @@ async def add_pipeline(
 @app.post("/v1/pipelines/upload")
 @app.post("/pipelines/upload")
 async def upload_pipeline(
-        file: UploadFile = File(...), user: str = Depends(get_current_user)
+    file: UploadFile = File(...), user: str = Depends(get_current_user)
 ):
     if user != API_KEY:
         raise HTTPException(
@@ -466,7 +499,7 @@ class DeletePipelineForm(BaseModel):
 @app.delete("/v1/pipelines/delete")
 @app.delete("/pipelines/delete")
 async def delete_pipeline(
-        form_data: DeletePipelineForm, user: str = Depends(get_current_user)
+    form_data: DeletePipelineForm, user: str = Depends(get_current_user)
 ):
     if user != API_KEY:
         raise HTTPException(
@@ -552,7 +585,6 @@ async def get_valves_spec(pipeline_id: str):
 @app.post("/v1/{pipeline_id}/valves/update")
 @app.post("/{pipeline_id}/valves/update")
 async def update_valves(pipeline_id: str, form_data: dict):
-
     if pipeline_id not in PIPELINE_MODULES:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
@@ -663,8 +695,8 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
     user_message = get_last_user_message(messages)
 
     if (
-            form_data.model not in app.state.PIPELINES
-            or app.state.PIPELINES[form_data.model]["type"] == "filter"
+        form_data.model not in app.state.PIPELINES
+        or app.state.PIPELINES[form_data.model]["type"] == "filter"
     ):
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
diff --git a/pipeline.service b/pipeline.service
new file mode 100644
index 00000000..1621a025
--- /dev/null
+++ b/pipeline.service
@@ -0,0 +1,28 @@
+[Unit]
+Description=SquadRun Pipelines Service
+After=network.target
+
+[Service]
+# Run inside project directory
+WorkingDirectory=/home/ubuntu/squadrun-pipelines/
+
+# Export environment (virtualenv)
+Environment="PATH=/home/ubuntu/squadrun-pipelines/.venv/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
+
+# Run the script with bash
+ExecStart=/bin/bash ./start.sh
+
+# Restart policy
+Restart=always
+RestartSec=5
+
+# User/group to run as
+User=ubuntu
+Group=ubuntu
+
+# Logging (journalctl -u squadrun-pipelines)
+StandardOutput=journal
+StandardError=journal
+
+[Install]
+WantedBy=multi-user.target
diff --git a/pipelines/__pycache__/rate_limit_filter_pipeline.cpython-312.pyc b/pipelines/__pycache__/rate_limit_filter_pipeline.cpython-312.pyc
new file mode 100644
index 00000000..978e29a0
Binary files /dev/null and b/pipelines/__pycache__/rate_limit_filter_pipeline.cpython-312.pyc differ
diff --git a/pipelines/rate_limit_filter_pipeline.py b/pipelines/rate_limit_filter_pipeline.py
new file mode 100644
index 00000000..396714be
--- /dev/null
+++ b/pipelines/rate_limit_filter_pipeline.py
@@ -0,0 +1,274 @@
+import json
+import os
+import logging
+from datetime import datetime
+from typing import List, Optional, Tuple, Dict, Any
+
+from pydantic import BaseModel
+
+from utils.pipelines.database import get_db_connection, is_database_available
+from utils.pipelines.redis_client import set_with_ttl, get, is_redis_available
+
+logger = logging.getLogger(__name__)
+
+# Redis key patterns
+DAILY_LIMIT_KEY = "{user_email}-{model}-{year}-{month}-{day}"
+MONTHLY_LIMIT_KEY = "{user_email}-{model}-{year}-{month}"
+
+
+class Pipeline:
+    class Valves(BaseModel):
+        pipelines: List[str] = []
+        priority: int = 0
+
+        # JSON strings defining default limits
+        user_level_limit: str = json.dumps(
+            [{"users": [], "daily_request_limit": 100, "monthly_request_limit": 3000, "model_groups": []}]
+        )
+        user_group_level_limit: str = json.dumps(
+            [{"user_groups": [], "daily_request_limit": 100, "monthly_request_limit": 3000, "model_groups": []}]
+        )
+
+    def __init__(self) -> None:
+        self.type = "filter"
+        self.name = "Rate Limit Filter"
+
+        self.valves = self.Valves(
+            pipelines=os.getenv("RATE_LIMIT_PIPELINES", "*").split(",")
+        )
+        self.user_level_limit_list = json.loads(self.valves.user_level_limit)
+        self.user_group_level_limit_list = json.loads(self.valves.user_group_level_limit)
+
+        # Track processed messages (deduplication)
+        self._processed_messages: set[str] = set()
+        self._max_processed_cache_size = 10_000
+
+    async def sync_configs(self) -> None:
+        """Reload limit configs from valves."""
+        self.user_level_limit_list = json.loads(self.valves.user_level_limit)
+        self.user_group_level_limit_list = json.loads(self.valves.user_group_level_limit)
+        logger.info(f"[RateLimit] User-level limits: {self.user_level_limit_list}")
+        logger.info(f"[RateLimit] User-group limits: {self.user_group_level_limit_list}")
+
+    async def on_startup(self) -> None:
+        logger.info(f"[RateLimit] Rate limit filter pipeline started: {__name__}")
+
+    async def on_shutdown(self) -> None:
+        logger.info(f"[RateLimit] Rate limit filter pipeline stopped: {__name__}")
+
+    # ------------------------
+    # Database fetch helpers
+    # ------------------------
+
+    async def get_user_groups(self, user_email: str) -> List[str]:
+        """Fetch groups associated with a user."""
+        try:
+            async with get_db_connection() as conn:
+                query = """
+                    SELECT g.name
+                    FROM "group" g
+                    JOIN auth a ON a.id = ANY (SELECT jsonb_array_elements_text(g.user_ids::jsonb))
+                    WHERE a.email = $1;
+                """
+                rows = await conn.fetch(query, user_email)
+                return [row["name"] for row in rows]
+        except Exception as e:
+            logger.error(f"[RateLimit] Error fetching groups for {user_email}: {e}")
+            return []
+
+    async def get_model_groups(self, model_id: str) -> List[str]:
+        """Fetch groups attached to a model."""
+        if not is_database_available():
+            logger.warning("[RateLimit] Database not available, returning empty model groups")
+            return []
+        try:
+            async with get_db_connection() as conn:
+                query = """
+                    SELECT g.name
+                    FROM model m
+                    JOIN "group" g
+                        ON g.id = ANY (
+                            SELECT jsonb_array_elements_text((m.access_control->'read'->'group_ids')::jsonb)
+                        )
+                    WHERE m.id = $1;
+                """
+                rows = await conn.fetch(query, model_id)
+                return [row["name"] for row in rows]
+        except Exception as e:
+            logger.error(f"[RateLimit] Error fetching groups for model {model_id}: {e}")
+            return []
+
+    # ------------------------
+    # Redis fetch helpers
+    # ------------------------
+
+    async def _fetch_used_requests(self, user_email: str, model_id: str, scope: str) -> int:
+        """Generic fetcher for daily/monthly usage from Redis."""
+        now = datetime.now()
+        if scope == "daily":
+            key = DAILY_LIMIT_KEY.format(
+                user_email=user_email, model=model_id,
+                year=now.year, month=now.month, day=now.day
+            )
+        elif scope == "monthly":
+            key = MONTHLY_LIMIT_KEY.format(
+                user_email=user_email, model=model_id,
+                year=now.year, month=now.month
+            )
+        else:
+            raise ValueError("Scope must be 'daily' or 'monthly'")
+
+        return int(await get(key) or 0)
+
+    async def fetch_daily_used_for_a_user(self, user_email: str, model_id: str) -> int:
+        return await self._fetch_used_requests(user_email, model_id, "daily")
+
+    async def fetch_monthly_used_for_a_user(self, user_email: str, model_id: str) -> int:
+        return await self._fetch_used_requests(user_email, model_id, "monthly")
+
+    # ------------------------
+    # Limit resolution helpers
+    # ------------------------
+
+    async def fetch_user_request_limits_for_model(
+        self, user_email: str, model_groups: List[str]
+    ) -> Tuple[float, float]:
+        """Find the smallest user-level daily & monthly limits matching the given model groups."""
+        min_daily, min_monthly = float("inf"), float("inf")
+        should_limit = False
+
+        for cfg in self.user_level_limit_list:
+            if user_email in cfg["users"] and any(m in cfg["model_groups"] for m in model_groups):
+                min_daily = min(cfg["daily_request_limit"], min_daily)
+                min_monthly = min(cfg["monthly_request_limit"], min_monthly)
+                should_limit = True
+
+        return (min_daily, min_monthly) if should_limit else (float("inf"), float("inf"))
+
+    async def fetch_user_groups_request_limits_for_model(
+        self, user_groups: List[str], model_groups: List[str]
+    ) -> Tuple[float, float]:
+        """Find the smallest group-level daily & monthly limits for the user's groups and model groups."""
+        min_daily, min_monthly = float("inf"), float("inf")
+        should_limit = False
+
+        for cfg in self.user_group_level_limit_list:
+            if any(g in cfg["user_groups"] for g in user_groups) and \
+                    any(m in cfg["model_groups"] for m in model_groups):
+                min_daily = min(cfg["daily_request_limit"], min_daily)
+                min_monthly = min(cfg["monthly_request_limit"], min_monthly)
+                should_limit = True
+
+        return (min_daily, min_monthly) if should_limit else (float("inf"), float("inf"))
+
+    # ------------------------
+    # Core rate limiting logic
+    # ------------------------
+
+    def rate_limit_the_user(
+        self, daily_used: int, monthly_used: int,
+        daily_request_limit: float, monthly_request_limit: float
+    ) -> None:
+        logger.info(
+            f"[RateLimit] Checking user: daily_used={daily_used}, "
+            f"monthly_used={monthly_used}, daily_limit={daily_request_limit}, "
+            f"monthly_limit={monthly_request_limit}"
+        )
+        if daily_used + 1 > daily_request_limit:
+            raise Exception("Rate limit exceeded: Daily limit reached for this model")
+        if monthly_used + 1 > monthly_request_limit:
+            raise Exception("Rate limit exceeded: Monthly limit reached for this model")
+
+    async def check_user_rate_limited(self, user: dict, model_id: str) -> None:
+        """Check whether a user exceeded daily/monthly request limits."""
+        daily_used = await self.fetch_daily_used_for_a_user(user["email"], model_id)
+        monthly_used = await self.fetch_monthly_used_for_a_user(user["email"], model_id)
+
+        user_groups = await self.get_user_groups(user["email"])
+        model_groups = await self.get_model_groups(model_id)
+
+        user_limits = await self.fetch_user_request_limits_for_model(user["email"], model_groups)
+        group_limits = await self.fetch_user_groups_request_limits_for_model(user_groups, model_groups)
+
+        logger.info(
+            f"[RateLimit] User {user['email']} groups={user_groups}, model_groups={model_groups}, "
+            f"daily_used={daily_used}, monthly_used={monthly_used}, "
+            f"user_limits={user_limits}, group_limits={group_limits}"
+        )
+
+        daily_request_limit = min(user_limits[0], group_limits[0])
+        monthly_request_limit = min(user_limits[1], group_limits[1])
+        self.rate_limit_the_user(daily_used, monthly_used, daily_request_limit, monthly_request_limit)
+
+    async def increment_rate_limit(self, user_email: str, model_id: str) -> None:
+        """Increment rate limit counters in Redis."""
+        if not is_redis_available():
+            return
+
+        try:
+            now = datetime.now()
+            daily_key = DAILY_LIMIT_KEY.format(
+                user_email=user_email, model=model_id,
+                year=now.year, month=now.month, day=now.day
+            )
+            monthly_key = MONTHLY_LIMIT_KEY.format(
+                user_email=user_email, model=model_id,
+                year=now.year, month=now.month
+            )
+
+            async def incr(key: str, ttl: int):
+                count = await get(key)
+                new_count = (int(count) + 1) if count is not None else 1
+                await set_with_ttl(key, new_count, ttl)
+                return new_count
+
+            daily_count = await incr(daily_key, 86400)  # 24h
+            monthly_count = await incr(monthly_key, 2592000)  # 30d
+
+            logger.debug(f"[RateLimit] Incremented {user_email}:{model_id} daily={daily_count}, monthly={monthly_count}")
+
+        except Exception as e:
+            logger.error(f"[RateLimit] Error incrementing rate limit for user {user_email}: {e}")
+
+    # ------------------------
+    # Entry point
+    # ------------------------
+
+    async def inlet(self, body: Dict[str, Any], user: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+        """Main filter method that checks rate limits before processing requests."""
+        user_id = user.get("id") if user else None
+        logger.info(f"[RateLimit] Processing request for user={user_id or 'anonymous'}")
+
+        if not user or user.get("role") not in {"user", "admin"}:
+            logger.info("[RateLimit] Skipping rate limit check for non-user/admin role")
+            return body
+
+        message_id = body.get("metadata", {}).get("message_id")
+        if not message_id:
+            logger.warning("[RateLimit] Missing message_id in request body, skipping rate limit check")
+            return body
+
+        # --- Deduplication with cleanup ---
+        if message_id in self._processed_messages:
+            logger.debug(f"[RateLimit] Message {message_id} already processed, skipping check")
+            return body
+
+        if len(self._processed_messages) >= self._max_processed_cache_size:
+            logger.debug("[RateLimit] Clearing processed message cache (size exceeded 10k)")
+            self._processed_messages.clear()
+
+        self._processed_messages.add(message_id)
+        logger.info(f"[RateLimit] Processing message_id={message_id} for user={user.get('email')}")
+
+        model_id = body.get("model", "unknown")
+        await self.sync_configs()
+
+        try:
+            await self.check_user_rate_limited(user, model_id)
+            await self.increment_rate_limit(user["email"], model_id)
+            return body
+        except Exception as e:
+            if "Rate limit exceeded" in str(e):
+                raise
+            logger.error(f"[RateLimit] Error in rate limit check for user {user_id}: {e}")
+            return body
diff --git a/requirements.txt b/requirements.txt
index b3cc96e8..690463b6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -28,6 +28,7 @@ redis
 sqlmodel
 chromadb
 psycopg2-binary
+asyncpg
 
 # Observability
 langfuse
diff --git a/test_rate_limiting.py b/test_rate_limiting.py
new file mode 100644
index 00000000..98eae9d5
--- /dev/null
+++ b/test_rate_limiting.py
@@ -0,0 +1,106 @@
+#!/usr/bin/env python3
+"""
+Test script for the updated rate limiting functionality
+"""
+import asyncio
+import logging
+from pipelines.rate_limit_filter_pipeline import Pipeline
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+async def test_rate_limiting():
+    """Test the rate limiting functionality"""
+    try:
+        # Initialize the pipeline
+        pipeline = Pipeline()
+
+        # Test user data
+        test_user = {
+            "id": "test_user_123",
+            "role": "user",
+            "email": "test@example.com"
+        }
+
+        # Test request body
+        test_body = {
+            "model": "gpt-4",
+            "messages": [{"role": "user", "content": "Hello"}]
+        }
+
+        logger.info("Testing rate limiting with user groups and model groups...")
+
+        # Test 1: Normal request (should pass)
+        logger.info("Test 1: Normal request")
+        try:
+            result = await pipeline.inlet(test_body, test_user)
+            logger.info("✅ Normal request passed")
+        except Exception as e:
+            logger.error(f"❌ Normal request failed: {e}")
+
+        # Test 2: Multiple requests to test rate limiting
+        logger.info("Test 2: Multiple requests to test rate limiting")
+        for i in range(5):
+            try:
+                result = await pipeline.inlet(test_body, test_user)
+                logger.info(f"✅ Request {i+1} passed")
+            except Exception as e:
+                logger.error(f"❌ Request {i+1} failed: {e}")
+                break
+
+        # Test 3: Different model
+        logger.info("Test 3: Different model")
+        test_body_claude = {
+            "model": "claude-3",
+            "messages": [{"role": "user", "content": "Hello"}]
+        }
+
+        try:
+            result = await pipeline.inlet(test_body_claude, test_user)
+            logger.info("✅ Different model request passed")
+        except Exception as e:
+            logger.error(f"❌ Different model request failed: {e}")
+
+        # Test 4: Admin user (admins are also subject to rate limiting)
+        logger.info("Test 4: Admin user")
+        admin_user = {
+            "id": "admin_123",
+            "role": "admin",
+            "email": "admin@example.com"
+        }
+
+        try:
+            result = await pipeline.inlet(test_body, admin_user)
+            logger.info("✅ Admin request passed")
+        except Exception as e:
+            logger.error(f"❌ Admin request failed: {e}")
+
+        # Test 5: User without ID
+        logger.info("Test 5: User without ID")
+        user_no_id = {
+            "role": "user",
+            "email": "no_id@example.com"
+        }
+
+        try:
+            result = await pipeline.inlet(test_body, user_no_id)
+            logger.info("✅ User without ID request passed")
+        except Exception as e:
+            logger.error(f"❌ User without ID request failed: {e}")
+
+        logger.info("🎉 Rate limiting tests completed!")
+
+    except Exception as e:
+        logger.error(f"❌ Test failed with error: {e}")
+
+
+async def main():
+    """Main test function"""
+    logger.info("Starting rate limiting tests...")
+    await test_rate_limiting()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/utils/pipelines/database.py b/utils/pipelines/database.py
new file mode 100644
index 00000000..3964cbe1
--- /dev/null
+++ b/utils/pipelines/database.py
@@ -0,0 +1,105 @@
+import os
+import logging
+from typing import Optional
+from contextlib import asynccontextmanager
+import asyncpg
+from config import DATABASE_URL
+
+logger = logging.getLogger(__name__)
+
+# Global database connection pool
+_db_pool: Optional[asyncpg.Pool] = None
+
+
+async def init_database():
+    """Initialize database connection pool if DATABASE_URL is provided"""
+    global _db_pool
+
+    if DATABASE_URL is None:
+        logger.info("DATABASE_URL not provided, skipping database initialization")
+        return
+
+    try:
+        _db_pool = await asyncpg.create_pool(
+            DATABASE_URL,
+            min_size=1,
+            max_size=10,
+            command_timeout=60
+        )
+        logger.info("Database connection pool initialized successfully")
+
+        # Test the connection
+        async with _db_pool.acquire() as conn:
+            await conn.fetchval('SELECT 1')
+        logger.info("Database connection test successful")
+
+    except Exception as e:
+        logger.error(f"Failed to initialize database connection: {e}")
+        _db_pool = None
+        raise
+
+
+async def close_database():
+    """Close database connection pool"""
+    global _db_pool
+
+    if _db_pool is not None:
+        await _db_pool.close()
+        _db_pool = None
+        logger.info("Database connection pool closed")
+
+
+def get_db_pool() -> Optional[asyncpg.Pool]:
+    """Get the database connection pool"""
+    return _db_pool
+
+
+@asynccontextmanager
+async def get_db_connection():
+    """Context manager for database connections"""
+    if _db_pool is None:
+        raise RuntimeError("Database not initialized. DATABASE_URL not provided or connection failed.")
+
+    async with _db_pool.acquire() as conn:
+        yield conn
+
+
+async def execute_query(query: str, *args):
+    """Execute a query and return results"""
+    if _db_pool is None:
+        raise RuntimeError("Database not initialized. DATABASE_URL not provided or connection failed.")
+
+    async with _db_pool.acquire() as conn:
+        return await conn.fetch(query, *args)
+
+
+async def execute_one(query: str, *args):
+    """Execute a query and return one result"""
+    if _db_pool is None:
+        raise RuntimeError("Database not initialized. DATABASE_URL not provided or connection failed.")
+
+    async with _db_pool.acquire() as conn:
+        return await conn.fetchrow(query, *args)
+
+
+async def execute_scalar(query: str, *args):
+    """Execute a query and return a scalar value"""
+    if _db_pool is None:
+        raise RuntimeError("Database not initialized. DATABASE_URL not provided or connection failed.")
+
+    async with _db_pool.acquire() as conn:
+        return await conn.fetchval(query, *args)
+
+
+async def execute_command(query: str, *args):
+    """Execute a command (INSERT, UPDATE, DELETE) and return status"""
+    if _db_pool is None:
+        raise RuntimeError("Database not initialized. DATABASE_URL not provided or connection failed.")
+
+    async with _db_pool.acquire() as conn:
+        return await conn.execute(query, *args)
+
+
+def is_database_available() -> bool:
+    """Check if database is available"""
+    return _db_pool is not None
diff --git a/utils/pipelines/redis_client.py b/utils/pipelines/redis_client.py
new file mode 100644
index 00000000..0ff367ba
--- /dev/null
+++ b/utils/pipelines/redis_client.py
@@ -0,0 +1,95 @@
+import os
+import logging
+from typing import Optional, Any
+import redis.asyncio as redis
+from config import REDIS_URL
+
+logger = logging.getLogger(__name__)
+
+# Global Redis client
+_redis_client: Optional[redis.Redis] = None
+
+
+async def init_redis():
+    """Initialize Redis client"""
+    global _redis_client
+
+    if REDIS_URL is None:
+        logger.info("REDIS_URL not provided, skipping Redis initialization")
+        return
+
+    try:
+        _redis_client = redis.from_url(REDIS_URL, decode_responses=True)
+        await _redis_client.ping()
+        logger.info("Redis client initialized successfully")
+    except Exception as e:
+        logger.error(f"Failed to initialize Redis client: {e}")
+        _redis_client = None
+
+
+async def close_redis():
+    """Close Redis client"""
+    global _redis_client
+
+    if _redis_client is not None:
+        await _redis_client.close()
+        _redis_client = None
+        logger.info("Redis client closed")
+
+
+def get_redis_client() -> Optional[redis.Redis]:
+    """Get Redis client"""
+    return _redis_client
+
+
+def is_redis_available() -> bool:
+    """Check if Redis is available"""
+    return _redis_client is not None
+
+
+async def set_with_ttl(key: str, value: Any, ttl: int = 3600) -> bool:
+    """Set key-value pair with TTL (time to live in seconds)"""
+    if _redis_client is None:
+        raise RuntimeError("Redis client not initialized")
+
+    try:
+        return await _redis_client.set(key, value, ex=ttl)
+    except Exception as e:
+        logger.error(f"Failed to set Redis key {key}: {e}")
+        raise
+
+
+async def get(key: str) -> Optional[str]:
+    """Get value by key"""
+    if _redis_client is None:
+        raise RuntimeError("Redis client not initialized")
+
+    try:
+        return await _redis_client.get(key)
+    except Exception as e:
+        logger.error(f"Failed to get Redis key {key}: {e}")
+        return None
+
+
+async def delete(key: str) -> bool:
+    """Delete key"""
+    if _redis_client is None:
+        raise RuntimeError("Redis client not initialized")
+
+    try:
+        return await _redis_client.delete(key) > 0
+    except Exception as e:
+        logger.error(f"Failed to delete Redis key {key}: {e}")
+        raise
+
+
+async def exists(key: str) -> bool:
+    """Check if key exists"""
+    if _redis_client is None:
+        raise RuntimeError("Redis client not initialized")
+
+    try:
+        return await _redis_client.exists(key) > 0
+    except Exception as e:
+        logger.error(f"Failed to check Redis key existence {key}: {e}")
+        raise
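
Note (not part of the diff): a minimal sketch of how the new rate-limit valves defined in pipelines/rate_limit_filter_pipeline.py might be configured. The user e-mail, group names, and numbers below are illustrative assumptions, not values taken from this change.

# Illustrative valve configuration for the rate limit filter (Python):
import json

# Per-user limits; "premium-models" and the address are hypothetical placeholders.
user_level_limit = json.dumps([
    {
        "users": ["alice@example.com"],
        "daily_request_limit": 50,
        "monthly_request_limit": 1000,
        "model_groups": ["premium-models"],
    }
])

# Per-group limits; "engineering" is a hypothetical Open WebUI group name.
user_group_level_limit = json.dumps([
    {
        "user_groups": ["engineering"],
        "daily_request_limit": 100,
        "monthly_request_limit": 3000,
        "model_groups": ["premium-models"],
    }
])

# These JSON strings would be set on the pipeline's valves; the filter then applies
# the smallest matching daily/monthly limit per user and model, tracked via Redis counters.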