Platonic Transformers: A Solid Choice For Equivariance


Platonic Transformers: A Solid Choice For Equivariance by Mohammad Mohaiminul Islam, Rishabh Anand, David R. Wessels, Friso de Kruiff, Thijs P. Kuipers, Rex Ying, Clara I. Sánchez, Sharvaree Vadgama, Georg Bökman, Erik J. Bekkers

Welcome to the Platonic Transformer project, where geometric group theory meets modern attention architectures 🌟. This repository contains research code for Platonic Transformers, a drop-in way to add geometric inductive biases to vanilla Transformers.

Figure: Platonic Transformer architecture overview.

📄 About the Paper

Platonic Transformers provide a drop-in method to build geometric inductive biases into the standard Transformer architecture, achieving approximate SE(2), E(2), SE(3), or E(3) equivariance at no additional computational cost. Our approach is based on:

  • Frame-relative attention. Point-wise features are lifted to functions on a finite roto-reflection group; each group element acts as a reference frame, and attention (with RoPE) runs in parallel across frames with shared weights.
  • Equivariance by design. This yields translation equivariance (via RoPE) and discrete roto-reflectional equivariance (via weight sharing over the chosen group), without changing the attention mechanism.
  • Dynamic group convolution. Omitting the softmax turns attention into a linear-time, content-aware dynamic group convolution (see the sketch after this list).
  • Cross-domain applicability. Competitive results across CIFAR-10 (images), ScanObjectNN (3D), QM9 & OMol25 (molecular learning).
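
Why dropping the softmax yields a linear-time variant follows from standard linear-attention algebra: without the softmax, the attention output can be re-associated so the number of tokens never enters quadratically. A minimal single-head sketch (illustrative only, ignoring the group dimension; not the repository's implementation):

```python
import torch

def softmax_attention(q, k, v):
    # Standard attention: the (N, N) score matrix makes this O(N^2) in tokens.
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ v

def linear_attention(q, k, v):
    # Without the softmax, (Q K^T) V re-associates to Q (K^T V),
    # which costs O(N * d_k * d_v) -- linear in the number of tokens.
    kv = k.transpose(-2, -1) @ v                      # (d_k, d_v)
    return (q @ kv) / q.shape[-1] ** 0.5

q, k, v = (torch.randn(1024, 64) for _ in range(3))
ref = (q @ k.transpose(-2, -1) / 64 ** 0.5) @ v       # quadratic ordering
assert torch.allclose(ref, linear_attention(q, k, v), atol=1e-3)
```

The re-association is exact only because the softmax nonlinearity is gone, which is what the linear-time variant exploits.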

✨ Key Features

  • 🔷 Group-Equivariant Attention - Based on the symmetries of Platonic solids (e.g., the tetrahedral group with 12 rotations or the octahedral group with 24).
  • 🔄 Unified Scalar/Vector I/O - Equivariantly processes scalar and vector features as both input and output.
  • 🔳 Generalizes Standard Transformers - The standard Transformer architecture is recovered by choosing the trivial symmetry group.
  • 🎯 Multiple Benchmarks - CIFAR-10, QM9 regression, ScanObjectNN, and OMol25.
  • ⚡ Linear-Time Variant - Dynamic group convolution obtained by dropping the softmax.
  • 🛠️ Easy to Use - Unified meta_main.py entry point for all datasets.

🚀 Quick Start

# Clone and setup
git clone https://github.com/niazoys/PlatonicTransformers.git
cd PlatonicTransformers
chmod +x setup.sh && ./setup.sh
source .venv/bin/activate


# Train on CIFAR-10 (loads configs/cifar10_deit.yaml)
python meta_main.py cifar10 --batch_size 256 --lr 8e-4

# Train on QM9 molecular properties (loads configs/qm9_regr.yaml)
python meta_main.py qm9_regr --target mu --batch_size 96

# Train on OMol energy/force regression (loads configs/omol.yaml)
python meta_main.py omol --predict_forces --force_weight 100

📂 Repository Structure

.
├── meta_main.py             # 🎯 Unified entry point for all datasets
├── configs/                 # Dataset-specific YAML configs
├── data/                    # Downloaded datasets and artifacts
├── mains/                   # Dataset-specific training scripts
│   ├── main_cifar10.py
│   ├── main_omol.py
│   └── main_qm9_regr.py
├── models/
│   └── platoformer/         # Platonic Transformer building blocks
│       ├── block.py         # Core PlatonicBlock (attention + feedforward)
│       ├── conv.py          # Group convolution utilities
│       ├── groups.py        # Symmetry group definitions for Platonic solids
│       ├── io.py            # Lifting, pooling, dense/sparse utilities
│       ├── linear.py        # Equivariant linear projections
│       └── platoformer.py   # Full PlatonicTransformer module
├── datasets/                # Dataset loaders for supported benchmarks
├── utils/                   # Config loader and helper utilities
├── requirements.txt         # Python dependencies
└── setup.sh                 # Environment setup script

🔧 Installation

Prerequisites

  • Python 3.12+
  • CUDA 12.1+ (for GPU support)
  • PyTorch 2.3+

Setup

  1. Clone the repository and install system dependencies if needed.
  2. Create the environment:
    chmod +x setup.sh
    ./setup.sh
  3. Activate the environment:
    source .venv/bin/activate
  4. Authenticate with Weights & Biases (optional, for experiment tracking):
    wandb login

🎮 Usage

Available Datasets

| Dataset | Task | Description |
| --- | --- | --- |
| cifar10 | Image Classification | CIFAR-10 with patch-based point cloud representation |
| qm9_regr | Molecular Property Prediction | QM9 quantum chemistry dataset |
| omol | Molecular Learning | Open Molecular Learning dataset |

Unified Entry Point

Use meta_main.py to run any dataset training script. Each dataset automatically loads its YAML configuration from configs/<dataset>.yaml. Pass --config path/to/custom.yaml to replace the entire config file, and layer additional CLI flags on top for quick tweaks:

# List available datasets
python meta_main.py --help

# Get help for a specific dataset (shows all available arguments)
# python meta_main.py scanobjectnn --help  # (temporarily disabled)
python meta_main.py qm9_regr --help

# Run training
# Swap in a different config
python meta_main.py cifar10 --config configs/cifar10_small.yaml

# Override individual keys from the active config
# python meta_main.py scanobjectnn --epochs 500 --batch_size 128  # (temporarily disabled)
python meta_main.py cifar10 --batch_size 256 --lr 8e-4
python meta_main.py qm9_regr --target mu --batch_size 96
python meta_main.py omol --predict_forces --force_weight 100

Direct Script Execution (Alternative)

You can also run scripts directly from the mains/ directory:

python mains/main_cifar10.py --batch_size 256 --lr 8e-4
python mains/main_qm9_regr.py --target alpha --batch_size 64
python mains/main_omol.py --predict_forces --force_weight 100

Common Configuration Flags

Model Architecture:

  • --solid_name - Platonic solid: {tetrahedron, octahedron, icosahedron, trivial_3} (default: octahedron)

  • --hidden_dim - Hidden dimension size

  • --layers - Number of transformer layers

  • --num_heads - Number of attention heads

    Note on Hidden Dimension: For the model to work correctly, --hidden_dim must be divisible by both the order of the chosen group (|G|) and the specified --num_heads. The internal dimensions for attention are calculated automatically from these values.

    Example: Let's say you use --solid_name tetrahedron, --hidden_dim 768, and --num_heads 48.

    • The tetrahedron group has an order |G| = 12.
    • The feature dimension per group element is hidden_dim / |G| = 768 / 12 = 64.
    • The dimension of each attention head is hidden_dim / num_heads = 768 / 48 = 16.
    • The number of independent heads applied to each group element's features is (hidden_dim / |G|) / (hidden_dim / num_heads) = 64 / 16 = 4.

    This means the model will run 4 attention heads per group element, where each head has a dimension of 16.
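
A small helper that reproduces this arithmetic is convenient for sanity-checking flag combinations before launching a run. This is a hypothetical sketch based on the note above; check_dims and GROUP_ORDERS are not part of the repository API:

```python
# Hypothetical helper, not part of the repository.
GROUP_ORDERS = {"trivial_3": 1, "tetrahedron": 12, "octahedron": 24, "icosahedron": 60}

def check_dims(hidden_dim: int, num_heads: int, solid_name: str) -> None:
    g = GROUP_ORDERS[solid_name]
    if hidden_dim % g or hidden_dim % num_heads:
        raise ValueError(f"hidden_dim={hidden_dim} must be divisible by |G|={g} and num_heads={num_heads}")
    per_group = hidden_dim // g          # feature dimension per group element
    head_dim = hidden_dim // num_heads   # dimension of each attention head
    if per_group % head_dim:
        raise ValueError(f"per-group dim {per_group} is not divisible by head dim {head_dim}")
    print(f"{per_group // head_dim} heads of dim {head_dim} per group element")

check_dims(768, 48, "tetrahedron")   # -> 4 heads of dim 16 per group element
```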

Positional Encodings:

  • --rope_sigma - Sigma for Rotational Positional Encoding (RoPE)
  • --ape_sigma - Sigma for Absolute Positional Encoding (APE)
  • --freq_init - Frequency initialization: {random, spiral}
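
For intuition, a coordinate-based RoPE rotates each channel pair by an angle that depends linearly on position, so attention scores depend only on relative positions (which is where the translation equivariance comes from). The sketch below assumes --rope_sigma scales randomly initialized frequency vectors (the random option of --freq_init); it is a generic illustration, not the repository's rope.py:

```python
import torch

def coordinate_rope(x, pos, sigma=1.0):
    """Rotate channel pairs of x by angles linear in the coordinates.

    x:   (N, d) token features, d even
    pos: (N, k) point coordinates (2D or 3D)
    """
    # In a real module the frequencies are fixed parameters; sampled here for brevity.
    freqs = sigma * torch.randn(pos.shape[1], x.shape[1] // 2)
    theta = pos @ freqs                            # (N, d/2) rotation angles
    cos, sin = torch.cos(theta), torch.sin(theta)
    x1, x2 = x[:, 0::2], x[:, 1::2]
    out = torch.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin
    out[:, 1::2] = x1 * sin + x2 * cos
    return out
```

A global translation of pos shifts all angles by the same offset, so dot products between rotated queries and keys are unchanged.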

Training:

  • --epochs - Number of training epochs
  • --batch_size - Training batch size
  • --lr - Learning rate
  • --weight_decay - Weight decay for optimizer
  • --seed - Random seed for reproducibility

System:

  • --gpus - Number of GPUs to use
  • --num_workers - Number of data loading workers
  • --log - Enable/disable WandB logging

💡 Tip: Start with smaller --hidden_dim (e.g., 64) and fewer --layers to validate pipelines quickly!

🧠 Model Architecture

Platonic Transformers leverage the rotational symmetries of Platonic solids to enforce SE(3)-equivariance in attention mechanisms. The architecture is implemented in models/platoformer/.

Core Components

  • Lifting (io.py) - Maps scalar and vector node features to group-aligned channels (see the sketch after this list)
  • Attention Blocks (block.py) - Stacked PlatonicBlock layers with group-aware attention and equivariant MLPs
  • Equivariant Convolutions (conv.py) - SE(3)-equivariant convolution layers
  • Group Theory (groups.py) - Platonic solid symmetry group implementations
  • Positional Encodings - Dual encoding strategy:
    • RoPE (rope.py) - Rotational Positional Encoding for relative positions
    • APE (ape.py) - Absolute Positional Encoding for global context
  • Readout (io.py) - Separate scalar/vector readouts with pooling for graph or node-level predictions
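
As a rough picture of the lifting step referenced in the list above, one common convention copies scalar features to every frame and expresses vector features in each frame's coordinates. The sketch below follows that convention and may differ from the actual io.py:

```python
import torch

def lift(scalars, vectors, group):
    """Lift point-wise features to a function on a finite rotation group.

    scalars: (N, C)        rotation-invariant features
    vectors: (N, 3)        e.g., positions or velocities
    group:   (|G|, 3, 3)   rotation matrices of the chosen solid
    """
    G = group.shape[0]
    # Scalars are unchanged by rotations: copy them to every frame.
    lifted_scalars = scalars.unsqueeze(1).expand(-1, G, -1)        # (N, |G|, C)
    # Vectors are read out frame-relatively: apply each R_g^T.
    lifted_vectors = torch.einsum("gji,nj->ngi", group, vectors)   # (N, |G|, 3)
    return lifted_scalars, lifted_vectors
```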

Supported Platonic Solids

| CLI label(s) | Dim | Type | Order (\|G\|) | Notes / Typical use |
| --- | --- | --- | --- | --- |
| trivial | 3 | Identity only | 1 | 3D baseline (no rotational bias); translation handled via RoPE. |
| trivial_n (n = 2…10) | n | Identity only | 1 | Identity-only group in the chosen dimension; e.g., trivial_2, trivial_3, … |
| tetrahedron | 3 | Platonic rotational | 12 | Default: lightweight 3D rotational equivariance; fewer frames/compute. |
| octahedron | 3 | Platonic rotational | 24 | Higher capacity than tetra; balanced accuracy/compute. |
| icosahedron | 3 | Platonic rotational | 60 | Highest rotational expressivity; most frames/compute. |
| octahedron_reflections | 3 | Axis-aligned reflections (x/y/z) | 8 | Independent flips about x, y, z; useful when parity (mirror) cues matter. |
| cyclic_n (n = 2…20) | 2 | Rotation-only | n | 2D discrete rotations; e.g., cyclic_4, cyclic_6. |
| dihedral_n (n = 2…20) | 2 | Rotations + reflections | 2n | 2D rotations and mirror symmetry; e.g., dihedral_4, dihedral_6. |
| flop_2d_<axis> (axis = 1, 2) | 2 | Single-axis reflection | 2 | Axis 1: reflect across the x-axis (flip y); Axis 2: reflect across the y-axis (flip x). |
| flop_3d_<axis> (axis = 1, 2, 3) | 3 | Single-axis reflection | 2 | Axis 1: YZ-plane (flip x); Axis 2: XZ-plane (flip y); Axis 3: XY-plane (flip z). |
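
For concreteness, the 12 rotations of the tetrahedron can be enumerated as signed permutation matrices with an even number of sign flips and determinant +1. This is an illustrative construction only; groups.py may build the groups differently:

```python
import itertools
import numpy as np

def tetrahedral_rotations():
    """Rotation group (order 12) of the tetrahedron with vertices
    (1,1,1), (1,-1,-1), (-1,1,-1), (-1,-1,1)."""
    mats = []
    for perm in itertools.permutations(range(3)):
        P = np.eye(3)[list(perm)]                        # permutation matrix
        for signs in itertools.product([1, -1], repeat=3):
            M = np.diag(signs) @ P
            if signs.count(-1) % 2 == 0 and np.isclose(np.linalg.det(M), 1.0):
                mats.append(M)
    return np.stack(mats)

R = tetrahedral_rotations()
assert R.shape == (12, 3, 3)
# Closure check: the product of any two elements is again in the set.
products = np.einsum("aij,bjk->abik", R, R).reshape(-1, 3, 3)
assert all(any(np.allclose(p, r) for r in R) for p in products)
```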

Examples

# Default (3D rotational, 12 frames)
python meta_main.py omol --solid_name tetrahedron ...

# 2D rotation-only / rotations+reflections
python meta_main.py cifar10 --solid_name cyclic_4 ...
python meta_main.py cifar10 --solid_name dihedral_6 ...

📊 Datasets

CIFAR-10 (cifar10)

  • Task: Image Classification (10 classes)
  • Representation: Patches converted to point clouds (see the sketch below)
  • Key Args: --patch_size, --num_points_per_patch
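
A rough sketch of what such a patch-to-point-cloud conversion can look like, with patch centers as 2D coordinates and flattened patch pixels as features (the exact layout used by the dataset loader, and the helper name here, are assumptions):

```python
import torch

def image_to_patch_cloud(img, patch_size=4):
    """Turn a (C, H, W) image into per-patch features plus patch-center coordinates."""
    c, h, w = img.shape
    patches = img.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    # (C, H/p, W/p, p, p) -> (H/p * W/p, C * p * p)
    feats = patches.permute(1, 2, 0, 3, 4).reshape(-1, c * patch_size * patch_size)
    ys = torch.arange(patch_size / 2, h, patch_size)
    xs = torch.arange(patch_size / 2, w, patch_size)
    coords = torch.stack(torch.meshgrid(ys, xs, indexing="ij"), dim=-1).reshape(-1, 2)
    return feats, coords

feats, coords = image_to_patch_cloud(torch.randn(3, 32, 32), patch_size=4)
assert feats.shape == (64, 48) and coords.shape == (64, 2)
```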

QM9 (qm9_regr)

  • Task: Molecular Property Regression
  • Properties: 12 quantum chemical properties (e.g., dipole moment μ, HOMO-LUMO gap)
  • Key Args: --target {mu, alpha, homo, lumo, ...}, --use_bonds

Open Molecular (omol)

  • Task: Molecular Property Prediction with LMDB backend
  • Features: Large-scale molecular learning with atomic embeddings
  • Key Args: --radius, --max_neighbors
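
These flags presumably control how local atomic neighborhoods are built: connect atoms within a distance cutoff, keeping at most a fixed number of closest neighbors per atom. A minimal dense sketch of such a rule (illustrative only, not the repository's loader):

```python
import torch

def radius_neighbors(pos, radius=5.0, max_neighbors=32):
    """Boolean (N, N) adjacency: neighbors within `radius`, at most `max_neighbors` per atom."""
    dist = torch.cdist(pos, pos)                       # pairwise distances
    dist.fill_diagonal_(float("inf"))                  # no self-loops
    k = min(max_neighbors, pos.shape[0] - 1)
    knn_dist, knn_idx = dist.topk(k, largest=False)    # k closest candidates per atom
    adj = torch.zeros_like(dist, dtype=torch.bool)
    adj.scatter_(1, knn_idx, knn_dist <= radius)       # keep only candidates inside the cutoff
    return adj

adj = radius_neighbors(torch.randn(100, 3) * 3.0)
```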

📖 Citation

If you use Platonic Transformers in your research, please cite:

@misc{islam2025platonictransformerssolidchoice,
      title={Platonic Transformers: A Solid Choice For Equivariance}, 
      author={Mohammad Mohaiminul Islam and Rishabh Anand and David R. Wessels and Friso de Kruiff and Thijs P. Kuipers and Rex Ying and Clara I. Sánchez and Sharvaree Vadgama and Georg Bökman and Erik J. Bekkers},
      year={2025},
      eprint={2510.03511},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2510.03511}, 
}

๐Ÿ“ License

This project is licensed under the MIT License - see the LICENSE file for details.

๐Ÿค Contributing

Contributions are welcome! Please feel free to submit a Pull Request. For major changes, please open an issue first to discuss what you would like to change.

📬 Contact

For questions or issues:

  • Open an issue on GitHub
  • Email us here
