diff --git a/src/cnn.py b/src/cnn.py
new file mode 100644
index 0000000..0b1be16
--- /dev/null
+++ b/src/cnn.py
@@ -0,0 +1,48 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+
+
+class CNN(nn.Module):
+    def __init__(self):
+        super(CNN, self).__init__()
+        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
+        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
+        self.pool = nn.MaxPool2d(2, 2)
+        self.fc1 = nn.Linear(64 * 4 * 4, 128)
+        self.fc2 = nn.Linear(128, 10)
+
+    def forward(self, x):
+        x = self.pool(nn.functional.relu(self.conv1(x)))
+        x = self.pool(nn.functional.relu(self.conv2(x)))
+        x = x.view(-1, 64 * 4 * 4)
+        x = nn.functional.relu(self.fc1(x))
+        x = self.fc2(x)
+        return nn.functional.log_softmax(x, dim=1)
+
+    def train_cnn(self, trainloader, epochs=3):
+        optimizer = optim.SGD(self.parameters(), lr=0.01)
+        criterion = nn.NLLLoss()
+
+        for epoch in range(epochs):
+            for images, labels in trainloader:
+                optimizer.zero_grad()
+                output = self(images)
+                loss = criterion(output, labels)
+                loss.backward()
+                optimizer.step()
+
+        torch.save(self.state_dict(), "mnist_cnn_model.pth")
+
+
+transform = transforms.Compose(
+    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
+)
+
+trainset = datasets.MNIST(".", download=True, train=True, transform=transform)
+trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
+
+cnn = CNN()
+cnn.train_cnn(trainloader)
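For reference, a minimal evaluation sketch that is not part of the diff: it assumes the checkpoint name "mnist_cnn_model.pth" written by train_cnn above, the same Normalize((0.5,), (0.5,)) transform, and that the CNN class is already in scope (importing src/cnn.py as-is would re-run training, since the training calls execute at module level).

# Evaluation sketch (illustrative only): assumes CNN is in scope and that
# mnist_cnn_model.pth was produced by cnn.train_cnn(trainloader) above.
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

test_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
testset = datasets.MNIST(".", download=True, train=False, transform=test_transform)
testloader = DataLoader(testset, batch_size=64, shuffle=False)

model = CNN()
model.load_state_dict(torch.load("mnist_cnn_model.pth"))
model.eval()  # switch to inference mode (no dropout/batchnorm here, but good practice)

correct = total = 0
with torch.no_grad():
    for images, labels in testloader:
        preds = model(images).argmax(dim=1)  # log-probabilities -> predicted class index
        correct += (preds == labels).sum().item()
        total += labels.size(0)

print(f"Test accuracy: {correct / total:.4f}")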