66import time
77from contextlib import asynccontextmanager
88from queue import Queue
9- from typing import TYPE_CHECKING , Any , List , Optional
9+ from typing import TYPE_CHECKING , Any , Dict , List , Optional
1010
1111import psutil
1212import uvicorn
13- from fastapi import FastAPI , WebSocket , WebSocketDisconnect
13+ from fastapi import FastAPI , WebSocket , WebSocketDisconnect , HTTPException , Query
14+ from fastapi .middleware .cors import CORSMiddleware
1415
1516from eval_protocol .dataset_logger import default_logger
1617from eval_protocol .dataset_logger .dataset_logger import LOG_EVENT_TYPE
1718from eval_protocol .event_bus import event_bus
1819from eval_protocol .models import Status
1920from eval_protocol .utils .vite_server import ViteServer
21+ from eval_protocol .logging .elasticsearch_client import ElasticsearchClient
22+ from eval_protocol .types .remote_rollout_processor import ElasticsearchConfig
23+ from eval_protocol .utils .logs_models import LogEntry , LogsResponse
2024
2125if TYPE_CHECKING :
2226 from eval_protocol .models import EvaluationRow
@@ -71,8 +75,11 @@ async def _start_broadcast_loop(self):
7175 while True :
7276 try :
7377 # Wait for a message to be queued
74- message = await asyncio .get_event_loop ().run_in_executor (None , self ._broadcast_queue .get )
75- await self ._send_text_to_all_connections (message )
78+ message_data = await asyncio .get_event_loop ().run_in_executor (None , self ._broadcast_queue .get )
79+
80+ # Regular string message for all connections
81+ await self ._send_text_to_all_connections (str (message_data ))
82+
7683 except Exception as e :
7784 logger .error (f"Error in broadcast loop: { e } " )
7885 await asyncio .sleep (0.1 )
@@ -238,8 +245,8 @@ class LogsServer(ViteServer):
238245 Enhanced server for serving Vite-built SPA with file watching and WebSocket support.
239246
240247 This server extends ViteServer to add:
241- - WebSocket connections for real-time updates
242- - Live log streaming
248+ - WebSocket connections for real-time evaluation row updates
249+ - REST API for log querying
243250 """
244251
245252 def __init__ (
@@ -250,17 +257,49 @@ def __init__(
250257 host : str = "localhost" ,
251258 port : Optional [int ] = 8000 ,
252259 index_file : str = "index.html" ,
260+ elasticsearch_config : Optional [ElasticsearchConfig ] = None ,
253261 ):
254262 # Initialize WebSocket manager
255263 self .websocket_manager = WebSocketManager ()
256264
257- super ().__init__ (build_dir , host , port if port is not None else 8000 , index_file )
265+ # Initialize Elasticsearch client if config is provided
266+ self .elasticsearch_client : Optional [ElasticsearchClient ] = None
267+ if elasticsearch_config :
268+ self .elasticsearch_client = ElasticsearchClient (elasticsearch_config )
269+
270+ self .app = FastAPI (title = "Logs Server" )
271+
272+ # Add WebSocket endpoint and API routes
273+ self ._setup_websocket_routes ()
274+ self ._setup_api_routes ()
275+
276+ super ().__init__ (build_dir , host , port if port is not None else 8000 , index_file , self .app )
277+
278+ # Add CORS middleware to allow frontend access
279+ allowed_origins = [
280+ "http://localhost:5173" , # Vite dev server
281+ "http://127.0.0.1:5173" , # Vite dev server (alternative)
282+ f"http://{ host } :{ port } " , # Server's own origin
283+ f"http://localhost:{ port } " , # Server on localhost
284+ ]
285+
286+ self .app .add_middleware (
287+ CORSMiddleware ,
288+ allow_origins = allowed_origins ,
289+ allow_credentials = True ,
290+ allow_methods = ["*" ],
291+ allow_headers = ["*" ],
292+ )
258293
259294 # Initialize evaluation watcher
260295 self .evaluation_watcher = EvaluationWatcher (self .websocket_manager )
261296
262- # Add WebSocket endpoint
263- self ._setup_websocket_routes ()
297+ # Log all registered routes for debugging
298+ logger .info ("Registered routes:" )
299+ for route in self .app .routes :
300+ path = getattr (route , "path" , "UNKNOWN" )
301+ methods = getattr (route , "methods" , {"UNKNOWN" })
302+ logger .info (f" { methods } { path } " )
264303
265304 # Subscribe to events and start listening for cross-process events
266305 event_bus .subscribe (self ._handle_event )
@@ -275,14 +314,17 @@ async def websocket_endpoint(websocket: WebSocket):
275314 await self .websocket_manager .connect (websocket )
276315 try :
277316 while True :
278- # Keep connection alive
317+ # Keep connection alive (for evaluation row updates)
279318 await websocket .receive_text ()
280319 except WebSocketDisconnect :
281320 self .websocket_manager .disconnect (websocket )
282321 except Exception as e :
283322 logger .error (f"WebSocket error: { e } " )
284323 self .websocket_manager .disconnect (websocket )
285324
325+ def _setup_api_routes (self ):
326+ """Set up API routes."""
327+
286328 @self .app .get ("/api/status" )
287329 async def status ():
288330 """Get server status including active connections."""
@@ -295,8 +337,75 @@ async def status():
295337 # LogsServer inherits from ViteServer which doesn't expose watch_paths
296338 # Expose an empty list to satisfy consumers and type checker
297339 "watch_paths" : [],
340+ "elasticsearch_enabled" : self .elasticsearch_client is not None ,
298341 }
299342
@self.app.get("/api/logs/{rollout_id}", response_model=LogsResponse, response_model_exclude_none=True)
async def get_logs(
    rollout_id: str,
    level: Optional[str] = Query(None, description="Filter by log level (DEBUG, INFO, WARNING, ERROR)"),
    limit: int = Query(
        100,
        ge=0,
        le=10000,
        description="Maximum number of log entries to return",
    ),
) -> LogsResponse:
    """Get logs for a specific rollout ID from Elasticsearch.

    Args:
        rollout_id: Rollout identifier matched against the ``rollout_id`` field.
        level: Optional log level; compared case-insensitively against each
            entry's ``level`` field. The filter is applied after the search,
            so fewer than ``limit`` entries may be returned.
        limit: Maximum number of hits requested from Elasticsearch. Bounded to
            0..10000 so invalid sizes are rejected with a validation error
            instead of reaching the backend.

    Returns:
        LogsResponse with the parsed entries sorted by timestamp, most recent
        first, and the hit count reported by Elasticsearch.

    Raises:
        HTTPException: 503 when no Elasticsearch client is configured,
            500 when the search or response handling fails.
    """
    client = self.elasticsearch_client
    if not client:
        raise HTTPException(status_code=503, detail="Elasticsearch is not configured for this logs server")

    # Normalise once so "info" and "INFO" select the same entries.
    wanted_level = level.strip().upper() if level else None

    try:
        # The client call is synchronous; run it in the default executor so the
        # event loop (which also drives WebSocket broadcasts) is not blocked.
        search_results = await asyncio.get_running_loop().run_in_executor(
            None,
            lambda: client.search_by_match("rollout_id", rollout_id, size=limit),
        )

        if not search_results or "hits" not in search_results:
            # Return empty response using Pydantic model
            return LogsResponse(
                logs=[],
                total=0,
                rollout_id=rollout_id,
                filtered_by_level=level,
            )

        hits_section = search_results["hits"] or {}

        log_entries = []
        for hit in hits_section.get("hits") or []:
            log_data = hit.get("_source")
            if not isinstance(log_data, dict):
                # A hit without a usable document cannot become a LogEntry.
                logger.warning(f"Skipping hit without _source: {hit}")
                continue

            # Filter by level if specified
            if wanted_level and str(log_data.get("level") or "").upper() != wanted_level:
                continue

            # Create LogEntry using Pydantic model for validation
            try:
                log_entries.append(LogEntry(**log_data))
            except Exception as e:
                # Log the error but continue processing other entries
                logger.warning(f"Failed to parse log entry: {e}, data: {log_data}")

        # Sort by timestamp (most recent first)
        log_entries.sort(key=lambda entry: entry.timestamp, reverse=True)

        # Elasticsearch 7+ reports the total as {"value": N, ...}; version 6 as a bare int.
        # NOTE(review): this is the unfiltered hit count, so with a level filter
        # it can exceed len(logs) - confirm that is what consumers expect.
        total_hits = hits_section.get("total", 0)
        if isinstance(total_hits, dict):
            total_count = total_hits.get("value", 0)
        else:
            total_count = total_hits

        # Return response using Pydantic model
        return LogsResponse(
            logs=log_entries,
            total=total_count,
            rollout_id=rollout_id,
            filtered_by_level=level,
        )

    except Exception as e:
        logger.error(f"Error retrieving logs for rollout {rollout_id}: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to retrieve logs: {str(e)}") from e
408+
300409 def _handle_event (self , event_type : str , data : Any ) -> None :
301410 """Handle events from the event bus."""
302411 if event_type in [LOG_EVENT_TYPE ]:
@@ -353,7 +462,12 @@ def run(self):
353462 asyncio .run (self .run_async ())
354463
355464
356- def create_app (host : str = "localhost" , port : int = 8000 , build_dir : Optional [str ] = None ) -> FastAPI :
465+ def create_app (
466+ host : str = "localhost" ,
467+ port : int = 8000 ,
468+ build_dir : Optional [str ] = None ,
469+ elasticsearch_config : Optional [ElasticsearchConfig ] = None ,
470+ ) -> FastAPI :
357471 """
358472 Factory function to create a FastAPI app instance and start the server with async loops.
359473
@@ -364,6 +478,7 @@ def create_app(host: str = "localhost", port: int = 8000, build_dir: Optional[st
364478 host: Host to bind to
365479 port: Port to bind to
366480 build_dir: Optional custom build directory path
481+ elasticsearch_config: Optional Elasticsearch configuration for log querying
367482
368483 Returns:
369484 FastAPI app instance with server running in background
@@ -373,17 +488,17 @@ def create_app(host: str = "localhost", port: int = 8000, build_dir: Optional[st
373488 os .path .join (os .path .dirname (os .path .dirname (os .path .dirname (__file__ ))), "vite-app" , "dist" )
374489 )
375490
376- server = LogsServer (host = host , port = port , build_dir = build_dir )
491+ server = LogsServer (host = host , port = port , build_dir = build_dir , elasticsearch_config = elasticsearch_config )
377492 server .start_loops ()
378493 return server .app
379494
380495
381496# For backward compatibility and direct usage
def serve_logs(port: Optional[int] = None, elasticsearch_config: Optional[ElasticsearchConfig] = None):
    """
    Build a LogsServer with the given settings and run it.

    Args:
        port: Port forwarded to LogsServer unchanged (None included).
        elasticsearch_config: Optional Elasticsearch configuration forwarded
            to LogsServer.
    """
    LogsServer(port=port, elasticsearch_config=elasticsearch_config).run()
388503
389504
0 commit comments