Skip to content

Commit eca601a

Browse files
committed
2 parents 91121cc + d2294bf commit eca601a

File tree

11 files changed

+555
-23
lines changed

11 files changed

+555
-23
lines changed

eval_protocol/cli_commands/logs.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,13 @@ def logs_command(args):
1919
print("Press Ctrl+C to stop the server")
2020
print("-" * 50)
2121

22+
# setup Elasticsearch
23+
from eval_protocol.pytest.elasticsearch_setup import ElasticsearchSetup
24+
25+
elasticsearch_config = ElasticsearchSetup().setup_elasticsearch()
26+
2227
try:
23-
serve_logs(port=args.port)
28+
serve_logs(port=args.port, elasticsearch_config=elasticsearch_config)
2429
return 0
2530
except KeyboardInterrupt:
2631
print("\n🛑 Server stopped by user")

eval_protocol/logging/elasticsearch_direct_http_handler.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,16 @@ def emit(self, record: logging.LogRecord) -> None:
5151
print(f"Error preparing log for Elasticsearch: {e}")
5252

5353
def _get_rollout_id(self, record: logging.LogRecord) -> str:
    """Resolve the rollout ID that a log record belongs to.

    A ``rollout_id`` attribute on the record (e.g. supplied through logging's
    ``extra`` mapping) takes precedence when it is not None; otherwise the
    ``EP_ROLLOUT_ID`` environment variable is used.

    Args:
        record: The log record being emitted.

    Returns:
        The rollout ID as a string.

    Raises:
        ValueError: If the record carries no rollout ID and ``EP_ROLLOUT_ID``
            is not set.
    """
    # Per-record value wins over the process-wide environment variable.
    record_rollout_id = getattr(record, "rollout_id", None)
    if record_rollout_id is not None:
        return str(record_rollout_id)

    env_rollout_id = os.getenv("EP_ROLLOUT_ID")
    if env_rollout_id is not None:
        return env_rollout_id

    raise ValueError(
        "EP_ROLLOUT_ID environment variable is not set and no rollout_id provided in extra data for ElasticsearchDirectHttpHandler"
    )
6166

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,9 +194,18 @@ def _get_status() -> Dict[str, Any]:
194194
terminated = bool(status.get("terminated", False))
195195
if terminated:
196196
break
197+
except requests.exceptions.HTTPError as e:
198+
if e.response is not None and e.response.status_code == 404:
199+
# 404 means server doesn't implement /status endpoint, stop polling
200+
logger.info(
201+
f"Server doesn't implement /status endpoint (404), stopping status polling for rollout {row.execution_metadata.rollout_id}"
202+
)
203+
break
204+
else:
205+
raise
197206
except Exception:
198-
# transient errors; continue polling
199-
pass
207+
# For all other exceptions, raise them
208+
raise
200209

201210
await asyncio.sleep(poll_interval)
202211
else:

