This repository contains the official implementation of the paper:
> **Graph Transformer Networks for Accurate Band Structure Prediction: An End-to-End Approach**
> Weiyi Gong, Tao Sun, Hexin Bai, Jeng-Yuan Tsai, Haibin Ling, Qimin Yan
> [arXiv:2411.16483](https://arxiv.org/abs/2411.16483)
Predicting electronic band structures from crystal structures is crucial for understanding structure-property correlations in materials science. First-principles approaches are accurate but computationally intensive. Here, we introduce a graph Transformer-based end-to-end approach that directly predicts band structures from crystal structures with high accuracy. Our method leverages the continuity of the k-path and treats continuous bands as a sequence. We demonstrate that our model not only provides accurate band structure predictions but also can derive other properties (such as band gap, band center, and band dispersion) with high accuracy.
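As a rough illustration of the "bands as a sequence" view described above (our sketch, with toy cosine dispersions; shapes and values are assumptions, not the paper's model), a band structure along a k-path can be stored as an `energies[k][b]` grid, and each band `b` traced along the path forms one continuous sequence:

```python
# Illustrative only: represent a band structure along a k-path as a grid of
# energies and read each continuous band out as a sequence over the k-path.
import math

n_kpoints, n_bands = 5, 3
# energies[k][b]: energy of band b at the k-th point on the path (toy values)
energies = [[math.cos(math.pi * k / (n_kpoints - 1)) + b
             for b in range(n_bands)] for k in range(n_kpoints)]

# Treating each continuous band as a sequence over the k-path:
band_sequences = [[energies[k][b] for k in range(n_kpoints)]
                  for b in range(n_bands)]

print(len(band_sequences), len(band_sequences[0]))  # 3 bands, 5 k-points each
```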
## Citation

If you find this work useful in your research, please cite our paper:

```bibtex
@misc{gong2024graphtransformernetworksaccurate,
      title={Graph Transformer Networks for Accurate Band Structure Prediction: An End-to-End Approach},
      author={Weiyi Gong and Tao Sun and Hexin Bai and Jeng-Yuan Tsai and Haibin Ling and Qimin Yan},
      year={2024},
      eprint={2411.16483},
      archivePrefix={arXiv},
      primaryClass={cond-mat.mtrl-sci},
      url={https://arxiv.org/abs/2411.16483},
}
```

## Requirements

The project relies on the following key packages:
- Python >= 3.11
- PyTorch == 2.8
- PyTorch Geometric == 2.7.0
- Pymatgen >= 2025.10.7
- ASE >= 3.26.0
- NumPy >= 2.3.5
- WandB >= 0.23.0
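As a quick sanity check, you can compare installed versions against these minimums. The helper below is our illustration, not part of this repository, and the naive version comparison ignores pre-release suffixes:

```python
# Hypothetical helper, not part of this repository: compare installed package
# versions against the minimums listed above.
from importlib import metadata

def meets_minimum(installed: str, required: str) -> bool:
    """Naive numeric version comparison (ignores pre-release suffixes)."""
    def parse(v):
        # Drop local version labels like "+cu126", keep numeric components.
        return [int(p) for p in v.split("+")[0].split(".") if p.isdigit()]
    return parse(installed) >= parse(required)

requirements = {"torch": "2.8", "numpy": "2.3.5", "pymatgen": "2025.10.7"}

for package, minimum in requirements.items():
    try:
        installed = metadata.version(package)
        status = "OK" if meets_minimum(installed, minimum) else "too old"
    except metadata.PackageNotFoundError:
        installed, status = "-", "missing"
    print(f"{package:10s} {installed:14s} {status}")
```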
## Installation

First, clone the repository:

```bash
git clone https://github.com/username/bandformer.git
cd bandformer
```

### Option 1: uv (recommended)

We recommend using uv for efficient and reliable dependency management. For detailed documentation, please visit the official uv website.
1. **Install uv**

   If uv is not already installed, use the standalone installer:

   ```bash
   curl -LsSf https://astral.sh/uv/install.sh | sh
   ```

2. **Install dependencies**

   The `pyproject.toml` is configured with `find-links` for CUDA 12.6 (cu126). If you have a different CUDA version, update the `find-links` URL in `pyproject.toml` before running `uv sync`:

   ```toml
   [tool.uv]
   find-links = ["https://data.pyg.org/whl/torch-2.8.0+cu126.html"]
   # ^^^^^ change to match your CUDA version, e.g. cu121, cu124
   ```

   Then create the virtual environment and install all dependencies:

   ```bash
   uv sync
   ```

   If the download times out for large packages, increase the timeout:

   ```bash
   UV_HTTP_TIMEOUT=300 uv sync
   ```

Activate the environment:

```bash
source .venv/bin/activate
```
### Option 2: Conda / pip

For users preferring Conda or standard pip, follow these steps:

1. **Create and activate the environment**

   ```bash
   conda create -n bandformer python=3.11
   conda activate bandformer
   ```

2. **Install dependencies**

   First, install PyTorch (version 2.8 recommended):

   ```bash
   pip install torch==2.8.0 --index-url https://download.pytorch.org/whl/cu126
   ```

   (Note: adjust the CUDA version `cu126` based on your system configuration.)

   Then, install the remaining dependencies:

   ```bash
   pip install -e .
   ```

   Alternatively, manually install the requirements:

   ```bash
   pip install setuptools pyyaml ase ase-db-backends numpy pymatgen wandb
   pip install torch-geometric==2.7.0 pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.8.0+cu126.html
   ```
## Dataset

The raw data was obtained from the Materials Project repository in the AWS Open Data Program (see https://docs.materialsproject.org/downloading-data/aws-opendata). We downloaded 705k parsed band structures (1.5 TB) from the bucket `materialsproject-parsed`. The data contains band structures computed in both uniform mode and symm-line mode. From the symm-line data, we selected non-magnetic materials and cleaned the set by removing outliers, leaving around 27k band structures. The data is split by `prepare.py` into 90% for training and 10% for evaluation.

The model expects the dataset file `nm-6-cleaned-maxlen-30.pt` to be present in the `data/` directory.
1. **Download data**: Download the dataset (~254 MB) from Figshare into the `data/` folder:

   ```bash
   wget -O data/nm-6-cleaned-maxlen-30.pt https://ndownloader.figshare.com/files/59214011
   ```

2. **Prepare splits**: Run the preparation script to generate the training and validation splits (`train.pt` and `val.pt`):

   ```bash
   python data/prepare.py
   ```
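The 90/10 train/validation split can be sketched as below. This is our illustration only: `split_indices`, the seed, and the shuffle strategy are assumptions, not the actual logic of `data/prepare.py`:

```python
# Hypothetical sketch of a 90/10 split like the one data/prepare.py performs.
# The real script may shuffle, seed, and serialize the data differently.
import random

def split_indices(n, val_fraction=0.1, seed=0):
    """Return (train, val) index lists with a reproducible shuffle."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    n_val = int(n * val_fraction)
    return idx[n_val:], idx[:n_val]

train_idx, val_idx = split_indices(27000)
print(len(train_idx), len(val_idx))  # 24300 2700
```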
## Training

The training script `train.py` supports both single-GPU and Distributed Data Parallel (DDP) training. Configuration is handled via `configs/train.yaml`.

You can modify training hyperparameters in `configs/train.yaml`, such as:

- `batch_size`, `learning_rate` (optimization)
- `n_layer`, `n_head`, `n_embd` (model architecture)
- `wandb_log` (logging)
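A configuration using these keys might look like the excerpt below. The keys mirror the options above, but the values shown are illustrative defaults, not necessarily the repository's:

```yaml
# Hypothetical excerpt of configs/train.yaml; values are illustrative only.
batch_size: 32
learning_rate: 3.0e-4
n_layer: 6
n_head: 8
n_embd: 256
wandb_log: false
```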
To run training on a single GPU:

```bash
python train.py --config configs/train.yaml
```

You can also override configuration parameters from the command line:

```bash
python train.py --batch_size 32 --compile False
```

To run on multiple GPUs (e.g., 4 GPUs on a single node):

```bash
torchrun --standalone --nproc_per_node=4 train.py --config configs/train.yaml
```

## Acknowledgements

We acknowledge the following resources that were helpful in the development of this project:
- **Transformer Implementation**: We referred to PyTorch's official implementation and The Annotated Transformer.
- **Open Source Repositories**:
