Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions experiments/conversion/fully_connected_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
import numpy as np
import matplotlib.pyplot as plt

import os
from time import time as t

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms

from bindsnet.datasets import MNIST
from bindsnet.conversion import ann_to_snn
from bindsnet.datasets import MNIST
from bindsnet.encoding import PoissonEncoder
from bindsnet.network.monitors import Monitor
from bindsnet.analysis.plotting import plot_spikes

Expand Down Expand Up @@ -38,8 +41,7 @@ def forward(self, x):
return x


def main(seed=0, n_epochs=5, batch_size=100, time=50, update_interval=50, plot=False):

def main(seed=0, n_epochs=5, batch_size=100, time=50, dt=250, intensity=128, update_interval=50, plot=False):
np.random.seed(seed)

if torch.cuda.is_available():
Expand All @@ -56,8 +58,20 @@ def main(seed=0, n_epochs=5, batch_size=100, time=50, update_interval=50, plot=F
ANN = FullyConnectedNetwork()

# Get the MNIST data.
images, labels = MNIST('../../data/MNIST', download=True).get_train()
images /= images.max() # Standardizing to [0, 1].
mnist = MNIST(
PoissonEncoder(time=time, dt=dt),
None,
root=os.path.join("..", "..", "data", "MNIST"),
download=True,
train=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
),
)
images = mnist.train_data
labels = mnist.train_labels

images = images / images.max() # Standardizing to [0, 1].
images = images.view(-1, 784)
labels = labels.long()

Expand Down Expand Up @@ -120,14 +134,14 @@ def main(seed=0, n_epochs=5, batch_size=100, time=50, update_interval=50, plot=F
)
start = t()

SNN.run(inpts={'Input': images[i].repeat(time, 1, 1)}, time=time)
SNN.run(inputs={'Input': images[i].repeat(time, 1, 1)}, time=time)

spikes = {layer: SNN.monitors[layer].get('s') for layer in SNN.monitors}
voltages = {layer: SNN.monitors[layer].get('v') for layer in SNN.monitors}
prediction = torch.softmax(voltages['5'].sum(1), 0).argmax()
prediction = torch.softmax(voltages['5'].squeeze()[-1], dim=0).argmax()
correct.append((prediction == labels[i]).item())

SNN.reset_()
SNN.reset_state_variables()

if plot:
spikes = {k: spikes[k].cpu() for k in spikes}
Expand All @@ -143,6 +157,8 @@ def main(seed=0, n_epochs=5, batch_size=100, time=50, update_interval=50, plot=F
parser.add_argument('--time', type=int, default=50)
parser.add_argument('--update_interval', type=int, default=50)
parser.add_argument('--plot', dest='plot', action='store_true')
parser.add_argument("--dt", type=int, default=1.0)
parser.add_argument("--intensity", type=float, default=128)
parser.set_defaults(plot=False)
args = vars(parser.parse_args())

Expand Down