@@ -13,62 +13,63 @@ class TenVad:
1313 def __init__ (self , hop_size : int = 256 , threshold : float = 0.5 ):
1414 self .hop_size = hop_size
1515 self .threshold = threshold
16+
17+ # Get the directory where this module is installed
18+ module_dir = os .path .dirname (os .path .abspath (__file__ ))
19+
1620 if platform .system () == "Linux" and platform .machine () == "x86_64" :
17- git_path = os .path .join (
18- os .path .dirname (os .path .relpath (__file__ )),
19- "../lib/Linux/x64/libten_vad.so"
20- )
21+ # Try git repo structure first
22+ git_path = os .path .join (module_dir , "../lib/Linux/x64/libten_vad.so" )
23+ # Try installed package structure (lib directory is inside the package)
24+ pip_path = os .path .join (module_dir , "lib/Linux/x64/libten_vad.so" )
25+
2126 if os .path .exists (git_path ):
2227 self .vad_library = CDLL (git_path )
23- else :
24- pip_path = os .path .join (
25- os .path .dirname (os .path .relpath (__file__ )),
26- "./ten_vad_library/libten_vad.so"
27- )
28+ elif os .path .exists (pip_path ):
2829 self .vad_library = CDLL (pip_path )
30+ else :
31+ raise FileNotFoundError ("Cannot find libten_vad.so at {} or {}" .format (git_path , pip_path ))
2932
3033 elif platform .system () == "Darwin" :
31- git_path = os .path .join (
32- os .path .dirname (os .path .relpath (__file__ )),
33- "../lib/macOS/ten_vad.framework/Versions/A/ten_vad"
34- )
34+ # Try git repo structure first
35+ git_path = os .path .join (module_dir , "../lib/macOS/ten_vad.framework/Versions/A/ten_vad" )
36+ # Try installed package structure
37+ pip_path = os .path .join (module_dir , "lib/macOS/ten_vad.framework/Versions/A/ten_vad" )
38+
3539 if os .path .exists (git_path ):
3640 self .vad_library = CDLL (git_path )
37- else :
38- pip_path = os .path .join (
39- os .path .dirname (os .path .relpath (__file__ )),
40- "./ten_vad_library/libten_vad"
41- )
41+ elif os .path .exists (pip_path ):
4242 self .vad_library = CDLL (pip_path )
43+ else :
44+ raise FileNotFoundError ("Cannot find libten_vad at {} or {}" .format (git_path , pip_path ))
45+
4346 elif platform .system ().upper () == 'WINDOWS' :
4447 if platform .machine ().upper () in ['X64' , 'X86_64' , 'AMD64' ]:
45- git_path = os .path .join (
46- os .path .dirname (os .path .realpath (__file__ )),
47- "../lib/Windows/x64/ten_vad.dll"
48- )
48+ # Try git repo structure first
49+ git_path = os .path .join (module_dir , "../lib/Windows/x64/ten_vad.dll" )
50+ # Try installed package structure
51+ pip_path = os .path .join (module_dir , "lib/Windows/x64/ten_vad.dll" )
52+
4953 if os .path .exists (git_path ):
5054 self .vad_library = CDLL (git_path )
51- else :
52- pip_path = os .path .join (
53- os .path .dirname (os .path .realpath (__file__ )),
54- "./ten_vad_library/ten_vad.dll"
55- )
55+ elif os .path .exists (pip_path ):
5656 self .vad_library = CDLL (pip_path )
57+ else :
58+ raise FileNotFoundError ("Cannot find ten_vad.dll at {} or {}" .format (git_path , pip_path ))
5759 else :
58- git_path = os .path .join (
59- os .path .dirname (os .path .realpath (__file__ )),
60- "../lib/Windows/x86/ten_vad.dll"
61- )
60+ # Try git repo structure first
61+ git_path = os .path .join (module_dir , "../lib/Windows/x86/ten_vad.dll" )
62+ # Try installed package structure
63+ pip_path = os .path .join (module_dir , "lib/Windows/x86/ten_vad.dll" )
64+
6265 if os .path .exists (git_path ):
6366 self .vad_library = CDLL (git_path )
64- else :
65- pip_path = os .path .join (
66- os .path .dirname (os .path .realpath (__file__ )),
67- "./ten_vad_library/ten_vad.dll"
68- )
67+ elif os .path .exists (pip_path ):
6968 self .vad_library = CDLL (pip_path )
69+ else :
70+ raise FileNotFoundError ("Cannot find ten_vad.dll at {} or {}" .format (git_path , pip_path ))
7071 else :
71- raise NotImplementedError (f "Unsupported platform: { platform .system ()} { platform .machine ()} " )
72+ raise NotImplementedError ("Unsupported platform: {} {}" . format ( platform .system (), platform .machine ()) )
7273 self .vad_handler = c_void_p (0 )
7374 self .out_probability = c_float ()
7475 self .out_flags = c_int32 ()
@@ -116,9 +117,7 @@ def get_input_data(self, audio_data: np.ndarray):
116117 assert (
117118 len (audio_data .shape ) == 1
118119 and audio_data .shape [0 ] == self .hop_size
119- ), "[TEN VAD]: audio data shape should be [%d]" % (
120- self .hop_size
121- )
120+ ), "[TEN VAD]: audio data shape should be [{}]" .format (self .hop_size )
122121 assert (
123122 type (audio_data [0 ]) == np .int16
124123 ), "[TEN VAD]: audio data type error, must be int16"
0 commit comments