Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 33 additions & 5 deletions Examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
#include "miscUtils.h"
#include <array>
#include <atomic>
#include <fstream>
#include <vector>
#include "textWriter.h"
using namespace Whisper;

#define STREAM_AUDIO 1
#define USEBUFFER 0

static HRESULT loadWhisperModel( const wchar_t* path, const std::wstring& gpu, iModel** pp )
{
Expand Down Expand Up @@ -302,21 +305,46 @@ int wmain( int argc, wchar_t* argv[] )
wparams.encoder_begin_callback_user_data = &is_aborted;
}

if( STREAM_AUDIO && !wparams.flag( eFullParamsFlags::TokenTimestamps ) )
#ifdef USEBUFFER
printf("----- USING BUFFER-------");

std::ifstream file(fname, std::ios::binary);
file.seekg(0, std::ios::end);
size_t fileSize = file.tellg();
file.seekg(0, std::ios::beg);
std::vector<std::byte> data(fileSize);
file.read(reinterpret_cast<char*>(data.data()), fileSize);
file.close();
if (STREAM_AUDIO && !wparams.flag(eFullParamsFlags::TokenTimestamps)) {
ComLight::CComPtr<iAudioReader> reader;
CHECK(mf->loadAudioFileData(data.data(), data.size(), false, &reader));
sProgressSink progressSink{ nullptr, nullptr };
hr = context->runStreamed(wparams, progressSink, reader);
}
else
{
ComLight::CComPtr<iAudioBuffer> buffer;
CHECK(mf->loadAudioFileDataBuffer(data.data(), data.size(), false, &buffer));
hr = context->runFull(wparams, buffer);
}

#else
if (STREAM_AUDIO && !wparams.flag(eFullParamsFlags::TokenTimestamps))
{
ComLight::CComPtr<iAudioReader> reader;
CHECK( mf->openAudioFile( fname.c_str(), params.diarize, &reader ) );
CHECK(mf->openAudioFile(fname.c_str(), params.diarize, &reader));
sProgressSink progressSink{ nullptr, nullptr };
hr = context->runStreamed( wparams, progressSink, reader );
hr = context->runStreamed(wparams, progressSink, reader);
}
else
{
// Token-level timestamps feature is not currently implemented when streaming the audio
// When these timestamps are requested, fall back to buffered mode.
ComLight::CComPtr<iAudioBuffer> buffer;
CHECK( mf->loadAudioFile( fname.c_str(), params.diarize, &buffer ) );
hr = context->runFull( wparams, buffer );
CHECK(mf->loadAudioFile(fname.c_str(), params.diarize, &buffer));
hr = context->runFull(wparams, buffer);
}
#endif

