1+ import logging
2+ import socket
13from dataclasses import dataclass
24from pathlib import Path
3- from typing import Union , Text , Optional , AnyStr , Dict , Any , Callable
4- import logging
5+ from typing import Any , AnyStr , Callable , Dict , Optional , Text , Union
6+
57from websocket_server import WebsocketServer
6- import socket
78
89from . import blocks
910from . import sources as src
1213
1314# Configure logging
1415logging .basicConfig (
15- level = logging .INFO ,
16- format = '%(asctime)s - %(levelname)s - %(message)s'
16+ level = logging .INFO , format = "%(asctime)s - %(levelname)s - %(message)s"
1717)
1818logger = logging .getLogger (__name__ )
1919
@@ -29,6 +29,7 @@ class WebSocketAudioSourceConfig:
2929 sample_rate : int
3030 Audio sample rate in Hz
3131 """
32+
3233 uri : str
3334 sample_rate : int = 16000
3435
@@ -52,6 +53,7 @@ class StreamingInferenceConfig:
5253 progress_bar : Optional[ProgressBar]
5354 Custom progress bar implementation
5455 """
56+
5557 pipeline : blocks .Pipeline
5658 batch_size : int = 1
5759 do_profile : bool = True
@@ -63,6 +65,7 @@ class StreamingInferenceConfig:
6365@dataclass
6466class ClientState :
6567 """Represents the state of a connected client."""
68+
6669 audio_source : src .WebSocketAudioSource
6770 inference : StreamingInference
6871
@@ -102,7 +105,7 @@ def __init__(
102105 self .sample_rate = sample_rate
103106 self .host = host
104107 self .port = port
105-
108+
106109 # Server configuration
107110 self .uri = f"{ host } :{ port } "
108111 self ._clients : Dict [Text , ClientState ] = {}
@@ -132,16 +135,16 @@ def _create_client_state(self, client_id: Text) -> ClientState:
132135 """
133136 # Create a new pipeline instance with the same config
134137 # This ensures each client has its own state while sharing model weights
135- pipeline = self .inference_config .pipeline .__class__ (self .inference_config .pipeline .config )
136-
138+ pipeline = self .inference_config .pipeline .__class__ (
139+ self .inference_config .pipeline .config
140+ )
141+
137142 audio_config = WebSocketAudioSourceConfig (
138- uri = f"{ self .uri } :{ client_id } " ,
139- sample_rate = self .sample_rate
143+ uri = f"{ self .uri } :{ client_id } " , sample_rate = self .sample_rate
140144 )
141-
145+
142146 audio_source = src .WebSocketAudioSource (
143- uri = audio_config .uri ,
144- sample_rate = audio_config .sample_rate
147+ uri = audio_config .uri , sample_rate = audio_config .sample_rate
145148 )
146149
147150 inference = StreamingInference (
@@ -151,7 +154,7 @@ def _create_client_state(self, client_id: Text) -> ClientState:
151154 do_profile = self .inference_config .do_profile ,
152155 do_plot = self .inference_config .do_plot ,
153156 show_progress = self .inference_config .show_progress ,
154- progress_bar = self .inference_config .progress_bar
157+ progress_bar = self .inference_config .progress_bar ,
155158 )
156159
157160 return ClientState (audio_source = audio_source , inference = inference )
@@ -182,7 +185,7 @@ def _on_connect(self, client: Dict[Text, Any], server: WebsocketServer) -> None:
182185 # Start inference
183186 client_state .inference ()
184187 logger .info (f"Started inference for client: { client_id } " )
185-
188+
186189 # Send ready notification to client
187190 self .send (client_id , "READY" )
188191 except Exception as e :
@@ -204,10 +207,7 @@ def _on_disconnect(self, client: Dict[Text, Any], server: WebsocketServer) -> No
204207 self .close (client_id )
205208
206209 def _on_message_received (
207- self ,
208- client : Dict [Text , Any ],
209- server : WebsocketServer ,
210- message : AnyStr
210+ self , client : Dict [Text , Any ], server : WebsocketServer , message : AnyStr
211211 ) -> None :
212212 """Process incoming client messages.
213213
@@ -245,16 +245,15 @@ def send(self, client_id: Text, message: AnyStr) -> None:
245245 if not message :
246246 return
247247
248- client = next (
249- (c for c in self .server .clients if c ["id" ] == client_id ),
250- None
251- )
252-
248+ client = next ((c for c in self .server .clients if c ["id" ] == client_id ), None )
249+
253250 if client is not None :
254251 try :
255252 self .server .send_message (client , message )
256253 except (socket .error , ConnectionError ) as e :
257- logger .warning (f"Client { client_id } disconnected while sending message: { e } " )
254+ logger .warning (
255+ f"Client { client_id } disconnected while sending message: { e } "
256+ )
258257 self .close (client_id )
259258 except Exception as e :
260259 logger .error (f"Failed to send message to client { client_id } : { e } " )
@@ -264,7 +263,7 @@ def run(self) -> None:
264263 logger .info (f"Starting WebSocket server on { self .uri } " )
265264 max_retries = 3
266265 retry_count = 0
267-
266+
268267 while retry_count < max_retries :
269268 try :
270269 self .server .run_forever ()
@@ -273,7 +272,9 @@ def run(self) -> None:
273272 logger .warning (f"WebSocket connection error: { e } " )
274273 retry_count += 1
275274 if retry_count < max_retries :
276- logger .info (f"Attempting to restart server (attempt { retry_count + 1 } /{ max_retries } )" )
275+ logger .info (
276+ f"Attempting to restart server (attempt { retry_count + 1 } /{ max_retries } )"
277+ )
277278 else :
278279 logger .error ("Max retry attempts reached. Server shutting down." )
279280 except Exception as e :
@@ -295,20 +296,24 @@ def close(self, client_id: Text) -> None:
295296 # Clean up pipeline state using built-in reset method
296297 client_state = self ._clients [client_id ]
297298 client_state .inference .pipeline .reset ()
298-
299+
299300 # Close audio source and remove client
300301 client_state .audio_source .close ()
301302 del self ._clients [client_id ]
302-
303+
303304 # Try to send a close frame to the client
304305 try :
305- client = next ((c for c in self .server .clients if c ["id" ] == client_id ), None )
306+ client = next (
307+ (c for c in self .server .clients if c ["id" ] == client_id ), None
308+ )
306309 if client :
307310 self .server .send_message (client , "CLOSE" )
308311 except Exception :
309312 pass # Ignore errors when trying to send close message
310-
311- logger .info (f"Closed connection and cleaned up state for client: { client_id } " )
313+
314+ logger .info (
315+ f"Closed connection and cleaned up state for client: { client_id } "
316+ )
312317 except Exception as e :
313318 logger .error (f"Error closing client { client_id } : { e } " )
314319 # Ensure client is removed from dictionary even if cleanup fails
0 commit comments