@@ -90,9 +90,9 @@ def _pad(seq, max_len, constant_values=0):
9090 mode = 'constant' , constant_values = constant_values )
9191
9292
93- def _pad_2d (x , max_len , b_pad = 0 ):
93+ def _pad_2d (x , max_len , b_pad = 0 , constant_values = 0 ):
9494 x = np .pad (x , [(b_pad , max_len - len (x ) - b_pad ), (0 , 0 )],
95- mode = "constant" , constant_values = 0 )
95+ mode = "constant" , constant_values = constant_values )
9696 return x
9797
9898
@@ -417,17 +417,19 @@ def collate_fn(batch):
417417 # (B, T, C)
418418 # pad for time-axis
419419 if is_mulaw_quantize (hparams .input_type ):
420+ padding_value = P .mulaw_quantize (0 , mu = hparams .quantize_channels )
420421 x_batch = np .array ([_pad_2d (np_utils .to_categorical (
421422 x [0 ], num_classes = hparams .quantize_channels ),
422- max_input_len ) for x in batch ], dtype = np .float32 )
423+ max_input_len , padding_value ) for x in batch ], dtype = np .float32 )
423424 else :
424425 x_batch = np .array ([_pad_2d (x [0 ].reshape (- 1 , 1 ), max_input_len )
425426 for x in batch ], dtype = np .float32 )
426427 assert len (x_batch .shape ) == 3
427428
428429 # (B, T)
429430 if is_mulaw_quantize (hparams .input_type ):
430- y_batch = np .array ([_pad (x [0 ], max_input_len ) for x in batch ], dtype = np .int )
431+ padding_value = P .mulaw_quantize (0 , mu = hparams .quantize_channels )
432+ y_batch = np .array ([_pad (x [0 ], max_input_len , constant_values = padding_value ) for x in batch ], dtype = np .int )
431433 else :
432434 y_batch = np .array ([_pad (x [0 ], max_input_len ) for x in batch ], dtype = np .float32 )
433435 assert len (y_batch .shape ) == 2
0 commit comments