Skip to content
This repository was archived by the owner on Jan 5, 2023. It is now read-only.

Commit a471e80

Browse files
committed
update SAT model
1 parent 87ac9de commit a471e80

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

nmtpytorch/layers/xu_decoder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def _rnn_init_mean_ctx(self, ctx, ctx_mask):
131131

132132
def f_init(self, ctx_dict):
133133
"""Returns the initial h_0, c_0 for the decoder."""
134+
self.alphas = []
134135
return self._init_func(*ctx_dict[self.ctx_name])
135136

136137
def f_next(self, ctx_dict, y, h):
@@ -140,6 +141,8 @@ def f_next(self, ctx_dict, y, h):
140141
# Apply attention
141142
self.alpha_t, z_t = self.att(
142143
h_c[0].unsqueeze(0), *ctx_dict[self.ctx_name])
144+
# Save reg loss terms
145+
self.alphas.append(1 - self.alpha_t)
143146

144147
if self.selector:
145148
z_t *= self.ff_selector(h_c[0])

nmtpytorch/models/sat.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# -*- coding: utf-8 -*-
2+
import torch.nn.functional as F
23
from ..layers import ImageEncoder, XuDecoder
34

45
from ..datasets import Multi30kRawDataset
@@ -41,13 +42,18 @@ def set_defaults(self):
4142
'cnn_pretrained': True, # Should we use pretrained imagenet weights
4243
'cnn_finetune': None, # Should we finetune part or all of CNN
4344
'pool': None, # ('Avg|Max', kernel_size, stride_size)
45+
'l2_norm': False, # L2 normalize features
46+
'l2_norm_dim': -1, # Which dimension to L2 normalize
4447
'resize': 256, # resize width, height for images
4548
'crop': 224, # center crop size after resize
49+
'replicate': 1, # For multi-caption setup, replicates images N times
4650
'direction': None, # Network directionality, i.e. en->de
4751
}
4852

4953
def __init__(self, opts, logger=None):
5054
super().__init__(opts, logger)
55+
if self.opts.model['alpha_c'] > 0:
56+
self.aux_loss['alpha_reg'] = 0.0
5157

5258
def setup(self, is_train=True):
5359
self.print('Loading CNN')
@@ -103,6 +109,7 @@ def load_data(self, split):
103109
warmup=(split != 'train'),
104110
resize=self.opts.model['resize'],
105111
crop=self.opts.model['crop'],
112+
replicate=self.opts.model['replicate'] if split == 'train' else 1,
106113
vocabs=self.vocabs,
107114
topology=self.topology)
108115
self.print(self.datasets[split])
@@ -111,5 +118,18 @@ def encode(self, batch):
111118
# Get features into (n,c,w*h) and then (w*h,n,c)
112119
feats = self.cnn(batch['image'])
113120
feats = feats.view((*feats.shape[:2], -1)).permute(2, 0, 1)
121+
if self.opts.model['l2_norm']:
122+
feats = F.normalize(
123+
feats, dim=self.opts.model['l2_norm_dim']).detach()
114124

115125
return {'image': (feats, None)}
126+
127+
def forward(self, batch):
128+
result = super().forward(batch)
129+
130+
if self.training and self.opts.model['alpha_c'] > 0:
131+
alpha_loss = (sum(self.dec.alphas)**2).sum(0).mean()
132+
self.aux_loss['alpha_reg'] = alpha_loss.mul(
133+
self.opts.model['alpha_c'])
134+
135+
return result

0 commit comments

Comments
 (0)