if( FAILED( hr ) )
{
Expand Down
1 change: 1 addition & 0 deletions Whisper/API/iContext.cl.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ namespace Whisper
virtual HRESULT COMLIGHTCALL runFull( const sFullParams& params, const iAudioBuffer* buffer ) = 0;
virtual HRESULT COMLIGHTCALL runStreamed( const sFullParams& params, const sProgressSink& progress, const iAudioReader* reader ) = 0;
virtual HRESULT COMLIGHTCALL runCapture( const sFullParams& params, const sCaptureCallbacks& callbacks, const iAudioCapture* reader ) = 0;
virtual HRESULT COMLIGHTCALL getLastError(char* error, size_t errorSize) = 0;

virtual HRESULT COMLIGHTCALL getResults( eResultFlags flags, iTranscribeResult** pp ) const = 0;
// Try to detect speaker by comparing channels of the stereo PCM data
Expand Down
1 change: 1 addition & 0 deletions Whisper/API/iContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ namespace Whisper
HRESULT __stdcall runFull( const sFullParams& params, const iAudioBuffer* buffer );
HRESULT __stdcall runStreamed( const sFullParams& params, const sProgressSink& progress, const iAudioReader* reader );
HRESULT __stdcall runCapture( const sFullParams& params, const sCaptureCallbacks& callbacks, const iAudioCapture* reader );
HRESULT __stdcall getLastError(char* error, size_t errorSize);

HRESULT __stdcall getResults( eResultFlags flags, iTranscribeResult** pp ) const;
// Try to detect speaker by comparing channels of the stereo PCM data
Expand Down
1 change: 1 addition & 0 deletions Whisper/API/iMediaFoundation.cl.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ namespace Whisper
virtual HRESULT COMLIGHTCALL loadAudioFile( LPCTSTR path, bool stereo, iAudioBuffer** pp ) const = 0;
virtual HRESULT COMLIGHTCALL openAudioFile( LPCTSTR path, bool stereo, iAudioReader** pp ) = 0;
virtual HRESULT COMLIGHTCALL loadAudioFileData( const void* data, uint64_t size, bool stereo, iAudioReader** pp ) = 0;
virtual HRESULT COMLIGHTCALL loadAudioFileDataBuffer(const void* data, uint64_t size, bool stereo, iAudioBuffer** pp) = 0;

virtual HRESULT COMLIGHTCALL listCaptureDevices( pfnFoundCaptureDevices pfn, void* pv ) = 0;
virtual HRESULT COMLIGHTCALL openCaptureDevice( LPCTSTR endpoint, const sCaptureParams& captureParams, iAudioCapture** pp ) = 0;
Expand Down
122 changes: 61 additions & 61 deletions Whisper/D3D/shaders.cpp
Original file line number Diff line number Diff line change
@@ -1,62 +1,62 @@
#include "stdafx.h"
#include "shaders.h"
#include "device.h"
#include "../Utils/LZ4/lz4.h"

namespace
{
#ifdef _DEBUG
#include "shaderData-Debug.inl"
#else
#include "shaderData-Release.inl"
#endif

// static std::vector<CComPtr<ID3D11ComputeShader>> s_shaders;
}

HRESULT DirectCompute::createComputeShaders( std::vector<CComPtr<ID3D11ComputeShader>>& shaders )
{
constexpr size_t countBinaries = s_shaderOffsets.size() - 1;
const size_t cbDecompressedLength = s_shaderOffsets[ countBinaries ];
constexpr size_t countShaders = s_shaderBlobs32.size();

std::vector<uint8_t> dxbc;
try
{
shaders.resize( countShaders );
dxbc.resize( cbDecompressedLength );
}
catch( const std::bad_alloc& )
{
return E_OUTOFMEMORY;
}

const int lz4Status = LZ4_decompress_safe( (const char*)s_compressedShaders.data(), (char*)dxbc.data(), (int)s_compressedShaders.size(), (int)cbDecompressedLength );
if( lz4Status != (int)cbDecompressedLength )
{
logError( u8"LZ4_decompress_safe failed with status %i", lz4Status );
return PLA_E_CABAPI_FAILURE;
}
ID3D11Device* const dev = device();

const auto& blobs = gpuInfo().wave64() ? s_shaderBlobs64 : s_shaderBlobs32;

for( size_t i = 0; i < countShaders; i++ )
{
const size_t idxBinary = blobs[ i ];
const uint32_t offThis = s_shaderOffsets[ idxBinary ];
const uint8_t* rsi = &dxbc[ offThis ];
const size_t len = s_shaderOffsets[ idxBinary + 1 ] - offThis;
const HRESULT hr = dev->CreateComputeShader( rsi, len, nullptr, &shaders[ i ] );
if( SUCCEEDED( hr ) )
continue;

const uint64_t binaryBit = ( 1ull << idxBinary );
if( 0 != ( binaryBit & fp64ShadersBitmap ) )
continue; // This shader uses FP64 math, the support for that is optional. When not supported, CreateComputeShader method is expected to fail.
// TODO [low]: ideally, query for the support when creating the device, and don't even try creating these compute shaders
return hr;
}

return S_OK;
#include "stdafx.h"
#include "shaders.h"
#include "device.h"
#include "../Utils/LZ4/lz4.h"
namespace
{
#ifdef _DEBUG
#include "shaderData-Release.inl" //"shaderData-Debug.inl"
#else
#include "shaderData-Release.inl"
#endif
// static std::vector<CComPtr<ID3D11ComputeShader>> s_shaders;
}
HRESULT DirectCompute::createComputeShaders( std::vector<CComPtr<ID3D11ComputeShader>>& shaders )
{
constexpr size_t countBinaries = s_shaderOffsets.size() - 1;
const size_t cbDecompressedLength = s_shaderOffsets[ countBinaries ];
constexpr size_t countShaders = s_shaderBlobs32.size();
std::vector<uint8_t> dxbc;
try
{
shaders.resize( countShaders );
dxbc.resize( cbDecompressedLength );
}
catch( const std::bad_alloc& )
{
return E_OUTOFMEMORY;
}
const int lz4Status = LZ4_decompress_safe( (const char*)s_compressedShaders.data(), (char*)dxbc.data(), (int)s_compressedShaders.size(), (int)cbDecompressedLength );
if( lz4Status != (int)cbDecompressedLength )
{
logError( u8"LZ4_decompress_safe failed with status %i", lz4Status );
return PLA_E_CABAPI_FAILURE;
}
ID3D11Device* const dev = device();
const auto& blobs = gpuInfo().wave64() ? s_shaderBlobs64 : s_shaderBlobs32;
for( size_t i = 0; i < countShaders; i++ )
{
const size_t idxBinary = blobs[ i ];
const uint32_t offThis = s_shaderOffsets[ idxBinary ];
const uint8_t* rsi = &dxbc[ offThis ];
const size_t len = s_shaderOffsets[ idxBinary + 1 ] - offThis;
const HRESULT hr = dev->CreateComputeShader( rsi, len, nullptr, &shaders[ i ] );
if( SUCCEEDED( hr ) )
continue;
const uint64_t binaryBit = ( 1ull << idxBinary );
if( 0 != ( binaryBit & fp64ShadersBitmap ) )
continue; // This shader uses FP64 math, the support for that is optional. When not supported, CreateComputeShader method is expected to fail.
// TODO [low]: ideally, query for the support when creating the device, and don't even try creating these compute shaders
return hr;
}
return S_OK;
}
4 changes: 4 additions & 0 deletions Whisper/MF/MediaFoundation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ namespace Whisper
res.detach( pp );
return S_OK;
}
HRESULT COMLIGHTCALL loadAudioFileDataBuffer(const void* data, uint64_t size, bool stereo, iAudioBuffer** pp) noexcept override final
{
return Whisper::loadAudioMemoryFile(data, size, stereo, pp);
}
HRESULT COMLIGHTCALL listCaptureDevices( pfnFoundCaptureDevices pfn, void* pv ) noexcept override final
{
return captureDeviceList( pfn, pv );
Expand Down
55 changes: 52 additions & 3 deletions Whisper/MF/loadAudioFile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <mfidl.h>
#include <mfreadwrite.h>
#include <mfapi.h>
#include <shlwapi.h>
#pragma comment(lib, "Mfreadwrite.lib")
#pragma comment(lib, "mfuuid.lib")

Expand Down Expand Up @@ -37,8 +38,10 @@ namespace Whisper
rdi = 0;
return S_OK;
}
HRESULT _loadBuffer(CComPtr<IMFSourceReader> reader, LPCTSTR path, bool stereo);
public:
HRESULT load( LPCTSTR path, bool stereo );
HRESULT loadBuffer(const void* data, uint64_t size, bool stereo);
};