eval_protocol/utils/logs_models.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""
2+
Pydantic models for the logs server API.
3+
4+
This module contains data models that match the TypeScript schemas in eval-protocol.ts
5+
to ensure consistent data structure between Python backend and TypeScript frontend.
6+
"""
7+
8+
from typing import Any, List, Optional
9+
from pydantic import BaseModel, ConfigDict, Field
10+
11+
12+
class LogEntry(BaseModel):
    """
    One log record as returned from Elasticsearch.

    The field set mirrors LogEntrySchema in eval-protocol.ts so that the Python
    backend and the TypeScript frontend agree on the payload shape.
    """

    # Accept both the field name (`timestamp=`) and its alias (`@timestamp`) on input.
    model_config = ConfigDict(populate_by_name=True)

    # --- required fields ---
    timestamp: str = Field(default=..., alias="@timestamp", description="ISO 8601 timestamp of the log entry")
    level: str = Field(default=..., description="Log level (DEBUG, INFO, WARNING, ERROR)")
    message: str = Field(default=..., description="The log message")
    logger_name: str = Field(default=..., description="Name of the logger that created this entry")
    rollout_id: str = Field(default=..., description="ID of the rollout this log belongs to")

    # --- optional status information ---
    status_code: Optional[int] = Field(default=None, description="Optional status code")
    status_message: Optional[str] = Field(default=None, description="Optional status message")
    status_details: Optional[List[Any]] = Field(default=None, description="Optional status details")
30+
31+
32+
class LogsResponse(BaseModel):
    """
    Payload returned by the get_logs endpoint.

    The field set mirrors LogsResponseSchema in eval-protocol.ts so that the
    Python backend and the TypeScript frontend agree on the payload shape.
    """

    # No non-default model settings; kept explicit alongside the fields.
    model_config = ConfigDict()

    logs: List[LogEntry] = Field(default=..., description="Array of log entries")
    total: int = Field(default=..., description="Total number of logs available")
    rollout_id: str = Field(default=..., description="The rollout ID these logs belong to")
    filtered_by_level: Optional[str] = Field(default=None, description="Log level filter applied")

eval_protocol/utils/logs_server.py

Lines changed: 129 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,21 @@
66
import time
77
from contextlib import asynccontextmanager
88
from queue import Queue
9-
from typing import TYPE_CHECKING, Any, List, Optional
9+
from typing import TYPE_CHECKING, Any, Dict, List, Optional
1010

1111
import psutil
1212
import uvicorn
13-
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
13+
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Query
14+
from fastapi.middleware.cors import CORSMiddleware
1415

1516
from eval_protocol.dataset_logger import default_logger
1617
from eval_protocol.dataset_logger.dataset_logger import LOG_EVENT_TYPE
1718
from eval_protocol.event_bus import event_bus
1819
from eval_protocol.models import Status
1920
from eval_protocol.utils.vite_server import ViteServer
21+
from eval_protocol.logging.elasticsearch_client import ElasticsearchClient
22+
from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig
23+
from eval_protocol.utils.logs_models import LogEntry, LogsResponse
2024

2125
if TYPE_CHECKING:
2226
from eval_protocol.models import EvaluationRow
@@ -71,8 +75,11 @@ async def _start_broadcast_loop(self):
7175
while True:
7276
try:
7377
# Wait for a message to be queued
74-
message = await asyncio.get_event_loop().run_in_executor(None, self._broadcast_queue.get)
75-
await self._send_text_to_all_connections(message)
78+
message_data = await asyncio.get_event_loop().run_in_executor(None, self._broadcast_queue.get)
79+
80+
# Regular string message for all connections
81+
await self._send_text_to_all_connections(str(message_data))
82+
7683
except Exception as e:
7784
logger.error(f"Error in broadcast loop: {e}")
7885
await asyncio.sleep(0.1)
@@ -238,8 +245,8 @@ class LogsServer(ViteServer):
238245
Enhanced server for serving Vite-built SPA with file watching and WebSocket support.
239246
240247
This server extends ViteServer to add:
241-
- WebSocket connections for real-time updates
242-
- Live log streaming
248+
- WebSocket connections for real-time evaluation row updates
249+
- REST API for log querying
243250
"""
244251

