diff --git a/README.md b/README.md index c6f284d..9857cc0 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Modules that consider successive calls to `forward` as different time-steps in a * [AbstractRecurrent](#rnn.AbstractRecurrent) : an abstract class inherited by Recurrent and LSTM; * [Recurrent](#rnn.Recurrent) : a generalized recurrent neural network container; * [LSTM](#rnn.LSTM) : a vanilla Long-Short Term Memory module; - * [FastLSTM](#rnn.FastLSTM) : a faster [LSTM](#rnn.LSTM) with optional support for batch normalization; + * [FastLSTM](#rnn.FastLSTM) : a faster [LSTM](#rnn.LSTM) with optional support for batch normalization; * [GRU](#rnn.GRU) : Gated Recurrent Units module; * [MuFuRu](#rnn.MuFuRu) : [Multi-function Recurrent Unit](https://arxiv.org/abs/1606.03002) module; * [Recursor](#rnn.Recursor) : decorates a module to make it conform to the [AbstractRecurrent](#rnn.AbstractRecurrent) interface; @@ -19,8 +19,11 @@ Modules that `forward` entire sequences through a decorated `AbstractRecurrent` * [AbstractSequencer](#rnn.AbstractSequencer) : an abstract class inherited by Sequencer, Repeater, RecurrentAttention, etc.; * [Sequencer](#rnn.Sequencer) : applies an encapsulated module to all elements in an input sequence (Tensor or Table); * [SeqLSTM](#rnn.SeqLSTM) : a very fast version of `nn.Sequencer(nn.FastLSTM)` where the `input` and `output` are tensors; - * [SeqLSTMP](#rnn.SeqLSTMP) : `SeqLSTM` with a projection layer; + * [SeqLSTM_WN](#rnn.SeqLSTM) : a weight-normalized version of `nn.SeqLSTM`; + * [SeqLSTMP](#rnn.SeqLSTMP) : `SeqLSTM` with a projection layer; + * [SeqLSTMP_WN](#rnn.SeqLSTMP) : a weight-normalized version of `nn.SeqLSTMP`; * [SeqGRU](#rnn.SeqGRU) : a very fast version of `nn.Sequencer(nn.GRU)` where the `input` and `output` are tensors; + * [SeqGRU_WN](#rnn.SeqGRU) : a weight-normalized version of `nn.SeqGRU`; * [SeqBRNN](#rnn.SeqBRNN) : Bidirectional RNN based on SeqLSTM; * [BiSequencer](#rnn.BiSequencer) : used for 
implementing Bidirectional RNNs and LSTMs; * [BiSequencerLM](#rnn.BiSequencerLM) : used for implementing Bidirectional RNNs and LSTMs for language models; diff --git a/SeqGRU_WN.lua b/SeqGRU_WN.lua new file mode 100644 index 0000000..39aee02 --- /dev/null +++ b/SeqGRU_WN.lua @@ -0,0 +1,574 @@ +--[[ +The MIT License (MIT) + +Copyright (c) 2016 Stéphane Guillitte, Joost van Doorn + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included +in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +--]] + +-- Modified by Richard Assar + +require 'torch' +require 'nn' + +local SeqGRU_WN, parent = torch.class('nn.SeqGRU_WN', 'nn.Module') + +--[[ +If we add up the sizes of all the tensors for output, gradInput, weights, +gradWeights, and temporary buffers, we get that a SequenceGRU stores this many +scalar values: + +NTD + 4NTH + 5NH + 6H^2 + 6DH + 7H + +Note that this class doesn't own input or gradOutput, so you'll +see a bit higher memory usage in practice. 
+--]]
+
+function SeqGRU_WN:__init(inputSize, outputSize) -- inputSize = D, outputSize = H
+   parent.__init(self)
+
+   self.inputSize = inputSize
+   self.outputSize = outputSize
+   self.seqLength = 1
+   self.miniBatch = 1
+
+   local D, H = inputSize, outputSize
+
+   self.weight = torch.Tensor(D + H, 3 * H) -- rows 1..D: input weights, rows D+1..D+H: recurrent weights
+   self.gradWeight = torch.Tensor(D + H, 3 * H):zero() -- overwritten as scratch by backward; not returned by parameters()
+   self.hTmp = torch.Tensor(H, 3 * H):zero() -- scratch for the recurrent block in backward
+   self.bias = torch.Tensor(3 * H)
+   self.gradBias = torch.Tensor(3 * H):zero()
+
+   self.g = torch.Tensor(2, 3 * H) -- learnable magnitudes: row 1 for the input block, row 2 for the recurrent block
+   self.gradG = torch.Tensor(2, 3 * H):zero()
+   self.v = torch.Tensor(D + H, 3 * H) -- learnable directions, same row layout as weight
+   self.gradV = torch.Tensor(D + H, 3 * H):zero()
+
+   self.norm = torch.Tensor(2, 3 * H) -- per-column 2-norms of the two blocks of v
+   self.scale = torch.Tensor(2, 3 * H) -- g / norm
+
+   self.eps = 1e-16 -- lower clamp applied to every norm
+
+   self:reset() -- fills bias and weight, then derives g and v from weight
+
+   self.gates = torch.Tensor() -- This will be (T, N, 3H)
+   self.buffer1 = torch.Tensor() -- This will be (N, H)
+   self.buffer2 = torch.Tensor() -- This will be (N, H)
+   self.buffer3 = torch.Tensor() -- receives the sum of the (N, 3H) gate gradients over dim 1 in backward
+   self.buffer4 = torch.Tensor()
+   self.buffer5 = torch.Tensor()
+   self.buffer6 = torch.Tensor()
+   self.buffer7 = torch.Tensor()
+   self.buffer8 = torch.Tensor()
+   self.grad_a_buffer = torch.Tensor() -- This will be (N, 3H)
+
+   self.h0 = torch.Tensor()
+
+   self._remember = 'neither'
+
+   self.grad_h0 = torch.Tensor()
+   self.grad_x = torch.Tensor()
+   self.gradInput = {self.grad_h0, self.grad_x}
+
+   -- set this to true to forward inputs as batchsize x seqlen x ...
+   -- instead of seqlen x batchsize
+   self.batchfirst = false
+   -- set this to true for variable length sequences that separate
+   -- independent sequences with a step of zeros (a tensor of size D)
+   self.maskzero = false
+end
+
+function SeqGRU_WN:parameters() -- the learnable tensors are g, v and bias; weight is derived from g and v
+   return {self.g, self.v, self.bias}, {self.gradG, self.gradV, self.gradBias}
+end
+
+function SeqGRU_WN:initFromWeight(weight) -- weight defaults to self.weight; returns self
+   weight = weight or self.weight
+
+   local D, H = self.inputSize, self.outputSize
+
+   self.g[{{1}}] = weight[{{1,D}}]:norm(2,1):clamp(self.eps,math.huge) -- column norms of the input block
+   self.g[{{2}}] = weight[{{D+1,D+H}}]:norm(2,1):clamp(self.eps,math.huge) -- column norms of the recurrent block
+
+   self.v[{{1,D}}]:copy(weight[{{1,D}}])
+   self.v[{{D+1,D+H}}]:copy(weight[{{D+1,D+H}}])
+
+   return self
+end
+
+function SeqGRU_WN:reset(std) -- std: optional standard deviation for the weight initialisation; returns self
+   if not std then
+      std = 1.0 / math.sqrt(self.outputSize + self.inputSize)
+   end
+   self.bias:zero()
+   self.bias[{{self.outputSize + 1, 2 * self.outputSize}}]:fill(1) -- columns H+1..2H are the update gate in updateOutput
+   self.weight:normal(0, std)
+
+   self:initFromWeight()
+
+   return self
+end
+
+function SeqGRU_WN:resetStates() -- replaces h0 by a fresh empty tensor of the same type
+   self.h0 = self.h0.new()
+end
+
+-- unlike MaskZero, the mask is applied in-place
+function SeqGRU_WN:recursiveMask(output, mask)
+   if torch.type(output) == 'table' then
+      for k,v in ipairs(output) do
+         self:recursiveMask(output[k], mask)
+      end
+   else
+      assert(torch.isTensor(output))
+
+      -- make sure mask has the same dimension as the output tensor
+      local outputSize = output:size():fill(1)
+      outputSize[1] = output:size(1)
+      mask:resize(outputSize)
+      -- build mask
+      local zeroMask = mask:expandAs(output)
+      output:maskedFill(zeroMask, 0)
+   end
+end
+
+local function check_dims(x, dims) -- asserts that tensor x has exactly the sizes listed in dims
+   assert(x:dim() == #dims)
+   for i, d in ipairs(dims) do
+      assert(x:size(i) == d)
+   end
+end
+
+-- makes sure x, h0 and gradOutput have correct sizes.
+-- batchfirst = true will transpose the N x T to conform to T x N
+function SeqGRU_WN:_prepare_size(input, gradOutput) -- input: x or {h0, x}; returns h0 (may be nil), x, gradOutput
+   local h0, x
+   if torch.type(input) == 'table' and #input == 2 then
+      h0, x = unpack(input)
+   elseif torch.isTensor(input) then
+      x = input
+   else
+      assert(false, 'invalid input')
+   end
+   assert(x:dim() == 3, "Only supports batch mode")
+
+   if self.batchfirst then
+      x = x:transpose(1,2)
+      gradOutput = gradOutput and gradOutput:transpose(1,2) or nil
+   end
+
+   local T, N = x:size(1), x:size(2)
+   local H, D = self.outputSize, self.inputSize
+
+   check_dims(x, {T, N, D})
+   if h0 then
+      check_dims(h0, {N, H})
+   end
+   if gradOutput then
+      check_dims(gradOutput, {T, N, H})
+   end
+   return h0, x, gradOutput
+end
+
+function SeqGRU_WN:updateWeightMatrix() -- rebuilds weight = v * g / ||v||, column-wise, separately for the two row blocks
+   local H, D = self.outputSize, self.inputSize
+
+   self.norm[{{1}}]:norm(self.v[{{1, D}}],2,1):clamp(self.eps,math.huge)
+   self.norm[{{2}}]:norm(self.v[{{D + 1, D + H}}],2,1):clamp(self.eps,math.huge)
+
+   self.scale:cdiv(self.g,self.norm)
+
+   self.weight[{{1, D}}]:cmul(self.v[{{1, D}}],self.scale[{{1}}]:expandAs(self.v[{{1, D}}]))
+   self.weight[{{D + 1, D + H}}]:cmul(self.v[{{D + 1, D + H}}],self.scale[{{2}}]:expandAs(self.v[{{D + 1, D + H}}]))
+end
+
+--[[
+Input:
+- h0: Initial hidden state, (N, H)
+- x: Input sequence, (T, N, D)
+
+Output:
+- h: Sequence of hidden states, (T, N, H)
+--]]
+
+
+
+function SeqGRU_WN:updateOutput(input)
+   if self.train ~= false then -- weight is only rebuilt here in training mode; evaluate() rebuilds it once on entry
+      self:updateWeightMatrix()
+   end
+
+   self.recompute_backward = true -- cleared by backward(); read by updateGradInput and accGradParameters
+   local h0, x = self:_prepare_size(input)
+   local T, N = x:size(1), x:size(2)
+   local D, H = self.inputSize, self.outputSize
+   self._output = self._output or self.weight.new()
+
+   -- remember previous state?
+   local remember
+   if self.train ~= false then -- training
+      if self._remember == 'both' or self._remember == 'train' then
+         remember = true
+      elseif self._remember == 'neither' or self._remember == 'eval' then
+         remember = false
+      end
+   else -- evaluate
+      if self._remember == 'both' or self._remember == 'eval' then
+         remember = true
+      elseif self._remember == 'neither' or self._remember == 'train' then
+         remember = false
+      end
+   end
+
+   self._return_grad_h0 = (h0 ~= nil) -- backward returns {grad_h0, grad_x} only when h0 was passed in
+
+   if not h0 then
+      h0 = self.h0
+      if self.userPrevOutput then
+         local prev_N = self.userPrevOutput:size(1)
+         assert(prev_N == N, 'batch sizes must be consistent with userPrevOutput')
+         h0:resizeAs(self.userPrevOutput):copy(self.userPrevOutput)
+      elseif h0:nElement() == 0 or not remember then
+         h0:resize(N, H):zero()
+      elseif remember then
+         local prev_T, prev_N = self._output:size(1), self._output:size(2)
+         assert(prev_N == N, 'batch sizes must be the same to remember states')
+         h0:copy(self._output[prev_T]) -- continue from the last step of the previous forward
+      end
+   end
+
+   local bias_expand = self.bias:view(1, 3 * H):expand(N, 3 * H)
+   local Wx = self.weight[{{1, D}}]
+   local Wh = self.weight[{{D + 1, D + H}}]
+
+   local h = self._output
+   h:resize(T, N, H):zero()
+   local prev_h = h0
+   self.gates:resize(T, N, 3 * H):zero()
+   for t = 1, T do
+      local cur_x = x[t]
+      local next_h = h[t]
+      local cur_gates = self.gates[t]
+
+      cur_gates:addmm(bias_expand, cur_x, Wx)
+      cur_gates[{{}, {1, 2 * H}}]:addmm(prev_h, Wh[{{}, {1, 2 * H}}])
+      cur_gates[{{}, {1, 2 * H}}]:sigmoid()
+      local r = cur_gates[{{}, {1, H}}] --reset gate : r = sig(Wx * x + Wh * prev_h + b)
+      local u = cur_gates[{{}, {H + 1, 2 * H}}] --update gate : u = sig(Wx * x + Wh * prev_h + b)
+      next_h:cmul(r, prev_h) --temporary buffer : r . prev_h
+      cur_gates[{{}, {2 * H + 1, 3 * H}}]:addmm(next_h, Wh[{{}, {2 * H + 1, 3 * H}}]) -- hc += Wh * r . prev_h
+      local hc = cur_gates[{{}, {2 * H + 1, 3 * H}}]:tanh() --hidden candidate : hc = tanh(Wx * x + Wh * r . prev_h + b)
+      next_h:addcmul(hc, -1, u, hc) -- next_h = hc - u . hc
+      next_h:addcmul(u, prev_h) --next_h = (1-u) . hc + u . prev_h
+
+      if self.maskzero then
+         -- build mask from input
+         local vectorDim = cur_x:dim()
+         self._zeroMask = self._zeroMask or cur_x.new()
+         self._zeroMask:norm(cur_x, 2, vectorDim)
+         self.zeroMask = self.zeroMask or ((torch.type(cur_x) == 'torch.CudaTensor') and torch.CudaByteTensor() or torch.ByteTensor())
+         self._zeroMask.eq(self.zeroMask, self._zeroMask, 0)
+         -- zero masked output
+         self:recursiveMask({next_h, cur_gates}, self.zeroMask)
+      end
+
+      prev_h = next_h
+   end
+   self.userPrevOutput = nil
+
+   if self.batchfirst then
+      self.output = self._output:transpose(1,2) -- T x N -> N X T
+   else
+      self.output = self._output
+   end
+
+   return self.output
+end
+
+function SeqGRU_WN:backward(input, gradOutput, scale) -- computes gradInput and accumulates gradG, gradV, gradBias
+   self.recompute_backward = false
+   scale = scale or 1.0
+   assert(scale == 1.0, 'must have scale=1')
+
+   local h0, x, grad_h = self:_prepare_size(input, gradOutput)
+   assert(grad_h, "Expecting gradOutput")
+   local N, T = x:size(2), x:size(1)
+   local D, H = self.inputSize, self.outputSize
+
+   self._grad_x = self._grad_x or self.weight.new()
+
+   if not h0 then h0 = self.h0 end
+
+   local grad_h0, grad_x = self.grad_h0, self._grad_x
+   local h = self._output
+
+   local Wx = self.weight[{{1, D}}]
+   local Wh = self.weight[{{D + 1, D + H}}]
+   local grad_Wx = self.gradWeight[{{1, D}}] -- used as per-step scratch below
+   local grad_Wh = self.gradWeight[{{D + 1, D + H}}] -- used as per-step scratch below
+   local grad_b = self.gradBias
+
+   local Vx = self.v[{{1, D}}]
+   local Vh = self.v[{{D + 1, D + H}}]
+
+   local scale_x = self.scale[{{1}}]:expandAs(Vx)
+   local scale_h = self.scale[{{2}}]:expandAs(Vh)
+
+   local norm_x = self.norm[{{1}}]:expandAs(Vx)
+   local norm_h = self.norm[{{2}}]:expandAs(Vh)
+
+   local grad_Gx = self.gradG[{{1}}]
+   local grad_Gh = self.gradG[{{2}}]
+
+   local grad_Vx = self.gradV[{{1, D}}]
+   local grad_Vh = self.gradV[{{D + 1, D + H}}]
+
+   grad_h0:resizeAs(h0):zero()
+
+   grad_x:resizeAs(x):zero()
+   self.buffer1:resizeAs(h0)
+   local grad_next_h = self.gradPrevOutput and self.buffer1:copy(self.gradPrevOutput) or self.buffer1:zero()
+   local temp_buffer = self.buffer2:resizeAs(h0):zero()
+   --local dWx = self.dWx:resizeAs()
+   for t = T, 1, -1 do -- walk the sequence backwards in time
+      local next_h = h[t]
+      local prev_h = nil
+      if t == 1 then
+         prev_h = h0
+      else
+         prev_h = h[t - 1]
+      end
+      grad_next_h:add(grad_h[t])
+
+      if self.maskzero then
+         -- build mask from input
+         local cur_x = x[t]
+         local vectorDim = cur_x:dim()
+         self._zeroMask = self._zeroMask or cur_x.new()
+         self._zeroMask:norm(cur_x, 2, vectorDim)
+         self.zeroMask = self.zeroMask or ((torch.type(cur_x) == 'torch.CudaTensor') and torch.CudaByteTensor() or torch.ByteTensor())
+         self._zeroMask.eq(self.zeroMask, self._zeroMask, 0)
+         -- zero masked gradOutput
+         self:recursiveMask(grad_next_h, self.zeroMask)
+      end
+
+      local r = self.gates[{t, {}, {1, H}}]
+      local u = self.gates[{t, {}, {H + 1, 2 * H}}]
+      local hc = self.gates[{t, {}, {2 * H + 1, 3 * H}}]
+
+      local grad_a = self.grad_a_buffer:resize(N, 3 * H):zero()
+      local grad_ar = grad_a[{{}, {1, H}}]
+      local grad_au = grad_a[{{}, {H + 1, 2 * H}}]
+      local grad_ahc = grad_a[{{}, {2 * H + 1, 3 * H}}]
+
+      -- We will use grad_au as temporary buffer
+      -- to compute grad_ahc.
+      local grad_hc = grad_au:fill(0):addcmul(grad_next_h, -1, u, grad_next_h) -- (1 - u) . grad_next_h
+      grad_ahc:fill(1):addcmul(-1, hc,hc):cmul(grad_hc) -- (1 - hc^2) . grad_hc
+      local grad_r = grad_au:fill(0):addmm(grad_ahc, Wh[{{}, {2 * H + 1, 3 * H}}]:t() ):cmul(prev_h)
+      grad_ar:fill(1):add(-1, r):cmul(r):cmul(grad_r) -- (1 - r) . r . grad_r
+
+      temp_buffer:fill(0):add(-1, hc):add(prev_h) -- prev_h - hc
+      grad_au:fill(1):add(-1, u):cmul(u):cmul(temp_buffer):cmul(grad_next_h) -- grad_au now holds its final value
+      grad_x[t]:mm(grad_a, Wx:t())
+
+      local dWx = self.buffer4:resize(x[t]:t():size(1), grad_a:size(2)):mm(x[t]:t(), grad_a) -- this step's gradient w.r.t. Wx
+      grad_Wx:cmul(dWx,Vx):cdiv(norm_x)
+
+      local dGradGx = self.buffer7:resize(1,grad_Wx:size(2)):sum(grad_Wx,1) -- column sums of dWx . Vx / norm
+      grad_Gx:add(dGradGx)
+
+      dWx:cmul(scale_x)
+
+      grad_Wx:cmul(Vx,scale_x):cdiv(norm_x)
+      grad_Wx:cmul(dGradGx:expandAs(grad_Wx))
+
+      dWx:add(-1,grad_Wx) -- dWx . scale - (Vx . scale / norm) . dGradGx
+
+      grad_Vx:add(dWx)
+
+      local dWh = self.buffer5:resize(prev_h:t():size(1),grad_a[{{}, {1, 2 * H}}]:size(2)):mm(prev_h:t(), grad_a[{{}, {1, 2 * H}}])
+      grad_Wh[{{}, {1, 2 * H}}]:copy(dWh)
+
+      local grad_a_sum = self.buffer3:resize(H):sum(grad_a, 1) -- sum of grad_a over dim 1 (the batch)
+      grad_b:add(scale, grad_a_sum)
+      temp_buffer:fill(0):add(prev_h):cmul(r) -- r . prev_h
+
+      local dWh = self.buffer6:resize(temp_buffer:t():size(1),grad_ahc:size(2)):mm(temp_buffer:t(), grad_ahc) -- shadows the dWh above
+      grad_Wh[{{}, {2 * H + 1, 3 * H}}]:copy(dWh)
+
+      self.hTmp:cmul(grad_Wh,Vh):cdiv(norm_h)
+
+      local dGradGh = self.buffer8:resize(1,self.hTmp:size(2)):sum(self.hTmp,1) -- column sums of grad_Wh . Vh / norm
+      grad_Gh:add(dGradGh)
+
+      grad_Wh:cmul(scale_h)
+
+      self.hTmp:cmul(Vh,scale_h):cdiv(norm_h)
+      self.hTmp:cmul(dGradGh:expandAs(self.hTmp))
+
+      grad_Wh:add(-1,self.hTmp) -- same form as the dWx update above, for the recurrent block
+
+      grad_Vh:add(grad_Wh)
+
+      grad_next_h:cmul(u)
+      grad_next_h:addmm(grad_a[{{}, {1, 2 * H}}], Wh[{{}, {1, 2 * H}}]:t())
+      temp_buffer:fill(0):addmm(grad_a[{{}, {2 * H + 1, 3 * H}}], Wh[{{}, {2 * H + 1, 3 * H}}]:t()):cmul(r)
+      grad_next_h:add(temp_buffer)
+   end
+   grad_h0:copy(grad_next_h)
+
+   if self.batchfirst then
+      self.grad_x = grad_x:transpose(1,2) -- T x N -> N x T
+   else
+      self.grad_x = grad_x
+   end
+   self.gradPrevOutput = nil
+   self.userGradPrevOutput = self.grad_h0
+
+   if self._return_grad_h0 then
+      self.gradInput = {self.grad_h0, self.grad_x}
+   else
+      self.gradInput = self.grad_x
+   end
+
+   return self.gradInput
+end
+
+function SeqGRU_WN:clearState() -- empties the buffers and drops cached outputs, masks and gradients
+   self.gates:set()
+   self.buffer1:set()
+   self.buffer2:set()
+   self.buffer3:set()
+   self.buffer4:set()
+   self.buffer5:set()
+   self.buffer6:set()
+   self.buffer7:set()
+   self.buffer8:set()
+   self.grad_a_buffer:set()
+
+   self.grad_h0:set()
+   self.grad_x:set()
+   self._grad_x = nil
+   self.output:set()
+   self._output = nil
+   self.gradInput = nil
+
+   self.zeroMask = nil
+   self._zeroMask = nil
+   self._maskbyte = nil
+   self._maskindices = nil
+
+   self.userGradPrevOutput = nil
+   self.gradPrevOutput = nil
+end
+
+function SeqGRU_WN:updateGradInput(input, gradOutput)
+   if self.recompute_backward then -- true only if backward() has not run since the last updateOutput
+      self:backward(input, gradOutput, 1.0)
+   end
+   return self.gradInput
+end
+
+function SeqGRU_WN:forget() -- an empty h0 makes the next updateOutput start from zeros
+   self.h0:resize(0)
+end
+
+function SeqGRU_WN:accGradParameters(input, gradOutput, scale)
+   if self.recompute_backward then -- parameter gradients are accumulated inside backward()
+      self:backward(input, gradOutput, scale)
+   end
+end
+
+function SeqGRU_WN:type(type, ...) -- drops the cached masks before delegating to the parent's type()
+   self.zeroMask = nil
+   self._zeroMask = nil
+   self._maskbyte = nil
+   self._maskindices = nil
+   return parent.type(self, type, ...)
+end
+
+-- Toggle to feed long sequences using multiple forwards.
+-- 'eval' only affects evaluation (recommended for RNNs)
+-- 'train' only affects training
+-- 'neither' affects neither training nor evaluation
+-- 'both' affects both training and evaluation (recommended for LSTMs)
+SeqGRU_WN.remember = nn.Sequencer.remember
+
+function SeqGRU_WN:training()
+   if self.train == false then
+      -- forget at the start of each training
+      self:forget()
+   end
+   parent.training(self)
+end
+
+function SeqGRU_WN:evaluate()
+   if self.train ~= false then
+      self:updateWeightMatrix() -- updateOutput does not rebuild weight in evaluation mode, so do it once here
+      -- forget at the start of each evaluation
+      self:forget()
+   end
+   parent.evaluate(self)
+   assert(self.train == false)
+end
+
+function SeqGRU_WN:toGRU() -- returns a new nn.GRU filled from this module's effective weights and bias
+   self:updateWeightMatrix()
+
+   local D, H = self.inputSize, self.outputSize
+
+   local Wx = self.weight[{{1, D}}]
+   local Wh = self.weight[{{D + 1, D + H}}]
+   local gWx = self.gradWeight[{{1, D}}] -- NOTE(review): backward overwrites gradWeight as scratch; confirm these are meaningful gradients
+   local gWh = self.gradWeight[{{D + 1, D + H}}]
+
+   -- bias
+   local bxi = self.bias[{{1, 2 * H}}]
+   local bxo = self.bias[{{2 * H + 1, 3 * H}}]
+
+   local gbxi = self.gradBias[{{1, 2 * H}}]
+   local gbxo = self.gradBias[{{2 * H + 1, 3 * H}}]
+
+   local gru = nn.GRU(self.inputSize, self.outputSize)
+   local params, gradParams = gru:parameters()
+   local nWxi, nbxi, nWhi, nWxo, nbxo, nWho = unpack(params) -- assumes nn.GRU returns its parameters in this order; TODO confirm
+   local ngWxi, ngbxi, ngWhi, ngWxo, ngbxo, ngWho = unpack(gradParams)
+
+
+   nWxi:t():copy(Wx[{{}, {1, 2*H}}]) -- update and reset gate
+   nWxo:t():copy(Wx[{{}, {2 * H + 1, 3 * H}}]) -- hidden candidate
+   nWhi:t():copy(Wh[{{}, {1, 2*H}}])
+   nWho:t():copy(Wh[{{}, {2 * H + 1, 3 * H}}])
+   nbxi:copy(bxi[{{1, 2 * H}}])
+   nbxo:copy(bxo)
+   ngWxi:t():copy(gWx[{{}, {1, 2*H}}]) -- update and reset gate
+   ngWxo:t():copy(gWx[{{}, {2 * H + 1, 3 * H}}]) -- hidden candidate
+   ngWhi:t():copy(gWh[{{}, {1, 2*H}}])
+   ngWho:t():copy(gWh[{{}, {2 * H + 1, 3 * H}}])
+   ngbxi:copy(gbxi[{{1, 2 * H}}])
+   ngbxo:copy(gbxo)
+
+   return gru
+end
+
+function SeqGRU_WN:maskZero() -- enables zero-masking in updateOutput and backward
+   self.maskzero = true
+end
\ No newline at end of file
diff --git a/SeqLSTMP_WN.lua b/SeqLSTMP_WN.lua
new file mode
100644
index 0000000..e112636
--- /dev/null
+++ b/SeqLSTMP_WN.lua
@@ -0,0 +1,153 @@
+-- Modified by Richard Assar
+-- Weight-normalized SeqLSTM with a projection layer (LSTMP): at every step the
+-- hidden state (size H) is multiplied by weightO (H x R) to produce the output.
+local SeqLSTMP_WN, parent = torch.class('nn.SeqLSTMP_WN', 'nn.SeqLSTM_WN')
+
+-- Constructor. inputsize = D, hiddensize = H, outputsize = R (defaults to H).
+-- The projection tensors are allocated before calling parent.__init because
+-- that constructor calls self:reset(), which reads self.weightO.
+function SeqLSTMP_WN:__init(inputsize, hiddensize, outputsize)
+   outputsize = outputsize or hiddensize
+   assert(inputsize and hiddensize and outputsize, "Expecting input, hidden and output size")
+   local H, R = hiddensize, outputsize
+
+   self.weightO = torch.Tensor(H, R) -- effective projection, rebuilt from gO and vO
+   self.gradWeightO = torch.Tensor(H, R) -- scratch buffer used by gradAdapter
+
+   self.gO = torch.Tensor(1, R) -- learnable per-column magnitude
+   self.gradGO = torch.Tensor(1, R):zero()
+   self.vO = torch.Tensor(H, R) -- learnable direction
+   self.gradVO = torch.Tensor(H, R):zero()
+
+   self.normO = torch.Tensor(1, R) -- per-column 2-norm of vO
+   self.scaleO = torch.Tensor(1, R) -- gO / normO
+
+   self.bufferO1 = torch.Tensor()
+   self.bufferO2 = torch.Tensor()
+
+   parent.__init(self, inputsize, hiddensize, outputsize)
+end
+
+-- Derives (g, v) and (gO, vO) from plain weight matrices: the magnitudes are
+-- the column norms (clamped below by self.eps) and the directions are copies.
+-- Both arguments default to the module's own weight / weightO. Returns self.
+function SeqLSTMP_WN:initFromWeight(weight, weightO)
+   parent.initFromWeight(self, weight)
+
+   weightO = weightO or self.weightO
+
+   -- copy in place: rebinding self.gO to a new tensor would detach it from any
+   -- storage previously obtained through parameters()/getParameters()
+   self.gO:copy(weightO:norm(2,1):clamp(self.eps,math.huge))
+   self.vO:copy(weightO)
+
+   return self
+end
+
+-- Re-initialises bias, weight and weightO, then the weight-norm parameters.
+-- std : optional standard deviation for both weight matrices. Returns self.
+function SeqLSTMP_WN:reset(std)
+   self.bias:zero()
+   -- forget-gate bias starts at 1; gate blocks are hiddensize (not outputsize) wide
+   self.bias[{{self.hiddensize + 1, 2 * self.hiddensize}}]:fill(1)
+
+   if not std then
+      self.weight:normal(0, 1.0 / math.sqrt(self.hiddensize + self.inputsize))
+      self.weightO:normal(0, 1.0 / math.sqrt(self.outputsize + self.hiddensize))
+   else
+      self.weight:normal(0, std)
+      self.weightO:normal(0, std)
+   end
+
+   self:initFromWeight()
+
+   return self
+end
+
+-- Rebuilds weight (through the parent) and weightO = vO * gO / ||vO||, with the
+-- norm taken per column and clamped below by self.eps.
+function SeqLSTMP_WN:updateWeightMatrix()
+   parent.updateWeightMatrix(self)
+
+   self.normO:norm(self.vO,2,1):clamp(self.eps,math.huge)
+   self.scaleO:cdiv(self.gO,self.normO)
+   self.weightO:cmul(self.vO,self.scaleO:expandAs(self.vO))
+end
+
+-- Forward hook called by the parent for time-step t: saves the unprojected
+-- hidden state in self._hidden[t] and overwrites self.next_h with its projection.
+function SeqLSTMP_WN:adapter(t)
+   local T, N = self._output:size(1), self._output:size(2)
+   self._hidden = self._hidden or self.next_h.new()
+   self._hidden:resize(T, N, self.hiddensize)
+
+   self._hidden[t]:copy(self.next_h)
+   self.next_h:resize(N,self.outputsize)
+   self.next_h:mm(self._hidden[t], self.weightO)
+end
+
+-- Backward hook called by the parent for time-step t: accumulates gradGO and
+-- gradVO, then replaces self.grad_next_h by the gradient with respect to the
+-- unprojected hidden state. Only scale == 1 is supported.
+function SeqLSTMP_WN:gradAdapter(scale, t)
+   scale = scale or 1.0
+   assert(scale == 1.0, 'must have scale=1')
+
+   -- keep the incoming gradient; grad_next_h is overwritten at the end
+   self.buffer3:resizeAs(self.grad_next_h):copy(self.grad_next_h)
+
+   -- gradient with respect to the effective weightO (H x R)
+   local dWo = self.bufferO1:resize(self._hidden[t]:t():size(1), self.grad_next_h:size(2)):mm(self._hidden[t]:t(), self.grad_next_h)
+
+   local normO = self.normO:expandAs(self.vO)
+   local scaleO = self.scaleO:expandAs(self.vO)
+
+   -- gradGO += column sums of dWo . vO / normO (gradWeightO is scratch here)
+   self.gradWeightO:cmul(dWo,self.vO):cdiv(normO)
+
+   local dGradGO = self.bufferO2:resize(1,self.gradWeightO:size(2)):sum(self.gradWeightO,1)
+   self.gradGO:add(dGradGO)
+
+   -- gradVO += dWo . scaleO - (vO . scaleO / normO) . dGradGO
+   dWo:cmul(scaleO)
+
+   self.gradWeightO:cmul(self.vO,scaleO):cdiv(normO)
+   self.gradWeightO:cmul(dGradGO:expandAs(self.gradWeightO))
+
+   dWo:add(-1,self.gradWeightO)
+
+   self.gradVO:add(dWo)
+
+   -- back-propagate through the projection
+   self.grad_next_h:resize(self._output:size(2), self.hiddensize)
+   self.grad_next_h:mm(self.buffer3, self.weightO:t())
+end
+
+-- Returns the parent's parameters followed by gO and vO, with matching gradients.
+function SeqLSTMP_WN:parameters()
+   local param,dparam = parent.parameters(self)
+
+   table.insert(param, self.gO)
+   table.insert(param, self.vO)
+
+   table.insert(dparam, self.gradGO)
+   table.insert(dparam, self.gradVO)
+
+   return param,dparam
+end
+
+-- Calls the parent's clearState, then empties the projection scratch buffers.
+function SeqLSTMP_WN:clearState()
+   parent.clearState(self)
+   self.bufferO1:set()
+   self.bufferO2:set()
+   self._hidden = nil -- recreated lazily by adapter()
+end
+
+function SeqLSTMP_WN:accUpdateGradParameters(input, gradOutput, lr)
+   error "accUpdateGradParameters not implemented for SeqLSTMP_WN"
+end
+
+function SeqLSTMP_WN:toFastLSTM()
+   error "toFastLSTM not supported for SeqLSTMP_WN"
+end
\ No newline at end of file
diff --git a/SeqLSTM_WN.lua b/SeqLSTM_WN.lua
new file mode 100644
index 0000000..27a2b62
--- /dev/null
+++ b/SeqLSTM_WN.lua
@@ -0,0 +1,672 @@
+--[[
+The MIT License (MIT)
+
+Copyright (c) 2016 Justin Johnson
+
+Permission is hereby granted, free of charge, to any
person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included +in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +--]] + +--[[ +Thank you Justin for this awesome super fast code: + * https://github.com/jcjohnson/torch-rnn + +If we add up the sizes of all the tensors for output, gradInput, weights, +gradWeights, and temporary buffers, we get that a SeqLSTM_WN stores this many +scalar values: + +NTD + 6NTH + 8NH + 8H^2 + 8DH + 9H + +N : batchsize; T : seqlen; D : inputsize; H : outputsize + +For N = 100, D = 512, T = 100, H = 1024 and with 4 bytes per number, this comes +out to 305MB. Note that this class doesn't own input or gradOutput, so you'll +see a bit higher memory usage in practice. 
+--]]
+
+-- Modified by Richard Assar
+local SeqLSTM_WN, parent = torch.class('nn.SeqLSTM_WN', 'nn.Module')
+
+function SeqLSTM_WN:__init(inputsize, hiddensize, outputsize) -- D, H, R (R defaults to H)
+   parent.__init(self)
+   -- without a projection (plain SeqLSTM_WN) only inputsize and hiddensize are given, so outputsize = hiddensize
+   outputsize = outputsize or hiddensize
+   local D, H, R = inputsize, hiddensize, outputsize
+   self.inputsize, self.hiddensize, self.outputsize = D, H, R
+
+   self.weight = torch.Tensor(D+R, 4 * H) -- rows 1..D: input weights, rows D+1..D+R: recurrent weights
+   self.gradWeight = torch.Tensor(D+R, 4 * H)
+
+   self.bias = torch.Tensor(4 * H)
+   self.gradBias = torch.Tensor(4 * H):zero()
+
+   self.g = torch.Tensor(2, 4 * H) -- learnable magnitudes: row 1 for the input block, row 2 for the recurrent block
+   self.gradG = torch.Tensor(2, 4 * H):zero()
+   self.v = torch.Tensor(D + R, 4 * H) -- learnable directions, same row layout as weight
+   self.gradV = torch.Tensor(D + R, 4 * H):zero()
+
+   self.norm = torch.Tensor(2, 4 * H) -- per-column 2-norms of the two blocks of v
+   self.scale = torch.Tensor(2, 4 * H) -- g / norm
+
+   self.eps = 1e-16 -- lower clamp applied to every norm
+
+   self:reset() -- dispatches to the subclass reset when one is defined
+
+   self.cell = torch.Tensor() -- This will be (T, N, H)
+   self.gates = torch.Tensor() -- This will be (T, N, 4H)
+   self.buffer1 = torch.Tensor() -- This will be (N, H)
+   self.buffer2 = torch.Tensor() -- This will be (N, H)
+   self.buffer3 = torch.Tensor() -- This will be (1, 4H)
+   self.buffer4 = torch.Tensor()
+   self.buffer5 = torch.Tensor()
+   self.buffer6 = torch.Tensor()
+   self.buffer7 = torch.Tensor()
+
+   self.grad_a_buffer = torch.Tensor() -- This will be (N, 4H)
+
+   self.h0 = torch.Tensor()
+   self.c0 = torch.Tensor()
+
+   self._remember = 'neither'
+
+   self.grad_c0 = torch.Tensor()
+   self.grad_h0 = torch.Tensor()
+   self.grad_x = torch.Tensor()
+   self.gradInput = {self.grad_c0, self.grad_h0, self.grad_x}
+
+   -- set this to true to forward inputs as batchsize x seqlen x ...
+   -- instead of seqlen x batchsize
+   self.batchfirst = false
+   -- set this to true for variable length sequences that separate
+   -- independent sequences with a step of zeros (a tensor of size D)
+   self.maskzero = false
+end
+
+function SeqLSTM_WN:parameters() -- the learnable tensors are g, v and bias; weight is derived from g and v
+   return {self.g, self.v, self.bias}, {self.gradG, self.gradV, self.gradBias}
+end
+
+function SeqLSTM_WN:initFromWeight(weight) -- weight defaults to self.weight; returns self
+   weight = weight or self.weight
+
+   local R, D = self.outputsize, self.inputsize
+
+   self.g[{{1}}] = weight[{{1,D}}]:norm(2,1):clamp(self.eps,math.huge) -- column norms of the input block
+   self.g[{{2}}] = weight[{{D+1,D+R}}]:norm(2,1):clamp(self.eps,math.huge) -- column norms of the recurrent block
+
+   self.v[{{1,D}}]:copy(weight[{{1,D}}])
+   self.v[{{D+1,D+R}}]:copy(weight[{{D+1,D+R}}])
+
+   return self
+end
+
+function SeqLSTM_WN:reset(std) -- std: optional standard deviation for the weight initialisation; returns self
+   std = std or 1.0 / math.sqrt(self.outputsize + self.inputsize)
+   -- gate blocks are hiddensize wide; same fallback as the "backwards compat" line in updateOutput
+   local H = self.hiddensize or self.outputsize
+
+   self.bias:zero()
+   self.bias[{{H + 1, 2 * H}}]:fill(1) -- forget-gate bias starts at 1
+   self.weight:normal(0, std)
+
+   self:initFromWeight()
+
+   return self
+end
+
+function SeqLSTM_WN:resetStates() -- replaces h0 and c0 by fresh empty tensors of the same type
+   self.h0 = self.h0.new()
+   self.c0 = self.c0.new()
+end
+
+-- unlike MaskZero, the mask is applied in-place
+function SeqLSTM_WN:recursiveMask(output, mask)
+   if torch.type(output) == 'table' then
+      for _, item in ipairs(output) do
+         self:recursiveMask(item, mask)
+      end
+   else
+      assert(torch.isTensor(output))
+
+      -- make sure mask has the same dimension as the output tensor
+      local outputSize = output:size():fill(1)
+      outputSize[1] = output:size(1)
+      mask:resize(outputSize)
+      -- build mask
+      local zeroMask = mask:expandAs(output)
+      output:maskedFill(zeroMask, 0)
+   end
+end
+
+local function check_dims(x, dims) -- asserts that tensor x has exactly the sizes listed in dims
+   assert(x:dim() == #dims, 'unexpected number of dimensions')
+   for i, d in ipairs(dims) do
+      assert(x:size(i) == d, 'unexpected size in dimension ' .. i)
+   end
+end
+
+-- makes sure x, h0, c0 and gradOutput have correct sizes.
+-- Unpack and validate the arguments of a forward/backward call.
+-- `input` is a tensor x, a table {h0, x} or a table {c0, h0, x}.
+-- batchfirst = true will transpose the N x T to conform to T x N
+-- Returns c0, h0, x, gradOutput; values that were not supplied are nil, and
+-- x / gradOutput are returned in T x N layout.
+function SeqLSTM_WN:_prepare_size(input, gradOutput)
+   local c0, h0, x
+   if torch.type(input) == 'table' and #input == 3 then
+      c0, h0, x = unpack(input)
+   elseif torch.type(input) == 'table' and #input == 2 then
+      h0, x = unpack(input)
+   elseif torch.isTensor(input) then
+      x = input
+   else
+      assert(false, 'invalid input')
+   end
+   assert(x:dim() == 3, "Only supports batch mode")
+
+   if self.batchfirst then
+      x = x:transpose(1,2)
+      gradOutput = gradOutput and gradOutput:transpose(1,2) or nil
+   end
+
+   local T, N = x:size(1), x:size(2)
+   local D, R = self.inputsize, self.outputsize
+   -- The cell state is sized with hiddensize in updateOutput (cell is
+   -- T x N x hiddensize), so c0 must be validated against hiddensize rather
+   -- than outputsize. The two only differ when hiddensize has been set to
+   -- something other than outputsize.
+   local H = self.hiddensize or R
+
+   check_dims(x, {T, N, D})
+   if h0 then
+      check_dims(h0, {N, R})
+   end
+   if c0 then
+      check_dims(c0, {N, H})
+   end
+   if gradOutput then
+      check_dims(gradOutput, {T, N, R})
+   end
+   return c0, h0, x, gradOutput
+end
+
+-- Rebuild self.weight from the weight-normalization parameters.
+-- Rows 1..D (input-to-hidden) and rows D+1..D+R (hidden-to-hidden) are
+-- normalized separately: the 2-norm of self.v is taken along dim 1 and
+-- clamped below by self.eps, self.scale = self.g / self.norm, and
+-- weight = v * scale.
+function SeqLSTM_WN:updateWeightMatrix()
+   local R, D = self.outputsize, self.inputsize
+
+   local Vx = self.v[{{1, D}}]
+   local Vh = self.v[{{D + 1, D + R}}]
+
+   self.norm[{{1}}]:norm(Vx,2,1):clamp(self.eps,math.huge)
+   self.norm[{{2}}]:norm(Vh,2,1):clamp(self.eps,math.huge)
+
+   self.scale:cdiv(self.g,self.norm)
+
+   self.weight[{{1, D}}]:cmul(Vx,self.scale[{{1}}]:expandAs(Vx))
+   self.weight[{{D + 1, D + R}}]:cmul(Vh,self.scale[{{2}}]:expandAs(Vh))
+end
+
+--[[
+Input:
+- c0: Initial cell state, (N, H)
+- h0: Initial hidden state, (N, H)
+- x: Input sequence, (T, N, D)
+
+Output:
+- h: Sequence of hidden states, (T, N, H)
+
+The weight matrix is refreshed from (v, g) on every call unless the module is
+in evaluation mode (self.train == false).
+--]]
+
+function SeqLSTM_WN:updateOutput(input)
+   if self.train ~= false then
+      self:updateWeightMatrix()
+   end
+
+   self.recompute_backward = true
+   local c0, h0, x = self:_prepare_size(input)
+   local N, T = x:size(2), x:size(1)
+   self.hiddensize = self.hiddensize or self.outputsize -- backwards compat
+   local H, R, D = self.hiddensize, self.outputsize, self.inputsize
+
+   self._output = self._output or self.weight.new()
+
+   -- remember previous state?
+   -- true when self._remember is 'both' or names the current mode
+   -- ('train' while training, 'eval' while evaluating); false otherwise.
+   local mode = (self.train ~= false) and 'train' or 'eval'
+   local remember = (self._remember == 'both' or self._remember == mode)
+
+   self._return_grad_c0 = (c0 ~= nil)
+   self._return_grad_h0 = (h0 ~= nil)
+   if not c0 then
+      c0 = self.c0
+      if self.userPrevCell then
+         local prev_N = self.userPrevCell:size(1)
+         assert(prev_N == N, 'batch sizes must be consistent with userPrevCell')
+         c0:resizeAs(self.userPrevCell):copy(self.userPrevCell)
+      elseif c0:nElement() == 0 or not remember then
+         c0:resize(N, H):zero()
+      elseif remember then
+         local prev_T, prev_N = self.cell:size(1), self.cell:size(2)
+         assert(prev_N == N, 'batch sizes must be constant to remember states')
+         c0:copy(self.cell[prev_T])
+      end
+   end
+   if not h0 then
+      h0 = self.h0
+      if self.userPrevOutput then
+         local prev_N = self.userPrevOutput:size(1)
+         assert(prev_N == N, 'batch sizes must be consistent with userPrevOutput')
+         h0:resizeAs(self.userPrevOutput):copy(self.userPrevOutput)
+      elseif h0:nElement() == 0 or not remember then
+         h0:resize(N, R):zero()
+      elseif remember then
+         local prev_T, prev_N = self._output:size(1), self._output:size(2)
+         assert(prev_N == N, 'batch sizes must be the same to remember states')
+         h0:copy(self._output[prev_T])
+      end
+   end
+
+   local bias_expand = self.bias:view(1, 4 * H):expand(N, 4 * H)
+   local Wx = self.weight:narrow(1,1,D)
+   local Wh = self.weight:narrow(1,D+1,R)
+
+   local h, c = self._output, self.cell
+   h:resize(T, N, R):zero()
+   c:resize(T, N, H):zero()
+   local prev_h, prev_c = h0, c0
+   self.gates:resize(T, N, 4 * H):zero()
+   for t = 1, T do
+      local cur_x = x[t]
+      self.next_h = h[t]
+      local next_c = c[t]
+      local cur_gates = self.gates[t]
+      cur_gates:addmm(bias_expand, cur_x, Wx)
+      cur_gates:addmm(prev_h, Wh)
+      cur_gates[{{}, {1, 3 * H}}]:sigmoid()
+      cur_gates[{{}, {3 * H + 1, 4 * H}}]:tanh()
+      local i = cur_gates[{{}, {1, H}}] -- input gate
+      local f = cur_gates[{{}, {H + 1, 2 * H}}] -- forget gate
+      local o = cur_gates[{{}, {2 * H + 1, 3 * H}}] -- output gate
+      local g = cur_gates[{{}, {3 * H + 1, 4 * H}}] -- input transform
+      self.next_h:cmul(i, g)
+      next_c:cmul(f, prev_c):add(self.next_h)
+      self.next_h:tanh(next_c):cmul(o)
+
+      -- for LSTMP
+      self:adapter(t)
+
+      if self.maskzero then
+         -- build mask from input
+         local vectorDim = cur_x:dim()
+         self._zeroMask = self._zeroMask or cur_x.new()
+         self._zeroMask:norm(cur_x, 2, vectorDim)
+         self.zeroMask = self.zeroMask or ((torch.type(cur_x) == 'torch.CudaTensor') and torch.CudaByteTensor() or torch.ByteTensor())
+         self._zeroMask.eq(self.zeroMask, self._zeroMask, 0)
+         -- zero masked output
+         self:recursiveMask({self.next_h, next_c, cur_gates}, self.zeroMask)
+      end
+
+      prev_h, prev_c = self.next_h, next_c
+   end
+   self.userPrevOutput = nil
+   self.userPrevCell = nil
+
+   if self.batchfirst then
+      self.output = self._output:transpose(1,2) -- T x N -> N X T
+   else
+      self.output = self._output
+   end
+
+   return self.output
+end
+
+-- Per-time-step hook invoked from updateOutput as self:adapter(t).
+-- The signature takes only the time-step so that it matches that call.
+function SeqLSTM_WN:adapter(t)
+   -- Placeholder for SeqLSTM_WNP
+end
+
+-- Accumulate the weight-normalization gradients for one half of the weight
+-- matrix (input or recurrent rows) at one time-step.
+--   dW      : gradient w.r.t. the effective weight (overwritten)
+--   grad_W  : matching slice of self.gradWeight, used purely as scratch space
+--   V       : matching slice of self.v
+--   scale   : g / ||v|| expanded to the shape of V
+--   norm    : ||v|| expanded to the shape of V
+--   grad_G  : row of self.gradG to accumulate into
+--   grad_V  : slice of self.gradV to accumulate into
+--   sum_buf : buffer receiving the 1 x ncol column sums
+local function accWeightNormGrads(dW, grad_W, V, scale, norm, grad_G, grad_V, sum_buf)
+   -- dL/dg = sum over rows of (dW .* v) / ||v||
+   grad_W:cmul(dW, V):cdiv(norm)
+   local dG = sum_buf:resize(1, grad_W:size(2)):sum(grad_W, 1)
+   grad_G:add(dG)
+
+   -- dL/dv = scale .* dW - (scale ./ ||v||) .* v .* dL/dg
+   dW:cmul(scale)
+   grad_W:cmul(V, scale):cdiv(norm)
+   grad_W:cmul(dG:expandAs(grad_W))
+   dW:add(-1, grad_W)
+
+   grad_V:add(dW)
+end
+
+-- Back-propagate through the whole sequence.
+-- Computes gradInput and accumulates into self.gradG, self.gradV and
+-- self.gradBias. Only scale == 1 is supported (asserted below).
+-- Returns self.gradInput: grad_x alone, {grad_h0, grad_x} or
+-- {grad_c0, grad_h0, grad_x}, depending on which initial states were passed
+-- to the preceding updateOutput call.
+function SeqLSTM_WN:backward(input, gradOutput, scale)
+   self.recompute_backward = false
+   scale = scale or 1.0
+   assert(scale == 1.0, 'must have scale=1')
+
+   local c0, h0, x, grad_h = self:_prepare_size(input, gradOutput)
+   assert(grad_h, "Expecting gradOutput")
+   local N, T = x:size(2), x:size(1)
+   self.hiddensize = self.hiddensize or self.outputsize -- backwards compat
+   local H, R, D = self.hiddensize, self.outputsize, self.inputsize
+
+   self._grad_x = self._grad_x or self.weight:narrow(1,1,D).new()
+
+   if not c0 then c0 = self.c0 end
+   if not h0 then h0 = self.h0 end
+
+   local grad_c0, grad_h0, grad_x = self.grad_c0, self.grad_h0, self._grad_x
+   local h, c = self._output, self.cell
+
+   local Wx = self.weight:narrow(1,1,D)
+   local Wh = self.weight:narrow(1,D+1,R)
+   local grad_Wx = self.gradWeight:narrow(1,1,D)
+   local grad_Wh = self.gradWeight:narrow(1,D+1,R)
+   local grad_b = self.gradBias
+
+   local Vx = self.v[{{1, D}}]
+   local Vh = self.v[{{D + 1, D + R}}]
+
+   local scale_x = self.scale[{{1}}]:expandAs(Vx)
+   local scale_h = self.scale[{{2}}]:expandAs(Vh)
+
+   local norm_x = self.norm[{{1}}]:expandAs(Vx)
+   local norm_h = self.norm[{{2}}]:expandAs(Vh)
+
+   local grad_Gx = self.gradG[{{1}}]
+   local grad_Gh = self.gradG[{{2}}]
+
+   local grad_Vx = self.gradV[{{1, D}}]
+   local grad_Vh = self.gradV[{{D + 1, D + R}}]
+
+   grad_h0:resizeAs(h0):zero()
+   grad_c0:resizeAs(c0):zero()
+   grad_x:resizeAs(x):zero()
+   self.buffer1:resizeAs(h0)
+   self.buffer2:resizeAs(c0)
+   self.grad_next_h = self.gradPrevOutput and self.buffer1:copy(self.gradPrevOutput) or self.buffer1:zero()
+   local grad_next_c = self.userNextGradCell and self.buffer2:copy(self.userNextGradCell) or self.buffer2:zero()
+
+   for t = T, 1, -1 do
+      local next_c = c[t]
+      local prev_h, prev_c
+      if t == 1 then
+         prev_h, prev_c = h0, c0
+      else
+         prev_h, prev_c = h[t - 1], c[t - 1]
+      end
+      self.grad_next_h:add(grad_h[t])
+
+      if self.maskzero and torch.type(self) ~= 'nn.SeqLSTM_WN' then
+         -- we only do this for sub-classes (LSTM doesn't need it)
+         -- build mask from input
+         local cur_x = x[t]
+         local vectorDim = cur_x:dim()
+         self._zeroMask = self._zeroMask or cur_x.new()
+         self._zeroMask:norm(cur_x, 2, vectorDim)
+         self.zeroMask = self.zeroMask or ((torch.type(cur_x) == 'torch.CudaTensor') and torch.CudaByteTensor() or torch.ByteTensor())
+         self._zeroMask.eq(self.zeroMask, self._zeroMask, 0)
+         -- zero masked gradOutput
+         self:recursiveMask(self.grad_next_h, self.zeroMask)
+      end
+
+      -- for LSTMP
+      self:gradAdapter(scale, t)
+
+      local i = self.gates[{t, {}, {1, H}}]
+      local f = self.gates[{t, {}, {H + 1, 2 * H}}]
+      local o = self.gates[{t, {}, {2 * H + 1, 3 * H}}]
+      local g = self.gates[{t, {}, {3 * H + 1, 4 * H}}]
+
+      local grad_a = self.grad_a_buffer:resize(N, 4 * H):zero()
+      local grad_ai = grad_a[{{}, {1, H}}]
+      local grad_af = grad_a[{{}, {H + 1, 2 * H}}]
+      local grad_ao = grad_a[{{}, {2 * H + 1, 3 * H}}]
+      local grad_ag = grad_a[{{}, {3 * H + 1, 4 * H}}]
+
+      -- We will use grad_ai, grad_af, and grad_ao as temporary buffers
+      -- to compute grad_next_c. We will need tanh_next_c (stored in grad_ai)
+      -- to compute grad_ao; the other values can be overwritten after we compute
+      -- grad_next_c
+      local tanh_next_c = grad_ai:tanh(next_c)
+      local tanh_next_c2 = grad_af:cmul(tanh_next_c, tanh_next_c)
+      local my_grad_next_c = grad_ao
+      my_grad_next_c:fill(1):add(-1, tanh_next_c2):cmul(o):cmul(self.grad_next_h)
+      grad_next_c:add(my_grad_next_c)
+
+      -- We need tanh_next_c (currently in grad_ai) to compute grad_ao; after
+      -- that we can overwrite it.
+      grad_ao:fill(1):add(-1, o):cmul(o):cmul(tanh_next_c):cmul(self.grad_next_h)
+
+      -- Use grad_ai as a temporary buffer for computing grad_ag
+      local g2 = grad_ai:cmul(g, g)
+      grad_ag:fill(1):add(-1, g2):cmul(i):cmul(grad_next_c)
+
+      -- We don't need any temporary storage for these so do them last
+      grad_ai:fill(1):add(-1, i):cmul(i):cmul(g):cmul(grad_next_c)
+      grad_af:fill(1):add(-1, f):cmul(f):cmul(prev_c):cmul(grad_next_c)
+
+      grad_x[t]:mm(grad_a, Wx:t())
+
+      -- input-to-hidden half: dWx = x[t]^T * grad_a, then map to (g, v)
+      local xT = x[t]:t()
+      local dWx = self.buffer4:resize(xT:size(1), grad_a:size(2)):mm(xT, grad_a)
+      accWeightNormGrads(dWx, grad_Wx, Vx, scale_x, norm_x, grad_Gx, grad_Vx, self.buffer5)
+
+      -- hidden-to-hidden half: dWh = prev_h^T * grad_a, then map to (g, v)
+      local prev_hT = prev_h:t()
+      local dWh = self.buffer6:resize(prev_hT:size(1), grad_a:size(2)):mm(prev_hT, grad_a)
+      accWeightNormGrads(dWh, grad_Wh, Vh, scale_h, norm_h, grad_Gh, grad_Vh, self.buffer7)
+
+      -- bias gradient
+      local grad_a_sum = self.buffer3:resize(1, 4 * H):sum(grad_a, 1)
+      grad_b:add(scale, grad_a_sum)
+
+      self.grad_next_h = torch.mm(grad_a, Wh:t())
+      grad_next_c:cmul(f)
+
+   end
+   grad_h0:copy(self.grad_next_h)
+   grad_c0:copy(grad_next_c)
+
+   if self.batchfirst then
+      self.grad_x = grad_x:transpose(1,2) -- T x N -> N x T
+   else
+      self.grad_x = grad_x
+   end
+   self.gradPrevOutput = nil
+   self.userNextGradCell = nil
+   self.userGradPrevCell = self.grad_c0
+   self.userGradPrevOutput = self.grad_h0
+
+   if self._return_grad_c0 and self._return_grad_h0 then
+      self.gradInput = {self.grad_c0, self.grad_h0, self.grad_x}
+   elseif self._return_grad_h0 then
+      self.gradInput = {self.grad_h0, self.grad_x}
+   else
+      self.gradInput = self.grad_x
+   end
+
+   return self.gradInput
+end
+
+-- Per-time-step hook invoked from backward as self:gradAdapter(scale, t).
+function SeqLSTM_WN:gradAdapter(scale, t)
+   -- Placeholder for SeqLSTM_WNP
+end
+
+-- Release the storage of all intermediate buffers and drop cached outputs,
+-- gradients and masks. Returns self so the call can be chained.
+function SeqLSTM_WN:clearState()
+   self.cell:set()
+   self.gates:set()
+   self.buffer1:set()
+   self.buffer2:set()
+   self.buffer3:set()
+   self.buffer4:set()
+   self.buffer5:set()
+   self.buffer6:set()
+   self.buffer7:set()
+   self.grad_a_buffer:set()
+
+   self.grad_c0:set()
+   self.grad_h0:set()
+   self.grad_x:set()
+   self._grad_x = nil
+   self.output:set()
+   self._output = nil
+   self.gradInput = nil
+
+   self.zeroMask = nil
+   self._zeroMask = nil
+   self._maskbyte = nil
+   self._maskindices = nil
+
+   return self
+end
+
+-- Runs the fused backward pass unless it already ran since the last
+-- updateOutput (tracked by self.recompute_backward).
+function SeqLSTM_WN:updateGradInput(input, gradOutput)
+   if self.recompute_backward then
+      self:backward(input, gradOutput, 1.0)
+   end
+   return self.gradInput
+end
+
+-- Runs the fused backward pass unless it already ran since the last
+-- updateOutput (tracked by self.recompute_backward).
+function SeqLSTM_WN:accGradParameters(input, gradOutput, scale)
+   if self.recompute_backward then
+      self:backward(input, gradOutput, scale)
+   end
+end
+
+-- Empty the stored initial states; updateOutput re-creates them as zeros
+-- when they have no elements.
+function SeqLSTM_WN:forget()
+   self.c0:resize(0)
+   self.h0:resize(0)
+end
+
+-- Drop the cached mask tensors before delegating the type conversion to the
+-- parent class; updateOutput/backward re-create them on demand.
+function SeqLSTM_WN:type(type, ...)
+   self.zeroMask = nil
+   self._zeroMask = nil
+   self._maskbyte = nil
+   self._maskindices = nil
+   return parent.type(self, type, ...)
+end
+
+-- Toggle to feed long sequences using multiple forwards.
+-- 'eval' only affects evaluation (recommended for RNNs)
+-- 'train' only affects training
+-- 'neither' affects neither training nor evaluation
+-- 'both' affects both training and evaluation (recommended for LSTMs)
+SeqLSTM_WN.remember = nn.Sequencer.remember
+
+-- Switch to training mode, forgetting stored states when coming from
+-- evaluation mode.
+function SeqLSTM_WN:training()
+   if self.train == false then
+      -- forget at the start of each training
+      self:forget()
+   end
+   parent.training(self)
+end
+
+-- Switch to evaluation mode. When coming from training mode the weight
+-- matrix is refreshed once here, because updateOutput does not recompute it
+-- while self.train == false.
+function SeqLSTM_WN:evaluate()
+   if self.train ~= false then
+      self:updateWeightMatrix()
+      -- forget at the start of each evaluation
+      self:forget()
+   end
+   parent.evaluate(self)
+   assert(self.train == false)
+end
+
+-- Build an nn.FastLSTM holding copies of this module's effective
+-- (normalized) weights, biases and their gradient buffers.
+-- Gate columns here are ordered [i, f, o, g]; the FastLSTM parameter rows
+-- are filled in the order [i, g, f, o].
+function SeqLSTM_WN:toFastLSTM()
+   self:updateWeightMatrix()
+
+   local D, H = self.inputsize, self.outputsize
+
+   local lstm = nn.FastLSTM(self.inputsize, self.outputsize)
+   local params, gradParams = lstm:parameters()
+   local Wx, b, Wh = params[1], params[2], params[3]
+   local gWx, gb, gWh = gradParams[1], gradParams[2], gradParams[3]
+
+   local x_rows = {1, D}          -- input : x to ...
+   local h_rows = {D + 1, D + H}  -- hidden : h to ...
+
+   -- src_gate[k] is the gate chunk of this module copied into chunk k of
+   -- the FastLSTM parameters (1 = i, 2 = f, 3 = o, 4 = g).
+   local src_gate = {1, 4, 2, 3}
+   for dst, src in ipairs(src_gate) do
+      local dst_range = {(dst - 1) * H + 1, dst * H}
+      local src_range = {(src - 1) * H + 1, src * H}
+
+      Wx[{dst_range}]:t():copy(self.weight[{x_rows, src_range}])
+      gWx[{dst_range}]:t():copy(self.gradWeight[{x_rows, src_range}])
+
+      Wh[{dst_range}]:t():copy(self.weight[{h_rows, src_range}])
+      gWh[{dst_range}]:t():copy(self.gradWeight[{h_rows, src_range}])
+
+      b[{dst_range}]:copy(self.bias[{src_range}])
+      gb[{dst_range}]:copy(self.gradBias[{src_range}])
+   end
+
+   return lstm
+end
+
+-- Enable zero-masking: time-steps whose input rows have zero norm get their
+-- output, cell and gate rows masked in updateOutput.
+function SeqLSTM_WN:maskZero()
+   self.maskzero = true
+end
diff --git a/init.lua b/init.lua
index 51d6451..8416ff3 100644
--- a/init.lua
+++ b/init.lua
@@ -54,6 +54,9 @@ torch.include('rnn', 'RecurrentAttention.lua')
 torch.include('rnn', 'SeqLSTM.lua')
 torch.include('rnn', 'SeqLSTMP.lua')
 torch.include('rnn', 'SeqGRU.lua')
+torch.include('rnn', 'SeqLSTM_WN.lua')
+torch.include('rnn', 'SeqLSTMP_WN.lua')
+torch.include('rnn', 'SeqGRU_WN.lua')
 torch.include('rnn', 'SeqReverseSequence.lua')
 torch.include('rnn', 'SeqBRNN.lua')