diff --git a/DeepSpeech-jni/org_mozilla_deepspeech_DeepSpeech.cpp b/DeepSpeech-jni/org_mozilla_deepspeech_DeepSpeech.cpp
index 20daf0b..e147b12 100644
--- a/DeepSpeech-jni/org_mozilla_deepspeech_DeepSpeech.cpp
+++ b/DeepSpeech-jni/org_mozilla_deepspeech_DeepSpeech.cpp
@@ -1,20 +1,16 @@
+#include <cstring>
 #include "org_mozilla_deepspeech_DeepSpeech.h"
 
 jint Java_org_mozilla_deepspeech_DeepSpeech_nCreateModel(JNIEnv *env, jclass, jstring modelPath,
-                                                         jlong nCep,
-                                                         jlong nContext,
                                                          jstring alphabetConfigPath,
-                                                         jlong beamWidth,
                                                          jobject modelStatePtr) {
-    jboolean isModelPathCopy, isAlphaBetCopy;
+    jboolean isModelPathCopy;
     ModelState *ptr = nullptr;
     auto modelPathCStr = (char *) env->GetStringUTFChars(modelPath, &isModelPathCopy);
-    auto alphaBetCStr = (char *) env->GetStringUTFChars(alphabetConfigPath, &isAlphaBetCopy);
-    jint state = DS_CreateModel(modelPathCStr, static_cast<unsigned int>(nCep),
-                                static_cast<unsigned int>(nContext), alphaBetCStr,
-                                static_cast<unsigned int>(beamWidth),
-                                &ptr);
+// https://github.com/mozilla/DeepSpeech/commit/8c820817794d445746aefb1b5347b35bf5e0c621#diff-0317a0e76ece10e0dba742af310a2362
+    jint state = DS_CreateModel(modelPathCStr, &ptr);
+
     auto *bufferPtr = (jlong *) (env->GetDirectBufferAddress(modelStatePtr));
     bufferPtr[0] = reinterpret_cast<jlong>(ptr);
@@ -26,46 +22,62 @@ Java_org_mozilla_deepspeech_DeepSpeech_nCreateModel(JNIEnv *env, jclass, jstring
 }
 
 void Java_org_mozilla_deepspeech_DeepSpeech_destroyModel(JNIEnv *, jclass, jlong modelPtr) {
-    DS_DestroyModel(reinterpret_cast<ModelState *>(modelPtr));
+    DS_FreeModel(reinterpret_cast<ModelState *>(modelPtr));
+}
+
+jint
+Java_org_mozilla_deepspeech_DeepSpeech_getModelBeamWidth(JNIEnv *env, jclass, jlong modelStatePtr) {
+    return DS_GetModelBeamWidth((ModelState *) modelStatePtr);
+}
+
+jint
+Java_org_mozilla_deepspeech_DeepSpeech_setModelBeamWidth(JNIEnv *env, jclass, jlong modelStatePtr, jlong beamWidth) {
+    return DS_SetModelBeamWidth((ModelState *) modelStatePtr, static_cast<unsigned int>(beamWidth));
+}
+
+jint
+Java_org_mozilla_deepspeech_DeepSpeech_getModelSampleRate(JNIEnv *env, jclass, jlong modelStatePtr) {
+    return DS_GetModelSampleRate((ModelState *) modelStatePtr);
 }
 
 jint
 Java_org_mozilla_deepspeech_DeepSpeech_enableDecoderWithLM(JNIEnv *env, jclass, jlong modelStatePtr,
-                                                           jstring alphaBetConfigPath,
-                                                           jstring lmPath,
-                                                           jstring triePath, jfloat alpha,
-                                                           jfloat beta) {
-    jboolean isAlphabetStrCopy, isLmPathCopy, isTriePathCopy;
-    auto alphaBetConfigPathCStr = const_cast<char *>(env->GetStringUTFChars(alphaBetConfigPath,
-                                                                            &isAlphabetStrCopy));
-    auto lmPathCStr = const_cast<char *>(env->GetStringUTFChars(lmPath, &isLmPathCopy));
-    auto triePathCStr = const_cast<char *>(env->GetStringUTFChars(triePath, &isTriePathCopy));
-
-    jint status = DS_EnableDecoderWithLM((ModelState *) modelStatePtr, alphaBetConfigPathCStr,
-                                         lmPathCStr, triePathCStr,
-                                         alpha, beta);
-
-    if (isAlphabetStrCopy == JNI_TRUE) {
-        env->ReleaseStringUTFChars(alphaBetConfigPath, alphaBetConfigPathCStr);
-    }
+                                                           jstring scorerPath,
+                                                           jfloat alpha, jfloat beta) {
+    jboolean isLmPathCopy;
+    auto scorerPathCStr = const_cast<char *>(env->GetStringUTFChars(scorerPath, &isLmPathCopy));
+
+    // https://github.com/mozilla/DeepSpeech/pull/2681/files
+    // lm: models/lm.binary, trie: models/trie
+    // became scorer: models/kenlm.scorer
+    DS_EnableExternalScorer((ModelState *) modelStatePtr, scorerPathCStr);
+    jint status = DS_SetScorerAlphaBeta((ModelState *) modelStatePtr, alpha, beta);
+
+    if (isLmPathCopy == JNI_TRUE) {
-        env->ReleaseStringUTFChars(lmPath, lmPathCStr);
-    }
-    if (isTriePathCopy == JNI_TRUE) {
-        env->ReleaseStringUTFChars(triePath, triePathCStr);
+        env->ReleaseStringUTFChars(scorerPath, scorerPathCStr);
     }
     return status;
 }
 
+jint
+Java_org_mozilla_deepspeech_DeepSpeech_setScorerAlphaBeta(JNIEnv *env, jclass, jlong modelStatePtr,
+                                                          jfloat alpha, jfloat beta) {
+    return DS_SetScorerAlphaBeta((ModelState *) modelStatePtr, alpha, beta);
+}
+
+jint
+Java_org_mozilla_deepspeech_DeepSpeech_disableExternalScorer(JNIEnv *env, jclass, jlong modelStatePtr) {
+    jint status = DS_DisableExternalScorer((ModelState *) modelStatePtr);
+    return status;
+}
+
 jstring
 Java_org_mozilla_deepspeech_DeepSpeech_nSpeechToText(JNIEnv *env, jclass, jlong modelStatePtr,
-                                                     jobject audioBuffer, jlong numSamples,
-                                                     jlong sampleRate) {
+                                                     jobject audioBuffer, jlong numSamples) {
     auto *array = (short *) (env->GetDirectBufferAddress(audioBuffer));
     char *cStr = DS_SpeechToText((ModelState *) modelStatePtr, array,
-                                 static_cast<unsigned int>(numSamples),
-                                 (unsigned int) sampleRate);
+                                 static_cast<unsigned int>(numSamples));
     if (cStr == nullptr) {
         return nullptr;
     }
@@ -76,12 +88,10 @@ Java_org_mozilla_deepspeech_DeepSpeech_nSpeechToText(JNIEnv *env, jclass, jlong
 
 jstring
 Java_org_mozilla_deepspeech_DeepSpeech_speechToTextUnsafe(JNIEnv *env, jclass, jlong modelStatePtr,
-                                                          jlong audioBuffer, jlong numSamples,
-                                                          jlong sampleRate) {
+                                                          jlong audioBuffer, jlong numSamples) {
     auto *array = (short *) (audioBuffer);
     char *cStr = DS_SpeechToText((ModelState *) modelStatePtr, array,
-                                 static_cast<unsigned int>(numSamples),
-                                 (unsigned int) sampleRate);
+                                 static_cast<unsigned int>(numSamples));
     if (cStr == nullptr) {
         return nullptr;
     }
@@ -95,12 +105,12 @@ Java_org_mozilla_deepspeech_DeepSpeech_nSpeechToTextWithMetadata(JNIEnv *env, jc
                                                                  jlong modelStatePtr,
                                                                  jobject audioBuffer,
                                                                  jlong bufferSize,
-                                                                 jlong sampleRate) {
+                                                                 jlong numResults) {
     auto *array = static_cast<short *>(env->GetDirectBufferAddress(audioBuffer));
     auto metaPtr = reinterpret_cast<jlong>(DS_SpeechToTextWithMetadata((ModelState *) modelStatePtr, array,
                                                                        static_cast<unsigned int>(bufferSize),
-                                                                       static_cast<unsigned int>(sampleRate)));
+                                                                       static_cast<unsigned int>(numResults)));
     return metaPtr;
 }
 jlong
@@ -108,23 +118,21 @@ Java_org_mozilla_deepspeech_DeepSpeech_speechToTextWithMetadataUnsafe(JNIEnv *,
                                                                       jlong modelStatePtr,
                                                                       jlong audioBuffer,
                                                                       jlong bufferSize,
-                                                                      jlong sampleRate) {
+                                                                      jlong numResults) {
     auto *array = (short *) audioBuffer;
     auto metaPtr = reinterpret_cast<jlong>(DS_SpeechToTextWithMetadata((ModelState *) modelStatePtr, array,
                                                                        static_cast<unsigned int>(bufferSize),
-                                                                       static_cast<unsigned int>(sampleRate)));
+                                                                       static_cast<unsigned int>(numResults)));
     return metaPtr;
 }
 
 jint Java_org_mozilla_deepspeech_DeepSpeech_nSetupStream(JNIEnv *env, jclass, jlong modelStatePtr,
-                                                         jlong preAllocFrames, jlong sampleRate,
                                                          jobject streamPtr) {
     StreamingState *pStreamingState;
-    jint status = DS_SetupStream((ModelState *) modelStatePtr,
-                                 static_cast<unsigned int>(preAllocFrames),
-                                 static_cast<unsigned int>(sampleRate), &pStreamingState);
+    jint status = DS_CreateStream((ModelState *) modelStatePtr, &pStreamingState);
+
     auto p = (StreamingState **) env->GetDirectBufferAddress(streamPtr);
     *p = pStreamingState;
     return status;
@@ -155,20 +163,34 @@ jstring Java_org_mozilla_deepspeech_DeepSpeech_finishStream(JNIEnv *env, jclass,
 }
 
 jlong
-Java_org_mozilla_deepspeech_DeepSpeech_finishStreamWithMetadata(JNIEnv *, jclass, jlong streamPtr) {
-    return reinterpret_cast<jlong>(DS_FinishStreamWithMetadata((StreamingState *) streamPtr));
+Java_org_mozilla_deepspeech_DeepSpeech_finishStreamWithMetadata(JNIEnv *, jclass, jlong streamPtr, jlong numResults) {
+    return reinterpret_cast<jlong>(DS_FinishStreamWithMetadata((StreamingState *) streamPtr, numResults));
 }
 void Java_org_mozilla_deepspeech_DeepSpeech_discardStream(JNIEnv *, jclass, jlong streamPtr) {
-    DS_DiscardStream((StreamingState *) streamPtr);
+    DS_FreeStream((StreamingState *) streamPtr);
 }
 
 void Java_org_mozilla_deepspeech_DeepSpeech_freeMetadata(JNIEnv *, jclass, jlong metaPtr) {
     DS_FreeMetadata((Metadata *) metaPtr);
 }
 
-void Java_org_mozilla_deepspeech_DeepSpeech_printVersions(JNIEnv *, jclass) {
-    DS_PrintVersions();
+jstring Java_org_mozilla_deepspeech_DeepSpeech_getVersion(JNIEnv *env, jclass) {
+    char *cString = DS_Version();
+    size_t cStrLen = strlen(cString);
+    jstring str = env->NewString(reinterpret_cast<const jchar *>(cString),
+                                 static_cast<jsize>(cStrLen));
+    DS_FreeString(cString);
+    return str;
+}
+
+jstring Java_org_mozilla_deepspeech_DeepSpeech_errorCodeToErrorMessage(JNIEnv *env, jclass, jlong errorCode) {
+    char *cString = DS_ErrorCodeToErrorMessage(errorCode);
+    size_t cStrLen = strlen(cString);
+    jstring str = env->NewString(reinterpret_cast<const jchar *>(cString),
+                                 static_cast<jsize>(cStrLen));
+    DS_FreeString(cString);
+    return str;
 }
 
 jint Java_org_mozilla_deepspeech_DeepSpeech_nGetConfiguration(JNIEnv *, jclass) {
@@ -181,4 +203,4 @@ jint Java_org_mozilla_deepspeech_DeepSpeech_nGetConfiguration(JNIEnv *, jclass)
     return BuildConfiguration::INVALID; // This should never be returned
 #endif
 #endif
-}
\ No newline at end of file
+}
diff --git a/DeepSpeech-jni/org_mozilla_deepspeech_DeepSpeech.h b/DeepSpeech-jni/org_mozilla_deepspeech_DeepSpeech.h
index 7d676e5..3ce2360 100644
--- a/DeepSpeech-jni/org_mozilla_deepspeech_DeepSpeech.h
+++ b/DeepSpeech-jni/org_mozilla_deepspeech_DeepSpeech.h
@@ -16,19 +16,35 @@ enum BuildConfiguration {
 };
 
 JNIEXPORT jint JNICALL Java_org_mozilla_deepspeech_DeepSpeech_nCreateModel
-        (JNIEnv *, jclass, jstring, jlong, jlong, jstring, jlong, jobject);
+        (JNIEnv *, jclass, jstring, jobject);
 
 JNIEXPORT void JNICALL Java_org_mozilla_deepspeech_DeepSpeech_destroyModel
         (JNIEnv *, jclass, jlong);
 
+JNIEXPORT jint JNICALL Java_org_mozilla_deepspeech_DeepSpeech_getModelBeamWidth
+        (JNIEnv *, jclass, jlong);
+
+JNIEXPORT jint JNICALL Java_org_mozilla_deepspeech_DeepSpeech_setModelBeamWidth
+        (JNIEnv *, jclass, jlong, jlong);
+
+JNIEXPORT jint JNICALL Java_org_mozilla_deepspeech_DeepSpeech_getModelSampleRate
+        (JNIEnv *, jclass, jlong);
+
 JNIEXPORT jint JNICALL Java_org_mozilla_deepspeech_DeepSpeech_enableDecoderWithLM
-        (JNIEnv *, jclass, jlong, jstring, jstring, jstring, jfloat, jfloat);
+        (JNIEnv *, jclass, jlong, jstring, jfloat, jfloat);
+
+JNIEXPORT jint JNICALL Java_org_mozilla_deepspeech_DeepSpeech_setScorerAlphaBeta
+        (JNIEnv *env, jclass, jlong, jfloat, jfloat);
+
+JNIEXPORT jint JNICALL Java_org_mozilla_deepspeech_DeepSpeech_disableExternalScorer
+        (JNIEnv *env, jclass, jlong);
 
 JNIEXPORT jstring JNICALL Java_org_mozilla_deepspeech_DeepSpeech_nSpeechToText
-        (JNIEnv *, jclass, jlong, jobject, jlong, jlong);
+        (JNIEnv *, jclass, jlong, jobject, jlong);
 
 JNIEXPORT jstring JNICALL Java_org_mozilla_deepspeech_DeepSpeech_speechToTextUnsafe
-        (JNIEnv *, jclass, jlong, jlong, jlong, jlong);
+        (JNIEnv *, jclass, jlong, jlong, jlong);
 
 JNIEXPORT jlong JNICALL Java_org_mozilla_deepspeech_DeepSpeech_nSpeechToTextWithMetadata
         (JNIEnv *, jclass, jlong, jobject, jlong, jlong);
@@ -37,7 +53,7 @@ JNIEXPORT jlong JNICALL Java_org_mozilla_deepspeech_DeepSpeech_speechToTextWithM
         (JNIEnv *, jclass, jlong, jlong, jlong, jlong);
 
 JNIEXPORT jint JNICALL Java_org_mozilla_deepspeech_DeepSpeech_nSetupStream
-        (JNIEnv *, jclass, jlong, jlong, jlong, jobject);
+        (JNIEnv *, jclass, jlong, jobject);
 
 JNIEXPORT void JNICALL Java_org_mozilla_deepspeech_DeepSpeech_nFeedAudioContent
        (JNIEnv *, jclass, jlong, jobject, jlong);
@@ -57,9 +73,12 @@ JNIEXPORT void JNICALL Java_org_mozilla_deepspeech_DeepSpeech_discardStream
 JNIEXPORT void JNICALL Java_org_mozilla_deepspeech_DeepSpeech_freeMetadata
         (JNIEnv *, jclass, jlong);
 
-JNIEXPORT void JNICALL Java_org_mozilla_deepspeech_DeepSpeech_printVersions
+JNIEXPORT jstring JNICALL Java_org_mozilla_deepspeech_DeepSpeech_getVersion
         (JNIEnv *, jclass);
 
+JNIEXPORT jstring JNICALL Java_org_mozilla_deepspeech_DeepSpeech_errorCodeToErrorMessage
+        (JNIEnv *, jclass, jlong);
+
 JNIEXPORT jint JNICALL Java_org_mozilla_deepspeech_DeepSpeech_nGetConfiguration
         (JNIEnv *, jclass);
 
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..c11c862
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,6 @@
+FROM java:8
+
+COPY . /idear
+WORKDIR /idear
+RUN ./gradlew buildPlugin
+
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..5007d3e
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,13 @@
+default:
+	@docker run --rm -v $(PWD):/idear -w /idear openasr/idear ./gradlew buildPlugin
+
+test:
+	@docker run --rm -v $(PWD):/idear -w /idear openasr/idear ./gradlew test
+
+docker:
+	@rm -rf build out
+	@docker build -t openasr/idear .
+
+push:
+	@docker push openasr/idear:latest
+
diff --git a/libdeepspeech/build.gradle b/libdeepspeech/build.gradle
index 396d3ee..cceee4e 100644
--- a/libdeepspeech/build.gradle
+++ b/libdeepspeech/build.gradle
@@ -1,9 +1,11 @@
 plugins {
     id 'java'
+    id 'distribution'
+    id 'maven-publish'
 }
 
 group 'org.mozilla.deepspeech'
-version '1.0'
+version '1.0.4-SNAPSHOT'
 
 sourceCompatibility = 1.6
 
@@ -14,4 +16,50 @@ repositories {
 dependencies {
     testCompile group: 'junit', name: 'junit', version: '4.12'
     compile("org.jetbrains:annotations:16.0.1") // Jetbrains annotations library
-}
\ No newline at end of file
+}
+
+//task linuxJniZip(type: Zip) {
+//    archiveName("deepspeech-jni.zip")
+//    destinationDir(file("$buildDir/dist"))
+////    archiveFileName = "deepspeech-jni.zip"
+////    destinationDirectory = file("$buildDir/target")
+//    from "../DeepSpeech-jni/libs/linux"
+//}
+
+//artifacts {
+//    linuxJniZip
+//}
+
+distributions {
+    linux {
+        contents {
+            from '../DeepSpeech-jni/libs/linux'
+        }
+    }
+    windows {
+        contents {
+            from '../DeepSpeech-jni/libs/windows'
+        }
+    }
+}
+
+publishing {
+    publications {
+        libdeepspeech(MavenPublication) {
+            from components.java
+            artifact linuxDistZip
+            artifact windowsDistZip
+        }
+    }
+    repositories {
+        maven {
+            name "GitHubPackages"
+            url "https://maven.pkg.github.com/" + System.getenv('GITHUB_REPOSITORY')
+            credentials {
+                username System.getenv('GITHUB_ACTOR')
+                password System.getenv('GITHUB_TOKEN')
+            }
+        }
+    }
+}
+
diff --git a/libdeepspeech/src/main/java/org/mozilla/deepspeech/DeepSpeech.java b/libdeepspeech/src/main/java/org/mozilla/deepspeech/DeepSpeech.java
index ca26ad6..475bf63 100644
--- a/libdeepspeech/src/main/java/org/mozilla/deepspeech/DeepSpeech.java
+++ b/libdeepspeech/src/main/java/org/mozilla/deepspeech/DeepSpeech.java
@@ -23,7 +23,8 @@
  * Unsafe. Use wrapper objects instead!
  *
  * @see org.mozilla.deepspeech.recognition.DeepSpeechModel
- * @see SpeechRecognitionAudioStream
+ * @see org.mozilla.deepspeech.recognition.stream.SpeechRecognitionAudioStream
+ * @see org.mozilla.deepspeech.recognition.stream.StreamingState
  * @see org.mozilla.deepspeech.recognition.SpeechRecognitionResult
  */
 public class DeepSpeech {
@@ -119,10 +120,6 @@ public static class ErrorCodes {
     * An object providing an interface to a trained DeepSpeech model.
     *
     * @param modelPath          The path to the frozen model graph.
-     * @param nCep               The number of cepstrum the recognition was trained with.
-     * @param nContext           The context window the recognition was trained with.
-     * @param alphabetConfigPath The path to the configuration file specifying
-     *                           the alphabet used by the network. See alphabet.h.
     * @param beamWidth          The beam width used by the decoder. A larger beam
     *                           width generates better results at the cost of decoding
     *                           time.
@@ -135,16 +132,12 @@ public static class ErrorCodes {
     */
    @NativeType("jint")
    public static int createModel(@NotNull String modelPath,
-                                  @NativeType("jlong") long nCep,
-                                  @NativeType("jlong") long nContext,
-                                  @NotNull String alphabetConfigPath,
-                                  @NativeType("jlong") long beamWidth,
                                   @CallByReference
                                   @NativeType("struct ModelState *")
                                   @NotNull
                                   @DynamicPointer("destroyModel") ByteBuffer modelStatePointer) throws UnexpectedBufferCapacityException, IncorrectBufferByteOrderException, IncorrectBufferTypeException, BufferReadonlyException {
        BufferUtils.checkByteBuffer(modelStatePointer, ByteOrder.nativeOrder(), 8); // 8 -> Long.BYTES
-        return nCreateModel(modelPath, nCep, nContext, alphabetConfigPath, beamWidth, modelStatePointer);
+        return nCreateModel(modelPath, modelStatePointer);
    }
 
    /**
@@ -152,10 +145,6 @@ public static int createModel(@NotNull String modelPath,
     */
    @NativeType("jint")
    private static native int nCreateModel(@NotNull String modelPath,
-                                           @NativeType("jlong") long nCep,
-                                           @NativeType("jlong") long nContext,
-                                           @NotNull String alphabetConfigPath,
-                                           @NativeType("jlong") long beamWidth,
                                            @CallByReference
                                            @NativeType("struct ModelState *")
                                            @NotNull
@@ -168,32 +157,43 @@ private static native int nCreateModel(@NotNull String modelPath,
     */
    public static native void destroyModel(@NativeType("ModelState *") long modelStatePointer);
 
+    public static native int getModelBeamWidth(@NativeType("ModelState *") long modelStatePointer);
+
+    public static native int setModelBeamWidth(@NativeType("ModelState *") long modelStatePointer, long beamWidth);
+
+    public static native int getModelSampleRate(@NativeType("ModelState *") long modelStatePointer);
+
    /**
     * Enable decoding using beam scoring with a KenLM language model.
     *
     * @param modelStatePtr The ModelState pointer for the model being changed.
-     * @param alphaBetConfigPath The path to the configuration file specifying the alphabet used by the network. See alphabet.h.
-     * @param lmPath The path to the language recognition binary file.
-     * @param triePath The path to the trie file build from the same vocabulary as the language recognition binary.
+     * @param scorerPath The path to the scorer package generated with `data/lm/generate_package.py`.
     * @param alpha The alpha hyperparameter of the CTC decoder. Language Model weight.
     * @param beta The beta hyperparameter of the CTC decoder. Word insertion weight.
     * @return Zero on success, non-zero on failure (invalid arguments).
     */
    @NativeType("jint")
    public static native int enableDecoderWithLM(@NativeType("struct ModelState *") long modelStatePtr,
-                                                 @NativeType("jstring") @NotNull String alphaBetConfigPath,
-                                                 @NativeType("jstring") @NotNull String lmPath,
-                                                 @NativeType("jstring") @NotNull String triePath,
+                                                 @NativeType("jstring") @NotNull String scorerPath,
                                                  @NativeType("jfloat") float alpha,
                                                  @NativeType("jfloat") float beta);
 
+    @NativeType("jint")
+    @Calls("DS_SetScorerAlphaBeta")
+    public static native int setScorerAlphaBeta(@NativeType("struct ModelState *") long modelStatePtr,
+                                                @NativeType("jfloat") float alpha,
+                                                @NativeType("jfloat") float beta);
+
+    @NativeType("jint")
+    @Calls("DS_DisableExternalScorer")
+    public static native int disableExternalScorer(@NativeType("struct ModelState *") long modelStatePtr);
+
    /**
     * Use the DeepSpeech model to perform Speech-To-Text.
     *
     * @param modelStatePointer The ModelState pointer for the model to use.
     * @param audioBuffer A 16-bit, mono raw audio signal at the appropriate sample rate.
     * @param numSamples The number of samples in the audio signal.
-     * @param sampleRate The sample-rate of the audio signal.
     * @return The STT result. Returns null on error.
     * @throws UnexpectedBufferCapacityException if #numSamples does not match the allocated buffer capacity. Condition: {@code numSamples * Short.BYTES > audioBuffer.capacity()}
     * @throws IncorrectBufferByteOrderException if the audioBuffer has a byte order different to {@link ByteOrder#LITTLE_ENDIAN}.
@@ -205,39 +205,35 @@
    public static String speechToText(@NativeType("struct ModelState *") long modelStatePointer,
                                      @NativeType("const short *")
                                      @NotNull ByteBuffer audioBuffer,
-                                      @NativeType("jlong") long numSamples,
-                                      @NativeType("jlong") long sampleRate) throws UnexpectedBufferCapacityException, IncorrectBufferByteOrderException, IncorrectBufferTypeException, BufferReadonlyException {
+                                      @NativeType("jlong") long numSamples) throws UnexpectedBufferCapacityException, IncorrectBufferByteOrderException, IncorrectBufferTypeException, BufferReadonlyException {
        BufferUtils.checkByteBuffer(audioBuffer, ByteOrder.LITTLE_ENDIAN, numSamples * 2 /* sizeof(short) */);
-        return nSpeechToText(modelStatePointer, audioBuffer, numSamples, sampleRate);
+        return nSpeechToText(modelStatePointer, audioBuffer, numSamples);
    }
 
    /**
-     * Unexposed unsafe method that should not be used. Use instead: {@link #speechToText(long, ByteBuffer, long, long)}
+     * Unexposed unsafe method that should not be used. Use instead: {@link #speechToText(long, ByteBuffer, long)}
     */
    @Nullable
    @Calls("DS_SpeechToText")
    private static native String nSpeechToText(@NativeType("struct ModelState *") long modelStatePointer,
                                               @NativeType("const short *") @NotNull ByteBuffer audioBuffer,
-                                               @NativeType("jlong") long numSamples,
-                                               @NativeType("jlong") long sampleRate);
+                                               @NativeType("jlong") long numSamples);
 
    /**
-     * WARNING: Unsafe function. Consider using {@link #speechToText(long, ByteBuffer, long, long)}
+     * WARNING: Unsafe function. Consider using {@link #speechToText(long, ByteBuffer, long)}
     * Use the DeepSpeech model to perform Speech-To-Text.
     *
     * @param modelStatePointer The ModelState pointer for the model to use.
     * @param audioBufferPtr A 16-bit, mono raw audio signal at the appropriate sample rate.
     * @param numSamples The number of samples in the audio signal.
-     * @param sampleRate The sample-rate of the audio signal.
     * @return The STT result. Returns null on error.
     */
    @Nullable
    @Calls("DS_SpeechToText")
    public static native String speechToTextUnsafe(@NativeType("struct ModelState *") long modelStatePointer,
                                                   @NativeType("const short *") long audioBufferPtr,
-                                                   @NativeType("jlong") long numSamples,
-                                                   @NativeType("jlong") long sampleRate);
+                                                   @NativeType("jlong") long numSamples);
 
    /**
     * Use the DeepSpeech model to perform Speech-To-Text and output metadata
@@ -245,8 +241,9 @@ public static native String speechToTextUnsafe(@NativeType("struct ModelState *"
     *
     * @param modelStatePointer The ModelState pointer for the model to use.
     * @param audioBufferPtr A 16-bit, mono raw audio signal at the appropriate sample rate.
+     *                       (matching what the model was trained on)
     * @param numSamples The number of samples in the audio signal.
-     * @param sampleRate The sample-rate of the audio signal.
+     * @param numResults The maximum number of CandidateTranscript structs to return. Returned value might be smaller than this.
     * @return Outputs a struct of individual letters along with their timing information.
     * The user is responsible for freeing Metadata by calling {@link #freeMetadata(long)}. Returns {@link #NULL} on error.
     */
@@ -256,7 +253,7 @@ public static native String speechToTextUnsafe(@NativeType("struct ModelState *"
    public static native long speechToTextWithMetadataUnsafe(@NativeType("struct ModelState *") long modelStatePointer,
                                                             @NativeType("const short *") long audioBufferPtr,
                                                             @NativeType("jlong") long numSamples,
-                                                             @NativeType("jlong") long sampleRate);
+                                                             @NativeType("jlong") long numResults);
 
    /**
     * Use the DeepSpeech model to perform Speech-To-Text and output metadata
     * about the results.
     *
     * @param modelStatePointer The ModelState pointer for the model to use.
     * @param audioBuffer A 16-bit, mono raw audio signal at the appropriate sample rate.
+     *                    (matching what the model was trained on)
     * @param numSamples The number of samples in the audio signal.
-     * @param sampleRate The sample-rate of the audio signal.
+     * @param numResults The maximum number of CandidateTranscript structs to return. Returned value might be smaller than this.
     * @return Outputs a struct of individual letters along with their timing information.
     * The user is responsible for freeing Metadata by calling {@link #freeMetadata(long)}. Returns {@link #NULL} on error.
     * @throws UnexpectedBufferCapacityException if #numSamples does not match the allocated buffer capacity. Condition: {@code numSamples * Short.BYTES > audioBuffer.capacity()}
@@ -280,9 +278,9 @@ public static long speechToTextWithMetadata(@NativeType("struct ModelState *") l
                                                @NativeType("const short *")
                                                @NotNull ByteBuffer audioBuffer,
                                                @NativeType("jlong") long numSamples,
-                                                @NativeType("jlong") long sampleRate) throws UnexpectedBufferCapacityException, IncorrectBufferByteOrderException, IncorrectBufferTypeException, BufferReadonlyException {
+                                                @NativeType("jlong") long numResults) throws UnexpectedBufferCapacityException, IncorrectBufferByteOrderException, IncorrectBufferTypeException, BufferReadonlyException {
        BufferUtils.checkByteBuffer(audioBuffer, ByteOrder.nativeOrder(), numSamples * 2 /* sizeof(short) */);
-        return nSpeechToTextWithMetadata(modelStatePointer, audioBuffer, numSamples, sampleRate);
+        return nSpeechToTextWithMetadata(modelStatePointer, audioBuffer, numSamples, numResults);
    }
 
    /**
@@ -295,7 +293,7 @@ private static native long nSpeechToTextWithMetadata(@NativeType("struct ModelSt
                                                         @NativeType("const short *")
                                                         @NotNull ByteBuffer audioBuffer,
                                                         @NativeType("jlong") long numSamples,
-                                                         @NativeType("jlong") long sampleRate);
+                                                         @NativeType("jlong") long numResults);
 
    /**
     * Create a new streaming inference state. The streaming state returned
@@ -303,13 +301,10 @@ private static native long nSpeechToTextWithMetadata(@NativeType("struct ModelSt
     * and {@link #finishStream(long)}.
     *
     * @param modelStatePointer The ModelState pointer for the model to use.
-     * @param preAllocFrames Number of timestep frames to reserve. One timestep
-     *                       is equivalent to two window lengths (20ms). If set to
-     *                       0 we reserve enough frames for 3 seconds of audio (150).
-     * @param sampleRate The sample-rate of the audio signal.
     * @param streamPointerOut an opaque pointer that represents the streaming state. Can
     *                         be {@link #NULL} if an error occurs.
-     *                         Note for JavaBindings: The long buffer must have a capacity of one long otherwise the function will return -1. No native memory will be allocated, so this does not result in a memory leak.
+     *                         Note for JavaBindings: The long buffer must have a capacity of one long otherwise the function will return -1.
+     *                         No native memory will be allocated, so this does not result in a memory leak.
     *                         The function will throw an {@link UnexpectedBufferCapacityException} stating the buffer does not have enough capacity.
     * @return Zero for success, non-zero on failure.
     * @throws UnexpectedBufferCapacityException if the buffer has a capacity smaller than {@link Long#BYTES} bytes.
@@ -320,24 +315,20 @@ private static native long nSpeechToTextWithMetadata(@NativeType("struct ModelSt
    @Calls("DS_SetupStream")
    @NativeType("jint")
    public static int setupStream(@NativeType("struct ModelState *") long modelStatePointer,
-                                  @NativeType("jlong") long preAllocFrames,
-                                  @NativeType("jlong") long sampleRate,
                                   @DynamicPointer("finishStream")
                                   @NativeType("struct StreamingState **")
                                   @NotNull
                                   @CallByReference ByteBuffer streamPointerOut) throws UnexpectedBufferCapacityException, IncorrectBufferByteOrderException, IncorrectBufferTypeException, BufferReadonlyException {
        BufferUtils.checkByteBuffer(streamPointerOut, ByteOrder.nativeOrder(), NATIVE_POINTER_SIZE);
-        return nSetupStream(modelStatePointer, preAllocFrames, sampleRate, streamPointerOut);
+        return nSetupStream(modelStatePointer, streamPointerOut);
    }
 
    /**
-     * Unexposed unsafe method that should not be used. Use instead: {@link #setupStream(long, long, long, ByteBuffer)}
+     * Unexposed unsafe method that should not be used. Use instead: {@link #setupStream(long, ByteBuffer)}
     */
-    @Calls("DS_SetupStream")
+    @Calls("DS_CreateStream")
    @NativeType("jint")
    private static native int nSetupStream(@NativeType("struct ModelState *") long modelStatePointer,
-                                           @NativeType("jlong") long preAllocFrames,
-                                           @NativeType("jlong") long sampleRate,
                                            @DynamicPointer("finishStream")
                                            @NativeType("struct StreamingState **")
                                            @NotNull
@@ -346,7 +337,7 @@ private static native int nSetupStream(@NativeType("struct ModelState *") long m
    /**
     * Feed audio samples to an ongoing streaming inference.
     *
-     * @param streamPointer A streaming state pointer created by {@link #setupStream(long, long, long, ByteBuffer)}.
+     * @param streamPointer A streaming state pointer created by {@link #setupStream(long, ByteBuffer)}.
     * @param audioBuffer An array of 16-bit, mono raw audio samples at the appropriate sample rate.
     * @param numSamples The number of samples in the audio content.
     * @throws UnexpectedBufferCapacityException if #numSamples does not match the allocated buffer capacity. Condition: {@code numSamples * Short.BYTES < audioBuffer.capacity()}
@@ -379,7 +370,7 @@ private static native void nFeedAudioContent(@NativeType("struct StreamingState
     * currently capable of streaming, so it always starts from the beginning
     * of the audio.
     *
-     * @param streamPointer A streaming state pointer created by {@link #setupStream(long, long, long, ByteBuffer)}.
+     * @param streamPointer A streaming state pointer created by {@link #setupStream(long, ByteBuffer)}.
     * @return The STT intermediate result.
     */
    @Calls("DS_IntermediateDecode")
@@ -391,7 +382,7 @@ private static native void nFeedAudioContent(@NativeType("struct StreamingState
     * Signal the end of an audio signal to an ongoing streaming
     * inference, returns the STT result over the whole audio signal.
     *
-     * @param streamPointer A streaming state pointer created by {@link #setupStream(long, long, long, ByteBuffer)}.
+     * @param streamPointer A streaming state pointer created by {@link #setupStream(long, ByteBuffer)}.
     * @return The STT result.
     */
    @Calls("DS_FinishStream")
@@ -403,14 +394,15 @@ private static native void nFeedAudioContent(@NativeType("struct StreamingState
     * Signal the end of an audio signal to an ongoing streaming
     * inference, returns per-letter metadata.
     *
-     * @param streamPointer A streaming state pointer created by {@link #setupStream(long, long, long, ByteBuffer)}.
+     * @param streamPointer A streaming state pointer created by {@link #setupStream(long, ByteBuffer)}.
+     * @param numResults The number of candidate transcripts to return.
     * @return Outputs a struct of individual letters along with their timing information.
     * The user is responsible for freeing Metadata by calling {@link #freeMetadata(long)}. Returns {@link #NULL} on error.
     */
    @Calls("DS_FinishStreamWithMetadata")
    @NativeType("struct Metadata *")
    @DynamicPointer("freeMetadata")
-    public static native long finishStreamWithMetadata(@NativeType("struct StreamingState *") long streamPointer);
+    public static native long finishStreamWithMetadata(@NativeType("struct StreamingState *") long streamPointer, long numResults);
 
    /**
     * This method will free the state pointer #streamPointer.
@@ -432,10 +424,13 @@ private static native void nFeedAudioContent(@NativeType("struct StreamingState
    public static native void freeMetadata(@NativeType("struct Metadata *") long metaDataPointer);
 
    /**
-     * Prints version of this library and of the linked TensorFlow library.
+     * Gets the version of this library. The returned version is a semantic version (SemVer 2.0.0).
     */
-    @Calls("DS_PrintVersions")
-    public static native void printVersions();
+    @Calls("DS_Version")
+    public static native String getVersion();
+
+    @Calls("DS_ErrorCodeToErrorMessage")
+    public static native String errorCodeToErrorMessage(long errorCode);
 
    /**
     * @return the configuration the jni library has been built for
diff --git a/libdeepspeech/src/main/java/org/mozilla/deepspeech/recognition/DeepSpeechModel.java b/libdeepspeech/src/main/java/org/mozilla/deepspeech/recognition/DeepSpeechModel.java
index 6ffacbd..2934c46 100644
--- a/libdeepspeech/src/main/java/org/mozilla/deepspeech/recognition/DeepSpeechModel.java
+++ b/libdeepspeech/src/main/java/org/mozilla/deepspeech/recognition/DeepSpeechModel.java
@@ -1,6 +1,7 @@
 package org.mozilla.deepspeech.recognition;
 
 import org.jetbrains.annotations.NotNull;
+import org.mozilla.deepspeech.DeepSpeech;
 import org.mozilla.deepspeech.doc.NativeType;
 import org.mozilla.deepspeech.doc.WrappsStruct;
 import org.mozilla.deepspeech.exception.buffer.BufferReadonlyException;
@@ -8,6 +9,7 @@
 import org.mozilla.deepspeech.exception.buffer.IncorrectBufferTypeException;
 import org.mozilla.deepspeech.exception.buffer.UnexpectedBufferCapacityException;
 import org.mozilla.deepspeech.nativewrapper.DynamicStruct;
+import org.mozilla.deepspeech.recognition.stream.StreamingState;
 
 import java.io.File;
 import java.io.FileNotFoundException;
@@ -28,39 +30,76 @@ public class DeepSpeechModel extends DynamicStruct.LifecycleDisposed {
 
    /**
     * @param modelFile The file pointing to the frozen model graph.
-     * @param numCep The number of cepstrum the recognition was trained with.
-     * @param context The context window the recognition was trained with.
-     * @param alphabetConfigFile The path to the configuration file specifying
-     *                           the alphabet used by the network. See alphabet.h.
-     * @param beamWidth The beam width used by the decoder. A larger beam
-     *                  width generates better results at the cost of decoding
-     *                  time.
-     * @throws FileNotFoundException if either modelFile or alphabetConfigFile is not found
+     * @throws FileNotFoundException if modelFile is not found
     */
-    public DeepSpeechModel(@NotNull File modelFile, long numCep, long context, @NotNull File alphabetConfigFile, long beamWidth) throws FileNotFoundException {
-        super(newModel(modelFile, numCep, context, alphabetConfigFile, beamWidth), UNDEFINED_STRUCT_SIZE);
+    public DeepSpeechModel(@NotNull File modelFile) throws FileNotFoundException {
+        super(newModel(modelFile), UNDEFINED_STRUCT_SIZE);
+    }
+
+    /**
+     * Get beam width value used by the model. If {@link #setBeamWidth(long)} was not called before,
+     * will return the default value loaded from the model file.
+     *
+     * @return Beam width value used by the model.
+     */
+    public int getBeamWidth() {
+        return getModelBeamWidth(this.pointer);
+    }
+
+    /**
+     * Set beam width value used by the model.
+     *
+     * @param beamWidth The beam width used by the model.
+     *                  A larger beam width value generates better results at the cost of decoding time.
+     *
+     * @return Zero on success, non-zero on failure.
+     */
+    public int setBeamWidth(long beamWidth) {
+        return setModelBeamWidth(this.pointer, beamWidth);
+    }
+
+    /**
+     * Return the sample rate expected by a model.
+     *
+     * @return Sample rate expected by the model for its input.
+     */
+    public int getSampleRate() {
+        return getModelSampleRate(this.pointer);
    }
 
    /**
     * Enables decoding using beam scoring with a KenLM language model.
     *
-     * @param alphabetFile The path to the configuration file specifying the alphabet used by the network.
-     * @param lmBinaryFile The path to the language model binary file.
-     * @param trieFile The path to the trie file build from the same vocabulary as the language model binary
+     * @param scorerPath The path to the scorer package generated with `data/lm/generate_package.py`.
     * @param lmAlpha The alpha hyper-parameter of the CTC decoder. Language Model weight.
     * @param lmBeta The beta hyper-parameter of the CTC decoder. Word insertion weight.
-     * @throws FileNotFoundException if one of the files is not found.
+     * @throws FileNotFoundException if the file is not found.
     */
-    public void enableLMLanguageModel(@NotNull File alphabetFile, @NotNull File lmBinaryFile, @NotNull File trieFile, float lmAlpha, float lmBeta) throws FileNotFoundException {
-        enableDecoderWithLM(this.pointer, checkExists(alphabetFile).getPath(), checkExists(lmBinaryFile).getPath(), checkExists(trieFile).getPath(), lmAlpha, lmBeta);
+    public void enableLMLanguageModel(@NotNull File scorerPath, float lmAlpha, float lmBeta) throws FileNotFoundException {
+        DeepSpeech.enableDecoderWithLM(this.pointer, checkExists(scorerPath).getPath(), lmAlpha, lmBeta);
+    }
+
+    /**
+     * @param lmAlpha The alpha hyper-parameter of the CTC decoder. Language Model weight.
+     * @param lmBeta The beta hyper-parameter of the CTC decoder. Word insertion weight.
+     */
+    public void setScorerAlphaBeta(float lmAlpha, float lmBeta) throws FileNotFoundException {
+        DeepSpeech.setScorerAlphaBeta(this.pointer, lmAlpha, lmBeta);
+    }
+
+    /**
+     * Disable decoding using an external scorer.
+     */
+    public void disableExternalScorer(@NotNull File scorerPath, float lmAlpha, float lmBeta) throws FileNotFoundException {
+        DeepSpeech.disableExternalScorer(this.pointer);
    }
 
    /**
     * Performs a speech-to-text call on the model
     *
-     * @param audioBuffer the audio buffer storing the audio data in samples / frames to perform the recognition on
+     * @param audioBuffer the 16-bit, mono raw audio buffer storing the audio data at the appropriate sample rate
+     *                    (matching what the model was trained on).
     * @param numSamples the number of samples / frames in the buffer
-     * @param sampleRate the amount of samples representing a given duration of audio. sampleRate = Δ samples / Δ time
     * @return the transcription string
     * @throws UnexpectedBufferCapacityException if #numSamples does not match the allocated buffer capacity. Condition: {@code numSamples * Short.BYTES > audioBuffer.capacity()}
     * @throws IncorrectBufferByteOrderException if the audioBuffer has a byte order different to {@link ByteOrder#nativeOrder()}.
@@ -69,7 +108,7 @@ public void enableLMLanguageModel(@NotNull File alphabetFile, @NotNull File lmBi
     */
    @NotNull
    public String doSpeechToText(@NativeType("const short *") @NotNull ByteBuffer audioBuffer, long numSamples, long sampleRate) throws UnexpectedBufferCapacityException, IncorrectBufferByteOrderException, IncorrectBufferTypeException, BufferReadonlyException {
-        String ret = speechToText(this.pointer, audioBuffer, numSamples, sampleRate);
+        String ret = speechToText(this.pointer, audioBuffer, numSamples);
        if (ret == null) throw new NullPointerException();
        return ret;
    }
@@ -79,13 +118,13 @@ public String doSpeechToText(@NativeType("const short *") @NotNull ByteBuffer au
     *
     * @param audioBuffer the audio buffer storing the audio data in samples / frames to perform the recognition on
     * @param numSamples the number of samples / frames in the buffer
-     * @param sampleRate the amount of samples representing a given duration of audio. sampleRate = Δ samples / Δ time
+     * @param numResults The maximum number of CandidateTranscript structs to return. Returned value might be smaller than this.
     * @return the meta data of transcription
     * @see SpeechRecognitionResult
     */
    @NotNull
-    public SpeechRecognitionResult doSpeechRecognitionWithMeta(@NativeType("const short *") @NotNull ByteBuffer audioBuffer, long numSamples, long sampleRate) {
-        long metaPointer = speechToTextWithMetadata(this.pointer, audioBuffer, numSamples, sampleRate);
+    public SpeechRecognitionResult doSpeechRecognitionWithMeta(@NativeType("const short *") @NotNull ByteBuffer audioBuffer, long numSamples, long numResults) {
+        long metaPointer = speechToTextWithMetadata(this.pointer, audioBuffer, numSamples, numResults);
        if (metaPointer == NULL) throw new NullPointerException();
        return new SpeechRecognitionResult(metaPointer); // Meta pointer is freed as Recognition Result instantly disposes it after copying the values.
    }
@@ -93,11 +132,11 @@ public SpeechRecognitionResult doSpeechRecognitionWithMeta(@NativeType("const sh
    /**
     * Allocates a new native recognition structure and returns the pointer pointing to the dynamically allocated memory
     *
-     * @see DeepSpeechModel#DeepSpeechModel(File, long, long, File, long)
+     * @see DeepSpeechModel#DeepSpeechModel(File)
     */
-    private static long newModel(@NotNull File modelFile, long numCep, long context, @NotNull File alphabetConfigFile, long beamWidth) throws FileNotFoundException {
+    private static long newModel(@NotNull File modelFile) throws FileNotFoundException {
        ByteBuffer ptr = ByteBuffer.allocateDirect(NATIVE_POINTER_SIZE).order(ByteOrder.LITTLE_ENDIAN);
-        if (createModel(checkExists(modelFile).getPath(), numCep, context, checkExists(alphabetConfigFile).getPath(), beamWidth, ptr) != 0)
+        if (createModel(checkExists(modelFile).getPath(), ptr) != 0)
            throw new RuntimeException("Failed to create recognition!");
        return getNativePointer(getBufferAddress(ptr));
    }
@@ -123,12 +162,11 @@ public long getPointer() {
     *
     * @param audioBufferPointer the audio buffer storing the audio data in samples / frames to perform the recognition on
     * @param numSamples the number of samples / frames in the buffer
-     * @param sampleRate the amount of samples representing a given duration of audio. sampleRate = Δ samples / Δ time
     * @return the transcription string
     */
    @NotNull
-    public String doSpeechToTextUnsafe(@NativeType("const short *") long audioBufferPointer, long numSamples, long sampleRate) {
-        String ret = speechToTextUnsafe(this.pointer, audioBufferPointer, numSamples, sampleRate);
+    public String doSpeechToTextUnsafe(@NativeType("const short *") long audioBufferPointer, long numSamples) {
+        String ret = speechToTextUnsafe(this.pointer, audioBufferPointer, numSamples);
        if (ret == null) throw new NullPointerException();
        return ret;
    }
@@ -138,13 +176,13 @@ public String doSpeechToTextUnsafe(@NativeType("const short *") long audioBuffer
     *
     * @param audioBufferPointer the audio buffer storing the audio data in samples / frames to perform the recognition on
     * @param numSamples the number of samples / frames in the buffer
-     * @param sampleRate the amount of samples representing a given duration of audio. sampleRate = Δ samples / Δ time
+     * @param numResults The maximum number of CandidateTranscript structs to return. Returned value might be smaller than this.
     * @return the meta data of transcription
     * @see SpeechRecognitionResult
     */
    @NotNull
-    public SpeechRecognitionResult doSpeechRecognitionWithMetaUnsafe(@NativeType("const short *") long audioBufferPointer, long numSamples, long sampleRate) {
-        long metaPtr = speechToTextWithMetadataUnsafe(this.pointer, audioBufferPointer, numSamples, sampleRate);
+    public SpeechRecognitionResult doSpeechRecognitionWithMetaUnsafe(@NativeType("const short *") long audioBufferPointer, long numSamples, long numResults) {
+        long metaPtr = speechToTextWithMetadataUnsafe(this.pointer, audioBufferPointer, numSamples, numResults);
        if (metaPtr == NULL) throw new NullPointerException();
        return new SpeechRecognitionResult(metaPtr); // MetaPtr is freed after this action
    }
diff --git a/libdeepspeech/src/main/java/org/mozilla/deepspeech/recognition/SpeechRecognitionResult.java b/libdeepspeech/src/main/java/org/mozilla/deepspeech/recognition/SpeechRecognitionResult.java
index 56d52a5..b905779 100644
--- a/libdeepspeech/src/main/java/org/mozilla/deepspeech/recognition/SpeechRecognitionResult.java
+++ b/libdeepspeech/src/main/java/org/mozilla/deepspeech/recognition/SpeechRecognitionResult.java
@@ -6,6 +6,7 @@
 import org.mozilla.deepspeech.nativewrapper.DynamicStruct;
 import org.mozilla.deepspeech.utils.NativeAccess;
 
+import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Collection;
 
@@ -14,7 +15,9 @@
 /**
  * Represents the entire STT output as an array of character metadata objects.
  * Stores properties like a confidence value and time stamps for each spoken character.
- *
+ *
+ * @see DeepSpeechModel#doSpeechRecognitionWithMeta(ByteBuffer, long, long)
+ * @see org.mozilla.deepspeech.recognition.stream.SpeechRecognitionAudioStream#doSpeechRecognitionWithMeta(DeepSpeechModel, long)
  * @see #spokenCharacters
  */
diff --git a/libdeepspeech/src/main/java/org/mozilla/deepspeech/recognition/stream/SpeechRecognitionAudioStream.java b/libdeepspeech/src/main/java/org/mozilla/deepspeech/recognition/stream/SpeechRecognitionAudioStream.java
index 5f3ebfa..23c76fa 100644
--- a/libdeepspeech/src/main/java/org/mozilla/deepspeech/recognition/stream/SpeechRecognitionAudioStream.java
+++ b/libdeepspeech/src/main/java/org/mozilla/deepspeech/recognition/stream/SpeechRecognitionAudioStream.java
@@ -5,34 +5,35 @@
 import org.mozilla.deepspeech.recognition.SpeechRecognitionResult;
 
 /**
- * Represents a 16 bit audio output stream where eg. the input of a microphone input stream are written onto. Intermediate speech recognition calls can be performed
+ * Represents a 16-bit audio output stream to which e.g. the input of a microphone stream is written.
+ * Intermediate speech recognition calls can be performed.
  */
 public class SpeechRecognitionAudioStream extends NativeByteArrayOutputStream {
 
-    /**
-     * The sample-rate of the audio signal
-     */
-    private final long sampleRate;
+//    /**
+//     * The sample-rate of the audio signal
+//     */
+//    private final long sampleRate;
 
    /**
     * @param sampleRate the sample-rate of the audio signal
     */
-    public SpeechRecognitionAudioStream(long sampleRate) {
-        this.sampleRate = sampleRate;
+    public SpeechRecognitionAudioStream() {
+//        this.sampleRate = sampleRate;
    }
 
    @NotNull
    public String doSpeechToText(@NotNull DeepSpeechModel model) {
-        return model.doSpeechToTextUnsafe(this.address(), this.getStreamSize() / 2, this.sampleRate);
+        return model.doSpeechToTextUnsafe(this.address(), this.getStreamSize() / 2);
    }
 
    @NotNull
-    public SpeechRecognitionResult doSpeechRecognitionWithMeta(@NotNull DeepSpeechModel model) {
-        return model.doSpeechRecognitionWithMetaUnsafe(this.address(), this.getStreamSize() / 2, this.sampleRate);
+    public SpeechRecognitionResult doSpeechRecognitionWithMeta(@NotNull DeepSpeechModel model, long numResults) {
+        return model.doSpeechRecognitionWithMetaUnsafe(this.address(), this.getStreamSize() / 2, numResults);
    }
 
-    public long getSampleRate() {
-        return sampleRate;
+    public long getSampleRate(@NotNull DeepSpeechModel model) {
+        return model.getSampleRate();
    }
 
    /**
diff --git a/libdeepspeech/src/main/java/org/mozilla/deepspeech/recognition/stream/StreamingState.java b/libdeepspeech/src/main/java/org/mozilla/deepspeech/recognition/stream/StreamingState.java
new file mode 100644
index 0000000..438ce1c
--- /dev/null
+++ b/libdeepspeech/src/main/java/org/mozilla/deepspeech/recognition/stream/StreamingState.java
@@ -0,0 +1,101 @@
+package org.mozilla.deepspeech.recognition.stream;
+
+import org.jetbrains.annotations.NotNull;
+import org.mozilla.deepspeech.DeepSpeech;
+import org.mozilla.deepspeech.doc.Calls;
+import org.mozilla.deepspeech.doc.NativeType;
+import org.mozilla.deepspeech.exception.buffer.BufferReadonlyException;
+import org.mozilla.deepspeech.exception.buffer.IncorrectBufferByteOrderException;
+import org.mozilla.deepspeech.exception.buffer.IncorrectBufferTypeException;
+import org.mozilla.deepspeech.exception.buffer.UnexpectedBufferCapacityException;
+import org.mozilla.deepspeech.nativewrapper.DynamicStruct;
+import org.mozilla.deepspeech.recognition.DeepSpeechModel;
+import org.mozilla.deepspeech.utils.BufferUtils;
+import org.mozilla.deepspeech.utils.NativeAccess;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
+import static org.mozilla.deepspeech.DeepSpeech.*;
+
+public class StreamingState extends DynamicStruct.LifecycleDisposed {
+
+    /**
+     * Create a new streaming inference state
+     */
+    public static StreamingState setupStream(@NotNull DeepSpeechModel model) {
+        ByteBuffer streamPointer = ByteBuffer.allocateDirect(NativeAccess.NATIVE_POINTER_SIZE).order(ByteOrder.nativeOrder());
+
+        long status = DeepSpeech.setupStream(model.getPointer(), streamPointer);
+        if (status != ErrorCodes.OK) {
+            throw new IllegalStateException(errorCodeToErrorMessage(status));
+        }
+
+        return new StreamingState(BufferUtils.getBufferAddress(streamPointer));
+    }
+
+    /**
+     * Feed audio samples to an ongoing streaming inference.
+     *
+     * @param audioBuffer An array of 16-bit, mono raw audio samples at the appropriate sample rate.
+     * @param numSamples The number of samples in the audio content.
+     * @throws UnexpectedBufferCapacityException if #numSamples does not match the allocated buffer capacity. Condition: {@code numSamples * Short.BYTES < audioBuffer.capacity()}
+     * @throws IncorrectBufferByteOrderException if the audioBuffer has a byte order different to {@link ByteOrder#nativeOrder()}.
+     * @throws IncorrectBufferTypeException if the audioBuffer is not directly allocated.
+     * @throws BufferReadonlyException if the buffer is read only
+     */
+    public void feedAudioContent(@NotNull ByteBuffer audioBuffer,
+                                 long numSamples) throws UnexpectedBufferCapacityException, IncorrectBufferByteOrderException, IncorrectBufferTypeException, BufferReadonlyException {
+        DeepSpeech.feedAudioContent(this.pointer, audioBuffer, numSamples);
+    }
+
+    /**
+     * Compute the intermediate decoding of an ongoing streaming inference.
+     * This is an expensive process as the decoder implementation isn't currently capable of streaming,
+     * so it always starts from the beginning of the audio.
+     *
+     * @return The STT intermediate result.
+     */
+    public String intermediateDecode() {
+        return DeepSpeech.intermediateDecode(this.pointer);
+    }
+
+    /**
+     * This method will free the state pointer (#pointer)
+     * Signal the end of an audio signal to an ongoing streaming
+     * inference, returns the STT result over the whole audio signal.
+     *
+     * @return The STT result.
+     */
+    @Calls("DS_FinishStream")
+    @NativeType("jstring")
+    public String finishStream() {
+        return DeepSpeech.finishStream(this.pointer);
+    }
+
+// TODO: org_mozilla_deepspeech_DeepSpeech.cpp needs to return a DynamicStruct that can be parsed to a SpeechRecognitionResult.SpokenCharacterData
+//    /**
+//     * This method will free the state pointer #streamPointer.
+//     * Signal the end of an audio signal to an ongoing streaming inference, returns per-letter metadata.
+//     *
+//     * @param numResults The number of candidate transcripts to return.
+//     * @return Outputs a struct of individual letters along with their timing information.
+//     * The user is responsible for freeing Metadata by calling {@link #freeMetadata(long)}.
+//     * Returns {@link #NULL} on error.
+//     */
+//    public long finishStreamWithMetadata(long numResults) {
+//        return DeepSpeech.finishStreamWithMetadata(this.streamPointer, numResults);
+//    }
+//    public void freeMetadata(long metaDataPointer) {
+//        DeepSpeech.freeMetadata(metaDataPointer);
+//    }
+
+    @Override
+    protected void deallocateStruct(long pointer) {
+        discardStream(pointer);
+    }
+
+    private StreamingState(long streamPointer) {
+        super(streamPointer, NativeAccess.NATIVE_POINTER_SIZE);
+    }
+}
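
Reviewer note (not part of the patch): a minimal end-to-end sketch of the migrated wrapper API, to make the changes above concrete. Class and method names are taken from the diff itself; the native library is assumed to be already loaded, the model/scorer paths ("output_graph.pbmm", "kenlm.scorer") are placeholders, and the alpha/beta values are example values only.

import org.mozilla.deepspeech.recognition.DeepSpeechModel;
import org.mozilla.deepspeech.recognition.stream.StreamingState;

import java.io.File;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class MigrationSketch {
    public static void main(String[] args) throws Exception {
        // nCep/nContext/alphabet/beamWidth are gone from model creation;
        // with the newer DeepSpeech API they come from the model file itself.
        DeepSpeechModel model = new DeepSpeechModel(new File("output_graph.pbmm"));

        // lm.binary + trie were merged into a single scorer package.
        model.enableLMLanguageModel(new File("kenlm.scorer"), 0.75f, 1.85f);

        // The sample rate is now queried from the model instead of being
        // passed to every speech-to-text call.
        int sampleRate = model.getSampleRate();
        System.out.println("Model expects " + sampleRate + " Hz audio");

        // Streaming no longer takes preAllocFrames/sampleRate (DS_CreateStream).
        StreamingState stream = StreamingState.setupStream(model);
        int numSamples = 512;
        // feedAudioContent requires a direct, native-order, writable buffer.
        ByteBuffer audio = ByteBuffer.allocateDirect(numSamples * 2) // 16-bit samples
                .order(ByteOrder.nativeOrder());
        // ... fill `audio` with mono PCM at `sampleRate`, e.g. from a microphone ...
        stream.feedAudioContent(audio, numSamples);
        System.out.println(stream.intermediateDecode());
        System.out.println(stream.finishStream());
    }
}

The sketch mirrors the intent of the commit: everything that used to be a per-call parameter (sample rate) or a constructor parameter (nCep, nContext, alphabet, beam width) is now owned by the model file, and the lm/trie pair collapses into one scorer argument.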