diff --git a/Examples/main/main.cpp b/Examples/main/main.cpp index 0919131..7e318b0 100644 --- a/Examples/main/main.cpp +++ b/Examples/main/main.cpp @@ -5,10 +5,13 @@ #include "miscUtils.h" #include #include +#include +#include #include "textWriter.h" using namespace Whisper; #define STREAM_AUDIO 1 +#define USEBUFFER 0 static HRESULT loadWhisperModel( const wchar_t* path, const std::wstring& gpu, iModel** pp ) { @@ -302,21 +305,46 @@ int wmain( int argc, wchar_t* argv[] ) wparams.encoder_begin_callback_user_data = &is_aborted; } - if( STREAM_AUDIO && !wparams.flag( eFullParamsFlags::TokenTimestamps ) ) +#ifdef USEBUFFER + printf("----- USING BUFFER-------"); + + std::ifstream file(fname, std::ios::binary); + file.seekg(0, std::ios::end); + size_t fileSize = file.tellg(); + file.seekg(0, std::ios::beg); + std::vector data(fileSize); + file.read(reinterpret_cast(data.data()), fileSize); + file.close(); + if (STREAM_AUDIO && !wparams.flag(eFullParamsFlags::TokenTimestamps)) { + ComLight::CComPtr reader; + CHECK(mf->loadAudioFileData(data.data(), data.size(), false, &reader)); + sProgressSink progressSink{ nullptr, nullptr }; + hr = context->runStreamed(wparams, progressSink, reader); + } + else + { + ComLight::CComPtr buffer; + CHECK(mf->loadAudioFileDataBuffer(data.data(), data.size(), false, &buffer)); + hr = context->runFull(wparams, buffer); + } + +#else + if (STREAM_AUDIO && !wparams.flag(eFullParamsFlags::TokenTimestamps)) { ComLight::CComPtr reader; - CHECK( mf->openAudioFile( fname.c_str(), params.diarize, &reader ) ); + CHECK(mf->openAudioFile(fname.c_str(), params.diarize, &reader)); sProgressSink progressSink{ nullptr, nullptr }; - hr = context->runStreamed( wparams, progressSink, reader ); + hr = context->runStreamed(wparams, progressSink, reader); } else { // Token-level timestamps feature is not currently implemented when streaming the audio // When these timestamps are requested, fall back to buffered mode. ComLight::CComPtr buffer; - CHECK( mf->loadAudioFile( fname.c_str(), params.diarize, &buffer ) ); - hr = context->runFull( wparams, buffer ); + CHECK(mf->loadAudioFile(fname.c_str(), params.diarize, &buffer)); + hr = context->runFull(wparams, buffer); } +#endif if( FAILED( hr ) ) { diff --git a/Whisper/API/iContext.cl.h b/Whisper/API/iContext.cl.h index d66fa37..e81c753 100644 --- a/Whisper/API/iContext.cl.h +++ b/Whisper/API/iContext.cl.h @@ -29,6 +29,7 @@ namespace Whisper virtual HRESULT COMLIGHTCALL runFull( const sFullParams& params, const iAudioBuffer* buffer ) = 0; virtual HRESULT COMLIGHTCALL runStreamed( const sFullParams& params, const sProgressSink& progress, const iAudioReader* reader ) = 0; virtual HRESULT COMLIGHTCALL runCapture( const sFullParams& params, const sCaptureCallbacks& callbacks, const iAudioCapture* reader ) = 0; + virtual HRESULT COMLIGHTCALL getLastError(char* error, size_t errorSize) = 0; virtual HRESULT COMLIGHTCALL getResults( eResultFlags flags, iTranscribeResult** pp ) const = 0; // Try to detect speaker by comparing channels of the stereo PCM data diff --git a/Whisper/API/iContext.h b/Whisper/API/iContext.h index 5d1f3b7..ef7abe2 100644 --- a/Whisper/API/iContext.h +++ b/Whisper/API/iContext.h @@ -26,6 +26,7 @@ namespace Whisper HRESULT __stdcall runFull( const sFullParams& params, const iAudioBuffer* buffer ); HRESULT __stdcall runStreamed( const sFullParams& params, const sProgressSink& progress, const iAudioReader* reader ); HRESULT __stdcall runCapture( const sFullParams& params, const sCaptureCallbacks& callbacks, const iAudioCapture* reader ); + HRESULT __stdcall getLastError(char* error, size_t errorSize); HRESULT __stdcall getResults( eResultFlags flags, iTranscribeResult** pp ) const; // Try to detect speaker by comparing channels of the stereo PCM data diff --git a/Whisper/API/iMediaFoundation.cl.h b/Whisper/API/iMediaFoundation.cl.h index f84d928..0b74ada 100644 --- a/Whisper/API/iMediaFoundation.cl.h +++ b/Whisper/API/iMediaFoundation.cl.h @@ -40,6 +40,7 @@ namespace Whisper virtual HRESULT COMLIGHTCALL loadAudioFile( LPCTSTR path, bool stereo, iAudioBuffer** pp ) const = 0; virtual HRESULT COMLIGHTCALL openAudioFile( LPCTSTR path, bool stereo, iAudioReader** pp ) = 0; virtual HRESULT COMLIGHTCALL loadAudioFileData( const void* data, uint64_t size, bool stereo, iAudioReader** pp ) = 0; + virtual HRESULT COMLIGHTCALL loadAudioFileDataBuffer(const void* data, uint64_t size, bool stereo, iAudioBuffer** pp) = 0; virtual HRESULT COMLIGHTCALL listCaptureDevices( pfnFoundCaptureDevices pfn, void* pv ) = 0; virtual HRESULT COMLIGHTCALL openCaptureDevice( LPCTSTR endpoint, const sCaptureParams& captureParams, iAudioCapture** pp ) = 0; diff --git a/Whisper/D3D/shaders.cpp b/Whisper/D3D/shaders.cpp index 9d4c9ca..4e07b6e 100644 --- a/Whisper/D3D/shaders.cpp +++ b/Whisper/D3D/shaders.cpp @@ -1,62 +1,62 @@ -#include "stdafx.h" -#include "shaders.h" -#include "device.h" -#include "../Utils/LZ4/lz4.h" - -namespace -{ -#ifdef _DEBUG -#include "shaderData-Debug.inl" -#else -#include "shaderData-Release.inl" -#endif - - // static std::vector> s_shaders; -} - -HRESULT DirectCompute::createComputeShaders( std::vector>& shaders ) -{ - constexpr size_t countBinaries = s_shaderOffsets.size() - 1; - const size_t cbDecompressedLength = s_shaderOffsets[ countBinaries ]; - constexpr size_t countShaders = s_shaderBlobs32.size(); - - std::vector dxbc; - try - { - shaders.resize( countShaders ); - dxbc.resize( cbDecompressedLength ); - } - catch( const std::bad_alloc& ) - { - return E_OUTOFMEMORY; - } - - const int lz4Status = LZ4_decompress_safe( (const char*)s_compressedShaders.data(), (char*)dxbc.data(), (int)s_compressedShaders.size(), (int)cbDecompressedLength ); - if( lz4Status != (int)cbDecompressedLength ) - { - logError( u8"LZ4_decompress_safe failed with status %i", lz4Status ); - return PLA_E_CABAPI_FAILURE; - } - ID3D11Device* const dev = device(); - - const auto& blobs = gpuInfo().wave64() ? s_shaderBlobs64 : s_shaderBlobs32; - - for( size_t i = 0; i < countShaders; i++ ) - { - const size_t idxBinary = blobs[ i ]; - const uint32_t offThis = s_shaderOffsets[ idxBinary ]; - const uint8_t* rsi = &dxbc[ offThis ]; - const size_t len = s_shaderOffsets[ idxBinary + 1 ] - offThis; - const HRESULT hr = dev->CreateComputeShader( rsi, len, nullptr, &shaders[ i ] ); - if( SUCCEEDED( hr ) ) - continue; - - const uint64_t binaryBit = ( 1ull << idxBinary ); - if( 0 != ( binaryBit & fp64ShadersBitmap ) ) - continue; // This shader uses FP64 math, the support for that is optional. When not supported, CreateComputeShader method is expected to fail. - // TODO [low]: ideally, query for the support when creating the device, and don't even try creating these compute shaders - return hr; - } - - return S_OK; +#include "stdafx.h" +#include "shaders.h" +#include "device.h" +#include "../Utils/LZ4/lz4.h" + +namespace +{ +#ifdef _DEBUG +#include "shaderData-Release.inl" //"shaderData-Debug.inl" +#else +#include "shaderData-Release.inl" +#endif + + // static std::vector> s_shaders; +} + +HRESULT DirectCompute::createComputeShaders( std::vector>& shaders ) +{ + constexpr size_t countBinaries = s_shaderOffsets.size() - 1; + const size_t cbDecompressedLength = s_shaderOffsets[ countBinaries ]; + constexpr size_t countShaders = s_shaderBlobs32.size(); + + std::vector dxbc; + try + { + shaders.resize( countShaders ); + dxbc.resize( cbDecompressedLength ); + } + catch( const std::bad_alloc& ) + { + return E_OUTOFMEMORY; + } + + const int lz4Status = LZ4_decompress_safe( (const char*)s_compressedShaders.data(), (char*)dxbc.data(), (int)s_compressedShaders.size(), (int)cbDecompressedLength ); + if( lz4Status != (int)cbDecompressedLength ) + { + logError( u8"LZ4_decompress_safe failed with status %i", lz4Status ); + return PLA_E_CABAPI_FAILURE; + } + ID3D11Device* const dev = device(); + + const auto& blobs = gpuInfo().wave64() ? s_shaderBlobs64 : s_shaderBlobs32; + + for( size_t i = 0; i < countShaders; i++ ) + { + const size_t idxBinary = blobs[ i ]; + const uint32_t offThis = s_shaderOffsets[ idxBinary ]; + const uint8_t* rsi = &dxbc[ offThis ]; + const size_t len = s_shaderOffsets[ idxBinary + 1 ] - offThis; + const HRESULT hr = dev->CreateComputeShader( rsi, len, nullptr, &shaders[ i ] ); + if( SUCCEEDED( hr ) ) + continue; + + const uint64_t binaryBit = ( 1ull << idxBinary ); + if( 0 != ( binaryBit & fp64ShadersBitmap ) ) + continue; // This shader uses FP64 math, the support for that is optional. When not supported, CreateComputeShader method is expected to fail. + // TODO [low]: ideally, query for the support when creating the device, and don't even try creating these compute shaders + return hr; + } + + return S_OK; } \ No newline at end of file diff --git a/Whisper/MF/MediaFoundation.cpp b/Whisper/MF/MediaFoundation.cpp index 361aa1c..4567e5e 100644 --- a/Whisper/MF/MediaFoundation.cpp +++ b/Whisper/MF/MediaFoundation.cpp @@ -136,6 +136,10 @@ namespace Whisper res.detach( pp ); return S_OK; } + HRESULT COMLIGHTCALL loadAudioFileDataBuffer(const void* data, uint64_t size, bool stereo, iAudioBuffer** pp) noexcept override final + { + return Whisper::loadAudioMemoryFile(data, size, stereo, pp); + } HRESULT COMLIGHTCALL listCaptureDevices( pfnFoundCaptureDevices pfn, void* pv ) noexcept override final { return captureDeviceList( pfn, pv ); diff --git a/Whisper/MF/loadAudioFile.cpp b/Whisper/MF/loadAudioFile.cpp index 354e140..1cd3c9c 100644 --- a/Whisper/MF/loadAudioFile.cpp +++ b/Whisper/MF/loadAudioFile.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #pragma comment(lib, "Mfreadwrite.lib") #pragma comment(lib, "mfuuid.lib") @@ -37,8 +38,10 @@ namespace Whisper rdi = 0; return S_OK; } + HRESULT _loadBuffer(CComPtr reader, LPCTSTR path, bool stereo); public: HRESULT load( LPCTSTR path, bool stereo ); + HRESULT loadBuffer(const void* data, uint64_t size, bool stereo); }; HRESULT MediaFileBuffer::load( LPCTSTR path, bool stereo ) @@ -51,8 +54,16 @@ namespace Whisper return hr; } - CHECK( reader->SetStreamSelection( MF_SOURCE_READER_ALL_STREAMS, FALSE ) ); - CHECK( reader->SetStreamSelection( MF_SOURCE_READER_FIRST_AUDIO_STREAM, TRUE ) ); + return _loadBuffer(reader, path, stereo); + } + + + HRESULT MediaFileBuffer::_loadBuffer(CComPtr reader, LPCTSTR path, bool stereo) + { + + CHECK(reader->SetStreamSelection(MF_SOURCE_READER_ALL_STREAMS, FALSE)); + CHECK(reader->SetStreamSelection(MF_SOURCE_READER_FIRST_AUDIO_STREAM, TRUE)); + CComPtr mtNative; CHECK( reader->GetNativeMediaType( MF_SOURCE_READER_FIRST_AUDIO_STREAM, MF_SOURCE_READER_CURRENT_TYPE_INDEX, &mtNative ) ); @@ -73,7 +84,7 @@ namespace Whisper CComPtr sample; // Read the next sample. - hr = reader->ReadSample( (DWORD)MF_SOURCE_READER_FIRST_AUDIO_STREAM, 0, nullptr, &dwFlags, nullptr, &sample ); + HRESULT hr = reader->ReadSample( (DWORD)MF_SOURCE_READER_FIRST_AUDIO_STREAM, 0, nullptr, &dwFlags, nullptr, &sample ); if( FAILED( hr ) ) { logErrorHr( hr, u8"IMFSourceReader.ReadSample" ); @@ -136,8 +147,34 @@ namespace Whisper return S_OK; } + + + + HRESULT MediaFileBuffer::loadBuffer(const void* data, uint64_t size, bool stereo) + { + + CComPtr reader; + CComPtr comStream; + // Microsoft neglected to document their API, but Wine returns a new stream with reference counter = 1 + // See shstream_create() function there https://source.winehq.org/source/dlls/shcore/main.c#0832 + // That's why we need the Attach(), as opposed to an assignment + comStream.Attach(SHCreateMemStream((const BYTE*)data, (UINT)size)); + if (!comStream) + { + logError(u8"SHCreateMemStream failed"); + return E_FAIL; + } + + CComPtr mfStream; + CHECK(MFCreateMFByteStreamOnStream(comStream, &mfStream)); + HRESULT hr = MFCreateSourceReaderFromByteStream(mfStream, nullptr, &reader); + + return _loadBuffer(reader, L"memory", stereo); + } } + + HRESULT COMLIGHTCALL Whisper::loadAudioFile( LPCTSTR path, bool stereo, iAudioBuffer** pp ) { if( nullptr == path || nullptr == pp ) @@ -148,4 +185,16 @@ HRESULT COMLIGHTCALL Whisper::loadAudioFile( LPCTSTR path, bool stereo, iAudioBu CHECK( obj->load( path, stereo ) ); obj.detach( pp ); return S_OK; +} + +HRESULT COMLIGHTCALL Whisper::loadAudioMemoryFile(const void *data, uint64_t size, bool stereo, iAudioBuffer** pp) +{ + if (nullptr == data || nullptr == pp || size == 0) + return E_POINTER; + + ComLight::CComPtr> obj; + CHECK(ComLight::Object::create(obj)); + CHECK(obj->loadBuffer(data, size, stereo)); + obj.detach(pp); + return S_OK; } \ No newline at end of file diff --git a/Whisper/MF/loadAudioFile.h b/Whisper/MF/loadAudioFile.h index 9736ccd..e800444 100644 --- a/Whisper/MF/loadAudioFile.h +++ b/Whisper/MF/loadAudioFile.h @@ -4,4 +4,5 @@ namespace Whisper { HRESULT COMLIGHTCALL loadAudioFile( LPCTSTR path, bool stereo, iAudioBuffer** pp ); + HRESULT COMLIGHTCALL loadAudioMemoryFile(const void* Buffer, uint64_t len, bool stereo, iAudioBuffer** pp); } \ No newline at end of file diff --git a/Whisper/Whisper.vcxproj b/Whisper/Whisper.vcxproj index 9052351..d9fcab1 100644 --- a/Whisper/Whisper.vcxproj +++ b/Whisper/Whisper.vcxproj @@ -1,366 +1,368 @@ - - - - - Debug - x64 - - - Release - x64 - - - - 16.0 - Win32Proj - {701df8c8-e4a5-43ec-9c6b-747bbf4d8e71} - Whisper - 10.0 - - - - DynamicLibrary - true - v143 - Unicode - - - DynamicLibrary - false - v143 - true - Unicode - - - - - - - - - - - - - - - $(ProjectDir);$(IncludePath) - - - $(ProjectDir);$(IncludePath) - - - - Level3 - true - _DEBUG;WHISPER_EXPORTS;_WINDOWS;_USRDLL;%(PreprocessorDefinitions) - true - stdcpp20 - Use - true - AdvancedVectorExtensions - - - Windows - true - false - whisper.def - - - - - Level3 - true - true - true - NDEBUG;WHISPER_EXPORTS;_WINDOWS;_USRDLL;%(PreprocessorDefinitions) - true - stdcpp20 - Use - true - MultiThreaded - AdvancedVectorExtensions - - - Windows - true - true - true - false - whisper.def - UseLinkTimeCodeGeneration - - - - - {52f486e7-830c-45d8-be47-e76b5aab2772} - - - - - - - AdvancedVectorExtensions2 - AdvancedVectorExtensions2 - - - AdvancedVectorExtensions - AdvancedVectorExtensions - - - AdvancedVectorExtensions - AdvancedVectorExtensions - - - - - - - - - - - AdvancedVectorExtensions - AdvancedVectorExtensions - - - - AdvancedVectorExtensions - AdvancedVectorExtensions - - - AdvancedVectorExtensions - AdvancedVectorExtensions - - - AdvancedVectorExtensions - AdvancedVectorExtensions - - - - - AdvancedVectorExtensions - AdvancedVectorExtensions - - - - - - NotUsing - - - - - - - - - - - - - - - - - - - - - - NotUsing - AdvancedVectorExtensions - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - true - true - - - true - true - - - - - Create - Create - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + Debug + x64 + + + Release + x64 + + + + 16.0 + Win32Proj + {701df8c8-e4a5-43ec-9c6b-747bbf4d8e71} + Whisper + 10.0 + + + + DynamicLibrary + true + v143 + Unicode + + + DynamicLibrary + false + v143 + true + Unicode + + + + + + + + + + + + + + + $(ProjectDir);$(IncludePath) + $(ProjectName)_k_d + + + $(ProjectDir);$(IncludePath) + $(ProjectName)_k + + + + Level3 + true + _DEBUG;WHISPER_EXPORTS;_WINDOWS;_USRDLL;%(PreprocessorDefinitions) + true + stdcpp20 + Use + true + AdvancedVectorExtensions + + + Windows + true + false + whisper.def + + + + + Level3 + true + true + true + NDEBUG;WHISPER_EXPORTS;_WINDOWS;_USRDLL;%(PreprocessorDefinitions) + true + stdcpp20 + Use + true + MultiThreaded + AdvancedVectorExtensions + + + Windows + true + true + true + false + whisper.def + UseLinkTimeCodeGeneration + + + + + {52f486e7-830c-45d8-be47-e76b5aab2772} + + + + + + + AdvancedVectorExtensions2 + AdvancedVectorExtensions2 + + + AdvancedVectorExtensions + AdvancedVectorExtensions + + + AdvancedVectorExtensions + AdvancedVectorExtensions + + + + + + + + + + + AdvancedVectorExtensions + AdvancedVectorExtensions + + + + AdvancedVectorExtensions + AdvancedVectorExtensions + + + AdvancedVectorExtensions + AdvancedVectorExtensions + + + AdvancedVectorExtensions + AdvancedVectorExtensions + + + + + AdvancedVectorExtensions + AdvancedVectorExtensions + + + + + + NotUsing + + + + + + + + + + + + + + + + + + + + + + NotUsing + AdvancedVectorExtensions + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + true + true + + + true + true + + + + + Create + Create + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/Whisper/Whisper/ContextImpl.h b/Whisper/Whisper/ContextImpl.h index 033f447..871aa2f 100644 --- a/Whisper/Whisper/ContextImpl.h +++ b/Whisper/Whisper/ContextImpl.h @@ -29,6 +29,7 @@ namespace Whisper HRESULT COMLIGHTCALL runFull( const sFullParams& params, const iAudioBuffer* buffer ) override final; HRESULT COMLIGHTCALL runStreamed( const sFullParams& params, const sProgressSink& progress, const iAudioReader* reader ) override final; HRESULT COMLIGHTCALL runCapture( const sFullParams& params, const sCaptureCallbacks& callbacks, const iAudioCapture* reader ) override final; + HRESULT COMLIGHTCALL getLastError(char* error, size_t errorSize) override final; struct Segment { diff --git a/Whisper/Whisper/ContextImpl.misc.cpp b/Whisper/Whisper/ContextImpl.misc.cpp index 21c4a49..f1cb96a 100644 --- a/Whisper/Whisper/ContextImpl.misc.cpp +++ b/Whisper/Whisper/ContextImpl.misc.cpp @@ -356,6 +356,21 @@ int ContextImpl::wrapSegment( int max_len ) return res; } +static std::string lastError; + +HRESULT COMLIGHTCALL ContextImpl::getLastError( char *error, size_t errorSize ) +{ + if( errorSize == 0 ) + return E_INVALIDARG; + + if( errorSize > lastError.size() ) + errorSize = lastError.size(); + + memcpy( error, lastError.c_str(), errorSize ); + error[ errorSize] = 0; + return S_OK; +} + HRESULT COMLIGHTCALL ContextImpl::runFull( const sFullParams& params, const iAudioBuffer* buffer ) { #if SAVE_DEBUG_TRACE @@ -379,9 +394,15 @@ HRESULT COMLIGHTCALL ContextImpl::runFull( const sFullParams& params, const iAud try { + lastError = ""; sProgressSink progressSink{ nullptr, nullptr }; return runFullImpl( params, progressSink, spectrogram ); } + catch (const std::exception& e) + { + lastError = e.what(); + return E_FAIL; + } catch( HRESULT hr ) { return hr; @@ -401,6 +422,7 @@ HRESULT COMLIGHTCALL ContextImpl::runStreamed( const sFullParams& params, const try { + lastError = ""; if( params.cpuThreads > 1 ) { MelStreamerThread mel{ model.shared->filters, profiler, reader, params.cpuThreads }; @@ -412,6 +434,11 @@ HRESULT COMLIGHTCALL ContextImpl::runStreamed( const sFullParams& params, const return runFullImpl( params, progress, mel ); } } + catch (const std::exception& e) + { + lastError = e.what(); + return E_FAIL; + } catch( HRESULT hr ) { return hr;