diff --git a/api.py b/api.py index 88d2e61..571fb9b 100644 --- a/api.py +++ b/api.py @@ -99,10 +99,11 @@ class ImageCaptioningInput(BaseModel): class WhisperInputs(BaseModel): audio: str - task: typing.Literal["translate", "transcribe"] = "transcribe" - language: str = None + task: typing.Literal["translate", "transcribe"] | None = None + language: str | None = None return_timestamps: bool = False - decoder_kwargs: dict = None + max_length: int | None = None + decoder_kwargs: dict | None = None chunk_length_s: float = 30 stride_length_s: typing.Tuple[float, float] = (6, 0) diff --git a/chart/model-values.yaml b/chart/model-values.yaml index d430eb1..6b177d1 100644 --- a/chart/model-values.yaml +++ b/chart/model-values.yaml @@ -199,6 +199,36 @@ deployments: Sunbird/asr-whisper-large-v3-salt Jacaranda-Health/ASR-STT + - name: "common-whisper-akera-kikuyu-short" + image: "crgooeyprodwestus1.azurecr.io/gooey-gpu-common:9" + autoscaling: + minReplicaCount: 0 + limits_gpu: "5Gi" + limits: + memory: "13Gi" + env: + QUEUE_PREFIX: "gooey-gpu/short" + WHISPER_TOKENIZER_FROM: "akera/whisper-large-v3-kik-full_v2" + IMPORTS: |- + common.whisper + WHISPER_MODEL_IDS: |- + akera/whisper-large-v3-kik-full_v2 + + - name: "common-whisper-akera-kikuyu-long" + image: "crgooeyprodwestus1.azurecr.io/gooey-gpu-common:9" + autoscaling: + minReplicaCount: 0 + limits_gpu: "10Gi" + limits: + memory: "27Gi" + env: + QUEUE_PREFIX: "gooey-gpu/long" + WHISPER_TOKENIZER_FROM: "akera/whisper-large-v3-kik-full_v2" + IMPORTS: |- + common.whisper + WHISPER_MODEL_IDS: |- + akera/whisper-large-v3-kik-full_v2 + - name: "common-whisper-en-short" image: *commonImgOld limits: diff --git a/common/whisper.py b/common/whisper.py index 81ced96..5580923 100644 --- a/common/whisper.py +++ b/common/whisper.py @@ -27,6 +27,8 @@ def whisper(pipeline: PipelineInfo, inputs: WhisperInputs) -> AsrOutput: generate_kwargs["language"] = inputs.language if inputs.task: generate_kwargs["task"] = inputs.task + if inputs.max_length: + generate_kwargs["max_length"] = inputs.max_length if generate_kwargs: kwargs["generate_kwargs"] = generate_kwargs