Open
Description
I have ported your code to Pytorch. Here is the link to repo
After training, I get same values for both indices. I sense something is wrong in the decoding part, but unable to pinpoint. Can you please take a look?
from torch import nn
from torch.autograd import Variable
class Model(nn.Module):
def __init__(self, input_dim, hidden_size, num_of_indices, blend_dim, batch_size):
super(Model, self).__init__()
self.batch_size = batch_size # B
self.input_dim = input_dim # I
self.hidden_size = hidden_size # H
self.num_of_indices = num_of_indices # N
self.blend_dim = blend_dim # D
self.encode = nn.LSTMCell(input_dim, hidden_size)
self.decode = nn.LSTMCell(input_dim, hidden_size)
self.blend_decoder = nn.Linear(hidden_size, blend_dim)
self.blend_encoder = nn.Linear(hidden_size, blend_dim)
self.scale_blend = nn.Linear(blend_dim, input_dim)
def zero_hidden_state(self):
return Variable(torch.zeros([self.batch_size, self.hidden_size]).cuda())
def forward(self, inp):
#TODO - zero
hidden = self.zero_hidden_state() # BxH
cell_state = self.zero_hidden_state() # BxH
encoder_states = []
for j in range(inp.size()[1]): # inp -> BxJxI
encoder_input = inp[:, j:j+1] # BxI
hidden, cell_state = self.encode(encoder_input, (hidden, cell_state))
encoder_states.append((hidden, cell_state))
pointers = []
pointer_distributions = []
start_token = 0
decoder_input = Variable(torch.Tensor([start_token] * self.batch_size) # BxI
.view(self.batch_size, self.input_dim).cuda())
hidden, cell_state = encoder_states[-1] # BxH
for i in range(self.num_of_indices):
hidden, cell_state = self.decode(decoder_input, (hidden, cell_state)) # BxH
print(hidden, cell_state)
decoder_blend = self.blend_decoder(hidden) # BxD
encoder_blends = []
index_predists = []
for i in range(inp.size()[1]):
encoder_blend = self.blend_encoder(encoder_states[i][1]) # BxD
raw_blend = encoder_blend + decoder_blend # BxD
scaled_blend = self.scale_blend(raw_blend).squeeze(1) # BxI
index_predist = scaled_blend
encoder_blends.append(encoder_blend)
index_predists.append(index_predist)
index_predistribution = torch.stack(index_predists).t() # BxJ
index_distribution = F.softmax(index_predistribution)
pointer_distributions.append(index_distribution)
index = index_distribution.data.max(1)[1].squeeze(1) # B
emb = embedding_lookup(inp.t(), Variable(index)) # BxB
pointer_raw = torch.diag(emb) # B
pointer = pointer_raw
pointers.append(pointer)
decoder_input = pointer.unsqueeze(1) # Bx1
print(decoder_input)
index_distributions = torch.stack(pointer_distributions)
return index_distributions
Metadata
Metadata
Assignees
Labels
No labels