Skip to content

Commit 98d81a7

Browse files
committed
Add preprocess hooks
1 parent c3764cc commit 98d81a7

File tree

9 files changed

+177
-116
lines changed

9 files changed

+177
-116
lines changed

eval_protocol/proxy/README.md

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -332,32 +332,6 @@ eval_protocol/proxy/
332332
└── README.md # This file
333333
```
334334
335-
### Adding Custom Authentication
336-
337-
Extend `AuthProvider` in `auth.py`:
338-
```python
339-
from .auth import AuthProvider
340-
from fastapi import HTTPException, Request
341-
342-
class MyAuthProvider(AuthProvider):
343-
def validate(self, request: Request) -> Optional[str]:
344-
api_key = None
345-
auth_header = request.headers.get("authorization", "")
346-
if auth_header.startswith("Bearer "):
347-
api_key = auth_header.replace("Bearer ", "").strip()
348-
if not api_key:
349-
raise HTTPException(status_code=401, detail="Invalid API key")
350-
return api_key
351-
```
352-
353-
Then pass it to `create_app`:
354-
```python
355-
from proxy_core import create_app
356-
from my_auth import MyAuthProvider
357-
358-
app = create_app(auth_provider=MyAuthProvider())
359-
```
360-
361335
### Testing
362336
363337
#### Test chat completion:

eval_protocol/proxy/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
Langfuse tracing for distributed evaluation workflows.
66
"""
77

8-
from .proxy_core import create_app, AuthProvider, NoAuthProvider, ProxyConfig
8+
from .proxy_core import create_app, AuthProvider, NoAuthProvider, ProxyConfig, ChatParams, TracesParams
99

1010
__all__ = [
1111
"create_app",
1212
"AuthProvider",
1313
"NoAuthProvider",
1414
"ProxyConfig",
15+
"ChatParams",
16+
"TracesParams",
1517
]

eval_protocol/proxy/docker-compose.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ services:
4949
- LOG_LEVEL=INFO
5050
# Langfuse and secrets
5151
- SECRETS_PATH=/app/proxy_core/secrets.yaml
52+
- LANGFUSE_HOST=${LANGFUSE_HOST:-https://cloud.langfuse.com}
5253
ports:
5354
- "4000:4000" # Main public-facing port
5455
networks:

eval_protocol/proxy/proxy_core/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
from .models import ProxyConfig
1+
from .models import ProxyConfig, ChatParams, TracesParams
22
from .app import create_app
33
from .auth import AuthProvider, NoAuthProvider
44

55
__all__ = [
66
"ProxyConfig",
7+
"ChatParams",
8+
"TracesParams",
79
"create_app",
810
"AuthProvider",
911
"NoAuthProvider",

eval_protocol/proxy/proxy_core/app.py

Lines changed: 75 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
from fastapi import FastAPI, Depends, HTTPException, Request, Query
7-
from typing import Optional, List
7+
from typing import Optional, Callable, Dict, Any, List
88
import os
99
import redis
1010
import logging
@@ -13,24 +13,28 @@
1313
import sys
1414
from contextlib import asynccontextmanager
1515

16-
from .models import ProxyConfig, LangfuseTracesResponse
16+
from .models import ProxyConfig, LangfuseTracesResponse, TracesParams, ChatParams, ChatRequestHook, TracesRequestHook
1717
from .auth import AuthProvider, NoAuthProvider
1818
from .litellm import handle_chat_completion, proxy_to_litellm
1919
from .langfuse import fetch_langfuse_traces
2020

2121
# Configure logging before any other imports (so all modules inherit this config)
2222
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
23-
logging.basicConfig(
24-
level=getattr(logging, log_level),
25-
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
26-
handlers=[logging.StreamHandler(sys.stdout)],
27-
)
23+
if not logging.getLogger().hasHandlers():
24+
logging.basicConfig(
25+
level=getattr(logging, log_level),
26+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
27+
handlers=[logging.StreamHandler(sys.stdout)],
28+
)
2829

2930
logger = logging.getLogger(__name__)
3031

3132

32-
def build_proxy_config() -> ProxyConfig:
33-
"""Load environment and secrets, and build ProxyConfig (no Redis)."""
33+
def build_proxy_config(
34+
preprocess_chat_request: Optional[ChatRequestHook] = None,
35+
preprocess_traces_request: Optional[TracesRequestHook] = None,
36+
) -> ProxyConfig:
37+
"""Load environment and secrets, and build ProxyConfig"""
3438
# Env
3539
litellm_url = os.getenv("LITELLM_URL")
3640
if not litellm_url:
@@ -67,6 +71,8 @@ def build_proxy_config() -> ProxyConfig:
6771
langfuse_host=langfuse_host,
6872
langfuse_keys=langfuse_keys,
6973
default_project_id=default_project_id,
74+
preprocess_chat_request=preprocess_chat_request,
75+
preprocess_traces_request=preprocess_traces_request,
7076
)
7177

7278

@@ -97,12 +103,15 @@ def init_redis() -> redis.Redis:
97103

98104
def create_app(
99105
auth_provider: AuthProvider = NoAuthProvider(),
106+
preprocess_chat_request: Optional[ChatRequestHook] = None,
107+
preprocess_traces_request: Optional[TracesRequestHook] = None,
100108
) -> FastAPI:
101109
@asynccontextmanager
102110
async def lifespan(app: FastAPI):
103111
# Build runtime on startup
104-
app.state.config = build_proxy_config()
112+
app.state.config = build_proxy_config(preprocess_chat_request, preprocess_traces_request)
105113
app.state.redis = init_redis()
114+
106115
try:
107116
yield
108117
finally:
@@ -119,8 +128,46 @@ def get_config(request: Request) -> ProxyConfig:
119128
def get_redis(request: Request) -> redis.Redis:
120129
return request.app.state.redis
121130

131+
def get_traces_params(
132+
tags: Optional[List[str]] = Query(default=None),
133+
project_id: Optional[str] = None,
134+
limit: int = 100,
135+
sample_size: Optional[int] = None,
136+
user_id: Optional[str] = None,
137+
session_id: Optional[str] = None,
138+
name: Optional[str] = None,
139+
environment: Optional[str] = None,
140+
version: Optional[str] = None,
141+
release: Optional[str] = None,
142+
fields: Optional[str] = None,
143+
hours_back: Optional[int] = None,
144+
from_timestamp: Optional[str] = None,
145+
to_timestamp: Optional[str] = None,
146+
sleep_between_gets: float = 2.5,
147+
max_retries: int = 3,
148+
) -> TracesParams:
149+
return TracesParams(
150+
tags=tags,
151+
project_id=project_id,
152+
limit=limit,
153+
sample_size=sample_size,
154+
user_id=user_id,
155+
session_id=session_id,
156+
name=name,
157+
environment=environment,
158+
version=version,
159+
release=release,
160+
fields=fields,
161+
hours_back=hours_back,
162+
from_timestamp=from_timestamp,
163+
to_timestamp=to_timestamp,
164+
sleep_between_gets=sleep_between_gets,
165+
max_retries=max_retries,
166+
)
167+
122168
async def require_auth(request: Request) -> None:
123-
auth_provider.validate(request)
169+
account_id = auth_provider.validate_and_return_account_id(request)
170+
request.state.account_id = account_id
124171
return None
125172

126173
# =====================
@@ -161,11 +208,9 @@ async def chat_completion_with_full_metadata(
161208
encoded_base_url: Optional[str] = None,
162209
config: ProxyConfig = Depends(get_config),
163210
redis_client: redis.Redis = Depends(get_redis),
211+
_: None = Depends(require_auth),
164212
):
165-
return await handle_chat_completion(
166-
config=config,
167-
redis_client=redis_client,
168-
request=request,
213+
params = ChatParams(
169214
project_id=project_id,
170215
rollout_id=rollout_id,
171216
invocation_id=invocation_id,
@@ -174,6 +219,12 @@ async def chat_completion_with_full_metadata(
174219
row_id=row_id,
175220
encoded_base_url=encoded_base_url,
176221
)
222+
return await handle_chat_completion(
223+
config=config,
224+
redis_client=redis_client,
225+
request=request,
226+
params=params,
227+
)
177228

178229
@app.post("/project_id/{project_id}/chat/completions")
179230
@app.post("/v1/project_id/{project_id}/chat/completions")
@@ -182,12 +233,14 @@ async def chat_completion_with_project_only(
182233
request: Request,
183234
config: ProxyConfig = Depends(get_config),
184235
redis_client: redis.Redis = Depends(get_redis),
236+
_: None = Depends(require_auth),
185237
):
238+
params = ChatParams(project_id=project_id)
186239
return await handle_chat_completion(
187240
config=config,
188241
redis_client=redis_client,
189242
request=request,
190-
project_id=project_id,
243+
params=params,
191244
)
192245

193246
# ===============
@@ -198,45 +251,20 @@ async def chat_completion_with_project_only(
198251
@app.get("/project_id/{project_id}/traces", response_model=LangfuseTracesResponse)
199252
@app.get("/v1/project_id/{project_id}/traces", response_model=LangfuseTracesResponse)
200253
async def get_langfuse_traces(
201-
tags: List[str] = Query(...), # REQUIRED query param
254+
request: Request,
255+
params: TracesParams = Depends(get_traces_params),
202256
project_id: Optional[str] = None,
203-
limit: int = 100,
204-
sample_size: Optional[int] = None,
205-
user_id: Optional[str] = None,
206-
session_id: Optional[str] = None,
207-
name: Optional[str] = None,
208-
environment: Optional[str] = None,
209-
version: Optional[str] = None,
210-
release: Optional[str] = None,
211-
fields: Optional[str] = None,
212-
hours_back: Optional[int] = None,
213-
from_timestamp: Optional[str] = None,
214-
to_timestamp: Optional[str] = None,
215-
sleep_between_gets: float = 2.5,
216-
max_retries: int = 3,
217257
config: ProxyConfig = Depends(get_config),
218258
redis_client: redis.Redis = Depends(get_redis),
219259
_: None = Depends(require_auth),
220260
) -> LangfuseTracesResponse:
261+
if project_id is not None:
262+
params.project_id = project_id
221263
return await fetch_langfuse_traces(
222264
config=config,
223265
redis_client=redis_client,
224-
tags=tags,
225-
project_id=project_id,
226-
limit=limit,
227-
sample_size=sample_size,
228-
user_id=user_id,
229-
session_id=session_id,
230-
name=name,
231-
environment=environment,
232-
version=version,
233-
release=release,
234-
fields=fields,
235-
hours_back=hours_back,
236-
from_timestamp=from_timestamp,
237-
to_timestamp=to_timestamp,
238-
sleep_between_gets=sleep_between_gets,
239-
max_retries=max_retries,
266+
request=request,
267+
params=params,
240268
)
241269

242270
# Health
Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
from abc import ABC, abstractmethod
2-
from typing import Optional
2+
import logging
33
from fastapi import Request
4+
from fastapi import HTTPException
5+
import httpx
6+
from typing import Optional
7+
8+
logger = logging.getLogger(__name__)
49

510

611
class AuthProvider(ABC):
712
@abstractmethod
8-
def validate(self, request: Request) -> Optional[str]: ...
13+
def validate_and_return_account_id(self, request: Request) -> Optional[str]: ...
914

1015

1116
class NoAuthProvider(AuthProvider):
12-
def validate(self, request: Request) -> Optional[str]:
17+
def validate_and_return_account_id(self, request: Request) -> Optional[str]:
1318
return None

eval_protocol/proxy/proxy_core/langfuse.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,18 @@
88
import asyncio
99
from typing import List, Optional, Dict, Any, Set
1010
from datetime import datetime, timedelta
11-
from fastapi import HTTPException
11+
from fastapi import HTTPException, Request
1212
import redis
1313
from .redis_utils import get_insertion_ids
14-
from .models import ProxyConfig, LangfuseTracesResponse, TraceResponse
14+
from .models import ProxyConfig, LangfuseTracesResponse, TraceResponse, TracesParams
1515

1616
logger = logging.getLogger(__name__)
1717

1818

19-
def _extract_tag_value(tags: List[str], prefix: str) -> Optional[str]:
19+
def _extract_tag_value(tags: Optional[List[str]], prefix: str) -> Optional[str]:
2020
"""Extract value from a tag with the given prefix (e.g., 'rollout_id:' or 'insertion_id:')."""
21+
if not tags:
22+
return None
2123
for tag in tags:
2224
if tag.startswith(prefix):
2325
return tag.split(":", 1)[1]
@@ -60,7 +62,7 @@ async def _fetch_trace_list_with_retry(
6062
langfuse_client: Any,
6163
page: int,
6264
limit: int,
63-
tags: List[str],
65+
tags: Optional[List[str]],
6466
user_id: Optional[str],
6567
session_id: Optional[str],
6668
name: Optional[str],
@@ -152,22 +154,8 @@ async def _fetch_trace_detail_with_retry(
152154
async def fetch_langfuse_traces(
153155
config: ProxyConfig,
154156
redis_client: redis.Redis,
155-
tags: List[str],
156-
project_id: Optional[str] = None,
157-
limit: int = 100,
158-
sample_size: Optional[int] = None,
159-
user_id: Optional[str] = None,
160-
session_id: Optional[str] = None,
161-
name: Optional[str] = None,
162-
environment: Optional[str] = None,
163-
version: Optional[str] = None,
164-
release: Optional[str] = None,
165-
fields: Optional[str] = None,
166-
hours_back: Optional[int] = None,
167-
from_timestamp: Optional[str] = None,
168-
to_timestamp: Optional[str] = None,
169-
sleep_between_gets: float = 2.5,
170-
max_retries: int = 3,
157+
request: Request,
158+
params: TracesParams,
171159
):
172160
"""
173161
Fetch full traces from Langfuse for the specified project.
@@ -184,9 +172,27 @@ async def fetch_langfuse_traces(
184172
185173
Returns a list of full trace objects (including observations) in JSON format.
186174
"""
187-
# Validate tags
188-
if not tags or not any(tag.startswith("rollout_id:") for tag in tags):
189-
raise HTTPException(status_code=422, detail="Tags must include at least one 'rollout_id:*' tag")
175+
176+
# Preprocess traces request
177+
if config.preprocess_traces_request:
178+
params = config.preprocess_traces_request(request, params)
179+
180+
tags = params.tags
181+
project_id = params.project_id
182+
limit = params.limit
183+
sample_size = params.sample_size
184+
user_id = params.user_id
185+
session_id = params.session_id
186+
name = params.name
187+
environment = params.environment
188+
version = params.version
189+
release = params.release
190+
fields = params.fields
191+
hours_back = params.hours_back
192+
from_timestamp = params.from_timestamp
193+
to_timestamp = params.to_timestamp
194+
sleep_between_gets = params.sleep_between_gets
195+
max_retries = params.max_retries
190196

191197
# Use default project if not specified
192198
if project_id is None:

0 commit comments

Comments
 (0)