diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..56734ea0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +.idea/ +*.pyc +__pycache__/ +.ipynb_checkpoints/ \ No newline at end of file diff --git a/README.md b/README.md index e35aa1a5..c1bf181a 100644 --- a/README.md +++ b/README.md @@ -89,10 +89,10 @@ epoch 360, loss_tr=0.033095 err_tr=0.009600 loss_te=4.254683 err_te=0.419954 err The converge is initially very fast (see the first 30 epochs). After that the performance improvement decreases and oscillations into the sentence error rate performance appear. Despite these oscillations an average improvement trend can be observed for the subsequent epochs. In this experiment, we stopped our training at epoch 360. The fields of the res.res file have the following meaning: - loss_tr: is the average training loss (i.e., cross-entropy function) computed at every frame. -- err_tr: is the classification error (measured at frame level) of the training data. Note that we split the speech signals into chunks of 200ms with 10ms overlap. The error is averaged for all the chunks of the training dataset. +- err_tr: is the classification error (measured at frame level) of the training data. Note that we split the speech signals into chunks of 200ms with 190ms overlap. The error is averaged for all the chunks of the training dataset. - loss_te is the average test loss (i.e., cross-entropy function) computed at every frame. - err_te: is the classification error (measured at frame level) of the test data. -- err_te_snt: is the classification error (measured at sentence level) of the test data. Note that we split the speech signal into chunks of 200ms with 10ms overlap. For each chunk, our SincNet performs a prediction over the set of speakers. To compute this classification error rate we averaged the predictions and, for each sentence, we voted for the speaker with the highest average probability. +- err_te_snt: is the classification error (measured at sentence level) of the test data. Note that we split the speech signal into chunks of 200ms with 190ms overlap. For each chunk, our SincNet performs a prediction over the set of speakers. To compute this classification error rate we averaged the predictions and, for each sentence, we voted for the speaker with the highest average probability. [You can find our trained model for TIMIT here.](https://bitbucket.org/mravanelli/sincnet_models/) diff --git a/TIMIT_preparation.py b/TIMIT_preparation.py index 16c9c4a2..0058d47d 100644 --- a/TIMIT_preparation.py +++ b/TIMIT_preparation.py @@ -1,17 +1,17 @@ #!/usr/bin/env python3 -# TIMIT_preparation -# Mirco Ravanelli -# Mila - University of Montreal +# TIMIT_preparation +# Mirco Ravanelli +# Mila - University of Montreal # July 2018 -# Description: -# This code prepares TIMIT for the following speaker identification experiments. +# Description: +# This code prepares TIMIT for the following speaker identification experiments. # It removes start and end silences according to the information reported in the *.wrd files and normalizes the amplitude of each sentence. - + # How to run it: -# python TIMIT_preparation.py $TIMIT_FOLDER $OUTPUT_FOLDER data_lists/TIMIT_all.scp +# python TIMIT_preparation.py $TIMIT_FOLDER $OUTPUT_FOLDER data_lists/TIMIT_all.scp # NOTE: This script expects filenames in lowercase (e.g, train/dr1/fcjf0/si1027.wav" rather than "TRAIN/DR1/FCJF0/SI1027.WAV) @@ -22,59 +22,58 @@ import numpy as np import sys + def ReadList(list_file): - f=open(list_file,"r") - lines=f.readlines() - list_sig=[] - for x in lines: + f = open(list_file, "r") + lines = f.readlines() + list_sig = [] + for x in lines: list_sig.append(x.rstrip()) - f.close() - return list_sig + f.close() + return list_sig -def copy_folder(in_folder,out_folder): - if not(os.path.isdir(out_folder)): - shutil.copytree(in_folder, out_folder, ignore=ig_f) -def ig_f(dir, files): - return [f for f in files if os.path.isfile(os.path.join(dir, f))] +def copy_folder(in_folder, out_folder): + if not (os.path.isdir(out_folder)): + shutil.copytree(in_folder, out_folder, ignore=ig_f) +def ig_f(dir, files): + return [f for f in files if os.path.isfile(os.path.join(dir, f))] + -in_folder=sys.argv[1] -out_folder=sys.argv[2] -list_file=sys.argv[3] +in_folder = sys.argv[1] +out_folder = sys.argv[2] +list_file = sys.argv[3] # Read List file -list_sig=ReadList(list_file) +list_sig = ReadList(list_file) # Replicate input folder structure to output folder -copy_folder(in_folder,out_folder) - +copy_folder(in_folder, out_folder) # Speech Data Reverberation Loop -for i in range(len(list_sig)): - - # Open the wav file - wav_file=in_folder+'/'+list_sig[i] - [signal, fs] = sf.read(wav_file) - signal=signal.astype(np.float64) - - # Signal normalization - signal=signal/np.max(np.abs(signal)) - - # Read wrd file - wrd_file=wav_file.replace(".wav",".wrd") - wrd_sig=ReadList(wrd_file) - beg_sig=int(wrd_sig[0].split(' ')[0]) - end_sig=int(wrd_sig[-1].split(' ')[1]) - - # Remove silences - signal=signal[beg_sig:end_sig] - - - # Save normalized speech - file_out=out_folder+'/'+list_sig[i] - - sf.write(file_out, signal, fs) - - print("Done %s" % (file_out)) +for i in range(len(list_sig)): + # Open the wav file + wav_file = in_folder + '/' + list_sig[i] + [signal, fs] = sf.read(wav_file) + signal = signal.astype(np.float64) + + # Signal normalization + signal = signal / np.max(np.abs(signal)) + + # Read wrd file + wrd_file = wav_file.replace(".wav", ".wrd") + wrd_sig = ReadList(wrd_file) + beg_sig = int(wrd_sig[0].split(' ')[0]) + end_sig = int(wrd_sig[-1].split(' ')[1]) + + # Remove silences + signal = signal[beg_sig:end_sig] + + # Save normalized speech + file_out = out_folder + '/' + list_sig[i] + + sf.write(file_out, signal, fs) + + print("Done %s" % (file_out)) diff --git a/__init__.py b/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/compute_d_vector.py b/compute_d_vector.py index df424ca8..eebe79c1 100644 --- a/compute_d_vector.py +++ b/compute_d_vector.py @@ -1,260 +1,183 @@ # compute_d_vector.py -# Mirco Ravanelli -# Mila - University of Montreal +# Mirco Ravanelli +# Mila - University of Montreal # Feb 2019 -# Description: +# Description: # This code computes d-vectors using a pre-trained model - -import os -import soundfile as sf -import torch -import torch.nn as nn -from torch.autograd import Variable +import collections + import numpy as np +import torch + +from data_io import read_conf_inp +from data_io import str_to_bool from dnn_models import MLP -from dnn_models import SincNet as CNN -from data_io import ReadList,read_conf_inp,str_to_bool -import sys +from dnn_models import SincNet as CNN # Model to use for computing the d-vectors -model_file='/home/mirco/sincnet_models/SincNet_TIMIT/model_raw.pkl' # This is the model to use for computing the d-vectors (it should be pre-trained using the speaker-id DNN) -cfg_file='/home/mirco/SincNet/cfg/SincNet_TIMIT.cfg' # Config file of the speaker-id experiment used to generate the model -te_lst='data_lists/TIMIT_test.scp' # List of the wav files to process -out_dict_file='d_vect_timit.npy' # output dictionary containing the a sentence id as key as the d-vector as value -data_folder='/home/mirco/Dataset/TIMIT_norm_nosil' +model_file = '/home/mirco/sincnet_models/SincNet_TIMIT/model_raw.pkl' # This is the model to use for computing the d-vectors (it should be pre-trained using the speaker-id DNN) +cfg_file = '/home/mirco/SincNet/cfg/SincNet_TIMIT.cfg' # Config file of the speaker-id experiment used to generate the model +te_lst = 'data_lists/TIMIT_test.scp' # List of the wav files to process +out_dict_file = 'd_vect_timit.npy' # output dictionary containing the a sentence id as key as the d-vector as value +data_folder = '/home/mirco/Dataset/TIMIT_norm_nosil' -avoid_small_en_fr=True +avoid_small_en_fr = True energy_th = 0.1 # Avoid frames with an energy that is 1/10 over the average energy -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -#device = None - +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # Reading cfg file -options=read_conf_inp(cfg_file) - - -#[data] -pt_file=options.pt_file -output_folder=options.output_folder - -#[windowing] -fs=int(options.fs) -cw_len=int(options.cw_len) -cw_shift=int(options.cw_shift) - -#[cnn] -cnn_N_filt=list(map(int, options.cnn_N_filt.split(','))) -cnn_len_filt=list(map(int, options.cnn_len_filt.split(','))) -cnn_max_pool_len=list(map(int, options.cnn_max_pool_len.split(','))) -cnn_use_laynorm_inp=str_to_bool(options.cnn_use_laynorm_inp) -cnn_use_batchnorm_inp=str_to_bool(options.cnn_use_batchnorm_inp) -cnn_use_laynorm=list(map(str_to_bool, options.cnn_use_laynorm.split(','))) -cnn_use_batchnorm=list(map(str_to_bool, options.cnn_use_batchnorm.split(','))) -cnn_act=list(map(str, options.cnn_act.split(','))) -cnn_drop=list(map(float, options.cnn_drop.split(','))) - - -#[dnn] -fc_lay=list(map(int, options.fc_lay.split(','))) -fc_drop=list(map(float, options.fc_drop.split(','))) -fc_use_laynorm_inp=str_to_bool(options.fc_use_laynorm_inp) -fc_use_batchnorm_inp=str_to_bool(options.fc_use_batchnorm_inp) -fc_use_batchnorm=list(map(str_to_bool, options.fc_use_batchnorm.split(','))) -fc_use_laynorm=list(map(str_to_bool, options.fc_use_laynorm.split(','))) -fc_act=list(map(str, options.fc_act.split(','))) - -#[class] -class_lay=list(map(int, options.class_lay.split(','))) -class_drop=list(map(float, options.class_drop.split(','))) -class_use_laynorm_inp=str_to_bool(options.class_use_laynorm_inp) -class_use_batchnorm_inp=str_to_bool(options.class_use_batchnorm_inp) -class_use_batchnorm=list(map(str_to_bool, options.class_use_batchnorm.split(','))) -class_use_laynorm=list(map(str_to_bool, options.class_use_laynorm.split(','))) -class_act=list(map(str, options.class_act.split(','))) - - -wav_lst_te=ReadList(te_lst) -snt_te=len(wav_lst_te) - - -# Folder creation -try: - os.stat(output_folder) -except: - os.mkdir(output_folder) - - -# loss function -cost = nn.NLLLoss() - - -# Converting context and shift in samples -wlen=int(fs*cw_len/1000.00) -wshift=int(fs*cw_shift/1000.00) +options = read_conf_inp(cfg_file) + +# [windowing] +fs = int(options.fs) +cw_len = int(options.cw_len) +cw_shift = int(options.cw_shift) + +# [cnn] +cnn_N_filt = list(map(int, options.cnn_N_filt.split(','))) +cnn_len_filt = list(map(int, options.cnn_len_filt.split(','))) +cnn_max_pool_len = list(map(int, options.cnn_max_pool_len.split(','))) +cnn_use_laynorm_inp = str_to_bool(options.cnn_use_laynorm_inp) +cnn_use_batchnorm_inp = str_to_bool(options.cnn_use_batchnorm_inp) +cnn_use_laynorm = list(map(str_to_bool, options.cnn_use_laynorm.split(','))) +cnn_use_batchnorm = list(map(str_to_bool, options.cnn_use_batchnorm.split(','))) +cnn_act = list(map(str, options.cnn_act.split(','))) +cnn_drop = list(map(float, options.cnn_drop.split(','))) + +# [dnn] +fc_lay = list(map(int, options.fc_lay.split(','))) +fc_drop = list(map(float, options.fc_drop.split(','))) +fc_use_laynorm_inp = str_to_bool(options.fc_use_laynorm_inp) +fc_use_batchnorm_inp = str_to_bool(options.fc_use_batchnorm_inp) +fc_use_batchnorm = list(map(str_to_bool, options.fc_use_batchnorm.split(','))) +fc_use_laynorm = list(map(str_to_bool, options.fc_use_laynorm.split(','))) +fc_act = list(map(str, options.fc_act.split(','))) + -# Batch_dev -Batch_dev=128 +# Converting context and shift in samples +wlen = int(fs * cw_len / 1000.00) +wshift = int(fs * cw_shift / 1000.00) +BATCH_SIZE = 128 # Feature extractor CNN CNN_arch = {'input_dim': wlen, - 'fs': fs, - 'cnn_N_filt': cnn_N_filt, - 'cnn_len_filt': cnn_len_filt, - 'cnn_max_pool_len':cnn_max_pool_len, - 'cnn_use_laynorm_inp': cnn_use_laynorm_inp, - 'cnn_use_batchnorm_inp': cnn_use_batchnorm_inp, - 'cnn_use_laynorm':cnn_use_laynorm, - 'cnn_use_batchnorm':cnn_use_batchnorm, - 'cnn_act': cnn_act, - 'cnn_drop':cnn_drop, - } - -CNN_net=CNN(CNN_arch) + 'fs': fs, + 'cnn_N_filt': cnn_N_filt, + 'cnn_len_filt': cnn_len_filt, + 'cnn_max_pool_len': cnn_max_pool_len, + 'cnn_use_laynorm_inp': cnn_use_laynorm_inp, + 'cnn_use_batchnorm_inp': cnn_use_batchnorm_inp, + 'cnn_use_laynorm': cnn_use_laynorm, + 'cnn_use_batchnorm': cnn_use_batchnorm, + 'cnn_act': cnn_act, + 'cnn_drop': cnn_drop} + +CNN_net = CNN(CNN_arch) CNN_net.to(device) - - DNN1_arch = {'input_dim': CNN_net.out_dim, - 'fc_lay': fc_lay, - 'fc_drop': fc_drop, - 'fc_use_batchnorm': fc_use_batchnorm, - 'fc_use_laynorm': fc_use_laynorm, - 'fc_use_laynorm_inp': fc_use_laynorm_inp, - 'fc_use_batchnorm_inp':fc_use_batchnorm_inp, - 'fc_act': fc_act, - } - -DNN1_net=MLP(DNN1_arch) + 'fc_lay': fc_lay, + 'fc_drop': fc_drop, + 'fc_use_batchnorm': fc_use_batchnorm, + 'fc_use_laynorm': fc_use_laynorm, + 'fc_use_laynorm_inp': fc_use_laynorm_inp, + 'fc_use_batchnorm_inp': fc_use_batchnorm_inp, + 'fc_act': fc_act} + +DNN1_net = MLP(DNN1_arch) DNN1_net.to(device) - -DNN2_arch = {'input_dim':fc_lay[-1] , - 'fc_lay': class_lay, - 'fc_drop': class_drop, - 'fc_use_batchnorm': class_use_batchnorm, - 'fc_use_laynorm': class_use_laynorm, - 'fc_use_laynorm_inp': class_use_laynorm_inp, - 'fc_use_batchnorm_inp':class_use_batchnorm_inp, - 'fc_act': class_act, - } - - -DNN2_net=MLP(DNN2_arch) -DNN2_net.to(device) - - -checkpoint_load = torch.load(model_file) +checkpoint_load = torch.load(model_file, map_location=device) +model_trained_using_data_parallel = False +if model_trained_using_data_parallel: + new_ckpt = {} + for k, v in checkpoint_load.items(): + new_v = collections.OrderedDict() + for kk, vv in v.items(): + if kk.startswith('module.'): + kk = '.'.join(kk.split('.')[1:]) + else: + assert False + new_v[kk] = vv + new_ckpt[k] = new_v + checkpoint_load = new_ckpt CNN_net.load_state_dict(checkpoint_load['CNN_model_par']) DNN1_net.load_state_dict(checkpoint_load['DNN1_model_par']) -DNN2_net.load_state_dict(checkpoint_load['DNN2_model_par']) - - CNN_net.eval() DNN1_net.eval() -DNN2_net.eval() -test_flag=1 - - -d_vector_dim=fc_lay[-1] -d_vect_dict={} - - -with torch.no_grad(): - - for i in range(snt_te): - - [signal, fs] = sf.read(data_folder+'/'+wav_lst_te[i]) - - # Amplitude normalization - signal=signal/np.max(np.abs(signal)) - - signal=torch.from_numpy(signal).float().to(device).contiguous() - - if avoid_small_en_fr: - # computing energy on each frame: - beg_samp=0 - end_samp=wlen - - N_fr=int((signal.shape[0]-wlen)/(wshift)) - Batch_dev=N_fr - en_arr=torch.zeros(N_fr).float().contiguous().to(device) - count_fr=0 - count_fr_tot=0 - while end_samptorch.mean(en_arr)*0.1 - en_arr_bin.to(device) - n_vect_elem=torch.sum(en_arr_bin) - - if n_vect_elem<10: - print('only few elements used to compute d-vectors') - sys.exit(0) - - - - # split signals into chunks - beg_samp=0 - end_samp=wlen - - N_fr=int((signal.shape[0]-wlen)/(wshift)) - - - sig_arr=torch.zeros([Batch_dev,wlen]).float().to(device).contiguous() - dvects=Variable(torch.zeros(N_fr,d_vector_dim).float().to(device).contiguous()) - count_fr=0 - count_fr_tot=0 - while end_samp0: - inp=Variable(sig_arr[0:count_fr]) - dvects[count_fr_tot-count_fr:count_fr_tot,:]=DNN1_net(CNN_net(inp)) - - if avoid_small_en_fr: - dvects=dvects.index_select(0, (en_arr_bin==1).nonzero().view(-1)) - - # averaging and normalizing all the d-vectors - d_vect_out=torch.mean(dvects/dvects.norm(p=2, dim=1).view(-1,1),dim=0) - - # checks for nan - nan_sum=torch.sum(torch.isnan(d_vect_out)) - - if nan_sum>0: - print(wav_lst_te[i]) - sys.exit(0) - - - # saving the d-vector in a numpy dictionary - dict_key=wav_lst_te[i].split('/')[-2]+'/'+wav_lst_te[i].split('/')[-1] - d_vect_dict[dict_key]=d_vect_out.cpu().numpy() - print(dict_key) - -# Save the dictionary -np.save(out_dict_file, d_vect_dict) - - - - - +d_vector_dim = fc_lay[-1] + +def audio_samples_to_d_vectors(signal: np.ndarray): + with torch.no_grad(): + # Amplitude normalization + signal = signal / np.max(np.abs(signal)) + + signal = torch.from_numpy(signal).float().to(device).contiguous() + + if avoid_small_en_fr: + # computing energy on each frame + en_N_fr_actual = 0 + en_N_fr = (signal.shape[0] - wlen) // wshift + 1 + en_arr = torch.zeros([en_N_fr]).float().cuda(device).contiguous() + for i_sig, beg_samp in enumerate(range(0, signal.shape[0], wshift)): + end_samp = beg_samp + wlen + if end_samp > signal.shape[0]: + break + else: + en_arr[i_sig] = torch.sum(signal[beg_samp:end_samp].pow(2)).item() + en_N_fr_actual += 1 + assert en_N_fr == en_N_fr_actual + + en_arr_bin = en_arr > torch.mean(en_arr) * 0.1 + en_arr_bin.to(device) + n_vect_elem = torch.sum(en_arr_bin) + + if n_vect_elem < 10: + raise Exception('Low energy') + + sig_arr = [] + d_vectors = [] + for beg_samp in range(0, signal.shape[0], wshift): + end_samp = beg_samp + wlen + if end_samp > signal.shape[0]: + break + else: + sig_arr.append(torch.unsqueeze(signal[beg_samp:end_samp], dim=0)) + if len(sig_arr) == BATCH_SIZE: + out = DNN1_net(CNN_net(torch.cat(sig_arr, dim=0))) + d_vectors.append(out) + sig_arr = [] + if sig_arr: + out = DNN1_net(CNN_net(torch.cat(sig_arr, dim=0))) + d_vectors.append(out) + if len(d_vectors) == 0: + raise Exception('Empty d-vectors') + d_vectors = torch.cat(d_vectors, dim=0) + + if avoid_small_en_fr: + d_vectors = d_vectors.index_select(0, (en_arr_bin == 1).nonzero().view(-1)) + + if d_vectors.shape[0] == 0: + raise Exception('Empty d-vectors') + return d_vectors.cpu().numpy() + + +def normalize_d_vectors(d_vectors): + with torch.no_grad(): + # averaging and normalizing all the d-vectors + d_vectors = torch.from_numpy(d_vectors).to(device) + d_vector_out = torch.mean(d_vectors / d_vectors.norm(p=2, dim=1).view(-1, 1), dim=0) + + # checks for nan + nan_sum = torch.sum(torch.isnan(d_vector_out)) + + if nan_sum > 0: + return Exception('NaN encountered when normalizing d-vectors') + else: + return d_vector_out.cpu().numpy() diff --git a/data_io.py b/data_io.py index 863c88e0..ed75559c 100644 --- a/data_io.py +++ b/data_io.py @@ -1,181 +1,141 @@ import configparser as ConfigParser from optparse import OptionParser -import numpy as np -#import scipy.io.wavfile -import torch + def ReadList(list_file): - f=open(list_file,"r") - lines=f.readlines() - list_sig=[] - for x in lines: + f = open(list_file, "r") + lines = f.readlines() + list_sig = [] + for x in lines: list_sig.append(x.rstrip()) - f.close() - return list_sig + f.close() + return list_sig def read_conf(): - - parser=OptionParser() - parser.add_option("--cfg") # Mandatory - (options,args)=parser.parse_args() - cfg_file=options.cfg - Config = ConfigParser.ConfigParser() - Config.read(cfg_file) - - #[data] - options.tr_lst=Config.get('data', 'tr_lst') - options.te_lst=Config.get('data', 'te_lst') - options.lab_dict=Config.get('data', 'lab_dict') - options.data_folder=Config.get('data', 'data_folder') - options.output_folder=Config.get('data', 'output_folder') - options.pt_file=Config.get('data', 'pt_file') - - #[windowing] - options.fs=Config.get('windowing', 'fs') - options.cw_len=Config.get('windowing', 'cw_len') - options.cw_shift=Config.get('windowing', 'cw_shift') - - #[cnn] - options.cnn_N_filt=Config.get('cnn', 'cnn_N_filt') - options.cnn_len_filt=Config.get('cnn', 'cnn_len_filt') - options.cnn_max_pool_len=Config.get('cnn', 'cnn_max_pool_len') - options.cnn_use_laynorm_inp=Config.get('cnn', 'cnn_use_laynorm_inp') - options.cnn_use_batchnorm_inp=Config.get('cnn', 'cnn_use_batchnorm_inp') - options.cnn_use_laynorm=Config.get('cnn', 'cnn_use_laynorm') - options.cnn_use_batchnorm=Config.get('cnn', 'cnn_use_batchnorm') - options.cnn_act=Config.get('cnn', 'cnn_act') - options.cnn_drop=Config.get('cnn', 'cnn_drop') - - - #[dnn] - options.fc_lay=Config.get('dnn', 'fc_lay') - options.fc_drop=Config.get('dnn', 'fc_drop') - options.fc_use_laynorm_inp=Config.get('dnn', 'fc_use_laynorm_inp') - options.fc_use_batchnorm_inp=Config.get('dnn', 'fc_use_batchnorm_inp') - options.fc_use_batchnorm=Config.get('dnn', 'fc_use_batchnorm') - options.fc_use_laynorm=Config.get('dnn', 'fc_use_laynorm') - options.fc_act=Config.get('dnn', 'fc_act') - - #[class] - options.class_lay=Config.get('class', 'class_lay') - options.class_drop=Config.get('class', 'class_drop') - options.class_use_laynorm_inp=Config.get('class', 'class_use_laynorm_inp') - options.class_use_batchnorm_inp=Config.get('class', 'class_use_batchnorm_inp') - options.class_use_batchnorm=Config.get('class', 'class_use_batchnorm') - options.class_use_laynorm=Config.get('class', 'class_use_laynorm') - options.class_act=Config.get('class', 'class_act') - - - #[optimization] - options.lr=Config.get('optimization', 'lr') - options.batch_size=Config.get('optimization', 'batch_size') - options.N_epochs=Config.get('optimization', 'N_epochs') - options.N_batches=Config.get('optimization', 'N_batches') - options.N_eval_epoch=Config.get('optimization', 'N_eval_epoch') - options.seed=Config.get('optimization', 'seed') - - return options + parser = OptionParser() + parser.add_option("--cfg") # Mandatory + (options, args) = parser.parse_args() + cfg_file = options.cfg + Config = ConfigParser.ConfigParser() + Config.read(cfg_file) + + # [data] + options.tr_lst = Config.get('data', 'tr_lst') + options.te_lst = Config.get('data', 'te_lst') + options.lab_dict = Config.get('data', 'lab_dict') + options.data_folder = Config.get('data', 'data_folder') + options.output_folder = Config.get('data', 'output_folder') + options.pt_file = Config.get('data', 'pt_file') + + # [windowing] + options.fs = Config.get('windowing', 'fs') + options.cw_len = Config.get('windowing', 'cw_len') + options.cw_shift = Config.get('windowing', 'cw_shift') + + # [cnn] + options.cnn_N_filt = Config.get('cnn', 'cnn_N_filt') + options.cnn_len_filt = Config.get('cnn', 'cnn_len_filt') + options.cnn_max_pool_len = Config.get('cnn', 'cnn_max_pool_len') + options.cnn_use_laynorm_inp = Config.get('cnn', 'cnn_use_laynorm_inp') + options.cnn_use_batchnorm_inp = Config.get('cnn', 'cnn_use_batchnorm_inp') + options.cnn_use_laynorm = Config.get('cnn', 'cnn_use_laynorm') + options.cnn_use_batchnorm = Config.get('cnn', 'cnn_use_batchnorm') + options.cnn_act = Config.get('cnn', 'cnn_act') + options.cnn_drop = Config.get('cnn', 'cnn_drop') + + # [dnn] + options.fc_lay = Config.get('dnn', 'fc_lay') + options.fc_drop = Config.get('dnn', 'fc_drop') + options.fc_use_laynorm_inp = Config.get('dnn', 'fc_use_laynorm_inp') + options.fc_use_batchnorm_inp = Config.get('dnn', 'fc_use_batchnorm_inp') + options.fc_use_batchnorm = Config.get('dnn', 'fc_use_batchnorm') + options.fc_use_laynorm = Config.get('dnn', 'fc_use_laynorm') + options.fc_act = Config.get('dnn', 'fc_act') + + # [class] + options.class_lay = Config.get('class', 'class_lay') + options.class_drop = Config.get('class', 'class_drop') + options.class_use_laynorm_inp = Config.get('class', 'class_use_laynorm_inp') + options.class_use_batchnorm_inp = Config.get('class', 'class_use_batchnorm_inp') + options.class_use_batchnorm = Config.get('class', 'class_use_batchnorm') + options.class_use_laynorm = Config.get('class', 'class_use_laynorm') + options.class_act = Config.get('class', 'class_act') + + # [optimization] + options.lr = Config.get('optimization', 'lr') + options.batch_size = Config.get('optimization', 'batch_size') + options.N_epochs = Config.get('optimization', 'N_epochs') + options.N_batches = Config.get('optimization', 'N_batches') + options.N_eval_epoch = Config.get('optimization', 'N_eval_epoch') + options.seed = Config.get('optimization', 'seed') + + return options def str_to_bool(s): - if s == 'True': - return True - elif s == 'False': - return False - else: - raise ValueError - - -def create_batches_rnd(batch_size,data_folder,wav_lst,N_snt,wlen,lab_dict,fact_amp): - - # Initialization of the minibatch (batch_size,[0=>x_t,1=>x_t+N,1=>random_samp]) - sig_batch=np.zeros([batch_size,wlen]) - lab_batch=np.zeros(batch_size) - - snt_id_arr=np.random.randint(N_snt, size=batch_size) - - rand_amp_arr = np.random.uniform(1.0-fact_amp,1+fact_amp,batch_size) - - for i in range(batch_size): - - # select a random sentence from the list (joint distribution) - [fs,signal]=scipy.io.wavfile.read(data_folder+wav_lst[snt_id_arr[i]]) - signal=signal.astype(float)/32768 - - # accesing to a random chunk - snt_len=signal.shape[0] - snt_beg=np.random.randint(snt_len-wlen-1) #randint(0, snt_len-2*wlen-1) - snt_end=snt_beg+wlen - - sig_batch[i,:]=signal[snt_beg:snt_end]*rand_amp_arr[i] - lab_batch[i]=lab_dict[wav_lst[snt_id_arr[i]]] - - inp=torch.from_numpy(sig_batch).float().cuda().contiguous() # Current Frame - lab=torch.from_numpy(lab_batch).float().cuda().contiguous() - - return inp,lab - - + if s == 'True': + return True + elif s == 'False': + return False + else: + raise ValueError def read_conf_inp(cfg_file): - - parser=OptionParser() - (options,args)=parser.parse_args() - - Config = ConfigParser.ConfigParser() - Config.read(cfg_file) - - #[data] - options.tr_lst=Config.get('data', 'tr_lst') - options.te_lst=Config.get('data', 'te_lst') - options.lab_dict=Config.get('data', 'lab_dict') - options.data_folder=Config.get('data', 'data_folder') - options.output_folder=Config.get('data', 'output_folder') - options.pt_file=Config.get('data', 'pt_file') - - #[windowing] - options.fs=Config.get('windowing', 'fs') - options.cw_len=Config.get('windowing', 'cw_len') - options.cw_shift=Config.get('windowing', 'cw_shift') - - #[cnn] - options.cnn_N_filt=Config.get('cnn', 'cnn_N_filt') - options.cnn_len_filt=Config.get('cnn', 'cnn_len_filt') - options.cnn_max_pool_len=Config.get('cnn', 'cnn_max_pool_len') - options.cnn_use_laynorm_inp=Config.get('cnn', 'cnn_use_laynorm_inp') - options.cnn_use_batchnorm_inp=Config.get('cnn', 'cnn_use_batchnorm_inp') - options.cnn_use_laynorm=Config.get('cnn', 'cnn_use_laynorm') - options.cnn_use_batchnorm=Config.get('cnn', 'cnn_use_batchnorm') - options.cnn_act=Config.get('cnn', 'cnn_act') - options.cnn_drop=Config.get('cnn', 'cnn_drop') - - - #[dnn] - options.fc_lay=Config.get('dnn', 'fc_lay') - options.fc_drop=Config.get('dnn', 'fc_drop') - options.fc_use_laynorm_inp=Config.get('dnn', 'fc_use_laynorm_inp') - options.fc_use_batchnorm_inp=Config.get('dnn', 'fc_use_batchnorm_inp') - options.fc_use_batchnorm=Config.get('dnn', 'fc_use_batchnorm') - options.fc_use_laynorm=Config.get('dnn', 'fc_use_laynorm') - options.fc_act=Config.get('dnn', 'fc_act') - - #[class] - options.class_lay=Config.get('class', 'class_lay') - options.class_drop=Config.get('class', 'class_drop') - options.class_use_laynorm_inp=Config.get('class', 'class_use_laynorm_inp') - options.class_use_batchnorm_inp=Config.get('class', 'class_use_batchnorm_inp') - options.class_use_batchnorm=Config.get('class', 'class_use_batchnorm') - options.class_use_laynorm=Config.get('class', 'class_use_laynorm') - options.class_act=Config.get('class', 'class_act') - - - #[optimization] - options.lr=Config.get('optimization', 'lr') - options.batch_size=Config.get('optimization', 'batch_size') - options.N_epochs=Config.get('optimization', 'N_epochs') - options.N_batches=Config.get('optimization', 'N_batches') - options.N_eval_epoch=Config.get('optimization', 'N_eval_epoch') - options.seed=Config.get('optimization', 'seed') - - return options \ No newline at end of file + parser = OptionParser() + (options, args) = parser.parse_args([]) + + Config = ConfigParser.ConfigParser() + Config.read(cfg_file) + + # [data] + options.tr_lst = Config.get('data', 'tr_lst') + options.te_lst = Config.get('data', 'te_lst') + options.lab_dict = Config.get('data', 'lab_dict') + options.data_folder = Config.get('data', 'data_folder') + options.output_folder = Config.get('data', 'output_folder') + options.pt_file = Config.get('data', 'pt_file') + + # [windowing] + options.fs = Config.get('windowing', 'fs') + options.cw_len = Config.get('windowing', 'cw_len') + options.cw_shift = Config.get('windowing', 'cw_shift') + + # [cnn] + options.cnn_N_filt = Config.get('cnn', 'cnn_N_filt') + options.cnn_len_filt = Config.get('cnn', 'cnn_len_filt') + options.cnn_max_pool_len = Config.get('cnn', 'cnn_max_pool_len') + options.cnn_use_laynorm_inp = Config.get('cnn', 'cnn_use_laynorm_inp') + options.cnn_use_batchnorm_inp = Config.get('cnn', 'cnn_use_batchnorm_inp') + options.cnn_use_laynorm = Config.get('cnn', 'cnn_use_laynorm') + options.cnn_use_batchnorm = Config.get('cnn', 'cnn_use_batchnorm') + options.cnn_act = Config.get('cnn', 'cnn_act') + options.cnn_drop = Config.get('cnn', 'cnn_drop') + + # [dnn] + options.fc_lay = Config.get('dnn', 'fc_lay') + options.fc_drop = Config.get('dnn', 'fc_drop') + options.fc_use_laynorm_inp = Config.get('dnn', 'fc_use_laynorm_inp') + options.fc_use_batchnorm_inp = Config.get('dnn', 'fc_use_batchnorm_inp') + options.fc_use_batchnorm = Config.get('dnn', 'fc_use_batchnorm') + options.fc_use_laynorm = Config.get('dnn', 'fc_use_laynorm') + options.fc_act = Config.get('dnn', 'fc_act') + + # [class] + options.class_lay = Config.get('class', 'class_lay') + options.class_drop = Config.get('class', 'class_drop') + options.class_use_laynorm_inp = Config.get('class', 'class_use_laynorm_inp') + options.class_use_batchnorm_inp = Config.get('class', 'class_use_batchnorm_inp') + options.class_use_batchnorm = Config.get('class', 'class_use_batchnorm') + options.class_use_laynorm = Config.get('class', 'class_use_laynorm') + options.class_act = Config.get('class', 'class_act') + + # [optimization] + options.lr = Config.get('optimization', 'lr') + options.batch_size = Config.get('optimization', 'batch_size') + options.N_epochs = Config.get('optimization', 'N_epochs') + options.N_batches = Config.get('optimization', 'N_batches') + options.N_eval_epoch = Config.get('optimization', 'N_eval_epoch') + options.seed = Config.get('optimization', 'seed') + + return options diff --git a/dnn_models.py b/dnn_models.py index 62421af0..2ba4bc7d 100644 --- a/dnn_models.py +++ b/dnn_models.py @@ -6,460 +6,430 @@ from torch.autograd import Variable import math + def flip(x, dim): - xsize = x.size() - dim = x.dim() + dim if dim < 0 else dim - x = x.contiguous() - x = x.view(-1, *xsize[dim:]) - x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1)-1, - -1, -1), ('cpu','cuda')[x.is_cuda])().long(), :] - return x.view(xsize) + xsize = x.size() + dim = x.dim() + dim if dim < 0 else dim + x = x.contiguous() + x = x.view(-1, *xsize[dim:]) + x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1) - 1, + -1, -1), ('cpu', 'cuda')[x.is_cuda])().long(), :] + return x.view(xsize) + +def sinc(band, t_right): + y_right = torch.sin(2 * math.pi * band * t_right) / (2 * math.pi * band * t_right) + y_left = flip(y_right, 0) -def sinc(band,t_right): - y_right= torch.sin(2*math.pi*band*t_right)/(2*math.pi*band*t_right) - y_left= flip(y_right,0) + y = torch.cat([y_left, Variable(torch.ones(1)).cuda(), y_right]) - y=torch.cat([y_left,Variable(torch.ones(1)).cuda(),y_right]) + return y - return y - class SincConv_fast(nn.Module): - """Sinc-based convolution + """Sinc-based convolution + Parameters + ---------- + in_channels : `int` + Number of input channels. Must be 1. + out_channels : `int` + Number of filters. + kernel_size : `int` + Filter length. + sample_rate : `int`, optional + Sample rate. Defaults to 16000. + Usage + ----- + See `torch.nn.Conv1d` + Reference + --------- + Mirco Ravanelli, Yoshua Bengio, + "Speaker Recognition from raw waveform with SincNet". + https://arxiv.org/abs/1808.00158 + """ + + @staticmethod + def to_mel(hz): + return 2595 * np.log10(1 + hz / 700) + + @staticmethod + def to_hz(mel): + return 700 * (10 ** (mel / 2595) - 1) + + def __init__(self, out_channels, kernel_size, sample_rate=16000, in_channels=1, + stride=1, padding=0, dilation=1, bias=False, groups=1, min_low_hz=50, min_band_hz=50): + + super(SincConv_fast, self).__init__() + + if in_channels != 1: + # msg = (f'SincConv only support one input channel ' + # f'(here, in_channels = {in_channels:d}).') + msg = "SincConv only support one input channel (here, in_channels = {%i})" % (in_channels) + raise ValueError(msg) + + self.out_channels = out_channels + self.kernel_size = kernel_size + + # Forcing the filters to be odd (i.e, perfectly symmetrics) + if kernel_size % 2 == 0: + self.kernel_size = self.kernel_size + 1 + + self.stride = stride + self.padding = padding + self.dilation = dilation + + if bias: + raise ValueError('SincConv does not support bias.') + if groups > 1: + raise ValueError('SincConv does not support groups.') + + self.sample_rate = sample_rate + self.min_low_hz = min_low_hz + self.min_band_hz = min_band_hz + + # initialize filterbanks such that they are equally spaced in Mel scale + low_hz = 30 + high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz) + + mel = np.linspace(self.to_mel(low_hz), + self.to_mel(high_hz), + self.out_channels + 1) + hz = self.to_hz(mel) + + # filter lower frequency (out_channels, 1) + self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1)) + + # filter frequency band (out_channels, 1) + self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1)) + + # Hamming window + # self.window_ = torch.hamming_window(self.kernel_size) + n_lin = torch.linspace(0, (self.kernel_size / 2) - 1, + steps=int((self.kernel_size / 2))) # computing only half of the window + self.window_ = 0.54 - 0.46 * torch.cos(2 * math.pi * n_lin / self.kernel_size); + + # (kernel_size, 1) + n = (self.kernel_size - 1) / 2.0 + self.n_ = 2 * math.pi * torch.arange(-n, 0).view(1, + -1) / self.sample_rate # Due to symmetry, I only need half of the time axes + + def forward(self, waveforms): + """ Parameters ---------- - in_channels : `int` - Number of input channels. Must be 1. - out_channels : `int` - Number of filters. - kernel_size : `int` - Filter length. - sample_rate : `int`, optional - Sample rate. Defaults to 16000. - Usage - ----- - See `torch.nn.Conv1d` - Reference - --------- - Mirco Ravanelli, Yoshua Bengio, - "Speaker Recognition from raw waveform with SincNet". - https://arxiv.org/abs/1808.00158 + waveforms : `torch.Tensor` (batch_size, 1, n_samples) + Batch of waveforms. + Returns + ------- + features : `torch.Tensor` (batch_size, out_channels, n_samples_out) + Batch of sinc filters activations. """ - @staticmethod - def to_mel(hz): - return 2595 * np.log10(1 + hz / 700) - - @staticmethod - def to_hz(mel): - return 700 * (10 ** (mel / 2595) - 1) - - def __init__(self, out_channels, kernel_size, sample_rate=16000, in_channels=1, - stride=1, padding=0, dilation=1, bias=False, groups=1, min_low_hz=50, min_band_hz=50): - - super(SincConv_fast,self).__init__() - - if in_channels != 1: - #msg = (f'SincConv only support one input channel ' - # f'(here, in_channels = {in_channels:d}).') - msg = "SincConv only support one input channel (here, in_channels = {%i})" % (in_channels) - raise ValueError(msg) - - self.out_channels = out_channels - self.kernel_size = kernel_size - - # Forcing the filters to be odd (i.e, perfectly symmetrics) - if kernel_size%2==0: - self.kernel_size=self.kernel_size+1 - - self.stride = stride - self.padding = padding - self.dilation = dilation - - if bias: - raise ValueError('SincConv does not support bias.') - if groups > 1: - raise ValueError('SincConv does not support groups.') - - self.sample_rate = sample_rate - self.min_low_hz = min_low_hz - self.min_band_hz = min_band_hz - - # initialize filterbanks such that they are equally spaced in Mel scale - low_hz = 30 - high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz) - - mel = np.linspace(self.to_mel(low_hz), - self.to_mel(high_hz), - self.out_channels + 1) - hz = self.to_hz(mel) - - - # filter lower frequency (out_channels, 1) - self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1)) - - # filter frequency band (out_channels, 1) - self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1)) - - # Hamming window - #self.window_ = torch.hamming_window(self.kernel_size) - n_lin=torch.linspace(0, (self.kernel_size/2)-1, steps=int((self.kernel_size/2))) # computing only half of the window - self.window_=0.54-0.46*torch.cos(2*math.pi*n_lin/self.kernel_size); - - - # (kernel_size, 1) - n = (self.kernel_size - 1) / 2.0 - self.n_ = 2*math.pi*torch.arange(-n, 0).view(1, -1) / self.sample_rate # Due to symmetry, I only need half of the time axes - - - - - def forward(self, waveforms): - """ - Parameters - ---------- - waveforms : `torch.Tensor` (batch_size, 1, n_samples) - Batch of waveforms. - Returns - ------- - features : `torch.Tensor` (batch_size, out_channels, n_samples_out) - Batch of sinc filters activations. - """ - - self.n_ = self.n_.to(waveforms.device) - - self.window_ = self.window_.to(waveforms.device) - - low = self.min_low_hz + torch.abs(self.low_hz_) - - high = torch.clamp(low + self.min_band_hz + torch.abs(self.band_hz_),self.min_low_hz,self.sample_rate/2) - band=(high-low)[:,0] - - f_times_t_low = torch.matmul(low, self.n_) - f_times_t_high = torch.matmul(high, self.n_) - - band_pass_left=((torch.sin(f_times_t_high)-torch.sin(f_times_t_low))/(self.n_/2))*self.window_ # Equivalent of Eq.4 of the reference paper (SPEAKER RECOGNITION FROM RAW WAVEFORM WITH SINCNET). I just have expanded the sinc and simplified the terms. This way I avoid several useless computations. - band_pass_center = 2*band.view(-1,1) - band_pass_right= torch.flip(band_pass_left,dims=[1]) - - - band_pass=torch.cat([band_pass_left,band_pass_center,band_pass_right],dim=1) - - - band_pass = band_pass / (2*band[:,None]) - - - self.filters = (band_pass).view( - self.out_channels, 1, self.kernel_size) - - return F.conv1d(waveforms, self.filters, stride=self.stride, - padding=self.padding, dilation=self.dilation, - bias=None, groups=1) - - - - + self.n_ = self.n_.to(waveforms.device) + + self.window_ = self.window_.to(waveforms.device) + + low = self.min_low_hz + torch.abs(self.low_hz_) + + high = torch.clamp(low + self.min_band_hz + torch.abs(self.band_hz_), self.min_low_hz, self.sample_rate / 2) + band = (high - low)[:, 0] + + f_times_t_low = torch.matmul(low, self.n_) + f_times_t_high = torch.matmul(high, self.n_) + + band_pass_left = ((torch.sin(f_times_t_high) - torch.sin(f_times_t_low)) / ( + self.n_ / 2)) * self.window_ # Equivalent of Eq.4 of the reference paper (SPEAKER RECOGNITION FROM RAW WAVEFORM WITH SINCNET). I just have expanded the sinc and simplified the terms. This way I avoid several useless computations. + band_pass_center = 2 * band.view(-1, 1) + band_pass_right = torch.flip(band_pass_left, dims=[1]) + + band_pass = torch.cat([band_pass_left, band_pass_center, band_pass_right], dim=1) + + band_pass = band_pass / (2 * band[:, None]) + + self.filters = (band_pass).view( + self.out_channels, 1, self.kernel_size) + + return F.conv1d(waveforms, self.filters, stride=self.stride, + padding=self.padding, dilation=self.dilation, + bias=None, groups=1) + + class sinc_conv(nn.Module): - def __init__(self, N_filt,Filt_dim,fs): - super(sinc_conv,self).__init__() - - # Mel Initialization of the filterbanks - low_freq_mel = 80 - high_freq_mel = (2595 * np.log10(1 + (fs / 2) / 700)) # Convert Hz to Mel - mel_points = np.linspace(low_freq_mel, high_freq_mel, N_filt) # Equally spaced in Mel scale - f_cos = (700 * (10**(mel_points / 2595) - 1)) # Convert Mel to Hz - b1=np.roll(f_cos,1) - b2=np.roll(f_cos,-1) - b1[0]=30 - b2[-1]=(fs/2)-100 - - self.freq_scale=fs*1.0 - self.filt_b1 = nn.Parameter(torch.from_numpy(b1/self.freq_scale)) - self.filt_band = nn.Parameter(torch.from_numpy((b2-b1)/self.freq_scale)) - - - self.N_filt=N_filt - self.Filt_dim=Filt_dim - self.fs=fs - - - def forward(self, x): - - filters=Variable(torch.zeros((self.N_filt,self.Filt_dim))).cuda() - N=self.Filt_dim - t_right=Variable(torch.linspace(1, (N-1)/2, steps=int((N-1)/2))/self.fs).cuda() - - - min_freq=50.0; - min_band=50.0; - - filt_beg_freq=torch.abs(self.filt_b1)+min_freq/self.freq_scale - filt_end_freq=filt_beg_freq+(torch.abs(self.filt_band)+min_band/self.freq_scale) - - n=torch.linspace(0, N, steps=N) - - # Filter window (hamming) - window=0.54-0.46*torch.cos(2*math.pi*n/N); - window=Variable(window.float().cuda()) - - - for i in range(self.N_filt): - - low_pass1 = 2*filt_beg_freq[i].float()*sinc(filt_beg_freq[i].float()*self.freq_scale,t_right) - low_pass2 = 2*filt_end_freq[i].float()*sinc(filt_end_freq[i].float()*self.freq_scale,t_right) - band_pass=(low_pass2-low_pass1) - - band_pass=band_pass/torch.max(band_pass) - - filters[i,:]=band_pass.cuda()*window - - out=F.conv1d(x, filters.view(self.N_filt,1,self.Filt_dim)) - - return out - + def __init__(self, N_filt, Filt_dim, fs): + super(sinc_conv, self).__init__() -def act_fun(act_type): + # Mel Initialization of the filterbanks + low_freq_mel = 80 + high_freq_mel = (2595 * np.log10(1 + (fs / 2) / 700)) # Convert Hz to Mel + mel_points = np.linspace(low_freq_mel, high_freq_mel, N_filt) # Equally spaced in Mel scale + f_cos = (700 * (10 ** (mel_points / 2595) - 1)) # Convert Mel to Hz + b1 = np.roll(f_cos, 1) + b2 = np.roll(f_cos, -1) + b1[0] = 30 + b2[-1] = (fs / 2) - 100 + + self.freq_scale = fs * 1.0 + self.filt_b1 = nn.Parameter(torch.from_numpy(b1 / self.freq_scale)) + self.filt_band = nn.Parameter(torch.from_numpy((b2 - b1) / self.freq_scale)) + + self.N_filt = N_filt + self.Filt_dim = Filt_dim + self.fs = fs + + def forward(self, x): + filters = Variable(torch.zeros((self.N_filt, self.Filt_dim))).cuda() + N = self.Filt_dim + t_right = Variable(torch.linspace(1, (N - 1) / 2, steps=int((N - 1) / 2)) / self.fs).cuda() + + min_freq = 50.0; + min_band = 50.0; + + filt_beg_freq = torch.abs(self.filt_b1) + min_freq / self.freq_scale + filt_end_freq = filt_beg_freq + (torch.abs(self.filt_band) + min_band / self.freq_scale) + + n = torch.linspace(0, N, steps=N) + + # Filter window (hamming) + window = 0.54 - 0.46 * torch.cos(2 * math.pi * n / N); + window = Variable(window.float().cuda()) + + for i in range(self.N_filt): + low_pass1 = 2 * filt_beg_freq[i].float() * sinc(filt_beg_freq[i].float() * self.freq_scale, t_right) + low_pass2 = 2 * filt_end_freq[i].float() * sinc(filt_end_freq[i].float() * self.freq_scale, t_right) + band_pass = (low_pass2 - low_pass1) + + band_pass = band_pass / torch.max(band_pass) - if act_type=="relu": + filters[i, :] = band_pass.cuda() * window + + out = F.conv1d(x, filters.view(self.N_filt, 1, self.Filt_dim)) + + return out + + +def act_fun(act_type): + if act_type == "relu": return nn.ReLU() - - if act_type=="tanh": + + if act_type == "tanh": return nn.Tanh() - - if act_type=="sigmoid": + + if act_type == "sigmoid": return nn.Sigmoid() - - if act_type=="leaky_relu": + + if act_type == "leaky_relu": return nn.LeakyReLU(0.2) - - if act_type=="elu": + + if act_type == "elu": return nn.ELU() - - if act_type=="softmax": + + if act_type == "softmax": return nn.LogSoftmax(dim=1) - - if act_type=="linear": - return nn.LeakyReLU(1) # initializzed like this, but not used in forward! - - + + if act_type == "linear": + return nn.LeakyReLU(1) # initializzed like this, but not used in forward! + + class LayerNorm(nn.Module): - def __init__(self, features, eps=1e-6): - super(LayerNorm,self).__init__() - self.gamma = nn.Parameter(torch.ones(features)) - self.beta = nn.Parameter(torch.zeros(features)) - self.eps = eps + def __init__(self, features, eps=1e-6): + super(LayerNorm, self).__init__() + self.gamma = nn.Parameter(torch.ones(features)) + self.beta = nn.Parameter(torch.zeros(features)) + self.eps = eps - def forward(self, x): - mean = x.mean(-1, keepdim=True) - std = x.std(-1, keepdim=True) - return self.gamma * (x - mean) / (std + self.eps) + self.beta + def forward(self, x): + mean = x.mean(-1, keepdim=True) + std = x.std(-1, keepdim=True) + return self.gamma * (x - mean) / (std + self.eps) + self.beta class MLP(nn.Module): - def __init__(self, options): - super(MLP, self).__init__() - - self.input_dim=int(options['input_dim']) - self.fc_lay=options['fc_lay'] - self.fc_drop=options['fc_drop'] - self.fc_use_batchnorm=options['fc_use_batchnorm'] - self.fc_use_laynorm=options['fc_use_laynorm'] - self.fc_use_laynorm_inp=options['fc_use_laynorm_inp'] - self.fc_use_batchnorm_inp=options['fc_use_batchnorm_inp'] - self.fc_act=options['fc_act'] - - - self.wx = nn.ModuleList([]) - self.bn = nn.ModuleList([]) - self.ln = nn.ModuleList([]) - self.act = nn.ModuleList([]) - self.drop = nn.ModuleList([]) - - - - # input layer normalization - if self.fc_use_laynorm_inp: - self.ln0=LayerNorm(self.input_dim) - - # input batch normalization - if self.fc_use_batchnorm_inp: - self.bn0=nn.BatchNorm1d([self.input_dim],momentum=0.05) - - - self.N_fc_lay=len(self.fc_lay) - - current_input=self.input_dim - - # Initialization of hidden layers - - for i in range(self.N_fc_lay): - - # dropout - self.drop.append(nn.Dropout(p=self.fc_drop[i])) - - # activation - self.act.append(act_fun(self.fc_act[i])) - - - add_bias=True - - # layer norm initialization - self.ln.append(LayerNorm(self.fc_lay[i])) - self.bn.append(nn.BatchNorm1d(self.fc_lay[i],momentum=0.05)) - - if self.fc_use_laynorm[i] or self.fc_use_batchnorm[i]: - add_bias=False - - - # Linear operations - self.wx.append(nn.Linear(current_input, self.fc_lay[i],bias=add_bias)) - - # weight initialization - self.wx[i].weight = torch.nn.Parameter(torch.Tensor(self.fc_lay[i],current_input).uniform_(-np.sqrt(0.01/(current_input+self.fc_lay[i])),np.sqrt(0.01/(current_input+self.fc_lay[i])))) - self.wx[i].bias = torch.nn.Parameter(torch.zeros(self.fc_lay[i])) - - current_input=self.fc_lay[i] - - - def forward(self, x): - - # Applying Layer/Batch Norm - if bool(self.fc_use_laynorm_inp): - x=self.ln0((x)) - - if bool(self.fc_use_batchnorm_inp): - x=self.bn0((x)) - - for i in range(self.N_fc_lay): - - if self.fc_act[i]!='linear': - - if self.fc_use_laynorm[i]: - x = self.drop[i](self.act[i](self.ln[i](self.wx[i](x)))) - - if self.fc_use_batchnorm[i]: - x = self.drop[i](self.act[i](self.bn[i](self.wx[i](x)))) - - if self.fc_use_batchnorm[i]==False and self.fc_use_laynorm[i]==False: - x = self.drop[i](self.act[i](self.wx[i](x))) - - else: - if self.fc_use_laynorm[i]: - x = self.drop[i](self.ln[i](self.wx[i](x))) - - if self.fc_use_batchnorm[i]: - x = self.drop[i](self.bn[i](self.wx[i](x))) - - if self.fc_use_batchnorm[i]==False and self.fc_use_laynorm[i]==False: - x = self.drop[i](self.wx[i](x)) - - return x + def __init__(self, options): + super(MLP, self).__init__() + + self.input_dim = int(options['input_dim']) + self.fc_lay = options['fc_lay'] + self.fc_drop = options['fc_drop'] + self.fc_use_batchnorm = options['fc_use_batchnorm'] + self.fc_use_laynorm = options['fc_use_laynorm'] + self.fc_use_laynorm_inp = options['fc_use_laynorm_inp'] + self.fc_use_batchnorm_inp = options['fc_use_batchnorm_inp'] + self.fc_act = options['fc_act'] + + self.wx = nn.ModuleList([]) + self.bn = nn.ModuleList([]) + self.ln = nn.ModuleList([]) + self.act = nn.ModuleList([]) + self.drop = nn.ModuleList([]) + + # input layer normalization + if self.fc_use_laynorm_inp: + self.ln0 = LayerNorm(self.input_dim) + + # input batch normalization + if self.fc_use_batchnorm_inp: + self.bn0 = nn.BatchNorm1d([self.input_dim], momentum=0.05) + + self.N_fc_lay = len(self.fc_lay) + + current_input = self.input_dim + + # Initialization of hidden layers + + for i in range(self.N_fc_lay): + + # dropout + self.drop.append(nn.Dropout(p=self.fc_drop[i])) + + # activation + self.act.append(act_fun(self.fc_act[i])) + add_bias = True + + # layer norm initialization + self.ln.append(LayerNorm(self.fc_lay[i])) + self.bn.append(nn.BatchNorm1d(self.fc_lay[i], momentum=0.05)) + + if self.fc_use_laynorm[i] or self.fc_use_batchnorm[i]: + add_bias = False + + # Linear operations + self.wx.append(nn.Linear(current_input, self.fc_lay[i], bias=add_bias)) + + # weight initialization + self.wx[i].weight = torch.nn.Parameter( + torch.Tensor(self.fc_lay[i], current_input).uniform_(-np.sqrt(0.01 / (current_input + self.fc_lay[i])), + np.sqrt(0.01 / (current_input + self.fc_lay[i])))) + self.wx[i].bias = torch.nn.Parameter(torch.zeros(self.fc_lay[i])) + + current_input = self.fc_lay[i] + + def forward(self, x): + + # Applying Layer/Batch Norm + if bool(self.fc_use_laynorm_inp): + x = self.ln0((x)) + + if bool(self.fc_use_batchnorm_inp): + x = self.bn0((x)) + + for i in range(self.N_fc_lay): + + if self.fc_act[i] != 'linear': + + if self.fc_use_laynorm[i]: + x = self.drop[i](self.act[i](self.ln[i](self.wx[i](x)))) + + if self.fc_use_batchnorm[i]: + x = self.drop[i](self.act[i](self.bn[i](self.wx[i](x)))) + + if self.fc_use_batchnorm[i] == False and self.fc_use_laynorm[i] == False: + x = self.drop[i](self.act[i](self.wx[i](x))) + + else: + if self.fc_use_laynorm[i]: + x = self.drop[i](self.ln[i](self.wx[i](x))) + + if self.fc_use_batchnorm[i]: + x = self.drop[i](self.bn[i](self.wx[i](x))) + + if self.fc_use_batchnorm[i] == False and self.fc_use_laynorm[i] == False: + x = self.drop[i](self.wx[i](x)) + + return x class SincNet(nn.Module): - - def __init__(self,options): - super(SincNet,self).__init__() - - self.cnn_N_filt=options['cnn_N_filt'] - self.cnn_len_filt=options['cnn_len_filt'] - self.cnn_max_pool_len=options['cnn_max_pool_len'] - - - self.cnn_act=options['cnn_act'] - self.cnn_drop=options['cnn_drop'] - - self.cnn_use_laynorm=options['cnn_use_laynorm'] - self.cnn_use_batchnorm=options['cnn_use_batchnorm'] - self.cnn_use_laynorm_inp=options['cnn_use_laynorm_inp'] - self.cnn_use_batchnorm_inp=options['cnn_use_batchnorm_inp'] - - self.input_dim=int(options['input_dim']) - - self.fs=options['fs'] - - self.N_cnn_lay=len(options['cnn_N_filt']) - self.conv = nn.ModuleList([]) - self.bn = nn.ModuleList([]) - self.ln = nn.ModuleList([]) - self.act = nn.ModuleList([]) - self.drop = nn.ModuleList([]) - - - if self.cnn_use_laynorm_inp: - self.ln0=LayerNorm(self.input_dim) - - if self.cnn_use_batchnorm_inp: - self.bn0=nn.BatchNorm1d([self.input_dim],momentum=0.05) - - current_input=self.input_dim - - for i in range(self.N_cnn_lay): - - N_filt=int(self.cnn_N_filt[i]) - len_filt=int(self.cnn_len_filt[i]) - - # dropout - self.drop.append(nn.Dropout(p=self.cnn_drop[i])) - - # activation - self.act.append(act_fun(self.cnn_act[i])) - - # layer norm initialization - self.ln.append(LayerNorm([N_filt,int((current_input-self.cnn_len_filt[i]+1)/self.cnn_max_pool_len[i])])) - - self.bn.append(nn.BatchNorm1d(N_filt,int((current_input-self.cnn_len_filt[i]+1)/self.cnn_max_pool_len[i]),momentum=0.05)) - - - if i==0: - self.conv.append(SincConv_fast(self.cnn_N_filt[0],self.cnn_len_filt[0],self.fs)) - - else: - self.conv.append(nn.Conv1d(self.cnn_N_filt[i-1], self.cnn_N_filt[i], self.cnn_len_filt[i])) - - current_input=int((current_input-self.cnn_len_filt[i]+1)/self.cnn_max_pool_len[i]) - - - self.out_dim=current_input*N_filt - - - - def forward(self, x): - batch=x.shape[0] - seq_len=x.shape[1] - - if bool(self.cnn_use_laynorm_inp): - x=self.ln0((x)) - - if bool(self.cnn_use_batchnorm_inp): - x=self.bn0((x)) - - x=x.view(batch,1,seq_len) - - - for i in range(self.N_cnn_lay): - - if self.cnn_use_laynorm[i]: - if i==0: - x = self.drop[i](self.act[i](self.ln[i](F.max_pool1d(torch.abs(self.conv[i](x)), self.cnn_max_pool_len[i])))) - else: - x = self.drop[i](self.act[i](self.ln[i](F.max_pool1d(self.conv[i](x), self.cnn_max_pool_len[i])))) - - if self.cnn_use_batchnorm[i]: - x = self.drop[i](self.act[i](self.bn[i](F.max_pool1d(self.conv[i](x), self.cnn_max_pool_len[i])))) - - if self.cnn_use_batchnorm[i]==False and self.cnn_use_laynorm[i]==False: - x = self.drop[i](self.act[i](F.max_pool1d(self.conv[i](x), self.cnn_max_pool_len[i]))) - - - x = x.view(batch,-1) - - return x - - - - + + def __init__(self, options): + super(SincNet, self).__init__() + + self.cnn_N_filt = options['cnn_N_filt'] + self.cnn_len_filt = options['cnn_len_filt'] + self.cnn_max_pool_len = options['cnn_max_pool_len'] + + self.cnn_act = options['cnn_act'] + self.cnn_drop = options['cnn_drop'] + + self.cnn_use_laynorm = options['cnn_use_laynorm'] + self.cnn_use_batchnorm = options['cnn_use_batchnorm'] + self.cnn_use_laynorm_inp = options['cnn_use_laynorm_inp'] + self.cnn_use_batchnorm_inp = options['cnn_use_batchnorm_inp'] + + self.input_dim = int(options['input_dim']) + + self.fs = options['fs'] + + self.N_cnn_lay = len(options['cnn_N_filt']) + self.conv = nn.ModuleList([]) + self.bn = nn.ModuleList([]) + self.ln = nn.ModuleList([]) + self.act = nn.ModuleList([]) + self.drop = nn.ModuleList([]) + + if self.cnn_use_laynorm_inp: + self.ln0 = LayerNorm(self.input_dim) + + if self.cnn_use_batchnorm_inp: + self.bn0 = nn.BatchNorm1d([self.input_dim], momentum=0.05) + + current_input = self.input_dim + + for i in range(self.N_cnn_lay): + + N_filt = int(self.cnn_N_filt[i]) + len_filt = int(self.cnn_len_filt[i]) + + # dropout + self.drop.append(nn.Dropout(p=self.cnn_drop[i])) + + # activation + self.act.append(act_fun(self.cnn_act[i])) + + # layer norm initialization + self.ln.append(LayerNorm([N_filt, int((current_input - self.cnn_len_filt[i] + 1) / self.cnn_max_pool_len[i])])) + + self.bn.append(nn.BatchNorm1d(N_filt, int((current_input - self.cnn_len_filt[i] + 1) / self.cnn_max_pool_len[i]), + momentum=0.05)) + + if i == 0: + self.conv.append(SincConv_fast(self.cnn_N_filt[0], self.cnn_len_filt[0], self.fs)) + + else: + self.conv.append(nn.Conv1d(self.cnn_N_filt[i - 1], self.cnn_N_filt[i], self.cnn_len_filt[i])) + + current_input = int((current_input - self.cnn_len_filt[i] + 1) / self.cnn_max_pool_len[i]) + + self.out_dim = current_input * N_filt + + def forward(self, x): + batch = x.shape[0] + seq_len = x.shape[1] + + if bool(self.cnn_use_laynorm_inp): + x = self.ln0((x)) + + if bool(self.cnn_use_batchnorm_inp): + x = self.bn0((x)) + + x = x.view(batch, 1, seq_len) + + for i in range(self.N_cnn_lay): + + if self.cnn_use_laynorm[i]: + if i == 0: + x = self.drop[i](self.act[i](self.ln[i](F.max_pool1d(torch.abs(self.conv[i](x)), self.cnn_max_pool_len[i])))) + else: + x = self.drop[i](self.act[i](self.ln[i](F.max_pool1d(self.conv[i](x), self.cnn_max_pool_len[i])))) + + if self.cnn_use_batchnorm[i]: + x = self.drop[i](self.act[i](self.bn[i](F.max_pool1d(self.conv[i](x), self.cnn_max_pool_len[i])))) + + if self.cnn_use_batchnorm[i] == False and self.cnn_use_laynorm[i] == False: + x = self.drop[i](self.act[i](F.max_pool1d(self.conv[i](x), self.cnn_max_pool_len[i]))) + + x = x.view(batch, -1) + + return x diff --git a/speaker_id.py b/speaker_id.py index 36adc9d2..f9695c6b 100644 --- a/speaker_id.py +++ b/speaker_id.py @@ -1,141 +1,133 @@ # speaker_id.py -# Mirco Ravanelli -# Mila - University of Montreal +# Mirco Ravanelli +# Mila - University of Montreal # July 2018 -# Description: +# Description: # This code performs a speaker_id experiments with SincNet. - + # How to run it: # python speaker_id.py --cfg=cfg/SincNet_TIMIT.cfg import os -#import scipy.io.wavfile -import soundfile as sf +import time + +import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F import torch.optim as optim -from torch.autograd import Variable +import soundfile as sf -import sys -import numpy as np -from dnn_models import MLP,flip -from dnn_models import SincNet as CNN -from data_io import ReadList,read_conf,str_to_bool - - -def create_batches_rnd(batch_size,data_folder,wav_lst,N_snt,wlen,lab_dict,fact_amp): - - # Initialization of the minibatch (batch_size,[0=>x_t,1=>x_t+N,1=>random_samp]) - sig_batch=np.zeros([batch_size,wlen]) - lab_batch=np.zeros(batch_size) - - snt_id_arr=np.random.randint(N_snt, size=batch_size) - - rand_amp_arr = np.random.uniform(1.0-fact_amp,1+fact_amp,batch_size) - - for i in range(batch_size): - - # select a random sentence from the list - #[fs,signal]=scipy.io.wavfile.read(data_folder+wav_lst[snt_id_arr[i]]) - #signal=signal.astype(float)/32768 - - [signal, fs] = sf.read(data_folder+wav_lst[snt_id_arr[i]]) - - # accesing to a random chunk - snt_len=signal.shape[0] - snt_beg=np.random.randint(snt_len-wlen-1) #randint(0, snt_len-2*wlen-1) - snt_end=snt_beg+wlen - - channels = len(signal.shape) - if channels == 2: - print('WARNING: stereo to mono: '+data_folder+wav_lst[snt_id_arr[i]]) - signal = signal[:,0] - - sig_batch[i,:]=signal[snt_beg:snt_end]*rand_amp_arr[i] - lab_batch[i]=lab_dict[wav_lst[snt_id_arr[i]]] - - inp=Variable(torch.from_numpy(sig_batch).float().cuda().contiguous()) - lab=Variable(torch.from_numpy(lab_batch).float().cuda().contiguous()) - - return inp,lab +from data_io import ReadList +from data_io import read_conf +from data_io import str_to_bool +from dnn_models import MLP +from dnn_models import SincNet as CNN +IS_DATA_PARALLEL = False +DEVICE_IDS = [0] +# IS_DATA_PARALLEL = True +# DEVICE_IDS = list(range(8)) +device = torch.device(f"cuda:{DEVICE_IDS[0]}") +def create_batches_rnd(batch_size, data_folder, wav_lst, N_snt, wlen, lab_dict, fact_amp): + # Initialization of the minibatch (batch_size,[0=>x_t,1=>x_t+N,1=>random_samp]) + sig_batch = np.zeros([batch_size, wlen]) + lab_batch = np.zeros(batch_size) -# Reading cfg file -options=read_conf() - -#[data] -tr_lst=options.tr_lst -te_lst=options.te_lst -pt_file=options.pt_file -class_dict_file=options.lab_dict -data_folder=options.data_folder+'/' -output_folder=options.output_folder - -#[windowing] -fs=int(options.fs) -cw_len=int(options.cw_len) -cw_shift=int(options.cw_shift) - -#[cnn] -cnn_N_filt=list(map(int, options.cnn_N_filt.split(','))) -cnn_len_filt=list(map(int, options.cnn_len_filt.split(','))) -cnn_max_pool_len=list(map(int, options.cnn_max_pool_len.split(','))) -cnn_use_laynorm_inp=str_to_bool(options.cnn_use_laynorm_inp) -cnn_use_batchnorm_inp=str_to_bool(options.cnn_use_batchnorm_inp) -cnn_use_laynorm=list(map(str_to_bool, options.cnn_use_laynorm.split(','))) -cnn_use_batchnorm=list(map(str_to_bool, options.cnn_use_batchnorm.split(','))) -cnn_act=list(map(str, options.cnn_act.split(','))) -cnn_drop=list(map(float, options.cnn_drop.split(','))) - - -#[dnn] -fc_lay=list(map(int, options.fc_lay.split(','))) -fc_drop=list(map(float, options.fc_drop.split(','))) -fc_use_laynorm_inp=str_to_bool(options.fc_use_laynorm_inp) -fc_use_batchnorm_inp=str_to_bool(options.fc_use_batchnorm_inp) -fc_use_batchnorm=list(map(str_to_bool, options.fc_use_batchnorm.split(','))) -fc_use_laynorm=list(map(str_to_bool, options.fc_use_laynorm.split(','))) -fc_act=list(map(str, options.fc_act.split(','))) - -#[class] -class_lay=list(map(int, options.class_lay.split(','))) -class_drop=list(map(float, options.class_drop.split(','))) -class_use_laynorm_inp=str_to_bool(options.class_use_laynorm_inp) -class_use_batchnorm_inp=str_to_bool(options.class_use_batchnorm_inp) -class_use_batchnorm=list(map(str_to_bool, options.class_use_batchnorm.split(','))) -class_use_laynorm=list(map(str_to_bool, options.class_use_laynorm.split(','))) -class_act=list(map(str, options.class_act.split(','))) - - -#[optimization] -lr=float(options.lr) -batch_size=int(options.batch_size) -N_epochs=int(options.N_epochs) -N_batches=int(options.N_batches) -N_eval_epoch=int(options.N_eval_epoch) -seed=int(options.seed) + snt_id_arr = np.random.randint(N_snt, size=batch_size) + + rand_amp_arr = np.random.uniform(1.0 - fact_amp, 1 + fact_amp, batch_size) + + for i in range(batch_size): + signal, _ = sf.read(data_folder + wav_lst[snt_id_arr[i]]) + + # accessing to a random chunk + snt_len = signal.shape[0] + snt_beg = np.random.randint(snt_len - wlen - 1) + snt_end = snt_beg + wlen + + channels = len(signal.shape) + if channels >= 2: + assert False + sig_batch[i, :] = signal[snt_beg:snt_end] * rand_amp_arr[i] + lab_batch[i] = lab_dict[wav_lst[snt_id_arr[i]]] + + inp = torch.from_numpy(sig_batch).float().cuda(device).contiguous() + lab = torch.from_numpy(lab_batch).float().cuda(device).contiguous() + + return inp, lab + + +# Reading cfg file +options = read_conf() +print(options) +# [data] +tr_lst = options.tr_lst +te_lst = options.te_lst +pt_file = options.pt_file +class_dict_file = options.lab_dict +data_folder = options.data_folder + '/' +output_folder = options.output_folder + +# [windowing] +fs = int(options.fs) +cw_len = int(options.cw_len) +cw_shift = int(options.cw_shift) + +# [cnn] +cnn_N_filt = list(map(int, options.cnn_N_filt.split(','))) +cnn_len_filt = list(map(int, options.cnn_len_filt.split(','))) +cnn_max_pool_len = list(map(int, options.cnn_max_pool_len.split(','))) +cnn_use_laynorm_inp = str_to_bool(options.cnn_use_laynorm_inp) +cnn_use_batchnorm_inp = str_to_bool(options.cnn_use_batchnorm_inp) +cnn_use_laynorm = list(map(str_to_bool, options.cnn_use_laynorm.split(','))) +cnn_use_batchnorm = list(map(str_to_bool, options.cnn_use_batchnorm.split(','))) +cnn_act = list(map(str, options.cnn_act.split(','))) +cnn_drop = list(map(float, options.cnn_drop.split(','))) + +# [dnn] +fc_lay = list(map(int, options.fc_lay.split(','))) +fc_drop = list(map(float, options.fc_drop.split(','))) +fc_use_laynorm_inp = str_to_bool(options.fc_use_laynorm_inp) +fc_use_batchnorm_inp = str_to_bool(options.fc_use_batchnorm_inp) +fc_use_batchnorm = list(map(str_to_bool, options.fc_use_batchnorm.split(','))) +fc_use_laynorm = list(map(str_to_bool, options.fc_use_laynorm.split(','))) +fc_act = list(map(str, options.fc_act.split(','))) + +# [class] +class_lay = list(map(int, options.class_lay.split(','))) +class_drop = list(map(float, options.class_drop.split(','))) +class_use_laynorm_inp = str_to_bool(options.class_use_laynorm_inp) +class_use_batchnorm_inp = str_to_bool(options.class_use_batchnorm_inp) +class_use_batchnorm = list(map(str_to_bool, options.class_use_batchnorm.split(','))) +class_use_laynorm = list(map(str_to_bool, options.class_use_laynorm.split(','))) +class_act = list(map(str, options.class_act.split(','))) + +# [optimization] +lr = float(options.lr) +batch_size = int(options.batch_size) +N_epochs = int(options.N_epochs) +N_batches = int(options.N_batches) +N_eval_epoch = int(options.N_eval_epoch) +seed = int(options.seed) # training list -wav_lst_tr=ReadList(tr_lst) -snt_tr=len(wav_lst_tr) +wav_lst_tr = ReadList(tr_lst) +snt_tr = len(wav_lst_tr) # test list -wav_lst_te=ReadList(te_lst) -snt_te=len(wav_lst_te) - +wav_lst_te = ReadList(te_lst) +snt_te = len(wav_lst_te) # Folder creation try: - os.stat(output_folder) + os.stat(output_folder) except: - os.mkdir(output_folder) - - + os.mkdir(output_folder) + # setting seed torch.manual_seed(seed) np.random.seed(seed) @@ -143,199 +135,175 @@ def create_batches_rnd(batch_size,data_folder,wav_lst,N_snt,wlen,lab_dict,fact_a # loss function cost = nn.NLLLoss() - # Converting context and shift in samples -wlen=int(fs*cw_len/1000.00) -wshift=int(fs*cw_shift/1000.00) - -# Batch_dev -Batch_dev=128 - +wlen = int(fs * cw_len / 1000.00) +wshift = int(fs * cw_shift / 1000.00) # Feature extractor CNN CNN_arch = {'input_dim': wlen, - 'fs': fs, - 'cnn_N_filt': cnn_N_filt, - 'cnn_len_filt': cnn_len_filt, - 'cnn_max_pool_len':cnn_max_pool_len, - 'cnn_use_laynorm_inp': cnn_use_laynorm_inp, - 'cnn_use_batchnorm_inp': cnn_use_batchnorm_inp, - 'cnn_use_laynorm':cnn_use_laynorm, - 'cnn_use_batchnorm':cnn_use_batchnorm, - 'cnn_act': cnn_act, - 'cnn_drop':cnn_drop, - } - -CNN_net=CNN(CNN_arch) -CNN_net.cuda() + 'fs': fs, + 'cnn_N_filt': cnn_N_filt, + 'cnn_len_filt': cnn_len_filt, + 'cnn_max_pool_len': cnn_max_pool_len, + 'cnn_use_laynorm_inp': cnn_use_laynorm_inp, + 'cnn_use_batchnorm_inp': cnn_use_batchnorm_inp, + 'cnn_use_laynorm': cnn_use_laynorm, + 'cnn_use_batchnorm': cnn_use_batchnorm, + 'cnn_act': cnn_act, + 'cnn_drop': cnn_drop, + } + +CNN_net = CNN(CNN_arch) +CNN_net_out_dim = CNN_net.out_dim +if IS_DATA_PARALLEL: + CNN_net = nn.DataParallel(CNN_net, device_ids=DEVICE_IDS) +CNN_net.cuda(device) # Loading label dictionary -lab_dict=np.load(class_dict_file).item() - - - -DNN1_arch = {'input_dim': CNN_net.out_dim, - 'fc_lay': fc_lay, - 'fc_drop': fc_drop, - 'fc_use_batchnorm': fc_use_batchnorm, - 'fc_use_laynorm': fc_use_laynorm, - 'fc_use_laynorm_inp': fc_use_laynorm_inp, - 'fc_use_batchnorm_inp':fc_use_batchnorm_inp, - 'fc_act': fc_act, - } - -DNN1_net=MLP(DNN1_arch) -DNN1_net.cuda() - - -DNN2_arch = {'input_dim':fc_lay[-1] , - 'fc_lay': class_lay, - 'fc_drop': class_drop, - 'fc_use_batchnorm': class_use_batchnorm, - 'fc_use_laynorm': class_use_laynorm, - 'fc_use_laynorm_inp': class_use_laynorm_inp, - 'fc_use_batchnorm_inp':class_use_batchnorm_inp, - 'fc_act': class_act, - } - - -DNN2_net=MLP(DNN2_arch) -DNN2_net.cuda() - - -if pt_file!='none': - checkpoint_load = torch.load(pt_file) - CNN_net.load_state_dict(checkpoint_load['CNN_model_par']) - DNN1_net.load_state_dict(checkpoint_load['DNN1_model_par']) - DNN2_net.load_state_dict(checkpoint_load['DNN2_model_par']) - - - -optimizer_CNN = optim.RMSprop(CNN_net.parameters(), lr=lr,alpha=0.95, eps=1e-8) -optimizer_DNN1 = optim.RMSprop(DNN1_net.parameters(), lr=lr,alpha=0.95, eps=1e-8) -optimizer_DNN2 = optim.RMSprop(DNN2_net.parameters(), lr=lr,alpha=0.95, eps=1e-8) - - - +lab_dict = np.load(class_dict_file, allow_pickle=True).item() + +DNN1_arch = {'input_dim': CNN_net_out_dim, + 'fc_lay': fc_lay, + 'fc_drop': fc_drop, + 'fc_use_batchnorm': fc_use_batchnorm, + 'fc_use_laynorm': fc_use_laynorm, + 'fc_use_laynorm_inp': fc_use_laynorm_inp, + 'fc_use_batchnorm_inp': fc_use_batchnorm_inp, + 'fc_act': fc_act, + } + +DNN1_net = MLP(DNN1_arch) +if IS_DATA_PARALLEL: + DNN1_net = nn.DataParallel(DNN1_net, device_ids=DEVICE_IDS) +DNN1_net.cuda(device) + +DNN2_arch = {'input_dim': fc_lay[-1], + 'fc_lay': class_lay, + 'fc_drop': class_drop, + 'fc_use_batchnorm': class_use_batchnorm, + 'fc_use_laynorm': class_use_laynorm, + 'fc_use_laynorm_inp': class_use_laynorm_inp, + 'fc_use_batchnorm_inp': class_use_batchnorm_inp, + 'fc_act': class_act, + } + +DNN2_net = MLP(DNN2_arch) +if IS_DATA_PARALLEL: + DNN2_net = nn.DataParallel(DNN2_net, device_ids=DEVICE_IDS) +DNN2_net.cuda(device) + +if pt_file != 'none': + checkpoint_load = torch.load(pt_file) + CNN_net.load_state_dict(checkpoint_load['CNN_model_par']) + DNN1_net.load_state_dict(checkpoint_load['DNN1_model_par']) + DNN2_net.load_state_dict(checkpoint_load['DNN2_model_par']) + +optimizer_CNN = optim.RMSprop(CNN_net.parameters(), lr=lr, alpha=0.95, eps=1e-8) +optimizer_DNN1 = optim.RMSprop(DNN1_net.parameters(), lr=lr, alpha=0.95, eps=1e-8) +optimizer_DNN2 = optim.RMSprop(DNN2_net.parameters(), lr=lr, alpha=0.95, eps=1e-8) + +print('localtime when starting:', time.localtime()) for epoch in range(N_epochs): - - test_flag=0 + epoch_start = time.monotonic() + test_flag = 0 CNN_net.train() DNN1_net.train() DNN2_net.train() - - loss_sum=0 - err_sum=0 + + loss_sum = 0 + err_sum = 0 for i in range(N_batches): + [inp, lab] = create_batches_rnd(batch_size, data_folder, wav_lst_tr, snt_tr, wlen, lab_dict, 0.2) + pout = DNN2_net(DNN1_net(CNN_net(inp))) - [inp,lab]=create_batches_rnd(batch_size,data_folder,wav_lst_tr,snt_tr,wlen,lab_dict,0.2) - pout=DNN2_net(DNN1_net(CNN_net(inp))) - - pred=torch.max(pout,dim=1)[1] + pred = torch.max(pout, dim=1)[1] loss = cost(pout, lab.long()) - err = torch.mean((pred!=lab.long()).float()) - - - + err = torch.mean((pred != lab.long()).float()) + optimizer_CNN.zero_grad() - optimizer_DNN1.zero_grad() - optimizer_DNN2.zero_grad() - + optimizer_DNN1.zero_grad() + optimizer_DNN2.zero_grad() + loss.backward() optimizer_CNN.step() optimizer_DNN1.step() optimizer_DNN2.step() - - loss_sum=loss_sum+loss.detach() - err_sum=err_sum+err.detach() - - - loss_tot=loss_sum/N_batches - err_tot=err_sum/N_batches - - - - -# Full Validation new - if epoch%N_eval_epoch==0: - - CNN_net.eval() - DNN1_net.eval() - DNN2_net.eval() - test_flag=1 - loss_sum=0 - err_sum=0 - err_sum_snt=0 - - with torch.no_grad(): - for i in range(snt_te): - - #[fs,signal]=scipy.io.wavfile.read(data_folder+wav_lst_te[i]) - #signal=signal.astype(float)/32768 - - [signal, fs] = sf.read(data_folder+wav_lst_te[i]) - - signal=torch.from_numpy(signal).float().cuda().contiguous() - lab_batch=lab_dict[wav_lst_te[i]] - - # split signals into chunks - beg_samp=0 - end_samp=wlen - - N_fr=int((signal.shape[0]-wlen)/(wshift)) - - - sig_arr=torch.zeros([Batch_dev,wlen]).float().cuda().contiguous() - lab= Variable((torch.zeros(N_fr+1)+lab_batch).cuda().contiguous().long()) - pout=Variable(torch.zeros(N_fr+1,class_lay[-1]).float().cuda().contiguous()) - count_fr=0 - count_fr_tot=0 - while end_samp0: - inp=Variable(sig_arr[0:count_fr]) - pout[count_fr_tot-count_fr:count_fr_tot,:]=DNN2_net(DNN1_net(CNN_net(inp))) - - - pred=torch.max(pout,dim=1)[1] - loss = cost(pout, lab.long()) - err = torch.mean((pred!=lab.long()).float()) - - [val,best_class]=torch.max(torch.sum(pout,dim=0),0) - err_sum_snt=err_sum_snt+(best_class!=lab[0]).float() - - - loss_sum=loss_sum+loss.detach() - err_sum=err_sum+err.detach() - - err_tot_dev_snt=err_sum_snt/snt_te - loss_tot_dev=loss_sum/snt_te - err_tot_dev=err_sum/snt_te - - - print("epoch %i, loss_tr=%f err_tr=%f loss_te=%f err_te=%f err_te_snt=%f" % (epoch, loss_tot,err_tot,loss_tot_dev,err_tot_dev,err_tot_dev_snt)) - - with open(output_folder+"/res.res", "a") as res_file: - res_file.write("epoch %i, loss_tr=%f err_tr=%f loss_te=%f err_te=%f err_te_snt=%f\n" % (epoch, loss_tot,err_tot,loss_tot_dev,err_tot_dev,err_tot_dev_snt)) - - checkpoint={'CNN_model_par': CNN_net.state_dict(), - 'DNN1_model_par': DNN1_net.state_dict(), - 'DNN2_model_par': DNN2_net.state_dict(), - } - torch.save(checkpoint,output_folder+'/model_raw.pkl') - - else: - print("epoch %i, loss_tr=%f err_tr=%f" % (epoch, loss_tot,err_tot)) + loss_sum = loss_sum + loss.detach() + err_sum = err_sum + err.detach() + + loss_tot = loss_sum / N_batches + err_tot = err_sum / N_batches + + # Full Validation new + if epoch % N_eval_epoch == 0: + + CNN_net.eval() + DNN1_net.eval() + DNN2_net.eval() + test_flag = 1 + loss_sum = 0 + err_sum = 0 + err_sum_snt = 0 + + with torch.no_grad(): + for i in range(snt_te): + signal, _ = sf.read(data_folder + wav_lst_te[i]) + lab_batch = lab_dict[wav_lst_te[i]] + + # 1 2 3 4 5 6 7 8 9 10 11 12 13 + # stride = 3 + # window = 5 + # we should have (13 - 5) // 3 + 1 = 3 examples + # for we have (13 - 5) segments, each segment corresponds to a left end point, + # and each left point corresponds to one example, + # plus the one at the end + # 1 2 3 4 5 + # 4 5 6 7 8 + # 7 8 9 10 11 + + N_fr_actual = 0 + N_fr = (signal.shape[0] - wlen) // wshift + 1 + sig_arr = np.zeros((N_fr, wlen), dtype=np.float32) + lab = (torch.zeros(N_fr) + lab_batch).cuda(device).contiguous().long() + for i_sig, beg_samp in enumerate(range(0, signal.shape[0], wshift)): + end_samp = beg_samp + wlen + if end_samp > signal.shape[0]: + break + else: + sig_arr[i_sig] = signal[beg_samp:end_samp] + N_fr_actual += 1 + assert N_fr_actual == N_fr + sig_arr = torch.from_numpy(sig_arr).cuda(device).contiguous() + pout = DNN2_net(DNN1_net(CNN_net(sig_arr))) + + pred = torch.max(pout, dim=1)[1] + loss = cost(pout, lab.long()) + err = torch.mean((pred != lab.long()).float()) + + [val, best_class] = torch.max(torch.sum(pout, dim=0), 0) + err_sum_snt = err_sum_snt + (best_class != lab[0]).float() + + loss_sum = loss_sum + loss.detach() + err_sum = err_sum + err.detach() + + err_tot_dev_snt = err_sum_snt / snt_te + loss_tot_dev = loss_sum / snt_te + err_tot_dev = err_sum / snt_te + + print("epoch %i, loss_tr=%f err_tr=%f loss_te=%f err_te=%f err_te_snt=%f time=%f" % ( + epoch, loss_tot, err_tot, loss_tot_dev, err_tot_dev, err_tot_dev_snt, time.monotonic() - epoch_start)) + + with open(output_folder + "/res.res", "a") as res_file: + res_file.write("epoch %i, loss_tr=%f err_tr=%f loss_te=%f err_te=%f err_te_snt=%f time=%f\n" % ( + epoch, loss_tot, err_tot, loss_tot_dev, err_tot_dev, err_tot_dev_snt, time.monotonic() - epoch_start)) + + checkpoint = {'CNN_model_par': CNN_net.state_dict(), + 'DNN1_model_par': DNN1_net.state_dict(), + 'DNN2_model_par': DNN2_net.state_dict()} + torch.save(checkpoint, output_folder + f'/model_raw_{epoch}.pkl') + else: + print("epoch %i, loss_tr=%f err_tr=%f time=%f" % ( + epoch, loss_tot, err_tot, time.monotonic() - epoch_start))