245252
def __init__(
@@ -250,17 +257,49 @@ def __init__(
250257
host: str = "localhost",
251258
port: Optional[int] = 8000,
252259
index_file: str = "index.html",
260+
elasticsearch_config: Optional[ElasticsearchConfig] = None,
253261
):
254262
# Initialize WebSocket manager
255263
self.websocket_manager = WebSocketManager()
256264

257-
super().__init__(build_dir, host, port if port is not None else 8000, index_file)
265+
# Initialize Elasticsearch client if config is provided
266+
self.elasticsearch_client: Optional[ElasticsearchClient] = None
267+
if elasticsearch_config:
268+
self.elasticsearch_client = ElasticsearchClient(elasticsearch_config)
269+
270+
self.app = FastAPI(title="Logs Server")
271+
272+
# Add WebSocket endpoint and API routes
273+
self._setup_websocket_routes()
274+
self._setup_api_routes()
275+
276+
super().__init__(build_dir, host, port if port is not None else 8000, index_file, self.app)
277+
278+
# Add CORS middleware to allow frontend access
279+
# `port` may be None, in which case the server is started on 8000 (see the
# super().__init__ call above). Resolve it here as well; interpolating the raw
# value would produce the meaningless origin "http://localhost:None".
effective_port = port if port is not None else 8000
allowed_origins = [
    "http://localhost:5173",  # Vite dev server
    "http://127.0.0.1:5173",  # Vite dev server (alternative)
    f"http://{host}:{effective_port}",  # Server's own origin
    f"http://localhost:{effective_port}",  # Server on localhost
]
285+
286+
self.app.add_middleware(
287+
CORSMiddleware,
288+
allow_origins=allowed_origins,
289+
allow_credentials=True,
290+
allow_methods=["*"],
291+
allow_headers=["*"],
292+
)
258293

259294
# Initialize evaluation watcher
260295
self.evaluation_watcher = EvaluationWatcher(self.websocket_manager)
261296

262-
# Add WebSocket endpoint
263-
self._setup_websocket_routes()
297+
# Log all registered routes for debugging
298+
logger.info("Registered routes:")
299+
for route in self.app.routes:
300+
path = getattr(route, "path", "UNKNOWN")
301+
methods = getattr(route, "methods", {"UNKNOWN"})
302+
logger.info(f" {methods} {path}")
264303

265304
# Subscribe to events and start listening for cross-process events
266305
event_bus.subscribe(self._handle_event)
@@ -275,14 +314,17 @@ async def websocket_endpoint(websocket: WebSocket):
275314
await self.websocket_manager.connect(websocket)
276315
try:
277316
while True:
278-
# Keep connection alive
317+
# Keep connection alive (for evaluation row updates)
279318
await websocket.receive_text()
280319
except WebSocketDisconnect:
281320
self.websocket_manager.disconnect(websocket)
282321
except Exception as e:
283322
logger.error(f"WebSocket error: {e}")
284323
self.websocket_manager.disconnect(websocket)
285324

325+
def _setup_api_routes(self):
326+
"""Set up API routes."""
327+
286328
@self.app.get("/api/status")
287329
async def status():
288330
"""Get server status including active connections."""
@@ -295,8 +337,75 @@ async def status():
295337
# LogsServer inherits from ViteServer which doesn't expose watch_paths
296338
# Expose an empty list to satisfy consumers and type checker
297339
"watch_paths": [],
340+
"elasticsearch_enabled": self.elasticsearch_client is not None,
298341
}
299342

343+
@self.app.get("/api/logs/{rollout_id}", response_model=LogsResponse, response_model_exclude_none=True)
async def get_logs(
    rollout_id: str,
    level: Optional[str] = Query(None, description="Filter by log level (DEBUG, INFO, WARNING, ERROR)"),
    limit: int = Query(100, ge=0, description="Maximum number of log entries to return"),
) -> LogsResponse:
    """Get logs for a specific rollout ID from Elasticsearch.

    Args:
        rollout_id: Rollout whose log entries are requested.
        level: Optional log level to keep; compared case-insensitively.
        limit: Maximum number of hits requested from Elasticsearch. Negative
            values are rejected by request validation.

    Returns:
        A LogsResponse with the parsed entries sorted newest first. ``total``
        is the hit count reported by Elasticsearch for the rollout, i.e. it is
        taken before the level filter is applied locally.

    Raises:
        HTTPException: 503 when no Elasticsearch client is configured, 500 when
            the query or result processing fails.
    """
    if not self.elasticsearch_client:
        raise HTTPException(status_code=503, detail="Elasticsearch is not configured for this logs server")

    # Normalise once so that `?level=error` matches entries stored as "ERROR".
    wanted_level = level.upper() if level else None

    try:
        # NOTE(review): this call is made synchronously inside an async handler;
        # if it performs blocking network I/O it stalls the event loop — confirm.
        search_results = self.elasticsearch_client.search_by_match("rollout_id", rollout_id, size=limit)

        if not search_results or "hits" not in search_results:
            # Nothing usable came back: answer with an empty result set.
            return LogsResponse(
                logs=[],
                total=0,
                rollout_id=rollout_id,
                filtered_by_level=level,
            )

        hits_section = search_results["hits"]

        log_entries = []
        for hit in hits_section.get("hits", []):
            log_data = hit["_source"]

            # Filter by level if specified
            if wanted_level and str(log_data.get("level", "")).upper() != wanted_level:
                continue

            # Validate through the Pydantic model; one malformed document
            # must not fail the whole request.
            try:
                log_entries.append(LogEntry(**log_data))
            except Exception as e:
                logger.warning(f"Failed to parse log entry: {e}, data: {log_data}")

        # Sort by timestamp (most recent first)
        log_entries.sort(key=lambda entry: entry.timestamp, reverse=True)

        # Elasticsearch 7+ reports {"value": N, ...}; Elasticsearch 6 reports a bare int.
        total_hits = hits_section.get("total", 0)
        total_count = total_hits["value"] if isinstance(total_hits, dict) else total_hits

        return LogsResponse(
            logs=log_entries,
            total=total_count,
            rollout_id=rollout_id,
            filtered_by_level=level,
        )

    except HTTPException:
        # Deliberate HTTP errors must not be rewrapped as a generic 500.
        raise
    except Exception as e:
        logger.exception(f"Error retrieving logs for rollout {rollout_id}: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to retrieve logs: {str(e)}") from e
