View Documentation on GitHub Pages
This project implements an image classification model based on the Fashion MNIST dataset, using PyTorch Lightning to structure the code in a modular and scalable way.
The dataset is loaded directly from torchvision.datasets. The goal is to train a simple convolutional network that classifies images into 10 different clothing categories.
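For reference, the ten integer labels in Fashion-MNIST correspond to the class names below. The class names are the ones published with the dataset; the names CLASS_NAMES and label_to_name are illustrative assumptions and are not taken from this project's source.

```python
# Fashion-MNIST label indices (0-9) and their class names, in dataset order.
# CLASS_NAMES and label_to_name are illustrative names, not part of this project.
CLASS_NAMES = (
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
)


def label_to_name(label: int) -> str:
    """Return the class name for an integer label in the range 0-9."""
    if not 0 <= label < len(CLASS_NAMES):
        raise ValueError(f"label must be in 0..{len(CLASS_NAMES) - 1}, got {label}")
    return CLASS_NAMES[label]
```

A mapping like this is handy when labelling the axes of a confusion matrix or showing a prediction in the app.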
To install dependencies and prepare the environment with uv, run the following commands in the terminal:
- Download and install uv:
curl -LsSf https://astral.sh/uv/install.sh | sh
- Initialize the environment:
uv init
- Sync dependencies and environment:
uv sync
You can run the training and evaluation script using uvx, which executes the project's command-line scripts without needing to install the package in editable mode.
To run the training script:
uvx fashion-mnist-classifier
To run the app:
uvx --from fashion-mnist-classifier fashion-mnist-app
To build and publish the package to PyPI, follow these steps:
- Install build tools:
uv pip install build twine
- Build the package:
python -m build
- Publish to PyPI:
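twine upload dist/*
You will be prompted for your PyPI credentials. PyPI requires an API token for uploads: enter __token__ as the username and the token itself as the password.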
- Dataset: torchvision.datasets.FashionMNIST with custom transformations.
- Model: Simple CNN with one convolutional layer, pooling, and fully connected layers.
- Training: Implemented with PyTorch Lightning, which handles the training loop, epochs, and metric logging.
- Configuration: Parameters such as batch size, paths, and number of epochs are defined in config.py.
- Optimization: Adam with CrossEntropyLoss.
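To size the first fully connected layer, it helps to trace how a 28x28 Fashion-MNIST image shrinks through the convolution and pooling stages. The sketch below does that arithmetic in plain Python. The kernel size, padding, pooling window, and channel count are assumptions chosen for illustration; the values used in model.py may differ.

```python
def conv_output_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution or pooling window along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


# Assumed hyperparameters (illustrative only, not read from model.py).
IMAGE_SIZE = 28      # Fashion-MNIST images are 28x28 grayscale
CONV_CHANNELS = 16   # hypothetical number of convolution filters
CONV_KERNEL = 3      # hypothetical 3x3 kernel
CONV_PADDING = 1     # hypothetical padding that preserves the spatial size
POOL_KERNEL = 2      # hypothetical 2x2 max pooling with stride 2

after_conv = conv_output_size(IMAGE_SIZE, CONV_KERNEL, padding=CONV_PADDING)
after_pool = conv_output_size(after_conv, POOL_KERNEL, stride=POOL_KERNEL)
flattened_features = CONV_CHANNELS * after_pool * after_pool

print(after_conv, after_pool, flattened_features)  # 28 14 3136
```

Under these assumed values, the first fully connected layer would take 3136 input features and the final layer would produce 10 outputs, one per class.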
When you run the training script (src/my_project/train.py) or the Gradio application (src/my_project/app.py), the following directories and files are created or updated:
- data/: Contains the downloaded Fashion-MNIST dataset files (e.g., FashionMNIST/raw/train-images-idx3-ubyte.gz).
- models/lightning_logs/: Stores logs and checkpoints generated by PyTorch Lightning during training. This typically includes:
  - version_X/checkpoints/: Model checkpoints (e.g., epoch=4-step=2340.ckpt).
  - version_X/metrics.csv: Training and validation metrics logged by CSVLogger.
- reports/figures/: Contains output visualizations from the evaluation step and data exploration. These include:
  - confusion_matrix.png: Confusion matrix of model predictions on the test set.
  - per_class_accuracy.png: Bar chart showing accuracy for each class.
  - misclassified_grid.png: Grid of misclassified sample images from the test set.
  - calibration_curve.png: Reliability diagram for model calibration.
  - train_loss_*.png: Plots of training loss over steps/epochs (generated if CSVLogger is used).
  - val_acc_*.png: Plots of validation accuracy over epochs (generated if CSVLogger is used).
  - class_distribution.png: Plot showing the distribution of samples per class in the dataset.
  - class_correlation_dendrogram.png: Dendrogram illustrating class similarity based on mean images.
  - reports/figures/gradio/: A subdirectory specifically for figures generated when using the Gradio application.
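To reuse a trained model, the newest checkpoint has to be located under models/lightning_logs/. A minimal sketch of one way to do that with pathlib follows. The function name find_latest_checkpoint is hypothetical, and choosing the highest version number and then the most recently modified file is an assumption, not necessarily what this project does.

```python
from pathlib import Path
from typing import Optional


def find_latest_checkpoint(log_dir: str = "models/lightning_logs") -> Optional[Path]:
    """Return the newest .ckpt file in the highest-numbered version_X directory.

    Hypothetical helper: assumes the version_X/checkpoints/ layout described above.
    """
    root = Path(log_dir)
    if not root.is_dir():
        return None

    # Collect (version number, directory) pairs, ignoring anything not named version_<int>.
    versions = []
    for path in root.glob("version_*"):
        suffix = path.name.split("_", 1)[1]
        if path.is_dir() and suffix.isdigit():
            versions.append((int(suffix), path))

    # Walk versions from newest to oldest and stop at the first one with a checkpoint.
    for _, version_dir in sorted(versions, reverse=True):
        ckpt_dir = version_dir / "checkpoints"
        if not ckpt_dir.is_dir():
            continue
        checkpoints = sorted(ckpt_dir.glob("*.ckpt"), key=lambda p: p.stat().st_mtime)
        if checkpoints:
            return checkpoints[-1]
    return None
```

Comparing version numbers as integers avoids the pitfall of sorting names as strings, where version_10 would come before version_2.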
The project generates various reports and visualizations to assess model performance and explore the dataset. Details on the specific files and their locations can be found in the Generated Artifacts section.
Key visualizations include:
- Confusion matrix: Shows how often each true class is predicted as each class, so systematic confusions between categories stand out.
- Per-class accuracy: Illustrates how well the model performs on each individual class.
- Calibration curve: Compares the model's predicted confidence with its observed accuracy to show how well calibrated the predictions are.
- Misclassified image grids: Display examples of images that the model predicted incorrectly.
- Class distribution plots: Visualize the balance of classes within the dataset.
- Class similarity dendrograms: Help understand relationships between different clothing categories based on their mean images.
- Training loss and validation accuracy curves: Track the model's learning progress over epochs.
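The confusion matrix, per-class accuracy, and calibration curve are all derived from simple counts over the test-set predictions. The sketch below computes those quantities in plain Python so the definitions are explicit. The function names and the toy inputs are illustrative assumptions; plots.py may compute them differently.

```python
from typing import List, Sequence, Tuple


def confusion_matrix(y_true: Sequence[int], y_pred: Sequence[int], num_classes: int) -> List[List[int]]:
    """matrix[t][p] counts samples whose true class is t and predicted class is p."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


def per_class_accuracy(matrix: List[List[int]]) -> List[float]:
    """Fraction of each true class predicted correctly (diagonal entry over row sum)."""
    return [row[i] / sum(row) if sum(row) else 0.0 for i, row in enumerate(matrix)]


def reliability_bins(
    confidences: Sequence[float], correct: Sequence[bool], num_bins: int = 10
) -> List[Tuple[float, float, int]]:
    """Per-bin (mean confidence, accuracy, count) for a reliability diagram.

    Bins split [0, 1] into equal widths; empty bins are skipped.
    """
    totals = [[0.0, 0, 0] for _ in range(num_bins)]
    for conf, ok in zip(confidences, correct):
        idx = min(int(conf * num_bins), num_bins - 1)  # confidence 1.0 goes in the last bin
        totals[idx][0] += conf
        totals[idx][1] += int(ok)
        totals[idx][2] += 1
    return [(s / n, hits / n, n) for s, hits, n in totals if n]


# Toy example with three classes (made-up data, not project results).
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
confidences = [0.95, 0.55, 0.90, 0.85, 0.99, 0.65]

matrix = confusion_matrix(y_true, y_pred, num_classes=3)
print(matrix)                      # [[1, 1, 0], [0, 2, 0], [1, 0, 1]]
print(per_class_accuracy(matrix))  # [0.5, 1.0, 0.5]

correct = [t == p for t, p in zip(y_true, y_pred)]
bins = reliability_bins(confidences, correct, num_bins=5)
```

In a well-calibrated model, the mean confidence and the accuracy in each bin are close to each other, which is what the reliability diagram makes visible.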
This project uses Typst to generate a PDF report from the model's results and visualizations.
To compile the report, you first need to install Typst. You can find installation instructions for your operating system on the official Typst GitHub repository.
Once Typst is installed, navigate to the reports directory and compile the main.typ file by running typst compile main.typ. This generates main.pdf in the same directory, a report that summarizes the project and includes the generated plots and evaluation metrics.
fashion-mnist-classifier/
├── models/                 # Directory for saved models and checkpoints
├── reports/                # Evaluation reports and generated figures
├── src/
│   └── my_project/         # Project source code
│       ├── __init__.py
│       ├── app.py          # Gradio application
│       ├── config.py       # Configurations and parameters
│       ├── dataset.py      # Dataset and DataModule
│       ├── model.py        # PyTorch Lightning model
│       ├── plots.py        # Visualization functions
│       └── train.py        # Main training script
├── .gitignore
├── LICENSE                 # Project license
├── pyproject.toml          # Project metadata and dependencies
└── README.md               # This file