HRESULT MediaFileBuffer::load( LPCTSTR path, bool stereo )
Expand All @@ -51,8 +54,16 @@ namespace Whisper
return hr;
}

CHECK( reader->SetStreamSelection( MF_SOURCE_READER_ALL_STREAMS, FALSE ) );
CHECK( reader->SetStreamSelection( MF_SOURCE_READER_FIRST_AUDIO_STREAM, TRUE ) );
return _loadBuffer(reader, path, stereo);
}


HRESULT MediaFileBuffer::_loadBuffer(CComPtr<IMFSourceReader> reader, LPCTSTR path, bool stereo)
{

CHECK(reader->SetStreamSelection(MF_SOURCE_READER_ALL_STREAMS, FALSE));
CHECK(reader->SetStreamSelection(MF_SOURCE_READER_FIRST_AUDIO_STREAM, TRUE));


CComPtr<IMFMediaType> mtNative;
CHECK( reader->GetNativeMediaType( MF_SOURCE_READER_FIRST_AUDIO_STREAM, MF_SOURCE_READER_CURRENT_TYPE_INDEX, &mtNative ) );
Expand All @@ -73,7 +84,7 @@ namespace Whisper
CComPtr<IMFSample> sample;

// Read the next sample.
hr = reader->ReadSample( (DWORD)MF_SOURCE_READER_FIRST_AUDIO_STREAM, 0, nullptr, &dwFlags, nullptr, &sample );
HRESULT hr = reader->ReadSample( (DWORD)MF_SOURCE_READER_FIRST_AUDIO_STREAM, 0, nullptr, &dwFlags, nullptr, &sample );
if( FAILED( hr ) )
{
logErrorHr( hr, u8"IMFSourceReader.ReadSample" );
Expand Down Expand Up @@ -136,8 +147,34 @@ namespace Whisper
return S_OK;

}



HRESULT MediaFileBuffer::loadBuffer(const void* data, uint64_t size, bool stereo)
{

CComPtr<IMFSourceReader> reader;
CComPtr<IStream> comStream;
// Microsoft neglected to document their API, but Wine returns a new stream with reference counter = 1
// See shstream_create() function there https://source.winehq.org/source/dlls/shcore/main.c#0832
// That's why we need the Attach(), as opposed to an assignment
comStream.Attach(SHCreateMemStream((const BYTE*)data, (UINT)size));
if (!comStream)
{
logError(u8"SHCreateMemStream failed");
return E_FAIL;
}

CComPtr<IMFByteStream> mfStream;
CHECK(MFCreateMFByteStreamOnStream(comStream, &mfStream));
HRESULT hr = MFCreateSourceReaderFromByteStream(mfStream, nullptr, &reader);

return _loadBuffer(reader, L"memory", stereo);
}
}



