MNIST, but backward. Lol. Instead of classifying digits, the model is trained to generate images from an input digit between 0 and 9.
Calling loader.load() should just work; it downloads imgs.zip first.
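The README doesn't say whether the real loader caches the archive between runs. A common pattern is to download once and unpack once. Here is a minimal sketch of that pattern, assuming the archive comes from some URL and is extracted into a local directory. ensure_downloaded, its parameters, and the directory layout are hypothetical, not the project's actual loader API:

```python
import os
import urllib.request
import zipfile


def ensure_downloaded(zip_path, url, extract_dir, fetch=urllib.request.urlretrieve):
    """Hypothetical helper: fetch and unpack an archive only when missing.

    `fetch` defaults to urllib's urlretrieve(url, filename) and can be
    swapped for a fake in tests.
    """
    # Only hit the network if the archive isn't cached locally yet.
    if not os.path.exists(zip_path):
        fetch(url, zip_path)
    # Unpack once; later calls reuse the extracted files.
    if not os.path.isdir(extract_dir):
        with zipfile.ZipFile(zip_path) as zf:
            zf.extractall(extract_dir)
    return extract_dir
```

Passing fetch as a parameter keeps the network call out of the caching logic, so a second call with the same paths does no work.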
The model is a dumb little MLP with these layers:
- Input Layer: One-hot encoded representation of the digit (dimension: 10).
- Hidden Layer: 128 neurons with ReLU activation.
- Output Layer: A flattened representation of the MNIST image (dimension: 784).
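The following forward pass is written in plain Python with no framework. The README doesn't say which library the real model uses, so the names one_hot, init_params and forward are hypothetical. The digit becomes a 10-long one-hot vector, passes through 128 ReLU units, and comes out as 784 values. The output activation isn't specified either, so this version assumes a linear output:

```python
import random


def one_hot(digit, size=10):
    """One-hot vector for a digit 0-9."""
    vec = [0.0] * size
    vec[digit] = 1.0
    return vec


def init_params(n_in=10, n_hidden=128, n_out=784, seed=0):
    """Small random weights and zero biases (init scheme is an assumption)."""
    rng = random.Random(seed)
    w1 = [[rng.gauss(0, 0.1) for _ in range(n_in)] for _ in range(n_hidden)]
    b1 = [0.0] * n_hidden
    w2 = [[rng.gauss(0, 0.1) for _ in range(n_hidden)] for _ in range(n_out)]
    b2 = [0.0] * n_out
    return w1, b1, w2, b2


def forward(params, digit):
    w1, b1, w2, b2 = params
    x = one_hot(digit, len(w1[0]))
    # Hidden layer: 128 units with ReLU.
    h = [max(0.0, sum(w * xi for w, xi in zip(row, x)) + b) for row, b in zip(w1, b1)]
    # Output layer: 784 values = one flattened 28x28 image (linear, by assumption).
    return [sum(w * hi for w, hi in zip(row, h)) + b for row, b in zip(w2, b2)]
```

Because the input is one-hot, multiplying it by the first weight matrix just selects one column of that matrix. The first layer therefore behaves like a 10-entry embedding lookup.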
- Instantiate the model:
model = Model()
- Train the model:
model.train()
- Generate images for a list of digits or one digit:
model.generate(list(range(10)))
model.generate(4)
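generate accepts either a single int or a list of digits. The sketch below shows one way to normalize that argument and tile the flattened 784-pixel outputs into a grid. normalize_digits and tile_grid are illustrative names, not the project's API:

```python
def normalize_digits(digits):
    """Accept one int or an iterable of ints, e.g. 4 or range(10)."""
    if isinstance(digits, int):
        digits = [digits]
    digits = list(digits)
    for d in digits:
        if not 0 <= d <= 9:
            raise ValueError(f"digit must be in 0-9, got {d}")
    return digits


def tile_grid(images, cols, side=28):
    """Arrange flattened side*side images into one 2-D grid (list of rows)."""
    rows = -(-len(images) // cols)  # ceiling division
    grid = [[0.0] * (cols * side) for _ in range(rows * side)]
    for i, img in enumerate(images):
        r, c = divmod(i, cols)
        for y in range(side):
            for x in range(side):
                grid[r * side + y][c * side + x] = img[y * side + x]
    return grid
```

Empty slots in the last row stay at 0.0. The resulting list of rows can be displayed with, for example, matplotlib's plt.imshow(grid, cmap="gray").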
Generated images are visualized in a grid. Definitely room for improvement, but not bad for a ~2s training loop.
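The README doesn't name the loss, so this part is an assumption. If training uses pixel-wise mean squared error, then for a fixed one-hot input the loss-minimizing output is the average of all training images of that digit. With no noise fed in at generation time, the best this setup can do is one averaged image per digit. The hypothetical helper per_digit_mean below computes those averages, which could serve as a baseline to compare generated images against:

```python
def per_digit_mean(images, labels, n_classes=10):
    """Pixel-wise mean image per label (None for labels with no examples).

    Under an MSE loss with a one-hot input, this is the output that
    minimizes the loss for each digit.
    """
    size = len(images[0])
    sums = [[0.0] * size for _ in range(n_classes)]
    counts = [0] * n_classes
    for img, label in zip(images, labels):
        counts[label] += 1
        for i, p in enumerate(img):
            sums[label][i] += p
    return [
        [s / counts[d] for s in sums[d]] if counts[d] else None
        for d in range(n_classes)
    ]
```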