
RISHIT7/Geospatial-Image-Classification-ViT


Geospatial Image Classification with CNNs and Vision Transformers

COL780 / COL7680 / JRL7680 — Assignment 3
Author: Rishit Jakharia (2022CS11621)


Overview

This project benchmarks CNN and Vision Transformer architectures on a 10-class geospatial land-use classification dataset. Four model configurations are trained and evaluated, each with a Focal Loss ablation study:

| Task | Architecture | Key Modification |
|------|--------------|------------------|
| 1.1 | ResNet-18 | Pretrained baseline |
| 1.2 | ResNet-18 + SE Blocks | Squeeze-and-Excitation channel attention |
| 2.1 | DeiT-3 Small | Pretrained Vision Transformer, [CLS] token classification |
| 2.2 | DeiT-3 Small + DyT | LayerNorm replaced with Dynamic Tanh (Zhu et al., CVPR 2025) |
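For reference, the Squeeze-and-Excitation module used in Task 1.2 (Hu et al., 2018) pools each channel to a single value, passes the pooled vector through a small bottleneck MLP, and rescales the channels with the resulting sigmoid gates. The sketch below is a minimal PyTorch version under assumptions: the reduction ratio of 16 is the default from the SE paper, and the class is not necessarily identical to models/cnn/se_block.py.

```python
import torch
import torch.nn as nn


class SEBlock(nn.Module):
    """Squeeze-and-Excitation: channel attention from globally pooled statistics."""

    def __init__(self, channels: int, reduction: int = 16):  # reduction=16 is an assumption
        super().__init__()
        hidden = max(channels // reduction, 1)
        self.pool = nn.AdaptiveAvgPool2d(1)  # squeeze: (B, C, H, W) -> (B, C, 1, 1)
        self.fc = nn.Sequential(             # excitation: bottleneck MLP
            nn.Linear(channels, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels),
            nn.Sigmoid(),                    # per-channel gates in [0, 1]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, _, _ = x.shape
        w = self.fc(self.pool(x).view(b, c)).view(b, c, 1, 1)
        return x * w                         # rescale each channel by its gate


x = torch.randn(2, 64, 56, 56)
y = SEBlock(64)(x)
print(y.shape)  # torch.Size([2, 64, 56, 56])
```

In an SE-ResNet basic block, the SE module is normally applied to the residual branch output before it is added to the identity shortcut, so the shortcut path itself is left unscaled.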

Bonus tasks include Grad-CAM visualizations (Tasks 1.1, 1.2), attention map extraction (Task 2.1), and an empirical analysis of activation distributions comparing LayerNorm and DyT (Task 2.2).
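The LayerNorm vs. DyT comparison hinges on what DyT computes: instead of normalizing each token by its own mean and variance, DyT (Zhu et al., 2025) applies an element-wise DyT(x) = γ ⊙ tanh(αx) + β with a learnable scalar α and per-channel γ, β. The sketch below assumes the paper's default initialization α₀ = 0.5 and a recursive module swap; how models/vit/deit.py actually performs the replacement may differ, and replace_layernorm_with_dyt is a hypothetical helper name.

```python
import torch
import torch.nn as nn


class DyT(nn.Module):
    """Dynamic Tanh: element-wise drop-in replacement for nn.LayerNorm."""

    def __init__(self, dim: int, alpha_init: float = 0.5):  # 0.5 follows the DyT paper's default
        super().__init__()
        self.alpha = nn.Parameter(torch.full((1,), alpha_init))  # learnable scalar steepness
        self.weight = nn.Parameter(torch.ones(dim))              # per-channel scale (gamma)
        self.bias = nn.Parameter(torch.zeros(dim))               # per-channel shift (beta)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # No mean/variance statistics: large inputs are squashed toward +/-1 by tanh.
        return self.weight * torch.tanh(self.alpha * x) + self.bias


def replace_layernorm_with_dyt(module: nn.Module) -> nn.Module:
    """Recursively swap every nn.LayerNorm for a DyT layer of the same width (hypothetical helper)."""
    for name, child in module.named_children():
        if isinstance(child, nn.LayerNorm):
            setattr(module, name, DyT(child.normalized_shape[-1]))
        else:
            replace_layernorm_with_dyt(child)
    return module


out = DyT(5)(torch.tensor([[-100.0, -1.0, 0.0, 1.0, 100.0]]))
print(out)  # bounded in [-1, 1] at initialization (gamma = 1, beta = 0)
model = replace_layernorm_with_dyt(nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8)))
print(model)
```

Because DyT is purely element-wise, extreme activations are compressed by tanh rather than by per-token statistics, which is the behaviour the activation-distribution plots are meant to expose.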


Results

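All metrics are computed on the held-out test set with sklearn. Macro F1 uses average='macro'; Macro AUC is the one-vs-rest ROC AUC computed from predicted class probabilities with multi_class='ovr' and average='macro'.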

| Model | Variant | Accuracy | Macro F1 | Macro AUC |
|-------|---------|----------|----------|-----------|
| ResNet-18 | Baseline | 0.9815 | 0.9808 | 0.9993 |
| ResNet-18 | Focal Loss | 0.9830 | 0.9820 | 0.9995 |
| ResNet-18 + SE | Baseline | 0.9856 | 0.9845 | 0.9995 |
| ResNet-18 + SE | Focal Loss | 0.9856 | 0.9851 | 0.9993 |
| DeiT-3 Small | Baseline | 0.9741 | 0.9731 | 0.9988 |
| DeiT-3 Small | Focal Loss | 0.9870 | 0.9866 | 0.9997 |
| DeiT-3 Small + DyT | Baseline | 0.9289 | 0.9265 | 0.9966 |
| DeiT-3 Small + DyT | Focal Loss | 0.9433 | 0.9410 | 0.9972 |
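The three reported metrics map directly onto sklearn calls. A minimal sketch, assuming probs holds softmax probabilities (rows summing to 1) and every class appears in y_true, since one-vs-rest AUC is undefined for a class with no positive samples; compute_metrics is a hypothetical name, not a function from test.py.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score


def compute_metrics(y_true: np.ndarray, probs: np.ndarray) -> dict:
    """y_true: (N,) integer labels; probs: (N, K) softmax probabilities."""
    y_pred = probs.argmax(axis=1)  # hard predictions for accuracy and F1
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "macro_f1": f1_score(y_true, y_pred, average="macro"),
        # AUC needs the full probability matrix, not the argmax predictions.
        "macro_auc": roc_auc_score(y_true, probs, multi_class="ovr", average="macro"),
    }


# Toy 3-class example with perfectly separated scores.
y_true = np.array([0, 1, 2, 0, 1, 2])
probs = np.array([
    [0.8, 0.1, 0.1],
    [0.1, 0.8, 0.1],
    [0.1, 0.1, 0.8],
    [0.6, 0.3, 0.1],
    [0.2, 0.7, 0.1],
    [0.3, 0.3, 0.4],
])
metrics = compute_metrics(y_true, probs)
print(metrics)  # all three metrics equal 1.0 for this perfectly separated toy example
```

Note that accuracy and F1 depend only on the argmax, while AUC depends on the ranking of the probabilities, which is why AUC can stay near 0.999 even when accuracy differs by several points between models.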

Repository Structure

.
├── train.py                  # Unified training entry point (config-driven)
├── test.py                   # Unified evaluation script
├── requirements.txt          # Python dependencies
├── configs/                  # YAML experiment configs
│   ├── run_1.1_resnet.yaml
│   ├── run_1.1_resnet_ablation.yaml
│   ├── run_1.2_se_focal.yaml
│   ├── run_1.2_se_focal_ablation.yaml
│   ├── run_2.1_deit.yaml
│   ├── run_2.1_deit_ablation.yaml
│   ├── run_2.2_deit_dyt.yaml
│   └── run_2.2_deit_dyt_ablation.yaml
├── data/
│   ├── load_data.py          # CropData dataset class
│   └── augmentations.py      # Train/eval augmentation pipelines
├── models/
│   ├── cnn/
│   │   ├── resnet.py         # ResNet-18 and SE-ResNet-18
│   │   └── se_block.py       # Squeeze-and-Excitation module
│   └── vit/
│       ├── deit.py           # DeiT-3 Small factory (with optional DyT swap)
│       └── dyt_layer.py      # Dynamic Tanh layer
├── engine/
│   ├── trainer.py            # Training loop, validation, early stopping
│   ├── losses.py             # Focal Loss implementation
│   └── optimizers.py         # Optimizer and scheduler utilities
├── tools/
│   ├── grad_cam.py           # Grad-CAM visualization
│   ├── attention_maps.py     # Transformer attention map extraction
│   └── analyze_norms.py      # LayerNorm vs DyT activation analysis
└── results/
    ├── results.md            # Tabulated test metrics
    ├── images/               # Grad-CAM and attention map outputs
    └── norms_analysis/       # Activation distribution plots (DeiT vs DyT)

Training

All training is config-driven through YAML files. Each config specifies the model family, loss function, optimizer, scheduler, and hyperparameters.

python train.py --config <path-to-config>
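The YAML schema is not documented here, so every key in the sketch below (model.family, model.use_se, loss.name, loss.gamma, optimizer.name, optimizer.lr) is hypothetical. It shows one plausible shape for a config-driven entry point: parse --config, load the file with PyYAML's safe_load, and fail fast on unsupported choices.

```python
import argparse
import os
import tempfile

import yaml

# Hypothetical schema for illustration; the real configs/*.yaml files may use different keys.
EXAMPLE_CONFIG = """
model:
  family: cnn        # cnn | transformer
  use_se: true
loss:
  name: focal        # cross_entropy | focal (hypothetical names)
  gamma: 2.0
optimizer:
  name: adamw
  lr: 0.0003
"""


def parse_args(argv=None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Config-driven training entry point")
    parser.add_argument("--config", required=True, help="Path to a YAML experiment config")
    return parser.parse_args(argv)


def load_config(path: str) -> dict:
    with open(path) as f:
        cfg = yaml.safe_load(f)
    # Fail fast on unsupported choices before building anything expensive.
    if cfg["model"]["family"] not in {"cnn", "transformer"}:
        raise ValueError(f"unknown model family: {cfg['model']['family']!r}")
    if cfg["loss"]["name"] not in {"cross_entropy", "focal"}:
        raise ValueError(f"unknown loss: {cfg['loss']['name']!r}")
    return cfg


# Demo: write the example config to a temporary file, then load it back.
with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as tmp:
    tmp.write(EXAMPLE_CONFIG)
args = parse_args(["--config", tmp.name])
cfg = load_config(args.config)
os.remove(tmp.name)
print(cfg["model"]["family"], cfg["loss"]["name"])  # cnn focal
```

Keeping every experiment in its own file, as the configs/ directory does, means a baseline and its ablation differ only in a few keys, which makes the comparison easy to audit.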

Task 1: CNNs

# 1.1 ResNet-18 Baseline
python train.py --config configs/run_1.1_resnet.yaml

# 1.1 ResNet-18 Ablation (Focal Loss)
python train.py --config configs/run_1.1_resnet_ablation.yaml

# 1.2 SE-ResNet-18 Baseline
python train.py --config configs/run_1.2_se_focal.yaml

# 1.2 SE-ResNet-18 Ablation (Focal Loss)
python train.py --config configs/run_1.2_se_focal_ablation.yaml

Task 2: Vision Transformers

# 2.1 DeiT-3 Small Baseline
python train.py --config configs/run_2.1_deit.yaml

# 2.1 DeiT-3 Small Ablation (Focal Loss)
python train.py --config configs/run_2.1_deit_ablation.yaml

# 2.2 DeiT-3 Small + DyT Baseline
python train.py --config configs/run_2.2_deit_dyt.yaml

# 2.2 DeiT-3 Small + DyT Ablation (Focal Loss)
python train.py --config configs/run_2.2_deit_dyt_ablation.yaml
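Each ablation run trains with Focal Loss (Lin et al., 2017), FL(p_t) = -(1 - p_t)^γ log(p_t), where p_t is the predicted probability of the true class. The sketch below is a minimal multi-class version assuming γ = 2 (the value recommended in the focal loss paper) and no α class-weighting term; the actual engine/losses.py may expose different options.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Multi-class focal loss: batch mean of -(1 - p_t)^gamma * log(p_t)."""

    def __init__(self, gamma: float = 2.0):  # gamma=2 is an assumption
        super().__init__()
        self.gamma = gamma

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        log_probs = F.log_softmax(logits, dim=1)                       # (B, K), numerically stable
        log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)  # log-prob of the true class
        pt = log_pt.exp()
        loss = -((1.0 - pt) ** self.gamma) * log_pt                    # easy examples (pt -> 1) are down-weighted
        return loss.mean()


torch.manual_seed(0)
logits = torch.randn(8, 10)
targets = torch.randint(0, 10, (8,))
fl = FocalLoss(gamma=2.0)(logits, targets)
ce = F.cross_entropy(logits, targets)
print(fl.item() <= ce.item())  # True: each term is scaled by (1 - p_t)^gamma <= 1
```

With γ = 0 the modulating factor is 1 and the loss reduces exactly to cross-entropy, which is a convenient sanity check for any implementation.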

Training logs are synced to Weights & Biases under the COL7680-A3 project. The best checkpoint (by validation Macro AUC) is saved automatically.
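The trainer itself is not shown. As a sketch of the selection rule described above, the class below keeps the checkpoint with the highest validation Macro AUC and signals early stopping after a number of epochs without improvement; BestCheckpoint, patience and min_delta are hypothetical names and defaults, not taken from engine/trainer.py.

```python
import os
import tempfile

import torch
import torch.nn as nn


class BestCheckpoint:
    """Save the model whenever validation Macro AUC improves; stop after `patience` stale epochs."""

    def __init__(self, path: str, patience: int = 5, min_delta: float = 0.0):  # hypothetical defaults
        self.path, self.patience, self.min_delta = path, patience, min_delta
        self.best = float("-inf")
        self.stale_epochs = 0

    def step(self, model: nn.Module, val_auc: float) -> bool:
        """Record one epoch's validation AUC; return True if training should stop."""
        if val_auc > self.best + self.min_delta:
            self.best = val_auc
            self.stale_epochs = 0
            torch.save(model.state_dict(), self.path)  # overwrite with the new best weights
        else:
            self.stale_epochs += 1
        return self.stale_epochs >= self.patience


ckpt_path = os.path.join(tempfile.mkdtemp(), "best_model.pth")
model = nn.Linear(4, 2)
tracker = BestCheckpoint(ckpt_path, patience=3)
stopped_at = None
for epoch, auc in enumerate([0.90, 0.95, 0.94, 0.95, 0.93]):  # fake validation AUCs
    if tracker.step(model, auc):
        stopped_at = epoch
        break
print(stopped_at, tracker.best)  # 4 0.95
```

A tie with the current best (epoch 3 above) does not count as an improvement, so the earliest checkpoint reaching the best score is the one kept.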


Evaluation

python test.py --checkpoint <path-to-checkpoint> --model_family <cnn|transformer> [options]

Arguments

| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| --checkpoint | str | required | Path to a .pth model checkpoint |
| --model_family | str | required | cnn or transformer |
| --img_dir | str | data/A3_Dataset | Root image directory |
| --test_csv | str | data/A3_Dataset/test.csv | Path to test split CSV |
| --batch_size | int | 32 | Inference batch size |
| --num_classes | int | 10 | Number of output classes |
| --use_se | flag | False | Use SE blocks (CNN only) |
| --use_dyt | flag | False | Use DyT layers (Transformer only) |
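The argument table maps one-to-one onto argparse. The sketch below reproduces the documented interface; the choices restriction on --model_family and the cross-checks rejecting --use_se with transformer or --use_dyt with cnn are added assumptions, not documented behaviour of test.py.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(description="Evaluate a trained checkpoint on the test split")
    p.add_argument("--checkpoint", type=str, required=True, help="Path to a .pth model checkpoint")
    p.add_argument("--model_family", type=str, required=True, choices=["cnn", "transformer"])
    p.add_argument("--img_dir", type=str, default="data/A3_Dataset", help="Root image directory")
    p.add_argument("--test_csv", type=str, default="data/A3_Dataset/test.csv", help="Path to test split CSV")
    p.add_argument("--batch_size", type=int, default=32, help="Inference batch size")
    p.add_argument("--num_classes", type=int, default=10, help="Number of output classes")
    p.add_argument("--use_se", action="store_true", help="Use SE blocks (CNN only)")
    p.add_argument("--use_dyt", action="store_true", help="Use DyT layers (Transformer only)")
    return p


def parse_args(argv=None) -> argparse.Namespace:
    parser = build_parser()
    args = parser.parse_args(argv)
    # Assumed sanity checks: each flag only applies to one model family.
    if args.use_se and args.model_family != "cnn":
        parser.error("--use_se only applies to --model_family cnn")
    if args.use_dyt and args.model_family != "transformer":
        parser.error("--use_dyt only applies to --model_family transformer")
    return args


args = parse_args(["--checkpoint", "best_model.pth", "--model_family", "cnn", "--use_se"])
print(args)
```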

Examples

# Evaluate ResNet-18
python test.py --checkpoint best_model.pth --model_family cnn

# Evaluate SE-ResNet-18
python test.py --checkpoint best_model.pth --model_family cnn --use_se

# Evaluate DeiT-3 Small
python test.py --checkpoint best_model.pth --model_family transformer

# Evaluate DeiT-3 Small + DyT
python test.py --checkpoint best_model.pth --model_family transformer --use_dyt

Visualization Tools

Grad-CAM (Tasks 1.1, 1.2)

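The tools/grad_cam.py module provides a GradCAM class that hooks into a target layer (in the example below, the last residual block, model.layer4[-1], whose output is the final convolutional feature map) to produce class-discriminative saliency maps.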

from tools.grad_cam import GradCAM

cam = GradCAM(model, target_layer=model.layer4[-1])
heatmap = cam(input_tensor, class_idx=3)

Attention Maps (Task 2.1)

The tools/attention_maps.py module extracts [CLS] token attention weights from a specified transformer block.

from tools.attention_maps import AttentionMap

attn = AttentionMap(model, target_block=model.blocks[-1])
attn_map = attn(input_tensor)
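The AttentionMap internals are not shown. Recent timm versions may compute attention with a fused kernel and not keep the weights, so one common workaround, sketched here under that assumption, is to capture the output of a block's fused qkv projection and recompute softmax(QKᵀ/√d) for the [CLS] query. The (B, N, 3, heads, head_dim) layout mirrors timm's Attention module; cls_attention_from_qkv is a hypothetical helper, and the shapes below assume DeiT-3 Small (384-dim, 6 heads, 14×14 patches plus one [CLS] token).

```python
import math

import torch


def cls_attention_from_qkv(qkv: torch.Tensor, num_heads: int, num_prefix_tokens: int = 1) -> torch.Tensor:
    """qkv: (B, N, 3*D) output of a block's fused qkv projection.
    Returns (B, heads, gh, gw): attention from the [CLS] query to every patch, per head."""
    B, N, three_d = qkv.shape
    head_dim = three_d // 3 // num_heads
    qkv = qkv.reshape(B, N, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)  # (3, B, H, N, hd)
    q, k = qkv[0], qkv[1]
    attn = ((q @ k.transpose(-2, -1)) * head_dim ** -0.5).softmax(dim=-1)  # (B, H, N, N)
    cls_to_patches = attn[:, :, 0, num_prefix_tokens:]                     # [CLS] row, patch keys only
    side = math.isqrt(cls_to_patches.shape[-1])                           # square patch grid assumed
    return cls_to_patches.reshape(B, num_heads, side, side)


# Hypothetical wiring for a timm-style ViT (attribute names assumed):
#   captured = {}
#   model.blocks[-1].attn.qkv.register_forward_hook(lambda m, i, o: captured.update(qkv=o))
#   model(images)
#   maps = cls_attention_from_qkv(captured["qkv"], model.blocks[-1].attn.num_heads)
qkv = torch.randn(2, 197, 3 * 384)
maps = cls_attention_from_qkv(qkv, num_heads=6)
print(maps.shape)  # torch.Size([2, 6, 14, 14])
```

The per-head maps are usually averaged over heads and upsampled to the image size for display; they sum to slightly less than 1 because the [CLS] token's attention to itself is dropped.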

Activation Distribution Analysis (Task 2.2)

# Standard DeiT-3 (LayerNorm)
python tools/analyze_norms.py --checkpoint <deit_checkpoint> --out_dir results/norms_analysis/deit

# DeiT-3 + DyT
python tools/analyze_norms.py --use_dyt --checkpoint <dyt_checkpoint> --out_dir results/norms_analysis/dyt
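How tools/analyze_norms.py gathers activations is not shown. A minimal way to collect the raw data such plots are built from (for example the input vs. output scatter of normalization layers used in the DyT paper) is a forward hook on every matching module. collect_norm_activations is a hypothetical helper; for the DyT model one would pass the DyT class in layer_types instead of nn.LayerNorm.

```python
import torch
import torch.nn as nn


def collect_norm_activations(model: nn.Module, inputs: torch.Tensor,
                             layer_types: tuple = (nn.LayerNorm,)) -> dict:
    """Run one forward pass; return {module_name: (flat_inputs, flat_outputs)} for matching layers."""
    records, handles = {}, []

    def make_hook(name):
        def hook(module, inp, out):
            records[name] = (inp[0].detach().flatten().cpu(), out.detach().flatten().cpu())
        return hook

    for name, module in model.named_modules():
        if isinstance(module, layer_types):
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(inputs)
    finally:
        for h in handles:  # always remove hooks so the model is left unchanged
            h.remove()
    return records


toy = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32))
recs = collect_norm_activations(toy, 10 * torch.randn(8, 16))
x_in, y_out = recs["1"]  # nn.Sequential names its children "0", "1", ...
print(x_in.std().item(), y_out.std().item())
```

Plotting x_in against y_out for each recorded layer shows the per-token affine rescaling of LayerNorm; for DyT the same plot traces the tanh curve directly.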

References

  1. He, K. et al. "Deep Residual Learning for Image Recognition." CVPR, 2016.
  2. Hu, J. et al. "Squeeze-and-Excitation Networks." CVPR, 2018.
  3. Touvron, H. et al. "DeiT III: Revenge of the ViT." ECCV, 2022.
  4. Zhu, J. et al. "Transformers without Normalization." CVPR, 2025.
  5. Lin, T.-Y. et al. "Focal Loss for Dense Object Detection." ICCV, 2017.

About

Benchmarking CNN and Vision Transformer architectures for geospatial land-use classification. Features ResNet-18 (SE), DeiT-3, and Normalization-Free Transformers (DyT) with Grad-CAM and Attention Map analysis.
