GAN-based image composition via Spatial Transformer Networks and HomographyNet. Jointly learns geometric alignment and adversarial appearance adaptation to composite donor regions (glasses) onto recipient face images from CelebA.
Research implementation of GAN-based image composition combining Homography estimation (HomographyNet), Spatial Transformer Networks, and a GAN discriminator to splice image regions — specifically glasses onto faces — with geometric alignment and appearance adaptation.
Compositing a donor region (e.g., glasses) into a recipient image (a face) requires:
- Geometric alignment — the donor must be warped to fit the recipient's geometry
- Appearance adaptation — lighting and colour must match the recipient
This project trains a Spatial Transformer GAN on the CelebA face dataset to jointly learn the six affine warp parameters (scale, rotation, shear, translation) and to discriminate realistic from unrealistic composites.
Donor (glasses) + Recipient (face)
↓
Geometric Predictor (CNN) → affine parameters [θ]
↓
Spatial Transformer → warped donor
↓
GAN Discriminator → realistic / unrealistic
↓
Composite output
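The warp-and-composite step can be sketched in plain NumPy/OpenCV. This is an illustrative stand-in rather than the repository's code: `params_to_matrix` and `composite` are hypothetical helpers, and the differentiable TensorFlow versions live in `src/stgan/warp.py` (`vec2mtrx`, `transformImage`).

```python
import numpy as np
import cv2  # the architecture diagram references warpAffine for the warp step

def params_to_matrix(p):
    """Hypothetical helper: 6 affine parameters -> 3x3 warp matrix,
    parameterised as a perturbation of the identity."""
    p = np.asarray(p, dtype=np.float32)
    return np.array([[1 + p[0], p[1],     p[2]],
                     [p[3],     1 + p[4], p[5]],
                     [0.0,      0.0,      1.0]], dtype=np.float32)

def composite(recipient_rgb, donor_rgba, p):
    """Warp the RGBA donor (assumed given at recipient resolution) by the
    predicted parameters, then alpha-blend it over the recipient."""
    h, w = recipient_rgb.shape[:2]
    M = params_to_matrix(p)[:2]                      # warpAffine expects the top 2x3 rows
    warped = cv2.warpAffine(donor_rgba, M, (w, h))   # donor aligned to the recipient frame
    alpha = warped[..., 3:4].astype(np.float32) / 255.0
    blended = alpha * warped[..., :3] + (1.0 - alpha) * recipient_rgb
    return blended.astype(np.uint8)
```

Parameterising the warp as a perturbation of the identity means an all-zero prediction leaves the donor where it started, which is a common choice for keeping early training stable.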
```mermaid
flowchart LR
    subgraph Input["Input"]
        D["Donor\n(glasses patch)"]
        R["Recipient\n(face, CelebA)"]
    end
    subgraph STN["Spatial Transformer GAN"]
        D --> GP["Geometric Predictor\nCNN → θ (6 affine params)"]
        R --> GP
        GP --> ST["Spatial Transformer\nwarpAffine"]
        ST --> WARP["Warped donor\n(aligned to face)"]
        WARP --> DISC["GAN Discriminator\nreal / fake"]
        R --> DISC
    end
    subgraph Loss["Training Losses"]
        DISC --> ADV["Adversarial loss\n(GAN)"]
        ST --> GEOM["Geometric loss\n(landmark distance)"]
    end
    WARP --> OUT["Composite output"]
```
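The two loss heads in the diagram can be written as standard sigmoid cross-entropy terms plus a landmark penalty. The snippet below is a generic sketch built on placeholder tensors, not the exact formulation in `graph.py`/`train_stgan.py`; the 5-point landmark layout and the 10.0 weighting are assumptions.

```python
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

# Placeholder discriminator logits for real faces and for generated composites.
d_real = tf.placeholder(tf.float32, [None, 1], name="d_logits_real")
d_fake = tf.placeholder(tf.float32, [None, 1], name="d_logits_fake")
bce = tf.nn.sigmoid_cross_entropy_with_logits

# Discriminator: push real faces towards 1 and composites towards 0.
loss_D = tf.reduce_mean(bce(logits=d_real, labels=tf.ones_like(d_real))) \
       + tf.reduce_mean(bce(logits=d_fake, labels=tf.zeros_like(d_fake)))

# Geometric predictor (generator): fool the discriminator ...
loss_G_adv = tf.reduce_mean(bce(logits=d_fake, labels=tf.ones_like(d_fake)))

# ... while keeping warped landmarks close to their targets (assumed 5-point layout).
warped_lm = tf.placeholder(tf.float32, [None, 5, 2], name="warped_landmarks")
target_lm = tf.placeholder(tf.float32, [None, 5, 2], name="target_landmarks")
loss_geom = tf.reduce_mean(tf.norm(warped_lm - target_lm, axis=-1))

loss_G = loss_G_adv + 10.0 * loss_geom  # weighting is illustrative, not the repo's value
```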
Glasses composited onto CelebA aligned faces:
Multi-face results (faces in the wild, MTCNN detection):
GenerativeAdversarialNetwork-ImageComposition/
├── src/
│ └── stgan/ # Installable Python package
│ ├── __init__.py # Package metadata
│ ├── _geometry.py # Pure NumPy/SciPy geometry (fit, compose, inverse)
│ ├── warp.py # Differentiable STN ops (TF): vec2mtrx, transformImage
│ ├── graph.py # Network architectures: geometric predictor + discriminator
│ ├── data.py # CelebA data loading, batch sampling, history buffer
│ ├── utils.py # I/O, terminal colours, TensorBoard helpers, checkpointing
│ └── options.py # Argparse config with derived fields (refMtrx, warpDim, …)
├── train_stgan.py # Full ST-GAN training loop (wrapped in main())
├── eval_stgan.py # Model evaluation / inference (wrapped in main())
├── train_Donly.py # Discriminator pre-training script
├── scripts/
│ └── preprocess_celeba.py # CelebA → .npy with --celeba-path argparse
├── tests/
│ └── test_geometry.py # 12 tests for _geometry (no TF required)
├── dataset/
│ ├── attribute_train.npy # CelebA attribute vectors (train split)
│ ├── attribute_test.npy # CelebA attribute vectors (test split)
│ └── glasses.npy # RGBA glasses donor patches
├── images/ # Result visualisations
├── media/ # Development progress images
├── pyproject.toml # Build config (hatchling), uv deps, pytest settings
└── requirements.txt # Legacy flat deps (see pyproject.toml)
Requires Python 3.8+. Uses uv for dependency management.
git clone https://github.com/ashish-code/GenerativeAdversarialNetwork-ImageComposition.git
cd GenerativeAdversarialNetwork-ImageComposition
# Using uv (recommended)
uv sync --extra dev
# Or using pip
pip install -e ".[dev]"
# For training (adds TensorFlow)
pip install -e ".[train,dev]"

TensorFlow note: The architecture uses the TF 1.x session-based API. For GPU training, install `tensorflow-gpu==1.15` (CUDA 10.0 + cuDNN 7.6). TF 2.x runs in compatibility mode via `tf.compat.v1`.
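A minimal sketch of what that compatibility mode looks like when building and running a 1.x-style graph under TF 2.x (the toy placeholder below is illustrative, not part of the repo):

```python
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()  # restores sessions, placeholders, and collection-based variables

x = tf.placeholder(tf.float32, [None, 6], name="warp_params")  # toy 6-parameter input
y = 2.0 * x
with tf.Session() as sess:
    print(sess.run(y, feed_dict={x: [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]}))
```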
Download CelebA and extract:
- Aligned & cropped images (`img_align_celeba/`)
- Attribute annotations (`list_attr_celeba.txt`)
- Train/val/test partitions (`list_eval_partition.txt`)
python scripts/preprocess_celeba.py \
--celeba-path /data/CelebA \
--output-dir dataset/ \
--fraction 10 # use 10% of images; set to 1 for full dataset
# Outputs: dataset/image_train.npy, dataset/image_test.npy,
#          dataset/attribute_{train,test}.npy

# Single-stage warp (warpN=1)
python train_stgan.py --group 0 --model STGAN --warpN 1 --toIt 50000
# Resume from checkpoint at iteration 20000
python train_stgan.py --group 0 --model STGAN --warpN 1 --fromIt 20000 --toIt 50000
# Multi-stage (stacks warpN=2 on top of completed warpN=1 checkpoint)
python train_stgan.py --group 0 --model STGAN --warpN 2 --toIt 50000

python eval_stgan.py \
--group 0 --model STGAN --warpN 1 \
--loadGP 0_STGAN \
--loadImage path/to/face.png

# Run all 12 tests (no TensorFlow required)
pytest
# With coverage
pytest --cov=stgan --cov-report=term-missing

For unconstrained face images, face detection uses MTCNN (Multi-Task Cascaded CNNs) to locate faces before passing cropped regions to the spatial transformer:
Input image (multiple faces)
↓
MTCNN face detection → bounding boxes
↓
Crop + resize to CelebA format (218×178)
↓
Spatial Transformer GAN → composite glasses
↓
Reinsert composited face into original image
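A rough sketch of that path, assuming the `mtcnn` PyPI package for detection; `composite_glasses` is a hypothetical hook standing in for ST-GAN inference on a single aligned crop:

```python
import cv2
from mtcnn import MTCNN

def composite_multiface(image_rgb, composite_glasses):
    """Detect faces, run the ST-GAN on each crop, and paste the results back."""
    detector = MTCNN()
    out = image_rgb.copy()
    for det in detector.detect_faces(image_rgb):           # boxes are [x, y, w, h]
        x, y, w, h = det["box"]
        x, y = max(x, 0), max(y, 0)
        crop = image_rgb[y:y + h, x:x + w]
        crop_celeba = cv2.resize(crop, (178, 218))         # CelebA aligned format (W x H)
        composited = composite_glasses(crop_celeba)        # hypothetical ST-GAN inference hook
        out[y:y + h, x:x + w] = cv2.resize(composited, (w, h))  # reinsert into the frame
    return out
```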
| Dataset | Classes | Images | Download |
|---|---|---|---|
| CelebA | 40 attributes | 202,599 | CelebA Project |
The glasses attribute (attribute index 15) is used as the primary training signal.
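For illustration, splitting the preprocessed training set on that attribute might look like the following (this assumes the arrays keep CelebA's original +1/-1 labelling; adjust if preprocessing remaps the values):

```python
import numpy as np

attrs = np.load("dataset/attribute_train.npy")            # one 40-dim attribute vector per image
EYEGLASSES = 15                                            # glasses attribute index (see above)
with_glasses = np.where(attrs[:, EYEGLASSES] > 0)[0]       # faces already wearing glasses
without_glasses = np.where(attrs[:, EYEGLASSES] <= 0)[0]   # recipients for compositing
print(f"{len(with_glasses)} with glasses, {len(without_glasses)} without")
```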
- Jaderberg, M. et al. (2015). Spatial Transformer Networks. NeurIPS.
- Zhang, K. et al. (2016). Joint Face Detection and Alignment using Multi-task Cascaded CNNs. IEEE SPL.
- Liu, Z. et al. (2015). Deep Learning Face Attributes in the Wild. ICCV.
- Goodfellow, I. et al. (2014). Generative Adversarial Nets. NeurIPS.
MIT — see LICENSE.