99import torch
1010import torch .nn .functional as F
1111
12- from Modules .Vocoder .Avocodo_Discriminators import MultiCoMBDiscriminator
13- from Modules .Vocoder .Avocodo_Discriminators import MultiSubBandDiscriminator
1412from Modules .Vocoder .SAN_modules import SANConv1d
1513from Modules .Vocoder .SAN_modules import SANConv2d
1614
@@ -456,10 +454,13 @@ def forward(self, x):
456454
457455
458456class AvocodoHiFiGANJointDiscriminator (torch .nn .Module ):
457+ """
458+ Contradicting the legacy name, the Avocodo parts were removed again for stability
459+ """
459460
460461 def __init__ (self ,
461462 # Multi-scale discriminator related
462- scales = 3 ,
463+ scales = 4 ,
463464 scale_downsample_pooling = "AvgPool1d" ,
464465 scale_downsample_pooling_params = {"kernel_size" : 4 ,
465466 "stride" : 2 ,
@@ -471,7 +472,7 @@ def __init__(self,
471472 "max_downsample_channels" : 1024 ,
472473 "max_groups" : 16 ,
473474 "bias" : True ,
474- "downsample_scales" : [4 , 4 , 4 , 4 , 1 ],
475+ "downsample_scales" : [4 , 4 , 4 , 1 ],
475476 "nonlinear_activation" : "LeakyReLU" ,
476477 "nonlinear_activation_params" : {"negative_slope" : 0.1 }, },
477478 follow_official_norm = True ,
@@ -481,41 +482,14 @@ def __init__(self,
481482 "out_channels" : 1 ,
482483 "kernel_sizes" : [5 , 3 ],
483484 "channels" : 32 ,
484- "downsample_scales" : [3 , 3 , 3 , 3 , 1 ],
485+ "downsample_scales" : [3 , 3 , 3 , 1 ],
485486 "max_downsample_channels" : 1024 ,
486487 "bias" : True ,
487488 "nonlinear_activation" : "LeakyReLU" ,
488489 "nonlinear_activation_params" : {"negative_slope" : 0.1 },
489490 "use_weight_norm" : True ,
490491 "use_spectral_norm" : False , },
491- # CoMB discriminator related
492- kernels = ((7 , 11 , 11 , 11 , 11 , 5 ),
493- (11 , 21 , 21 , 21 , 21 , 5 ),
494- (15 , 41 , 41 , 41 , 41 , 5 )),
495- channels = (16 , 64 , 256 , 1024 , 1024 , 1024 ),
496- groups = (1 , 4 , 16 , 64 , 256 , 1 ),
497- strides = (1 , 1 , 4 , 4 , 4 , 1 ),
498- # Sub-Band discriminator related
499- tkernels = (7 , 5 , 3 ),
500- fkernel = 5 ,
501- tchannels = (64 , 128 , 256 , 256 , 256 ),
502- fchannels = (32 , 64 , 128 , 128 , 128 ),
503- tstrides = ((1 , 1 , 3 , 3 , 1 ),
504- (1 , 1 , 3 , 3 , 1 ),
505- (1 , 1 , 3 , 3 , 1 )),
506- fstride = (1 , 1 , 3 , 3 , 1 ),
507- tdilations = (((5 , 7 , 11 ), (5 , 7 , 11 ), (5 , 7 , 11 ), (5 , 7 , 11 ), (5 , 7 , 11 ), (5 , 7 , 11 )),
508- ((3 , 5 , 7 ), (3 , 5 , 7 ), (3 , 5 , 7 ), (3 , 5 , 7 ), (3 , 5 , 7 )),
509- ((1 , 2 , 3 ), (1 , 2 , 3 ), (1 , 2 , 3 ), (1 , 2 , 3 ), (1 , 2 , 3 ))),
510- fdilations = ((1 , 2 , 3 ),
511- (1 , 2 , 3 ),
512- (1 , 2 , 3 ),
513- (2 , 3 , 5 ),
514- (2 , 3 , 5 )),
515- tsubband = (6 , 11 , 16 ),
516- n = 16 ,
517- m = 64 ,
518- freq_init_ch = 192 ):
492+ ):
519493 super ().__init__ ()
520494 self .msd = HiFiGANMultiScaleDiscriminator (scales = scales ,
521495 downsample_pooling = scale_downsample_pooling ,
@@ -524,10 +498,8 @@ def __init__(self,
524498 follow_official_norm = follow_official_norm , )
525499 self .mpd = HiFiGANMultiPeriodDiscriminator (periods = periods ,
526500 discriminator_params = period_discriminator_params , )
527- self .mcmbd = MultiCoMBDiscriminator (kernels , channels , groups , strides )
528- self .msbd = MultiSubBandDiscriminator (tkernels , fkernel , tchannels , fchannels , tstrides , fstride , tdilations , fdilations , tsubband , n , m , freq_init_ch )
529501
530- def forward (self , wave , intermediate_wave_upsampled_twice = None , intermediate_wave_upsampled_once = None , discriminator_train_flag = False ):
502+ def forward (self , wave , discriminator_train_flag = False ):
531503 """
532504 Calculate forward propagation.
533505
@@ -542,9 +514,9 @@ def forward(self, wave, intermediate_wave_upsampled_twice=None, intermediate_wav
542514 """
543515 msd_outs , msd_feats = self .msd (wave , discriminator_train_flag )
544516 mpd_outs , mpd_feats = self .mpd (wave , discriminator_train_flag )
545- mcmbd_outs , mcmbd_feats = self . mcmbd ( wave_final = wave ,
546- intermediate_wave_upsampled_twice = intermediate_wave_upsampled_twice ,
547- intermediate_wave_upsampled_once = intermediate_wave_upsampled_once ,
548- discriminator_train_flag = discriminator_train_flag )
549- msbd_outs , msbd_feats = self . msbd ( wave , discriminator_train_flag )
550- return msd_outs + mpd_outs + mcmbd_outs + msbd_outs , msd_feats + mpd_feats + mcmbd_feats + msbd_feats
517+ return msd_outs + mpd_outs , msd_feats + mpd_feats
518+
519+
520+ if __name__ == '__main__' :
521+ d = AvocodoHiFiGANJointDiscriminator ( )
522+ print ( d ( torch . randn ([ 2 , 1 , 12288 * 2 ])))
0 commit comments