11# -*- coding: utf-8 -*-
2+ import torch .nn .functional as F
23from ..layers import ImageEncoder , XuDecoder
34
45from ..datasets import Multi30kRawDataset
@@ -41,13 +42,18 @@ def set_defaults(self):
4142 'cnn_pretrained' : True , # Should we use pretrained imagenet weights
4243 'cnn_finetune' : None , # Should we finetune part or all of CNN
4344 'pool' : None , # ('Avg|Max', kernel_size, stride_size)
45+ 'l2_norm' : False , # L2 normalize features
46+ 'l2_norm_dim' : - 1 , # Which dimension to L2 normalize
4447 'resize' : 256 , # resize width, height for images
4548 'crop' : 224 , # center crop size after resize
49+ 'replicate' : 1 , # For multi-caption setup, replicates images N times
4650 'direction' : None , # Network directionality, i.e. en->de
4751 }
4852
def __init__(self, opts, logger=None):
    """Construct the model and, if enabled, register the attention penalty.

    Args:
        opts: Options object forwarded unchanged to the parent constructor.
        logger: Optional logger forwarded unchanged to the parent
            constructor.
    """
    super().__init__(opts, logger)

    # A strictly positive ``alpha_c`` switches the auxiliary term on; it is
    # seeded with 0.0 here and overwritten by ``forward`` while training.
    # NOTE(review): presumably ``self.opts`` and ``self.aux_loss`` are
    # created by the parent constructor -- confirm.
    alpha_coeff = self.opts.model['alpha_c']
    if alpha_coeff > 0:
        self.aux_loss['alpha_reg'] = 0.0
5157
5258 def setup (self , is_train = True ):
5359 self .print ('Loading CNN' )
@@ -103,6 +109,7 @@ def load_data(self, split):
103109 warmup = (split != 'train' ),
104110 resize = self .opts .model ['resize' ],
105111 crop = self .opts .model ['crop' ],
112+ replicate = self .opts .model ['replicate' ] if split == 'train' else 1 ,
106113 vocabs = self .vocabs ,
107114 topology = self .topology )
108115 self .print (self .datasets [split ])
@@ -111,5 +118,18 @@ def encode(self, batch):
111118 # Get features into (n,c,w*h) and then (w*h,n,c)
112119 feats = self .cnn (batch ['image' ])
113120 feats = feats .view ((* feats .shape [:2 ], - 1 )).permute (2 , 0 , 1 )
121+ if self .opts .model ['l2_norm' ]:
122+ feats = F .normalize (
123+ feats , dim = self .opts .model ['l2_norm_dim' ]).detach ()
114124
115125 return {'image' : (feats , None )}
126+
def forward(self, batch):
    """Run the parent forward pass and refresh the ``alpha_reg`` penalty.

    In training mode with a positive ``alpha_c`` option, the entries of
    ``self.dec.alphas`` are summed, squared element-wise, summed over
    dimension 0 and averaged; that value scaled by ``alpha_c`` is stored
    in ``self.aux_loss['alpha_reg']``.

    Args:
        batch: Batch object passed through to the parent ``forward``.

    Returns:
        Whatever the parent ``forward`` returned, unmodified.
    """
    result = super().forward(batch)

    # Outside of training the auxiliary term is left untouched.
    if not self.training:
        return result

    coeff = self.opts.model['alpha_c']
    if coeff > 0:
        # NOTE(review): presumably ``self.dec.alphas`` holds one attention
        # tensor per decoding step -- confirm against the decoder.
        accumulated = sum(self.dec.alphas)
        penalty = (accumulated ** 2).sum(0).mean()
        self.aux_loss['alpha_reg'] = penalty.mul(coeff)

    return result
0 commit comments