ainhoupna/SD

Fashion-MNIST Classifier

View Documentation on GitHub Pages

This project implements an image classifier for the Fashion-MNIST dataset, using PyTorch Lightning to keep the code modular and scalable. The dataset is loaded directly from torchvision.datasets. The goal is to train a simple convolutional network that classifies images into 10 clothing categories.

Installation

To install dependencies and prepare the environment with uv, run the following commands in the terminal:

  1. Install uv: curl -LsSf https://astral.sh/uv/install.sh | sh
  2. Initialize the environment: uv init (only needed when starting a new project; skip this step if pyproject.toml already exists)
  3. Sync dependencies and environment: uv sync

Training and Evaluation

You can run the training and evaluation scripts with uvx, which runs a package's command-line entry points in a temporary, isolated environment, so you do not need to install the package yourself.

To run the training script:

uvx fashion-mnist-classifier

To run the app:

uvx --from fashion-mnist-classifier fashion-mnist-app

Building and Publishing to PyPI

To build and publish the package to PyPI, follow these steps:

  1. Install build tools:
    uv pip install build twine
  2. Build the package:
    python -m build
  3. Publish to PyPI:
    twine upload dist/*
    PyPI requires an API token for uploads: when prompted, enter __token__ as the username and your API token as the password.

Technical Details

  • Dataset: torchvision.datasets.FashionMNIST with custom transformations.
  • Model: Simple CNN with one convolutional layer, pooling, and fully connected layers.
  • Training: Implemented with PyTorch Lightning, which handles the epoch loop, checkpointing, and metric logging.
  • Configuration: Parameters such as batch size, paths, and number of epochs are defined in config.py.
  • Optimization: Adam optimizer with cross-entropy loss (CrossEntropyLoss).
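
To make the model description above more concrete, here is a minimal sketch of how tensor sizes flow through one convolutional layer and one pooling layer before the fully connected layers. The filter count, kernel sizes, stride, and padding are assumptions chosen for illustration; the actual values in model.py may differ.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output side length of a square convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


# Fashion-MNIST images are 28x28 grayscale with 10 classes.
IMAGE_SIZE = 28
NUM_CLASSES = 10

# Hypothetical hyperparameters, not taken from the project.
CONV_CHANNELS = 16  # assumed number of convolution filters
CONV_KERNEL = 3     # assumed 3x3 kernel with stride 1 and padding 1
POOL_KERNEL = 2     # assumed 2x2 max pooling with stride 2

after_conv = conv2d_out(IMAGE_SIZE, CONV_KERNEL, stride=1, padding=1)
after_pool = conv2d_out(after_conv, POOL_KERNEL, stride=POOL_KERNEL)

# Number of inputs the first fully connected layer would receive.
flat_features = CONV_CHANNELS * after_pool * after_pool

print(f"conv output: {after_conv}x{after_conv}")
print(f"pool output: {after_pool}x{after_pool}")
print(f"flattened features: {flat_features} -> {NUM_CLASSES} classes")
```

Under these assumed settings the first fully connected layer takes 16 x 14 x 14 = 3136 inputs. Changing the kernel size or padding changes that number, so the linear layer must be sized to match.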

Generated Artifacts

When you run the training script (src/my_project/train.py) or the Gradio application (src/my_project/app.py), the following directories and files are created or updated:

  • data/:
    • Contains the downloaded Fashion-MNIST dataset files (e.g., FashionMNIST/raw/train-images-idx3-ubyte.gz).
  • models/lightning_logs/:
    • Stores logs and checkpoints generated by PyTorch Lightning during training. This typically includes:
      • version_X/checkpoints/: Model checkpoints (e.g., epoch=4-step=2340.ckpt).
      • version_X/metrics.csv: Training and validation metrics logged by CSVLogger.
  • reports/figures/:
    • Contains output visualizations from the evaluation step and data exploration. These include:
      • confusion_matrix.png: Confusion matrix of model predictions on the test set.
      • per_class_accuracy.png: Bar chart showing accuracy for each class.
      • misclassified_grid.png: Grid of misclassified sample images from the test set.
      • calibration_curve.png: Reliability diagram for model calibration.
      • train_loss_*.png: Plots of training loss over steps/epochs (generated if CSVLogger is used).
      • val_acc_*.png: Plots of validation accuracy over epochs (generated if CSVLogger is used).
      • class_distribution.png: Plot showing the distribution of samples per class in the dataset.
      • class_correlation_dendrogram.png: Dendrogram illustrating class similarity based on mean images.
    • reports/figures/gradio/: A subdirectory specifically for figures generated when using the Gradio application.
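
As an illustration of the checkpoint layout listed above, the following sketch parses names such as epoch=4-step=2340.ckpt and picks the checkpoint with the highest (epoch, step) pair. The helper names are hypothetical and not part of the project. The sketch also assumes the default naming shown in the example, so custom checkpoint names would be ignored.

```python
import re
from pathlib import Path
from typing import Optional, Tuple

# Matches the example file name format "epoch=4-step=2340.ckpt".
CKPT_PATTERN = re.compile(r"epoch=(\d+)-step=(\d+)\.ckpt$")


def parse_checkpoint_name(name: str) -> Optional[Tuple[int, int]]:
    """Return (epoch, step) for a matching file name, otherwise None."""
    match = CKPT_PATTERN.search(name)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def latest_checkpoint(log_dir: str = "models/lightning_logs") -> Optional[Path]:
    """Return the checkpoint with the highest (epoch, step) across all versions."""
    root = Path(log_dir)
    if not root.is_dir():
        return None
    candidates = []
    for path in root.glob("version_*/checkpoints/*.ckpt"):
        parsed = parse_checkpoint_name(path.name)
        if parsed is not None:
            candidates.append((parsed, path))
    if not candidates:
        return None
    return max(candidates, key=lambda item: item[0])[1]


if __name__ == "__main__":
    print(parse_checkpoint_name("epoch=4-step=2340.ckpt"))
    print(latest_checkpoint())
```

Note that this compares epoch and step counts only, so it does not necessarily return the checkpoint from the most recent run.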

Reports & Visualizations

The project generates various reports and visualizations to assess model performance and explore the dataset. Details on the specific files and their locations can be found in the Generated Artifacts section.

Key visualizations include:

  • Confusion matrix: Shows, for each true class, how the model's predictions are distributed across the predicted classes.
  • Per-class accuracy: Illustrates how well the model performs on each individual class.
  • Calibration curve: Compares the model's predicted confidence with its observed accuracy.
  • Misclassified image grids: Display examples of images that the model predicted incorrectly.
  • Class distribution plots: Visualize the balance of classes within the dataset.
  • Class similarity dendrograms: Show relationships between clothing categories based on their mean images.
  • Training loss and validation accuracy curves: Track the model's learning progress over epochs.
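
The quantities behind the first three visualizations can be sketched in plain Python. This is an illustrative sketch on toy data, not the code in plots.py; the function names and the equal-width binning scheme are assumptions.

```python
from typing import List, Sequence, Tuple


def confusion_matrix(y_true: Sequence[int], y_pred: Sequence[int],
                     num_classes: int) -> List[List[int]]:
    """Rows index the true class, columns index the predicted class."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for true_label, predicted_label in zip(y_true, y_pred):
        matrix[true_label][predicted_label] += 1
    return matrix


def per_class_accuracy(matrix: List[List[int]]) -> List[float]:
    """Diagonal count divided by row total (0.0 for classes with no samples)."""
    return [row[i] / sum(row) if sum(row) else 0.0
            for i, row in enumerate(matrix)]


def reliability_bins(confidences: Sequence[float], correct: Sequence[bool],
                     num_bins: int = 10) -> List[Tuple[float, float, int]]:
    """Return (mean confidence, accuracy, count) for each non-empty bin."""
    conf_sums = [0.0] * num_bins
    hits = [0] * num_bins
    counts = [0] * num_bins
    for conf, ok in zip(confidences, correct):
        # Equal-width bins over [0, 1]; a confidence of 1.0 goes in the last bin.
        idx = min(int(conf * num_bins), num_bins - 1)
        conf_sums[idx] += conf
        hits[idx] += int(ok)
        counts[idx] += 1
    return [(conf_sums[i] / counts[i], hits[i] / counts[i], counts[i])
            for i in range(num_bins) if counts[i]]


# Toy labels for three classes, made up for this example.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
matrix = confusion_matrix(y_true, y_pred, num_classes=3)
print(matrix)
print(per_class_accuracy(matrix))

# Toy confidences, also made up.
bins = reliability_bins([0.95, 0.85, 0.65, 0.55, 0.35],
                        [True, True, True, False, False], num_bins=2)
print(bins)
```

A well-calibrated model has mean confidence close to accuracy in every bin; a reliability diagram plots one against the other.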

Generating PDF Report with Typst

This project uses Typst to generate a PDF report from the model's results and visualizations.

Installation

To compile the report, you first need to install Typst. You can find installation instructions for your operating system on the official Typst GitHub repository.

Compiling the Report

Once Typst is installed, navigate to the reports directory and compile the main.typ file. This generates a PDF report in the same directory (main.pdf by default) containing a summary of the project, including the generated plots and evaluation metrics.
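
The compile step can also be scripted. The sketch below wraps the `typst compile` command in Python; the helper names are hypothetical, and it assumes the typst executable is on PATH and that the script runs from the repository root.

```python
import shutil
import subprocess
from pathlib import Path
from typing import List


def typst_command(source: str = "main.typ") -> List[str]:
    """Build the command line for compiling a Typst source file."""
    return ["typst", "compile", source]


def compile_report(reports_dir: str = "reports") -> Path:
    """Compile main.typ inside reports_dir and return the expected PDF path."""
    if shutil.which("typst") is None:
        raise RuntimeError("typst executable not found on PATH")
    # check=True raises CalledProcessError if compilation fails.
    subprocess.run(typst_command(), cwd=reports_dir, check=True)
    return Path(reports_dir) / "main.pdf"


if __name__ == "__main__":
    print(compile_report())
```

This is equivalent to running typst compile main.typ from inside the reports directory.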

Project Structure

fashion-mnist-classifier/
├── models/               # Directory for saved models and checkpoints
├── reports/              # Evaluation reports and generated figures
├── src/
│   └── my_project/       # Project source code
│       ├── __init__.py
│       ├── app.py        # Gradio application
│       ├── config.py     # Configurations and parameters
│       ├── dataset.py    # Dataset and DataModule
│       ├── model.py      # PyTorch Lightning model
│       ├── plots.py      # Visualization functions
│       └── train.py      # Main training script
├── .gitignore
├── LICENSE               # Project license
├── pyproject.toml        # Project metadata and dependencies
└── README.md             # This file

Contact