HRESULT COMLIGHTCALL Whisper::loadAudioFile( LPCTSTR path, bool stereo, iAudioBuffer** pp )
{
if( nullptr == path || nullptr == pp )
Expand All @@ -148,4 +185,16 @@ HRESULT COMLIGHTCALL Whisper::loadAudioFile( LPCTSTR path, bool stereo, iAudioBu
CHECK( obj->load( path, stereo ) );
obj.detach( pp );
return S_OK;
}

HRESULT COMLIGHTCALL Whisper::loadAudioMemoryFile(const void *data, uint64_t size, bool stereo, iAudioBuffer** pp)
{
if (nullptr == data || nullptr == pp || size == 0)
return E_POINTER;

ComLight::CComPtr<ComLight::Object<MediaFileBuffer>> obj;
CHECK(ComLight::Object<MediaFileBuffer>::create(obj));
CHECK(obj->loadBuffer(data, size, stereo));
obj.detach(pp);
return S_OK;
}
1 change: 1 addition & 0 deletions Whisper/MF/loadAudioFile.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
namespace Whisper
{
HRESULT COMLIGHTCALL loadAudioFile( LPCTSTR path, bool stereo, iAudioBuffer** pp );
HRESULT COMLIGHTCALL loadAudioMemoryFile(const void* Buffer, uint64_t len, bool stereo, iAudioBuffer** pp);
}
Loading