1919
2020
@dataclass
class WebSocketAudioSourceConfig:
    """Configuration for a WebSocket-backed audio source.

    Parameters
    ----------
    uri : str
        WebSocket URI the audio source connects to.
    sample_rate : int, optional
        Sampling rate of the incoming audio in Hz, by default 16000.
    """

    uri: str
    sample_rate: int = 16000
36-
37- @dataclass
38- class StreamingInferenceConfig :
22+ class StreamingHandlerConfig :
3923 """Configuration for streaming inference.
4024
4125 Parameters
4226 ----------
43- pipeline : blocks.Pipeline
44- Diarization pipeline configuration
27+ pipeline_class : type
28+ Pipeline class
29+ pipeline_config : blocks.PipelineConfig
30+ Pipeline configuration
4531 batch_size : int
4632 Number of inputs to process at once
4733 do_profile : bool
@@ -54,7 +40,8 @@ class StreamingInferenceConfig:
5440 Custom progress bar implementation
5541 """
5642
57- pipeline : blocks .Pipeline
43+ pipeline_class : type
44+ pipeline_config : blocks .PipelineConfig
5845 batch_size : int = 1
5946 do_profile : bool = True
6047 do_plot : bool = False
@@ -70,18 +57,16 @@ class ClientState:
7057 inference : StreamingInference
7158
7259
73- class StreamingInferenceHandler :
60+ class StreamingHandler :
7461 """Handles real-time speaker diarization inference for multiple audio sources over WebSocket.
7562
7663 This handler manages WebSocket connections from multiple clients, processing
7764 audio streams and performing speaker diarization in real-time.
7865
7966 Parameters
8067 ----------
81- inference_config : StreamingInferenceConfig
68+ config : StreamingHandlerConfig
8269 Streaming inference configuration
83- sample_rate : int, optional
84- Audio sample rate in Hz, by default 16000
8570 host : str, optional
8671 WebSocket server host, by default "127.0.0.1"
8772 port : int, optional
@@ -94,15 +79,13 @@ class StreamingInferenceHandler:
9479
9580 def __init__ (
9681 self ,
97- inference_config : StreamingInferenceConfig ,
98- sample_rate : int = 16000 ,
82+ config : StreamingHandlerConfig ,
9983 host : Text = "127.0.0.1" ,
10084 port : int = 7007 ,
10185 key : Optional [Union [Text , Path ]] = None ,
10286 certificate : Optional [Union [Text , Path ]] = None ,
10387 ):
104- self .inference_config = inference_config
105- self .sample_rate = sample_rate
88+ self .config = config
10689 self .host = host
10790 self .port = port
10891
@@ -135,26 +118,21 @@ def _create_client_state(self, client_id: Text) -> ClientState:
135118 """
136119 # Create a new pipeline instance with the same config
137120 # This ensures each client has its own state while sharing model weights
138- pipeline = self .inference_config .pipeline .__class__ (
139- self .inference_config .pipeline .config
140- )
141-
142- audio_config = WebSocketAudioSourceConfig (
143- uri = f"{ self .uri } :{ client_id } " , sample_rate = self .sample_rate
144- )
121+ pipeline = self .config .pipeline_class (self .config .pipeline_config )
145122
146123 audio_source = src .WebSocketAudioSource (
147- uri = audio_config .uri , sample_rate = audio_config .sample_rate
124+ uri = f"{ self .uri } :{ client_id } " ,
125+ sample_rate = self .config .pipeline_config .sample_rate ,
148126 )
149127
150128 inference = StreamingInference (
151129 pipeline = pipeline ,
152130 source = audio_source ,
153- batch_size = self .inference_config .batch_size ,
154- do_profile = self .inference_config .do_profile ,
155- do_plot = self .inference_config .do_plot ,
156- show_progress = self .inference_config .show_progress ,
157- progress_bar = self .inference_config .progress_bar ,
131+ batch_size = self .config .batch_size ,
132+ do_profile = self .config .do_profile ,
133+ do_plot = self .config .do_plot ,
134+ show_progress = self .config .show_progress ,
135+ progress_bar = self .config .progress_bar ,
158136 )
159137
160138 return ClientState (audio_source = audio_source , inference = inference )
@@ -174,16 +152,15 @@ def _on_connect(self, client: Dict[Text, Any], server: WebsocketServer) -> None:
174152
175153 if client_id not in self ._clients :
176154 try :
177- client_state = self ._create_client_state (client_id )
178- self ._clients [client_id ] = client_state
155+ self ._clients [client_id ] = self ._create_client_state (client_id )
179156
180157 # Setup RTTM response hook
181- client_state .inference .attach_hooks (
158+ self . _clients [ client_id ] .inference .attach_hooks (
182159 lambda ann_wav : self .send (client_id , ann_wav [0 ].to_rttm ())
183160 )
184161
185162 # Start inference
186- client_state .inference ()
163+ self . _clients [ client_id ] .inference ()
187164 logger .info (f"Started inference for client: { client_id } " )
188165
189166 # Send ready notification to client
0 commit comments