diff --git a/examples/pytorch/FastCells/fastcell_example.py b/examples/pytorch/FastCells/fastcell_example.py
index 9d55dd9d7..91674894a 100644
--- a/examples/pytorch/FastCells/fastcell_example.py
+++ b/examples/pytorch/FastCells/fastcell_example.py
@@ -43,10 +43,17 @@ def main():
     (dataDimension, numClasses, Xtrain, Ytrain, Xtest, Ytest,
      mean, std) = helpermethods.preProcessData(dataDir)
-
     assert dataDimension % inputDims == 0, "Infeasible per step input, " + \
         "Timesteps have to be integer"
 
+    timeSteps = int(dataDimension / inputDims)
+    Xtrain = Xtrain.reshape((-1, timeSteps, inputDims))
+    Xtest = Xtest.reshape((-1, timeSteps, inputDims))
+
+    if not batch_first:
+        Xtrain = np.swapaxes(Xtrain, 0, 1)
+        Xtest = np.swapaxes(Xtest, 0, 1)
+
     currDir = helpermethods.createTimeStampDir(dataDir, cell)
     helpermethods.dumpCommand(sys.argv, currDir)
diff --git a/pytorch/edgeml_pytorch/graph/rnn.py b/pytorch/edgeml_pytorch/graph/rnn.py
index 5a292ee00..ceed5a5e1 100644
--- a/pytorch/edgeml_pytorch/graph/rnn.py
+++ b/pytorch/edgeml_pytorch/graph/rnn.py
@@ -144,8 +144,8 @@ def getVars(self):
 
     def get_model_size(self):
         '''
-        Function to get aimed model size
-        '''
+        Function to get aimed model size
+        '''
         mats = self.getVars()
         endW = self._num_W_matrices
         endU = endW + self._num_U_matrices
@@ -261,7 +261,7 @@ def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
         self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1]))
         self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1]))
 
-        self.copy_previous_UW()
+        # self.copy_previous_UW()
 
     @property
     def name(self):
@@ -330,7 +330,7 @@ class FastGRNNCUDACell(RNNCell):
     '''
    def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid", update_nonlinearity="tanh",
            wRank=None, uRank=None, zetaInit=1.0, nuInit=-4.0, wSparsity=1.0, uSparsity=1.0, name="FastGRNNCUDACell"):
-        super(FastGRNNCUDACell, self).__init__(input_size, hidden_size, gate_non_linearity, update_nonlinearity,
+        super(FastGRNNCUDACell, self).__init__(input_size, hidden_size, gate_nonlinearity, update_nonlinearity,
                                                1, 1, 2, wRank, uRank, wSparsity, uSparsity)
         if utils.findCUDA() is None:
             raise Exception('FastGRNNCUDA is supported only on GPU devices.')
@@ -967,10 +967,15 @@ class BaseRNN(nn.Module):
         [batchSize, timeSteps, inputDims]
     '''
 
-    def __init__(self, cell: RNNCell, batch_first=False):
+    def __init__(self, cell: RNNCell, cell_reverse: RNNCell=None, batch_first=False, bidirectional=False):
         super(BaseRNN, self).__init__()
-        self._RNNCell = cell
+        self._RNNCell = cell
         self._batch_first = batch_first
+        self._bidirectional = bidirectional
+        if cell_reverse is not None:
+            self._RNNCell_reverse = cell_reverse
+        elif self._bidirectional:
+            self._RNNCell_reverse = cell
 
     def getVars(self):
         return self._RNNCell.getVars()
@@ -978,12 +983,24 @@ def forward(self, input, hiddenState=None, cellState=None):
         self.device = input.device
+        self.num_directions = 2 if self._bidirectional else 1
+        if self._bidirectional:
+            self.num_directions = 2
+        else:
+            self.num_directions = 1
+
         hiddenStates = torch.zeros(
             [input.shape[0], input.shape[1],
              self._RNNCell.output_size]).to(self.device)
+
+        if self._bidirectional:
+            hiddenStates_reverse = torch.zeros(
+                [input.shape[0], input.shape[1],
+                 self._RNNCell_reverse.output_size]).to(self.device)
+
         if hiddenState is None:
             hiddenState = torch.zeros(
-                [input.shape[0] if self._batch_first else input.shape[1],
+                [self.num_directions, input.shape[0] if self._batch_first else input.shape[1],
                  self._RNNCell.output_size]).to(self.device)
 
         if self._batch_first is True:
@@ -991,39 +1008,77 @@ def forward(self, input, hiddenState=None, cellState=None):
                 cellStates = torch.zeros(
                     [input.shape[0], input.shape[1],
                      self._RNNCell.output_size]).to(self.device)
+                if self._bidirectional:
+                    cellStates_reverse = torch.zeros(
+                        [input.shape[0], input.shape[1],
+                         self._RNNCell_reverse.output_size]).to(self.device)
                 if cellState is None:
                     cellState = torch.zeros(
-                        [input.shape[0], self._RNNCell.output_size]).to(self.device)
+                        [self.num_directions, input.shape[0], self._RNNCell.output_size]).to(self.device)
                 for i in range(0, input.shape[1]):
-                    hiddenState, cellState = self._RNNCell(
-                        input[:, i, :], (hiddenState, cellState))
-                    hiddenStates[:, i, :] = hiddenState
-                    cellStates[:, i, :] = cellState
-                return hiddenStates, cellStates
+                    hiddenState[0], cellState[0] = self._RNNCell(
+                        input[:, i, :], (hiddenState[0].clone(), cellState[0].clone()))
+                    hiddenStates[:, i, :] = hiddenState[0]
+                    cellStates[:, i, :] = cellState[0]
+                    if self._bidirectional:
+                        hiddenState[1], cellState[1] = self._RNNCell_reverse(
+                            input[:, input.shape[1]-i-1, :], (hiddenState[1].clone(), cellState[1].clone()))
+                        hiddenStates_reverse[:, i, :] = hiddenState[1]
+                        cellStates_reverse[:, i, :] = cellState[1]
+                if not self._bidirectional:
+                    return hiddenStates, cellStates
+                else:
+                    return torch.cat([hiddenStates,hiddenStates_reverse],-1), torch.cat([cellStates,cellStates_reverse],-1)
             else:
                 for i in range(0, input.shape[1]):
-                    hiddenState = self._RNNCell(input[:, i, :], hiddenState)
-                    hiddenStates[:, i, :] = hiddenState
-                return hiddenStates
+                    hiddenState[0] = self._RNNCell(input[:, i, :], hiddenState[0].clone())
+                    hiddenStates[:, i, :] = hiddenState[0]
+                    if self._bidirectional:
+                        hiddenState[1] = self._RNNCell_reverse(
+                            input[:, input.shape[1]-i-1, :], hiddenState[1].clone())
+                        hiddenStates_reverse[:, i, :] = hiddenState[1]
+                if not self._bidirectional:
+                    return hiddenStates
+                else:
+                    return torch.cat([hiddenStates,hiddenStates_reverse],-1)
         else:
             if self._RNNCell.cellType == "LSTMLR":
                 cellStates = torch.zeros(
                     [input.shape[0], input.shape[1],
                      self._RNNCell.output_size]).to(self.device)
+                if self._bidirectional:
+                    cellStates_reverse = torch.zeros(
+                        [input.shape[0], input.shape[1],
+                         self._RNNCell_reverse.output_size]).to(self.device)
                 if cellState is None:
                     cellState = torch.zeros(
-                        [input.shape[1], self._RNNCell.output_size]).to(self.device)
+                        [self.num_directions, input.shape[1], self._RNNCell.output_size]).to(self.device)
                 for i in range(0, input.shape[0]):
-                    hiddenState, cellState = self._RNNCell(
-                        input[i, :, :], (hiddenState, cellState))
-                    hiddenStates[i, :, :] = hiddenState
-                    cellStates[i, :, :] = cellState
-                return hiddenStates, cellStates
+                    hiddenState[0], cellState[0] = self._RNNCell(
+                        input[i, :, :], (hiddenState[0].clone(), cellState[0].clone()))
+                    hiddenStates[i, :, :] = hiddenState[0]
+                    cellStates[i, :, :] = cellState[0]
+                    if self._bidirectional:
+                        hiddenState[1], cellState[1] = self._RNNCell_reverse(
+                            input[input.shape[0]-i-1, :, :], (hiddenState[1].clone(), cellState[1].clone()))
+                        hiddenStates_reverse[i, :, :] = hiddenState[1]
+                        cellStates_reverse[i, :, :] = cellState[1]
+                if not self._bidirectional:
+                    return hiddenStates, cellStates
+                else:
+                    return torch.cat([hiddenStates,hiddenStates_reverse],-1), torch.cat([cellStates,cellStates_reverse],-1)
             else:
                 for i in range(0, input.shape[0]):
-                    hiddenState = self._RNNCell(input[i, :, :], hiddenState)
-                    hiddenStates[i, :, :] = hiddenState
-                return hiddenStates
+                    hiddenState[0] = self._RNNCell(input[i, :, :], hiddenState[0].clone())
+                    hiddenStates[i, :, :] = hiddenState[0]
+                    if self._bidirectional:
+                        hiddenState[1] = self._RNNCell_reverse(
+                            input[input.shape[0]-i-1, :, :], hiddenState[1].clone())
+                        hiddenStates_reverse[i, :, :] = hiddenState[1]
+                if not self._bidirectional:
+                    return hiddenStates
+                else:
+                    return torch.cat([hiddenStates,hiddenStates_reverse],-1)
 
 
 class LSTM(nn.Module):
@@ -1031,14 +1086,26 @@ def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
                  update_nonlinearity="tanh", wRank=None, uRank=None,
-                 wSparsity=1.0, uSparsity=1.0, batch_first=False):
+                 wSparsity=1.0, uSparsity=1.0, batch_first=False,
+                 bidirectional=False, is_shared_bidirectional=True):
         super(LSTM, self).__init__()
+        self._bidirectional = bidirectional
+        self._batch_first = batch_first
+        self._is_shared_bidirectional = is_shared_bidirectional
         self.cell = LSTMLRCell(input_size, hidden_size,
                                gate_nonlinearity=gate_nonlinearity,
                                update_nonlinearity=update_nonlinearity,
                                wRank=wRank, uRank=uRank,
                                wSparsity=wSparsity, uSparsity=uSparsity)
-        self.unrollRNN = BaseRNN(self.cell, batch_first=batch_first)
+        self.unrollRNN = BaseRNN(self.cell, batch_first=self._batch_first, bidirectional=self._bidirectional)
+
+        if self._bidirectional is True and self._is_shared_bidirectional is False:
+            self.cell_reverse = LSTMLRCell(input_size, hidden_size,
+                                           gate_nonlinearity=gate_nonlinearity,
+                                           update_nonlinearity=update_nonlinearity,
+                                           wRank=wRank, uRank=uRank,
+                                           wSparsity=wSparsity, uSparsity=uSparsity)
+            self.unrollRNN = BaseRNN(self.cell, self.cell_reverse, batch_first=self._batch_first, bidirectional=self._bidirectional)
 
     def forward(self, input, hiddenState=None, cellState=None):
         return self.unrollRNN(input, hiddenState, cellState)
@@ -1049,14 +1116,26 @@ def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
                  update_nonlinearity="tanh", wRank=None, uRank=None,
-                 wSparsity=1.0, uSparsity=1.0, batch_first=False):
+                 wSparsity=1.0, uSparsity=1.0, batch_first=False,
+                 bidirectional=False, is_shared_bidirectional=True):
         super(GRU, self).__init__()
+        self._bidirectional = bidirectional
+        self._batch_first = batch_first
+        self._is_shared_bidirectional = is_shared_bidirectional
         self.cell = GRULRCell(input_size, hidden_size,
                               gate_nonlinearity=gate_nonlinearity,
                               update_nonlinearity=update_nonlinearity,
                               wRank=wRank, uRank=uRank,
                               wSparsity=wSparsity, uSparsity=uSparsity)
-        self.unrollRNN = BaseRNN(self.cell, batch_first=batch_first)
+        self.unrollRNN = BaseRNN(self.cell, batch_first=self._batch_first, bidirectional=self._bidirectional)
+
+        if self._bidirectional is True and self._is_shared_bidirectional is False:
+            self.cell_reverse = GRULRCell(input_size, hidden_size,
+                                          gate_nonlinearity=gate_nonlinearity,
+                                          update_nonlinearity=update_nonlinearity,
+                                          wRank=wRank, uRank=uRank,
+                                          wSparsity=wSparsity, uSparsity=uSparsity)
+            self.unrollRNN = BaseRNN(self.cell, self.cell_reverse, batch_first=self._batch_first, bidirectional=self._bidirectional)
 
     def forward(self, input, hiddenState=None, cellState=None):
         return self.unrollRNN(input, hiddenState, cellState)
@@ -1067,14 +1146,26 @@ def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
                  update_nonlinearity="tanh", wRank=None, uRank=None,
-                 wSparsity=1.0, uSparsity=1.0, batch_first=False):
+                 wSparsity=1.0, uSparsity=1.0, batch_first=False,
+                 bidirectional=False, is_shared_bidirectional=True):
         super(UGRNN, self).__init__()
+        self._bidirectional = bidirectional
+        self._batch_first = batch_first
+        self._is_shared_bidirectional = is_shared_bidirectional
         self.cell = UGRNNLRCell(input_size, hidden_size,
                                 gate_nonlinearity=gate_nonlinearity,
                                 update_nonlinearity=update_nonlinearity,
                                 wRank=wRank, uRank=uRank,
                                 wSparsity=wSparsity, uSparsity=uSparsity)
-        self.unrollRNN = BaseRNN(self.cell, batch_first=batch_first)
+        self.unrollRNN = BaseRNN(self.cell, batch_first=self._batch_first, bidirectional=self._bidirectional)
+
+        if self._bidirectional is True and self._is_shared_bidirectional is False:
+            self.cell_reverse = UGRNNLRCell(input_size, hidden_size,
+                                            gate_nonlinearity=gate_nonlinearity,
+                                            update_nonlinearity=update_nonlinearity,
+                                            wRank=wRank, uRank=uRank,
+                                            wSparsity=wSparsity, uSparsity=uSparsity)
+            self.unrollRNN = BaseRNN(self.cell, self.cell_reverse, batch_first=self._batch_first, bidirectional=self._bidirectional)
 
     def forward(self, input, hiddenState=None, cellState=None):
         return self.unrollRNN(input, hiddenState, cellState)
@@ -1085,15 +1176,28 @@ def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
                  update_nonlinearity="tanh", wRank=None, uRank=None,
-                 wSparsity=1.0, uSparsity=1.0, alphaInit=-3.0, betaInit=3.0, batch_first=False):
+                 wSparsity=1.0, uSparsity=1.0, alphaInit=-3.0, betaInit=3.0,
+                 batch_first=False, bidirectional=False, is_shared_bidirectional=True):
         super(FastRNN, self).__init__()
+        self._bidirectional = bidirectional
+        self._batch_first = batch_first
+        self._is_shared_bidirectional = is_shared_bidirectional
         self.cell = FastRNNCell(input_size, hidden_size,
                                 gate_nonlinearity=gate_nonlinearity,
                                 update_nonlinearity=update_nonlinearity,
                                 wRank=wRank, uRank=uRank,
                                 wSparsity=wSparsity, uSparsity=uSparsity,
                                 alphaInit=alphaInit, betaInit=betaInit)
-        self.unrollRNN = BaseRNN(self.cell, batch_first=batch_first)
+        self.unrollRNN = BaseRNN(self.cell, batch_first=self._batch_first, bidirectional=self._bidirectional)
+
+        if self._bidirectional is True and self._is_shared_bidirectional is False:
+            self.cell_reverse = FastRNNCell(input_size, hidden_size,
+                                            gate_nonlinearity=gate_nonlinearity,
+                                            update_nonlinearity=update_nonlinearity,
+                                            wRank=wRank, uRank=uRank,
+                                            wSparsity=wSparsity, uSparsity=uSparsity,
+                                            alphaInit=alphaInit, betaInit=betaInit)
+            self.unrollRNN = BaseRNN(self.cell, self.cell_reverse, batch_first=self._batch_first, bidirectional=self._bidirectional)
 
     def forward(self, input, hiddenState=None, cellState=None):
         return self.unrollRNN(input, hiddenState, cellState)
@@ -1105,15 +1209,27 @@ def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
                  update_nonlinearity="tanh", wRank=None, uRank=None,
                  wSparsity=1.0, uSparsity=1.0, zetaInit=1.0, nuInit=-4.0,
-                 batch_first=False):
+                 batch_first=False, bidirectional=False, is_shared_bidirectional=True):
         super(FastGRNN, self).__init__()
+        self._bidirectional = bidirectional
+        self._batch_first = batch_first
+        self._is_shared_bidirectional = is_shared_bidirectional
         self.cell = FastGRNNCell(input_size, hidden_size,
                                  gate_nonlinearity=gate_nonlinearity,
                                  update_nonlinearity=update_nonlinearity,
                                  wRank=wRank, uRank=uRank,
                                  wSparsity=wSparsity, uSparsity=uSparsity,
                                  zetaInit=zetaInit, nuInit=nuInit)
-        self.unrollRNN = BaseRNN(self.cell, batch_first=batch_first)
+        self.unrollRNN = BaseRNN(self.cell, batch_first=self._batch_first, bidirectional=self._bidirectional)
+
+        if self._bidirectional is True and self._is_shared_bidirectional is False:
+            self.cell_reverse = FastGRNNCell(input_size, hidden_size,
+                                             gate_nonlinearity=gate_nonlinearity,
+                                             update_nonlinearity=update_nonlinearity,
+                                             wRank=wRank, uRank=uRank,
+                                             wSparsity=wSparsity, uSparsity=uSparsity,
+                                             zetaInit=zetaInit, nuInit=nuInit)
+            self.unrollRNN = BaseRNN(self.cell, self.cell_reverse, batch_first=self._batch_first, bidirectional=self._bidirectional)
 
     def getVars(self):
         return self.unrollRNN.getVars()
@@ -1222,8 +1338,8 @@ def getVars(self):
 
     def get_model_size(self):
         '''
-        Function to get aimed model size
-        '''
+        Function to get aimed model size
+        '''
         mats = self.getVars()
         endW = self._num_W_matrices
         endU = endW + self._num_U_matrices
diff --git a/pytorch/edgeml_pytorch/trainer/fastTrainer.py b/pytorch/edgeml_pytorch/trainer/fastTrainer.py
index 3f0ebd338..32c96b68d 100644
--- a/pytorch/edgeml_pytorch/trainer/fastTrainer.py
+++ b/pytorch/edgeml_pytorch/trainer/fastTrainer.py
@@ -9,6 +9,14 @@ from edgeml_pytorch.graph.rnn import *
 import numpy as np
 
+class SimpleFC(nn.Module):
+    def __init__(self, input_size, num_classes, name="SimpleFC"):
+        super(SimpleFC, self).__init__()
+        self.FC = nn.Parameter(torch.randn([input_size, num_classes]))
+        self.FCbias = nn.Parameter(torch.randn([num_classes]))
+
+    def forward(self, input):
+        return torch.matmul(input, self.FC) + self.FCbias
 
 class FastTrainer:
@@ -50,23 +58,17 @@ def __init__(self, FastObj, numClasses, sW=1.0, sU=1.0,
         self.numMatrices = self.FastObj.num_weight_matrices
         self.totalMatrices = self.numMatrices[0] + self.numMatrices[1]
 
-        self.optimizer = self.optimizer()
-
         self.RNN = BaseRNN(self.FastObj, batch_first=self.batch_first).to(self.device)
-
-        self.FC = nn.Parameter(torch.randn(
-            [self.FastObj.output_size, self.numClasses])).to(self.device)
-        self.FCbias = nn.Parameter(torch.randn(
-            [self.numClasses])).to(self.device)
-
+        self.simpleFC = SimpleFC(self.FastObj.output_size, self.numClasses).to(self.device)
         self.FastParams = self.FastObj.getVars()
+        self.optimizer = self.optimizer()
 
     def classifier(self, feats):
         '''
         Can be raplaced by any classifier
         TODO: Make this a separate class if needed
         '''
-        return torch.matmul(feats, self.FC) + self.FCbias
+        return self.simpleFC(feats)
 
     def computeLogits(self, input):
         '''
@@ -74,19 +76,23 @@ def computeLogits(self, input):
         '''
         if self.FastObj.cellType == "LSTMLR":
             feats, _ = self.RNN(input)
-            logits = self.classifier(feats[-1, :])
         else:
             feats = self.RNN(input)
-            logits = self.classifier(feats[-1, :])
-        return logits, feats[:, -1]
+        if self.batch_first:
+            logits = self.classifier(feats[:, -1])
+            return logits, feats[:, -1]
+        else:
+            logits = self.classifier(feats[-1, :])
+            return logits, feats[-1, :]
 
     def optimizer(self):
         '''
         Optimizer for FastObj Params
         '''
+        paramList = list(self.FastObj.parameters()) + list(self.simpleFC.parameters())
         optimizer = torch.optim.Adam(
-            self.FastObj.parameters(), lr=self.learningRate)
+            paramList, lr=self.learningRate)
         return optimizer
@@ -168,12 +174,12 @@ def getModelSize(self):
             hasSparse = hasSparse or sparseFlag
 
         # Replace this with classifier class call
-        nnz, size, sparseFlag = utils.estimateNNZ(self.FC, 1.0)
+        nnz, size, sparseFlag = utils.estimateNNZ(self.simpleFC.FC, 1.0)
         totalnnZ += nnz
         totalSize += size
         hasSparse = hasSparse or sparseFlag
 
-        nnz, size, sparseFlag = utils.estimateNNZ(self.FCbias, 1.0)
+        nnz, size, sparseFlag = utils.estimateNNZ(self.simpleFC.FCbias, 1.0)
         totalnnZ += nnz
         totalSize += size
         hasSparse = hasSparse or sparseFlag
@@ -341,8 +347,8 @@ def saveParams(self, currDir):
             np.save(os.path.join(currDir, "Bo.npy"),
                     self.FastParams[self.totalMatrices + 3].data.cpu())
 
-        np.save(os.path.join(currDir, "FC.npy"), self.FC.data.cpu())
-        np.save(os.path.join(currDir, "FCbias.npy"), self.FCbias.data.cpu())
+        np.save(os.path.join(currDir, "FC.npy"), self.simpleFC.FC.data.cpu())
+        np.save(os.path.join(currDir, "FCbias.npy"), self.simpleFC.FCbias.data.cpu())
 
     def train(self, batchSize, totalEpochs, Xtrain, Xtest, Ytrain, Ytest,
               decayStep, decayRate, dataDir, currDir):
@@ -351,7 +357,13 @@ def train(self, batchSize, totalEpochs, Xtrain, Xtest, Ytrain, Ytest,
         '''
         fileName = str(self.FastObj.cellType) + 'Results_pytorch.txt'
         resultFile = open(os.path.join(dataDir, fileName), 'a+')
-        numIters = int(np.ceil(float(Xtrain.shape[0]) / float(batchSize)))
+        if self.batch_first:
+            self.timeSteps = Xtrain.shape[1]
+            self.numPoints = Xtrain.shape[0]
+        else:
+            self.timeSteps = Xtrain.shape[0]
+            self.numPoints = Xtrain.shape[1]
+        numIters = int(np.ceil(float(self.numPoints) / float(batchSize)))
         totalBatches = numIters * totalEpochs
 
         counter = 0
@@ -362,11 +374,6 @@ def train(self, batchSize, totalEpochs, Xtrain, Xtest, Ytrain, Ytest,
         ihtDone = 1
         maxTestAcc = -10000
         header = '*' * 20
-        self.timeSteps = int(Xtest.shape[1] / self.inputDims)
-        Xtest = Xtest.reshape((-1, self.timeSteps, self.inputDims))
-        Xtest = np.swapaxes(Xtest, 0, 1)
-        Xtrain = Xtrain.reshape((-1, self.timeSteps, self.inputDims))
-        Xtrain = np.swapaxes(Xtrain, 0, 1)
 
         for i in range(0, totalEpochs):
             print("\nEpoch Number: " + str(i), file=self.outFile)
@@ -376,7 +383,7 @@ def train(self, batchSize, totalEpochs, Xtrain, Xtest, Ytrain, Ytest,
                 for param_group in self.optimizer.param_groups:
                     param_group['lr'] = self.learningRate
 
-            shuffled = list(range(Xtrain.shape[1]))
+            shuffled = list(range(self.numPoints))
             np.random.shuffle(shuffled)
             trainAcc = 0.0
             trainLoss = 0.0
@@ -389,9 +396,12 @@ def train(self, batchSize, totalEpochs, Xtrain, Xtest, Ytrain, Ytest,
                          (header, msg, header), file=self.outFile)
 
                 k = shuffled[j * batchSize:(j + 1) * batchSize]
-                batchX = Xtrain[:, k, :]
+                if self.batch_first:
+                    batchX = Xtrain[k, :, :]
+                else:
+                    batchX = Xtrain[:, k, :]
+
                 batchY = Ytrain[k]
-
                 self.optimizer.zero_grad()
                 logits, _ = self.computeLogits(batchX.to(self.device))
                 batchLoss = self.loss(logits, batchY.to(self.device))
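
Usage sketch (not part of the patch): a minimal illustration of how the bidirectional and batch_first options added to FastGRNN above are expected to be exercised; the tensor sizes below are made up for the example.

# Assumes the patched edgeml_pytorch package is importable.
import torch
from edgeml_pytorch.graph.rnn import FastGRNN

batchSize, timeSteps, inputDims, hiddenDims = 8, 10, 16, 32
# With is_shared_bidirectional left at its default (True), the same cell
# weights are reused for the forward and reverse passes.
rnn = FastGRNN(inputDims, hiddenDims, wRank=8, uRank=8,
               batch_first=True, bidirectional=True)
x = torch.randn(batchSize, timeSteps, inputDims)  # [batchSize, timeSteps, inputDims]
out = rnn(x)  # forward and reverse hidden states concatenated on the last axis
print(out.shape)  # torch.Size([8, 10, 64]) == [batchSize, timeSteps, 2 * hiddenDims]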