This project implements a deep learning model for classifying stereo images (9×9 pixels) into two classes: 0 and 1.
The model is a convolutional neural network (CNN) that processes 2-channel images:
- Channel 1: Grayscale image
- Channel 2: Disparity map (NA values encoded as 255)
Install the required Python packages:
pip install torch torchvision numpy matplotlib scikit-learn seaborn tqdm pandas Pillow

Run the main experiment:

python train.py --data_path ./img --batch_size 64 --lr 0.0005 --epochs 20 --experiment_name 20_epochs_test

To run with default parameters:

python train.py

Project structure:

project/
├── train.py # Main training script with CLI
├── img/ # Dataset folder (PNG images)
│ ├── label_0_vol_001_00000050_3416.png
│ ├── label_1_vol_001_00000100_2713.png
│ └── ...
├── logs/ # Training logs and results
│ └── 20_epochs_test/ # Example experiment
│ ├── images/ # Generated graphs and plots
│ │ ├── training_curves.png
│ │ ├── confusion_matrix.png
│ │ ├── roc_curve.png
│ │ ├── precision_recall_curve.png
│ │ ├── class_metrics.png
│ │ └── inference_grid_*.png
│ ├── metrics/ # Numerical results
│ │ ├── training_metrics.json
│ │ ├── training_history.csv
│ │ ├── evaluation_metrics.json
│ │ ├── classification_report.csv
│ │ └── confusion_matrix.csv
│ ├── inference/ # Prediction visualizations
│ ├── training_config.json
│ ├── summary_report.json
│ └── stereo_classifier_final.pth
└── utils/ # Modular code organization
  ├── dataset.py # StereoImageDataset class
  ├── model.py # StereoClassifier CNN architecture
  ├── utils.py # Training / evaluation / logging helpers
  ├── train.py # Training loop utilities
  └── inference.py # Inference and visualization functions
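The sample file names in img/ begin with label_0_ or label_1_, which suggests that the class label is encoded in the file name. Assuming that convention (the meaning of the remaining fields, such as vol_001, is not documented here), a label could be recovered as in this sketch. label_from_filename is a hypothetical helper, not a function from utils/dataset.py.

```python
import re
from pathlib import Path

# Hypothetical convention: every file name starts with "label_<class>_".
LABEL_PATTERN = re.compile(r"^label_(\d+)_")


def label_from_filename(path: str) -> int:
    """Return the class label encoded in a name such as label_1_vol_001_00000100_2713.png."""
    match = LABEL_PATTERN.match(Path(path).name)  # match only the base name, not the directory
    if match is None:
        raise ValueError(f"File name does not start with 'label_<n>_': {path}")
    return int(match.group(1))


print(label_from_filename("img/label_0_vol_001_00000050_3416.png"))  # 0
print(label_from_filename("img/label_1_vol_001_00000100_2713.png"))  # 1
```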
Input: 2-channel, 9×9 pixel images
Convolutional backbone:
- Conv2d(2 → 16, kernel_size=3, padding=1) + ReLU + BatchNorm + MaxPool2d(2)
- Conv2d(16 → 32, kernel_size=3, padding=1) + ReLU + BatchNorm + MaxPool2d(2)
- Conv2d(32 → 64, kernel_size=3, padding=1) + ReLU + BatchNorm
Classifier head:
- Flatten (64 channels × 2 × 2 spatial positions = 256 features)
- Fully connected: 256 → 64 → 32 → 2 (2 output classes)
- Dropout: p = 0.3 (first FC layer), p = 0.2 (second FC layer)
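As a cross-check of the layer sizes above, here is a minimal PyTorch sketch of the described network. StereoClassifierSketch is a hypothetical name, and placing each Dropout after the ReLU of its FC layer is an assumption; the actual StereoClassifier in utils/model.py may differ in such details.

```python
import torch
from torch import nn


class StereoClassifierSketch(nn.Module):
    """Hypothetical re-creation of the described CNN (not the code in utils/model.py)."""

    def __init__(self, num_classes: int = 2):
        super().__init__()

        def block(c_in: int, c_out: int, pool: bool):
            # Layer order follows the README: Conv -> ReLU -> BatchNorm (-> MaxPool)
            layers = [nn.Conv2d(c_in, c_out, kernel_size=3, padding=1), nn.ReLU(), nn.BatchNorm2d(c_out)]
            if pool:
                layers.append(nn.MaxPool2d(2))
            return layers

        self.features = nn.Sequential(
            *block(2, 16, pool=True),    # 9x9 -> 4x4 (max pooling floors 9/2)
            *block(16, 32, pool=True),   # 4x4 -> 2x2
            *block(32, 64, pool=False),  # stays 2x2
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),                # 64 * 2 * 2 = 256 features
            nn.Linear(256, 64), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(64, 32), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(32, num_classes),  # raw logits, as expected by CrossEntropyLoss
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))


model = StereoClassifierSketch().eval()  # eval mode: BatchNorm uses running stats, Dropout is off
with torch.no_grad():
    probs = torch.softmax(model(torch.zeros(4, 2, 9, 9)), dim=1)
print(probs.shape)  # torch.Size([4, 2])
```

The comments trace the spatial size (9×9 → 4×4 → 2×2), which is why the first fully connected layer takes 64 × 2 × 2 = 256 inputs.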
Preprocessing:
- Grayscale channel: pixel values normalized to [0, 1].
- Disparity channel: NA values (encoded as 255) are remapped to 0, then pixel values are normalized to [0, 1].

Data augmentation:
- Random horizontal flip (probability = 0.5)
- Random rotation in the range ±5°
Main experiment command:
python train.py --data_path ./img --batch_size 64 --lr 0.0005 --epochs 20 --experiment_name 20_epochs_test

Training configuration:

| Parameter | Value | Description |
|---|---|---|
| Batch size | 64 | Training batch size |
| Learning rate | 0.0005 | Adam optimizer LR |
| Epochs | 20 | Total training epochs |
| Weight decay | 1e-4 | L2 regularization |
| Early stopping | 10 | Patience (epochs) |
| Train / Val / Test | 70 / 15 / 15 | Data split (percent) |
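One way to produce the 70 / 15 / 15 split is torch.utils.data.random_split with explicit integer lengths. This is a sketch: split_dataset is a hypothetical helper, and the rounding rule and the fixed seed are assumptions rather than documented behavior of train.py.

```python
import torch
from torch.utils.data import Dataset, TensorDataset, random_split


def split_dataset(dataset: Dataset, train_frac: float = 0.70, val_frac: float = 0.15, seed: int = 42):
    """Randomly split a dataset into train / val / test subsets (70 / 15 / 15 by default)."""
    n = len(dataset)
    n_train = round(n * train_frac)
    n_val = round(n * val_frac)
    n_test = n - n_train - n_val  # remainder goes to test, so the sizes always sum to n
    generator = torch.Generator().manual_seed(seed)  # fixed seed makes the split reproducible
    return random_split(dataset, [n_train, n_val, n_test], generator=generator)


# Usage with a dummy dataset of 100 samples
train_set, val_set, test_set = split_dataset(TensorDataset(torch.arange(100)))
print(len(train_set), len(val_set), len(test_set))  # 70 15 15
```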
- Optimizer: Adam (with weight decay)
- Loss: CrossEntropyLoss
- LR scheduling: ReduceLROnPlateau
- Validation: Per epoch with detailed metrics
- Checkpointing: Best model saved as stereo_classifier_final.pth
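The training setup listed above maps onto standard PyTorch components roughly as follows. The ReduceLROnPlateau factor and patience, and monitoring validation loss for both the scheduler and early stopping, are assumptions; make_training_components, EarlyStopping and end_of_epoch are hypothetical helpers, not functions from utils/train.py.

```python
import torch
from torch import nn


def make_training_components(model: nn.Module, lr: float = 5e-4, weight_decay: float = 1e-4):
    """Loss, Adam optimizer with L2 weight decay, and a plateau-based LR scheduler."""
    criterion = nn.CrossEntropyLoss()  # expects raw logits of shape (N, 2) and integer labels
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # factor and patience are assumed values; the README only names the scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=3)
    return criterion, optimizer, scheduler


class EarlyStopping:
    """Signal a stop after `patience` epochs without a lower validation loss."""

    def __init__(self, patience: int = 10):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float):
        """Return (improved, should_stop) for this epoch's validation loss."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
            return True, False
        self.bad_epochs += 1
        return False, self.bad_epochs >= self.patience


def end_of_epoch(model, val_loss, scheduler, stopper, ckpt_path="stereo_classifier_final.pth"):
    """Per-epoch bookkeeping: LR scheduling, best-model checkpointing, early stopping."""
    scheduler.step(val_loss)  # ReduceLROnPlateau reacts to a stalled validation loss
    improved, should_stop = stopper.step(val_loss)
    if improved:
        torch.save(model.state_dict(), ckpt_path)  # overwrite with the best weights so far
    return should_stop
```

Keeping only the best weights means the saved checkpoint can come from an earlier epoch than the last one run.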
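Observed behavior in the main experiment:
- Convergence: Rapid convergence within the first ~5 epochs
- Stability: Minimal oscillation in loss and accuracy
- Generalization: Validation performance comparable to or better than training performance, suggesting no overfitting (dropout is disabled at evaluation time, which can make validation metrics look better than training metrics)
- Efficiency: Training completes in about 5 minutes or less on a GPU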
Use default arguments:

python train.py

Custom hyperparameters:

python train.py --data_path ./img --batch_size 32 --lr 0.001 --epochs 50 --experiment_name custom_run

Force CPU-only training:

python train.py --data_path ./img --batch_size 64 --lr 0.0005 --epochs 20 --experiment_name my_experiment --no_cuda

Output files of the main experiment:

- logs/20_epochs_test/summary_report.json: Executive summary of the experiment
- logs/20_epochs_test/training_config.json: Full configuration used for training
- logs/20_epochs_test/images/training_curves.png: Loss and accuracy over epochs
- logs/20_epochs_test/images/confusion_matrix.png: Classification error structure
- logs/20_epochs_test/images/roc_curve.png: ROC curve
- logs/20_epochs_test/images/precision_recall_curve.png: Precision–Recall profile
- logs/20_epochs_test/images/inference_grid_1.png: Sample predictions
- logs/20_epochs_test/metrics/evaluation_metrics.json: Global performance metrics
- logs/20_epochs_test/metrics/classification_report.csv: Per-class metrics
- logs/20_epochs_test/metrics/confusion_matrix.csv: Numerical confusion matrix
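The evaluation outputs (confusion matrix, ROC and Precision–Recall curves, per-class report) correspond to standard scikit-learn metrics. A minimal sketch, assuming predictions are summarized as the softmax probability of class 1 and thresholded at 0.5 (equivalent to argmax for two classes). evaluate_binary is a hypothetical helper, and the project's own evaluation code may compute these differently.

```python
import numpy as np
from sklearn.metrics import (average_precision_score, classification_report,
                             confusion_matrix, roc_auc_score)


def evaluate_binary(y_true, prob_class1, threshold=0.5):
    """Summarize binary predictions given true labels and P(class 1) per sample."""
    y_true = np.asarray(y_true)
    prob_class1 = np.asarray(prob_class1)
    y_pred = (prob_class1 >= threshold).astype(int)  # hard labels for the confusion matrix
    return {
        "confusion_matrix": confusion_matrix(y_true, y_pred, labels=[0, 1]),  # rows: true, cols: predicted
        "roc_auc": roc_auc_score(y_true, prob_class1),  # threshold-free ranking quality
        "average_precision": average_precision_score(y_true, prob_class1),  # area under the PR curve
        "report": classification_report(y_true, y_pred, labels=[0, 1], output_dict=True, zero_division=0),
    }


metrics = evaluate_binary([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(metrics["confusion_matrix"])
print(round(metrics["roc_auc"], 2))  # 0.75
```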