44"""
55
66from fastapi import FastAPI , Depends , HTTPException , Request , Query
7- from typing import Optional , List
7+ from typing import Optional , Callable , Dict , Any , List
88import os
99import redis
1010import logging
1313import sys
1414from contextlib import asynccontextmanager
1515
16- from .models import ProxyConfig , LangfuseTracesResponse
16+ from .models import ProxyConfig , LangfuseTracesResponse , TracesParams , ChatParams , ChatRequestHook , TracesRequestHook
1717from .auth import AuthProvider , NoAuthProvider
1818from .litellm import handle_chat_completion , proxy_to_litellm
1919from .langfuse import fetch_langfuse_traces
2020
2121# Configure logging before any other imports (so all modules inherit this config)
2222log_level = os .getenv ("LOG_LEVEL" , "INFO" ).upper ()
23- logging .basicConfig (
24- level = getattr (logging , log_level ),
25- format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" ,
26- handlers = [logging .StreamHandler (sys .stdout )],
27- )
23+ if not logging .getLogger ().hasHandlers ():
24+ logging .basicConfig (
25+ level = getattr (logging , log_level ),
26+ format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" ,
27+ handlers = [logging .StreamHandler (sys .stdout )],
28+ )
2829
2930logger = logging .getLogger (__name__ )
3031
3132
32- def build_proxy_config () -> ProxyConfig :
33- """Load environment and secrets, and build ProxyConfig (no Redis)."""
33+ def build_proxy_config (
34+ preprocess_chat_request : Optional [ChatRequestHook ] = None ,
35+ preprocess_traces_request : Optional [TracesRequestHook ] = None ,
36+ ) -> ProxyConfig :
37+ """Load environment and secrets, and build ProxyConfig"""
3438 # Env
3539 litellm_url = os .getenv ("LITELLM_URL" )
3640 if not litellm_url :
@@ -67,6 +71,8 @@ def build_proxy_config() -> ProxyConfig:
6771 langfuse_host = langfuse_host ,
6872 langfuse_keys = langfuse_keys ,
6973 default_project_id = default_project_id ,
74+ preprocess_chat_request = preprocess_chat_request ,
75+ preprocess_traces_request = preprocess_traces_request ,
7076 )
7177
7278
@@ -97,12 +103,15 @@ def init_redis() -> redis.Redis:
97103
98104def create_app (
99105 auth_provider : AuthProvider = NoAuthProvider (),
106+ preprocess_chat_request : Optional [ChatRequestHook ] = None ,
107+ preprocess_traces_request : Optional [TracesRequestHook ] = None ,
100108) -> FastAPI :
101109 @asynccontextmanager
102110 async def lifespan (app : FastAPI ):
103111 # Build runtime on startup
104- app .state .config = build_proxy_config ()
112+ app .state .config = build_proxy_config (preprocess_chat_request , preprocess_traces_request )
105113 app .state .redis = init_redis ()
114+
106115 try :
107116 yield
108117 finally :
@@ -119,8 +128,46 @@ def get_config(request: Request) -> ProxyConfig:
119128 def get_redis (request : Request ) -> redis .Redis :
120129 return request .app .state .redis
121130
131+ def get_traces_params (
132+ tags : Optional [List [str ]] = Query (default = None ),
133+ project_id : Optional [str ] = None ,
134+ limit : int = 100 ,
135+ sample_size : Optional [int ] = None ,
136+ user_id : Optional [str ] = None ,
137+ session_id : Optional [str ] = None ,
138+ name : Optional [str ] = None ,
139+ environment : Optional [str ] = None ,
140+ version : Optional [str ] = None ,
141+ release : Optional [str ] = None ,
142+ fields : Optional [str ] = None ,
143+ hours_back : Optional [int ] = None ,
144+ from_timestamp : Optional [str ] = None ,
145+ to_timestamp : Optional [str ] = None ,
146+ sleep_between_gets : float = 2.5 ,
147+ max_retries : int = 3 ,
148+ ) -> TracesParams :
149+ return TracesParams (
150+ tags = tags ,
151+ project_id = project_id ,
152+ limit = limit ,
153+ sample_size = sample_size ,
154+ user_id = user_id ,
155+ session_id = session_id ,
156+ name = name ,
157+ environment = environment ,
158+ version = version ,
159+ release = release ,
160+ fields = fields ,
161+ hours_back = hours_back ,
162+ from_timestamp = from_timestamp ,
163+ to_timestamp = to_timestamp ,
164+ sleep_between_gets = sleep_between_gets ,
165+ max_retries = max_retries ,
166+ )
167+
122168 async def require_auth (request : Request ) -> None :
123- auth_provider .validate (request )
169+ account_id = auth_provider .validate_and_return_account_id (request )
170+ request .state .account_id = account_id
124171 return None
125172
126173 # =====================
@@ -161,11 +208,9 @@ async def chat_completion_with_full_metadata(
161208 encoded_base_url : Optional [str ] = None ,
162209 config : ProxyConfig = Depends (get_config ),
163210 redis_client : redis .Redis = Depends (get_redis ),
211+ _ : None = Depends (require_auth ),
164212 ):
165- return await handle_chat_completion (
166- config = config ,
167- redis_client = redis_client ,
168- request = request ,
213+ params = ChatParams (
169214 project_id = project_id ,
170215 rollout_id = rollout_id ,
171216 invocation_id = invocation_id ,
@@ -174,6 +219,12 @@ async def chat_completion_with_full_metadata(
174219 row_id = row_id ,
175220 encoded_base_url = encoded_base_url ,
176221 )
222+ return await handle_chat_completion (
223+ config = config ,
224+ redis_client = redis_client ,
225+ request = request ,
226+ params = params ,
227+ )
177228
178229 @app .post ("/project_id/{project_id}/chat/completions" )
179230 @app .post ("/v1/project_id/{project_id}/chat/completions" )
@@ -182,12 +233,14 @@ async def chat_completion_with_project_only(
182233 request : Request ,
183234 config : ProxyConfig = Depends (get_config ),
184235 redis_client : redis .Redis = Depends (get_redis ),
236+ _ : None = Depends (require_auth ),
185237 ):
238+ params = ChatParams (project_id = project_id )
186239 return await handle_chat_completion (
187240 config = config ,
188241 redis_client = redis_client ,
189242 request = request ,
190- project_id = project_id ,
243+ params = params ,
191244 )
192245
193246 # ===============
@@ -198,45 +251,20 @@ async def chat_completion_with_project_only(
198251 @app .get ("/project_id/{project_id}/traces" , response_model = LangfuseTracesResponse )
199252 @app .get ("/v1/project_id/{project_id}/traces" , response_model = LangfuseTracesResponse )
200253 async def get_langfuse_traces (
201- tags : List [str ] = Query (...), # REQUIRED query param
254+ request : Request ,
255+ params : TracesParams = Depends (get_traces_params ),
202256 project_id : Optional [str ] = None ,
203- limit : int = 100 ,
204- sample_size : Optional [int ] = None ,
205- user_id : Optional [str ] = None ,
206- session_id : Optional [str ] = None ,
207- name : Optional [str ] = None ,
208- environment : Optional [str ] = None ,
209- version : Optional [str ] = None ,
210- release : Optional [str ] = None ,
211- fields : Optional [str ] = None ,
212- hours_back : Optional [int ] = None ,
213- from_timestamp : Optional [str ] = None ,
214- to_timestamp : Optional [str ] = None ,
215- sleep_between_gets : float = 2.5 ,
216- max_retries : int = 3 ,
217257 config : ProxyConfig = Depends (get_config ),
218258 redis_client : redis .Redis = Depends (get_redis ),
219259 _ : None = Depends (require_auth ),
220260 ) -> LangfuseTracesResponse :
261+ if project_id is not None :
262+ params .project_id = project_id
221263 return await fetch_langfuse_traces (
222264 config = config ,
223265 redis_client = redis_client ,
224- tags = tags ,
225- project_id = project_id ,
226- limit = limit ,
227- sample_size = sample_size ,
228- user_id = user_id ,
229- session_id = session_id ,
230- name = name ,
231- environment = environment ,
232- version = version ,
233- release = release ,
234- fields = fields ,
235- hours_back = hours_back ,
236- from_timestamp = from_timestamp ,
237- to_timestamp = to_timestamp ,
238- sleep_between_gets = sleep_between_gets ,
239- max_retries = max_retries ,
266+ request = request ,
267+ params = params ,
240268 )
241269
242270 # Health
0 commit comments