408+
300409
def _handle_event(self, event_type: str, data: Any) -> None:
301410
"""Handle events from the event bus."""
302411
if event_type in [LOG_EVENT_TYPE]:
@@ -353,7 +462,12 @@ def run(self):
353462
asyncio.run(self.run_async())
354463

355464

356-
def create_app(host: str = "localhost", port: int = 8000, build_dir: Optional[str] = None) -> FastAPI:
465+
def create_app(
466+
host: str = "localhost",
467+
port: int = 8000,
468+
build_dir: Optional[str] = None,
469+
elasticsearch_config: Optional[ElasticsearchConfig] = None,
470+
) -> FastAPI:
357471
"""
358472
Factory function to create a FastAPI app instance and start the server with async loops.
359473
@@ -364,6 +478,7 @@ def create_app(host: str = "localhost", port: int = 8000, build_dir: Optional[st
364478
host: Host to bind to
365479
port: Port to bind to
366480
build_dir: Optional custom build directory path
481+
elasticsearch_config: Optional Elasticsearch configuration for log querying
367482
368483
Returns:
369484
FastAPI app instance with server running in background
@@ -373,17 +488,17 @@ def create_app(host: str = "localhost", port: int = 8000, build_dir: Optional[st
373488
os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "vite-app", "dist")
374489
)
375490

376-
server = LogsServer(host=host, port=port, build_dir=build_dir)
491+
server = LogsServer(host=host, port=port, build_dir=build_dir, elasticsearch_config=elasticsearch_config)
377492
server.start_loops()
378493
return server.app
379494

380495

381496
# For backward compatibility and direct usage
382-
def serve_logs(port: Optional[int] = None, elasticsearch_config: Optional[ElasticsearchConfig] = None):
    """
    Build a LogsServer with the given settings and run it.

    Args:
        port: Port handed to LogsServer; may be None.
        elasticsearch_config: Optional Elasticsearch configuration handed to LogsServer.
    """
    LogsServer(port=port, elasticsearch_config=elasticsearch_config).run()
388503

389504

eval_protocol/utils/vite_server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ def __init__(
3232
host: str = "localhost",
3333
port: int = 8000,
3434
index_file: str = "index.html",
35-
lifespan: Optional[Callable[[FastAPI], Any]] = None,
35+
app: Optional[FastAPI] = None,
3636
):
3737
self.build_dir = Path(build_dir)
3838
self.host = host
3939
self.port = port
4040
self.index_file = index_file
41-
self.app = FastAPI(title="Vite SPA Server", lifespan=lifespan)
41+
self.app = app if app is not None else FastAPI(title="Vite SPA Server")
4242

4343
# Validate build directory exists
4444
if not self.build_dir.exists():

0 commit comments

Comments
 (0)