
Both the starting and ending indices are always the same #2

Open

@vanangamudi

I have ported your code to PyTorch. Here is the link to the repo.

After training, I get the same value for both indices. I suspect something is wrong in the decoding part, but I am unable to pinpoint it. Could you please take a look?
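
For concreteness, this is how the symptom shows up (an illustrative check, not code from the repo; model is assumed to be a trained instance of the Model below, and inp a BxJ input batch):

dists = model(inp)                              # NxBxJ: one BxJ distribution per decoded index
for step in range(dists.size(0)):
    print(step, dists[step].data.max(1)[1])     # the argmax comes out identical at every step

Here is the ported model: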

import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self, input_dim, hidden_size, num_of_indices, blend_dim, batch_size):
        super(Model, self).__init__()
        
        self.batch_size = batch_size               # B
        self.input_dim = input_dim                 # I
        self.hidden_size = hidden_size             # H
        self.num_of_indices = num_of_indices       # N
        self.blend_dim = blend_dim                 # D
                
        self.encode = nn.LSTMCell(input_dim, hidden_size)
        self.decode = nn.LSTMCell(input_dim, hidden_size)
        self.blend_decoder = nn.Linear(hidden_size, blend_dim)
        self.blend_encoder = nn.Linear(hidden_size, blend_dim)
        self.scale_blend = nn.Linear(blend_dim, input_dim)
        
    def zero_hidden_state(self):
        return Variable(torch.zeros([self.batch_size, self.hidden_size]).cuda())
        
    def forward(self, inp):
        #TODO - zero 
        hidden = self.zero_hidden_state()                                            # BxH
        cell_state = self.zero_hidden_state()                                        # BxH
        encoder_states = []
        for j in range(inp.size()[1]):                                           # inp -> BxJ (input_dim I is 1)
            encoder_input = inp[:, j:j+1]                                        # BxI
            hidden, cell_state = self.encode(encoder_input, (hidden, cell_state)) 
            encoder_states.append((hidden, cell_state))
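        # (this per-step LSTMCell loop is the unrolled equivalent of a single
        #  nn.LSTM pass over the sequence; unrolling keeps every intermediate
        #  (hidden, cell) pair around for the attention below)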
            
        pointers = []
        pointer_distributions = []
        
        start_token = 0
        decoder_input = Variable(torch.Tensor([start_token] * self.batch_size)  # BxI (assumes input_dim == 1)
                                 .view(self.batch_size, self.input_dim).cuda())

        hidden, cell_state = encoder_states[-1]                                  # BxH
        for i in range(self.num_of_indices):       
            hidden, cell_state = self.decode(decoder_input, (hidden, cell_state)) # BxH
            print(hidden, cell_state)
            decoder_blend = self.blend_decoder(hidden)                            # BxD
            encoder_blends = []
            index_predists = []
            for j in range(inp.size()[1]):                                       # j, not i: don't shadow the outer loop variable
                # NOTE: encoder_states[j][1] is the cell state; the pointer
                # network formulation attends over the hidden state,
                # encoder_states[j][0]
                encoder_blend = self.blend_encoder(encoder_states[j][1])          # BxD
                raw_blend = encoder_blend + decoder_blend                              # BxD
                scaled_blend = self.scale_blend(raw_blend).squeeze(1)             # B (Bx1 squeezed)
                
                index_predist = scaled_blend
                
                encoder_blends.append(encoder_blend)
                index_predists.append(index_predist)
                
            index_predistribution = torch.stack(index_predists).t()              # BxJ
            index_distribution = F.softmax(index_predistribution, dim=1)         # softmax over the J positions
            pointer_distributions.append(index_distribution)
            index = index_distribution.data.max(1)[1]                            # B (on PyTorch < 0.2, add .squeeze(1): max kept the reduced dim)

            # embedding_lookup is assumed to be a repo helper mirroring
            # tf.nn.embedding_lookup: emb[a][b] = inp.t()[index[a]][b], so the
            # diagonal picks inp[b, index[b]] for each batch element. An
            # equivalent without the BxB intermediate:
            #     pointer = inp.gather(1, Variable(index).unsqueeze(1)).squeeze(1)
            emb = embedding_lookup(inp.t(), Variable(index))                     # BxB
            pointer = torch.diag(emb)                                            # B
            pointers.append(pointer)
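            # feed the selected input value back as the next decoder input
            # (greedy decoding, no teacher forcing)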
            decoder_input = pointer.unsqueeze(1)                                       # Bx1
            print(decoder_input)

        index_distributions = torch.stack(pointer_distributions)                # NxBxJ
        return index_distributions    
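
While writing this up I noticed one candidate explanation: scale_blend is affine, so scale_blend(encoder_blend + decoder_blend) equals scale_blend(encoder_blend) plus a term that is the same for every encoder position, and softmax is invariant to adding a constant to all logits. If that is right, index_distribution never depends on the decoder state at all, and every decoding step points at the same index, which matches what I see. The pointer network paper avoids this by squashing the sum before the final projection, u_j = v^T tanh(W1 e_j + W2 d). A minimal sketch of that scoring step (PointerScore, W1, W2, and v are my names, not anything from the repo):

import torch
from torch import nn

class PointerScore(nn.Module):
    # illustrative additive-attention score following the pointer network
    # paper; not the repo's code
    def __init__(self, hidden_size, blend_dim):
        super(PointerScore, self).__init__()
        self.W1 = nn.Linear(hidden_size, blend_dim)   # blends each encoder state
        self.W2 = nn.Linear(hidden_size, blend_dim)   # blends the decoder state
        self.v = nn.Linear(blend_dim, 1, bias=False)  # the v^T projection

    def forward(self, encoder_hiddens, decoder_hidden):
        # encoder_hiddens: BxJxH, decoder_hidden: BxH
        blend = self.W1(encoder_hiddens) + self.W2(decoder_hidden).unsqueeze(1)  # BxJxD
        return self.v(torch.tanh(blend)).squeeze(2)                             # BxJ logits

Because of the tanh, the decoder state interacts nonlinearly with each encoder position, so the logits (and hence the argmax) can differ from one decoding step to the next.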
