From d371b7094f2a5d3d6150e0a9c7236d7de856664f Mon Sep 17 00:00:00 2001 From: Long Chen Date: Tue, 18 Feb 2025 19:13:16 +0800 Subject: [PATCH 1/3] add user transcription for realtime model --- examples/roomio_worker.py | 2 +- livekit-agents/livekit/agents/llm/__init__.py | 4 ++ livekit-agents/livekit/agents/llm/realtime.py | 18 +++++ .../livekit/agents/llm/remote_chat_context.py | 3 + .../livekit/agents/pipeline/task_activity.py | 36 ++++++++++ .../plugins/openai/realtime/realtime_model.py | 68 ++++++++++++++++++- 6 files changed, 129 insertions(+), 2 deletions(-) diff --git a/examples/roomio_worker.py b/examples/roomio_worker.py index ce29e848bc..fa17ea7116 100644 --- a/examples/roomio_worker.py +++ b/examples/roomio_worker.py @@ -45,7 +45,7 @@ async def entrypoint(ctx: JobContext): await ctx.connect() agent = PipelineAgent( - task=EchoTask(), + task=AlloyTask(), ) await agent.start(room=ctx.room) diff --git a/livekit-agents/livekit/agents/llm/__init__.py b/livekit-agents/livekit/agents/llm/__init__.py index 495fb5ec08..31af0c8e96 100644 --- a/livekit-agents/livekit/agents/llm/__init__.py +++ b/livekit-agents/livekit/agents/llm/__init__.py @@ -34,6 +34,8 @@ GenerationCreatedEvent, InputSpeechStartedEvent, InputSpeechStoppedEvent, + InputTranscriptionCompleted, + InputTranscriptionFailed, MessageGeneration, RealtimeCapabilities, RealtimeError, @@ -74,6 +76,8 @@ "RealtimeError", "RealtimeCapabilities", "RealtimeSession", + "InputTranscriptionCompleted", + "InputTranscriptionFailed", "InputSpeechStartedEvent", "InputSpeechStoppedEvent", "GenerationCreatedEvent", diff --git a/livekit-agents/livekit/agents/llm/realtime.py b/livekit-agents/livekit/agents/llm/realtime.py index caf3ca9ecf..f48bf63cd2 100644 --- a/livekit-agents/livekit/agents/llm/realtime.py +++ b/livekit-agents/livekit/agents/llm/realtime.py @@ -70,6 +70,8 @@ async def aclose(self) -> None: ... EventTypes = Literal[ "input_speech_started", # serverside VAD (also used for interruptions) "input_speech_stopped", # serverside VAD + "input_audio_transcription_completed", + "input_audio_transcription_failed", "generation_created", "error", ] @@ -77,6 +79,22 @@ async def aclose(self) -> None: ... 
TEvent = TypeVar("TEvent") +@dataclass +class InputTranscriptionCompleted: + item_id: str + """id of the item""" + transcript: str + """transcript of the input audio""" + + +@dataclass +class InputTranscriptionFailed: + item_id: str + """id of the item""" + message: str + """error message""" + + class RealtimeSession( ABC, rtc.EventEmitter[Union[EventTypes, TEvent]], diff --git a/livekit-agents/livekit/agents/llm/remote_chat_context.py b/livekit-agents/livekit/agents/llm/remote_chat_context.py index 7998f55884..b9af6a6aa8 100644 --- a/livekit-agents/livekit/agents/llm/remote_chat_context.py +++ b/livekit-agents/livekit/agents/llm/remote_chat_context.py @@ -28,6 +28,9 @@ def to_chat_ctx(self) -> ChatContext: current_node = current_node._next return ChatContext(items=items) + + def get(self, item_id: str) -> _RemoteChatItem | None: + return self._id_to_item.get(item_id) def insert(self, previous_item_id: str | None, message: ChatItem) -> None: """ diff --git a/livekit-agents/livekit/agents/pipeline/task_activity.py b/livekit-agents/livekit/agents/pipeline/task_activity.py index 7c48e1d28e..ac1d36aba1 100644 --- a/livekit-agents/livekit/agents/pipeline/task_activity.py +++ b/livekit-agents/livekit/agents/pipeline/task_activity.py @@ -122,6 +122,14 @@ async def start(self) -> None: self._rt_session.on( "input_speech_stopped", self._on_input_speech_stopped ) + self._rt_session.on( + "input_audio_transcription_completed", + self._on_input_audio_transcription_completed, + ) + self._rt_session.on( + "input_audio_transcription_failed", + self._on_input_audio_transcription_failed, + ) try: await self._rt_session.update_instructions( self._agent_task.instructions @@ -303,6 +311,34 @@ def _on_input_speech_started(self, _: llm.InputSpeechStartedEvent) -> None: def _on_input_speech_stopped(self, _: llm.InputSpeechStoppedEvent) -> None: log_event("input_speech_stopped") + self.on_interim_transcript( + stt.SpeechEvent( + stt.SpeechEventType.INTERIM_TRANSCRIPT, + alternatives=[stt.SpeechData(text="", language="")], + ) + ) + + def _on_input_audio_transcription_completed( + self, ev: llm.InputTranscriptionCompleted + ) -> None: + log_event("input_audio_transcription_completed") + self.on_final_transcript( + stt.SpeechEvent( + stt.SpeechEventType.FINAL_TRANSCRIPT, + alternatives=[stt.SpeechData(text=ev.transcript, language="")], + ) + ) + + def _on_input_audio_transcription_failed( + self, ev: llm.InputTranscriptionFailed + ) -> None: + log_event("input_audio_transcription_failed") + self.on_final_transcript( + stt.SpeechEvent( + stt.SpeechEventType.FINAL_TRANSCRIPT, + alternatives=[stt.SpeechData(text="", language="")], + ) + ) def _on_generation_created(self, ev: llm.GenerationCreatedEvent) -> None: if self.draining: diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py index 073a4ad0ad..1799db0927 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py @@ -3,6 +3,7 @@ import asyncio import base64 from dataclasses import dataclass +from typing import Literal, Optional from livekit import rtc from livekit.agents import llm, utils @@ -18,6 +19,8 @@ ConversationItemCreateEvent, ConversationItemDeletedEvent, ConversationItemDeleteEvent, + ConversationItemInputAudioTranscriptionCompletedEvent, + 
ConversationItemInputAudioTranscriptionFailedEvent, ConversationItemTruncateEvent, ErrorEvent, InputAudioBufferAppendEvent, @@ -59,10 +62,19 @@ NUM_CHANNELS = 1 +@dataclass +class _InputAudioTranscription: + model: Literal["whisper-1"] = "whisper-1" + + +DEFAULT_INPUT_AUDIO_TRANSCRIPTION = _InputAudioTranscription() + + @dataclass class _RealtimeOptions: model: str voice: str + input_audio_transcription: Optional[_InputAudioTranscription] @dataclass @@ -86,11 +98,18 @@ def __init__( *, model: str = "gpt-4o-realtime-preview", voice: str = "alloy", + input_audio_transcription: Optional[ + _InputAudioTranscription + ] = DEFAULT_INPUT_AUDIO_TRANSCRIPTION, client: openai.AsyncClient | None = None, ) -> None: super().__init__(capabilities=llm.RealtimeCapabilities(message_truncation=True)) - self._opts = _RealtimeOptions(model=model, voice=voice) + self._opts = _RealtimeOptions( + model=model, + voice=voice, + input_audio_transcription=input_audio_transcription, + ) self._client = client or openai.AsyncClient() def session(self) -> "RealtimeSession": @@ -145,6 +164,15 @@ async def _listen_for_events() -> None: self._handle_conversion_item_created(event) elif event.type == "conversation.item.deleted": self._handle_conversion_item_deleted(event) + elif ( + event.type + == "conversation.item.input_audio_transcription.completed" + ): + self._handle_conversion_item_input_audio_transcription_completed( + event + ) + elif event.type == "conversation.item.input_audio_transcription.failed": + self._handle_conversion_item_input_audio_transcription_failed(event) elif event.type == "response.audio_transcript.delta": self._handle_response_audio_transcript_delta(event) elif event.type == "response.audio.delta": @@ -168,12 +196,23 @@ async def _forward_input() -> None: except Exception: break + input_audio_transcription: Optional[ + session_update_event.SessionInputAudioTranscription + ] = None + if self._realtime_model._opts.input_audio_transcription: + input_audio_transcription = ( + session_update_event.SessionInputAudioTranscription( + model=self._realtime_model._opts.input_audio_transcription.model, + ) + ) + self._msg_ch.send_nowait( SessionUpdateEvent( type="session.update", session=session_update_event.Session( model=self._realtime_model._opts.model, # type: ignore voice=self._realtime_model._opts.voice, # type: ignore + input_audio_transcription=input_audio_transcription, ), event_id=utils.shortuuid("session_update_"), ) @@ -428,6 +467,33 @@ def _handle_conversion_item_deleted( if fut := self._item_delete_future.pop(event.item_id, None): fut.set_result(None) + def _handle_conversion_item_input_audio_transcription_completed( + self, event: ConversationItemInputAudioTranscriptionCompletedEvent + ) -> None: + remote_item = self._remote_chat_ctx.get(event.item_id) + if remote_item: + remote_item.item.content.append(event.transcript) + self.emit( + "input_audio_transcription_completed", + llm.InputTranscriptionCompleted( + item_id=event.item_id, transcript=event.transcript + ), + ) + + def _handle_conversion_item_input_audio_transcription_failed( + self, event: ConversationItemInputAudioTranscriptionFailedEvent + ) -> None: + logger.error( + "OpenAI Realtime API failed to transcribe input audio", + extra={"error": event.error}, + ) + self.emit( + "input_audio_transcription_failed", + llm.InputTranscriptionFailed( + item_id=event.item_id, message=event.error.message + ), + ) + def _handle_response_audio_transcript_delta( self, event: ResponseAudioTranscriptDeltaEvent ) -> None: From 
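A minimal usage sketch of what patch 1/3 adds, for reviewers (not part of the diff): the plugin import path, the class being exported as realtime.RealtimeModel, and the async wrapper are assumptions; the constructor argument, the session() accessor, and the two event names and payloads are taken from the hunks above. Transcription of the user's input audio defaults to whisper-1, and passing input_audio_transcription=None turns it off.

    from livekit.agents import llm
    from livekit.plugins.openai import realtime  # assumed export path for the realtime model


    async def demo() -> None:
        # whisper-1 transcription is the default; input_audio_transcription=None disables it
        model = realtime.RealtimeModel(voice="alloy")
        session = model.session()  # assumed to require a running event loop

        def on_completed(ev: llm.InputTranscriptionCompleted) -> None:
            # emitted once the server finishes transcribing an input audio item
            print(f"user ({ev.item_id}): {ev.transcript}")

        def on_failed(ev: llm.InputTranscriptionFailed) -> None:
            print(f"transcription failed for {ev.item_id}: {ev.message}")

        session.on("input_audio_transcription_completed", on_completed)
        session.on("input_audio_transcription_failed", on_failed)

Inside the agent this wiring is already done by task_activity.py above, which forwards completed transcripts as FINAL_TRANSCRIPT speech events. Patch 2/3 follows.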
07e79accadd10f86f8428037c139f24b5be82548 Mon Sep 17 00:00:00 2001 From: Long Chen Date: Wed, 19 Feb 2025 00:06:30 +0800 Subject: [PATCH 2/3] wip for text input --- .../livekit/agents/pipeline/pipeline_agent.py | 6 ++++ .../livekit/agents/pipeline/room_io.py | 36 ++++++++++++++++++- 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index 1689f33111..a6e15ade76 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -255,6 +255,12 @@ def generate_reply( allow_interruptions=allow_interruptions, ) + def interrupt(self) -> None: + if self._activity is None: + raise ValueError("PipelineAgent isn't running") + + self._activity.interrupt() + def update_task(self, task: AgentTask) -> None: self._agent_task = task diff --git a/livekit-agents/livekit/agents/pipeline/room_io.py b/livekit-agents/livekit/agents/pipeline/room_io.py index 8d9e8883f3..1712726b76 100644 --- a/livekit-agents/livekit/agents/pipeline/room_io.py +++ b/livekit-agents/livekit/agents/pipeline/room_io.py @@ -17,6 +17,8 @@ @dataclass(frozen=True) class RoomInputOptions: + text_enabled: bool = True + """Whether to subscribe to text input""" audio_enabled: bool = True """Whether to subscribe to audio""" video_enabled: bool = False @@ -48,6 +50,7 @@ class RoomOutputOptions: DEFAULT_ROOM_INPUT_OPTIONS = RoomInputOptions() DEFAULT_ROOM_OUTPUT_OPTIONS = RoomOutputOptions() LK_PUBLISH_FOR_ATTR = "lk.publish_for" +LK_TEXT_INPUT_TOPIC = "lk.room_text_input" class BaseStreamHandle: @@ -226,6 +229,7 @@ def __init__( """ self._options = options self._room = room + self._agent: Optional["PipelineAgent"] = None self._tasks: set[asyncio.Task] = set() # target participant @@ -263,6 +267,12 @@ def __init__( for participant in self._room.remote_participants.values(): self._on_participant_connected(participant) + # text input from datastream + if options.text_enabled: + self._room.register_text_stream_handler( + LK_TEXT_INPUT_TOPIC, self._on_text_input + ) + @property def audio(self) -> AsyncIterator[rtc.AudioFrame] | None: if not self._audio_handle: @@ -287,7 +297,8 @@ async def start(self, agent: Optional["PipelineAgent"] = None) -> None: # link to the first connected participant if not set self.set_participant(participant.identity) - if not agent: + self._agent = agent + if not self._agent: return agent.input.audio = self.audio @@ -399,6 +410,29 @@ async def _capture_text(): self._tasks.add(task) task.add_done_callback(self._tasks.discard) + def _on_text_input( + self, reader: rtc.TextStreamReader, participant_identity: str + ) -> None: + if participant_identity != self._participant_identity: + return + + async def _read_text(): + if not self._agent: + return + + text = await reader.read_all() + # TODO(long): text is always "0"? 
+ logger.debug( + "received text input", + extra={"text": text, "participant": self._participant_identity}, + ) + self._agent.interrupt() + self._agent.generate_reply(user_input=text) + + task = asyncio.create_task(_read_text()) + self._tasks.add(task) + task.add_done_callback(self._tasks.discard) + async def aclose(self) -> None: self._room.off("participant_connected", self._on_participant_connected) self._room.off("participant_disconnected", self._on_participant_disconnected) From fffa2dc07305662e004172f099f171a100698f4a Mon Sep 17 00:00:00 2001 From: Long Chen Date: Wed, 19 Feb 2025 18:57:25 +0800 Subject: [PATCH 3/3] fix when realtime model user transcription disabled --- livekit-agents/livekit/agents/llm/realtime.py | 2 +- livekit-agents/livekit/agents/pipeline/room_io.py | 2 +- .../livekit/agents/pipeline/task_activity.py | 13 +++++++------ .../plugins/openai/realtime/realtime_model.py | 10 +++++++++- 4 files changed, 18 insertions(+), 9 deletions(-) diff --git a/livekit-agents/livekit/agents/llm/realtime.py b/livekit-agents/livekit/agents/llm/realtime.py index f48bf63cd2..6a31b015bc 100644 --- a/livekit-agents/livekit/agents/llm/realtime.py +++ b/livekit-agents/livekit/agents/llm/realtime.py @@ -18,7 +18,7 @@ class InputSpeechStartedEvent: @dataclass class InputSpeechStoppedEvent: - pass + user_transcription_enabled: bool @dataclass diff --git a/livekit-agents/livekit/agents/pipeline/room_io.py b/livekit-agents/livekit/agents/pipeline/room_io.py index 1712726b76..b2d233fa6d 100644 --- a/livekit-agents/livekit/agents/pipeline/room_io.py +++ b/livekit-agents/livekit/agents/pipeline/room_io.py @@ -297,6 +297,7 @@ async def start(self, agent: Optional["PipelineAgent"] = None) -> None: # link to the first connected participant if not set self.set_participant(participant.identity) + # TODO(long): should we force the agent to be set or provide a set_agent method? self._agent = agent if not self._agent: return @@ -421,7 +422,6 @@ async def _read_text(): return text = await reader.read_all() - # TODO(long): text is always "0"? 
logger.debug( "received text input", extra={"text": text, "participant": self._participant_identity}, diff --git a/livekit-agents/livekit/agents/pipeline/task_activity.py b/livekit-agents/livekit/agents/pipeline/task_activity.py index 02c38ea7eb..735caefc02 100644 --- a/livekit-agents/livekit/agents/pipeline/task_activity.py +++ b/livekit-agents/livekit/agents/pipeline/task_activity.py @@ -319,14 +319,15 @@ def _on_input_speech_started(self, _: llm.InputSpeechStartedEvent) -> None: log_event("input_speech_started") self.interrupt() # input_speech_started is also interrupting on the serverside realtime session - def _on_input_speech_stopped(self, _: llm.InputSpeechStoppedEvent) -> None: + def _on_input_speech_stopped(self, ev: llm.InputSpeechStoppedEvent) -> None: log_event("input_speech_stopped") - self.on_interim_transcript( - stt.SpeechEvent( - stt.SpeechEventType.INTERIM_TRANSCRIPT, - alternatives=[stt.SpeechData(text="", language="")], + if ev.user_transcription_enabled: + self.on_interim_transcript( + stt.SpeechEvent( + stt.SpeechEventType.INTERIM_TRANSCRIPT, + alternatives=[stt.SpeechData(text="", language="")], + ) ) - ) def _on_input_audio_transcription_completed( self, ev: llm.InputTranscriptionCompleted diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py index 1799db0927..95f3becad5 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py @@ -397,7 +397,15 @@ def _handle_input_audio_buffer_speech_started( def _handle_input_audio_buffer_speech_stopped( self, _: InputAudioBufferSpeechStoppedEvent ) -> None: - self.emit("input_speech_stopped", llm.InputSpeechStoppedEvent()) + user_transcription_enabled = ( + self._realtime_model._opts.input_audio_transcription is not None + ) + self.emit( + "input_speech_stopped", + llm.InputSpeechStoppedEvent( + user_transcription_enabled=user_transcription_enabled + ), + ) def _handle_response_created(self, event: ResponseCreatedEvent) -> None: assert event.response.id is not None, "response.id is None"
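To see how the three patches fit together from the room side, here is a sketch of driving RoomInput directly; the constructor keywords are an assumption (they are not visible in the hunks), while the option fields, the lk.room_text_input topic, start(agent), and the interrupt()/generate_reply() flow come from patch 2, and user_transcription_enabled from patch 3.

    from livekit import rtc
    from livekit.agents.pipeline.pipeline_agent import PipelineAgent
    from livekit.agents.pipeline.room_io import RoomInput, RoomInputOptions


    async def wire_up(room: rtc.Room, agent: PipelineAgent) -> None:
        # text_enabled registers a handler for the "lk.room_text_input" text-stream topic
        room_input = RoomInput(  # constructor keywords assumed; see RoomInput.__init__ above
            room,
            options=RoomInputOptions(text_enabled=True, audio_enabled=True),
        )
        # start() links the audio/video streams to the agent and, with patch 2, keeps a
        # reference to it, so text received from the linked participant interrupts the
        # current reply and triggers agent.generate_reply(user_input=<received text>)
        await room_input.start(agent)

        # Patch 3: if the realtime model was created with input_audio_transcription=None,
        # InputSpeechStoppedEvent.user_transcription_enabled is False and task_activity.py
        # no longer forwards an empty interim transcript.

Text from other participants is ignored: _on_text_input returns early when the stream's participant identity does not match the linked participant.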