Recreate the paper "Leveraging Persistent Homology Features for Accurate Defect Formation Energy Predictions via Graph Neural Networks" (Fang & Yan, Chem. Mater. 2025) in C++, with eventual WebAssembly compilation for browser-based visualization.
Claims from the paper to reproduce:
- 55% MAE reduction with persistent homology features
- Global max pooling outperforms mean pooling for defect tasks
- Transformer architecture performs best (MAE: 0.72 eV)
# Fork https://github.com/qmatyanlab/Defect_GNN
# Clone your fork
git clone https://github.com/<your-username>/Defect_GNN.git
cd Defect_GNN
# Extract dataset
cd datasets
tar -xzf raw.tar.gz
What you get:
- data.csv — Labels (pristine_idx, vacancy_idx, formation_energy)
- atomic_embedding_CGCNN.json — 92-dim embeddings per element
- pristine_structures/ — 1112 VASP files
- defective_structures/ — 7753 VASP files
git clone https://github.com/qmatyanlab/Configurational-Disorder.git
Files to reference:
| File | Purpose |
|---|---|
| Disorder_GNN/model/model_embedding.py | CGNN, GAT, Transformer architectures |
| Disorder_GNN/GNN.py | Training loop, Optuna hyperparameter search |
| Disorder_GNN/dataset.py | PyG dataset class pattern |
| Disorder_GNN/utilities.py | Graph construction |
defect-gnn-cpp/
│
├── CMakeLists.txt
├── Makefile # Convenience wrapper
├── README.md
│
├── data/ # Symlink or copy from fork
│ ├── raw/
│ │ ├── data.csv
│ │ ├── atomic_embedding_CGCNN.json
│ │ ├── pristine_structures/
│ │ └── defective_structures/
│ └── processed/ # Generated at runtime
│ ├── graphs/ # Serialized crystal graphs
│ └── betti/ # Precomputed topology features
│
├── include/
│ ├── io/
│ │ ├── vasp_parser.hpp
│ │ ├── csv_parser.hpp
│ │ └── json_parser.hpp
│ │
│ ├── crystal/
│ │ └── structure.hpp # Contains Atom class + Structure class
│ │
│ ├── graph/
│ │ ├── crystal_graph.hpp
│ │ ├── neighbor_list.hpp
│ │ └── edge_features.hpp
│ │
│ ├── topology/
│ │ ├── ripser_wrapper.hpp # C++ Ripser integration
│ │ ├── betti_features.hpp
│ │ └── pca.hpp
│ │
│ ├── nn/
│ │ ├── tensor.hpp # Eigen aliases + activations
│ │ ├── linear.hpp
│ │ ├── pool.hpp # global_max_pool + global_mean_pool
│ │ ├── model.hpp
│ │ └── conv/
│ │ ├── message_passing.hpp
│ │ ├── cgconv.hpp
│ │ ├── gatv2conv.hpp
│ │ └── transformer_conv.hpp
│ │
│ ├── train/
│ │ ├── dataset.hpp
│ │ ├── dataloader.hpp
│ │ ├── optimizer.hpp
│ │ └── trainer.hpp
│ │
│ └── utils/
│ ├── math.hpp
│ └── logging.hpp
│
├── src/
│ ├── io/
│ │ ├── vasp_parser.cpp
│ │ ├── csv_parser.cpp
│ │ └── json_parser.cpp
│ │
│ ├── crystal/
│ │ └── structure.cpp
│ │
│ ├── graph/
│ │ ├── crystal_graph.cpp
│ │ ├── neighbor_list.cpp
│ │ └── edge_features.cpp
│ │
│ ├── topology/
│ │ ├── ripser_wrapper.cpp # Ripser integration impl
│ │ ├── betti_features.cpp
│ │ └── pca.cpp
│ │
│ ├── nn/
│ │ ├── tensor.cpp # Activations impl
│ │ ├── linear.cpp
│ │ ├── pool.cpp
│ │ ├── model.cpp
│ │ └── conv/
│ │ ├── message_passing.cpp
│ │ ├── cgconv.cpp
│ │ ├── gatv2conv.cpp
│ │ └── transformer_conv.cpp
│ │
│ ├── train/
│ │ ├── dataset.cpp
│ │ ├── dataloader.cpp
│ │ ├── optimizer.cpp
│ │ └── trainer.cpp
│ │
│ └── main.cpp
│
├── tests/
│ ├── test_vasp_parser.cpp
│ ├── test_graph_construction.cpp
│ ├── test_betti_features.cpp
│ ├── test_convolutions.cpp
│ └── test_forward_pass.cpp
│
├── scripts/
│ └── validate_against_pytorch.py
│
└── third_party/
├── eigen/ # Header-only
├── nlohmann_json/ # Header-only
├── nanoflann/ # Header-only (KD-tree)
├── ripser/ # C++ persistent homology
└── spdlog/ # Header-only logging
Reference: Defect_GNN/Betti_number.py:15-46 (get_prim_structure_info)
File: include/io/vasp_parser.hpp
namespace defect_gnn::io {
struct VASPStructure {
Eigen::Matrix3d lattice; // 3x3 lattice vectors (rows)
std::vector<std::string> elements; // Element symbols
std::vector<int> counts; // Count per element
Eigen::MatrixXd frac_coords; // Nx3 fractional coordinates
std::vector<int> atom_types; // Element index per atom
};
// Parse VASP POSCAR/CONTCAR format
// Reference: Betti_number.py:15-46
VASPStructure parse_vasp(const std::string& filepath);
// Convert fractional to Cartesian coordinates
Eigen::MatrixXd frac_to_cart(const Eigen::Matrix3d& lattice,
const Eigen::MatrixXd& frac_coords);
} // namespace defect_gnn::io
VASP format reminder:
Line 1: Comment
Line 2: Scale factor
Lines 3-5: Lattice vectors (3x3)
Line 6: Element symbols
Line 7: Element counts
Line 8: "Direct" or "Cartesian"
Lines 9+: Coordinates
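A minimal parsing sketch, assuming the plain layout above (no selective-dynamics line, "Direct"/fractional coordinates) and omitting error handling:
// src/io/vasp_parser.cpp (sketch)
#include <fstream>
#include <sstream>
#include "io/vasp_parser.hpp"
namespace defect_gnn::io {
VASPStructure parse_vasp(const std::string& filepath) {
    std::ifstream in(filepath);
    std::string line;
    std::getline(in, line);                        // Line 1: comment
    std::getline(in, line);
    const double scale = std::stod(line);          // Line 2: scale factor
    VASPStructure s;
    for (int r = 0; r < 3; ++r) {                  // Lines 3-5: lattice vectors (rows)
        std::getline(in, line);
        std::istringstream row(line);
        row >> s.lattice(r, 0) >> s.lattice(r, 1) >> s.lattice(r, 2);
    }
    s.lattice *= scale;
    std::getline(in, line);                        // Line 6: element symbols
    std::istringstream symbols(line);
    for (std::string sym; symbols >> sym;) s.elements.push_back(sym);
    std::getline(in, line);                        // Line 7: counts per element
    std::istringstream counts(line);
    for (int c; counts >> c;) s.counts.push_back(c);
    std::getline(in, line);                        // Line 8: assume "Direct"
    int n_atoms = 0;
    for (int c : s.counts) n_atoms += c;
    s.frac_coords.resize(n_atoms, 3);              // Lines 9+: fractional coordinates
    int atom = 0;
    for (size_t e = 0; e < s.counts.size(); ++e)
        for (int k = 0; k < s.counts[e]; ++k, ++atom) {
            std::getline(in, line);
            std::istringstream coord(line);
            coord >> s.frac_coords(atom, 0) >> s.frac_coords(atom, 1) >> s.frac_coords(atom, 2);
            s.atom_types.push_back(static_cast<int>(e));
        }
    return s;
}
Eigen::MatrixXd frac_to_cart(const Eigen::Matrix3d& lattice,
                             const Eigen::MatrixXd& frac_coords) {
    return frac_coords * lattice;                  // row-vector convention: cart = frac · L
}
} // namespace defect_gnn::io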
File: include/crystal/structure.hpp
namespace defect_gnn::crystal {
class Atom {
public:
int element; // Atomic number
Eigen::Vector3d position; // Cartesian coordinates
Eigen::Vector3d frac_position; // Fractional coordinates
};
class Structure {
public:
Structure(const io::VASPStructure& vasp);
const Eigen::Matrix3d& lattice() const;
const std::vector<Atom>& atoms() const;
size_t num_atoms() const;
// Get minimum image distance (periodic boundary conditions)
double distance(size_t i, size_t j) const;
Eigen::Vector3d displacement(size_t i, size_t j) const;
private:
Eigen::Matrix3d lattice_;
Eigen::Matrix3d inv_lattice_; // For PBC
std::vector<Atom> atoms_;
};
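// Minimal minimum-image sketch (an assumption; it uses the usual convention of
// wrapping fractional deltas into [-0.5, 0.5), valid while the cutoff stays below
// half the cell dimensions). Needs <cmath>.
inline Eigen::Vector3d Structure::displacement(size_t i, size_t j) const {
    Eigen::Vector3d df = atoms_[j].frac_position - atoms_[i].frac_position;
    for (int k = 0; k < 3; ++k)
        df[k] -= std::round(df[k]);           // wrap each fractional component
    return lattice_.transpose() * df;         // rows of lattice_ are the lattice vectors
}
inline double Structure::distance(size_t i, size_t j) const {
    return displacement(i, j).norm();
}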
} // namespace defect_gnn::crystal
Reference: Defect_GNN/utilities.py:34-71 (structureToGraph)
File: include/graph/neighbor_list.hpp
namespace defect_gnn::graph {
struct Neighbor {
size_t idx;
double distance;
Eigen::Vector3d displacement; // Vector from center to neighbor
};
class NeighborList {
public:
// r_cutoff=10.0 Å, max_neighbors=20 (from paper)
NeighborList(const crystal::Structure& structure,
double r_cutoff = 10.0,
size_t max_neighbors = 20);
const std::vector<Neighbor>& neighbors(size_t atom_idx) const;
private:
// Use nanoflann KD-tree for efficient neighbor search
// Must handle periodic boundary conditions by replicating atoms
void build_with_pbc(const crystal::Structure& structure);
std::vector<std::vector<Neighbor>> neighbor_lists_;
};
} // namespace defect_gnn::graph
Reference: Defect_GNN/utilities.py:25-32 (calculateEdgeAttributes)
File: include/graph/edge_features.hpp
namespace defect_gnn::graph {
// Gaussian radial basis function expansion
// r_cutoff=10.0, dr=0.1 → 100 bins
// sigma = r_cutoff / 3
//
// Formula: exp(-0.5 * (r_grid - dist)^2 / sigma^2) / sqrt(2*pi) / sigma
//
// Reference: utilities.py:25-32
Eigen::VectorXd gaussian_rbf(double distance,
double r_cutoff = 10.0,
double dr = 0.1);
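// Implementation sketch (e.g., src/graph/edge_features.cpp); the grid is assumed
// to start at 0, giving bin centres 0.0, 0.1, ..., 9.9 for the defaults. Needs <cmath>.
inline Eigen::VectorXd gaussian_rbf(double distance, double r_cutoff, double dr) {
    const int n_bins = static_cast<int>(std::round(r_cutoff / dr));   // 100 by default
    const double sigma = r_cutoff / 3.0;
    const double norm = 1.0 / (std::sqrt(2.0 * 3.141592653589793) * sigma);
    Eigen::VectorXd out(n_bins);
    for (int k = 0; k < n_bins; ++k) {
        const double r = k * dr;
        out(k) = norm * std::exp(-0.5 * (r - distance) * (r - distance) / (sigma * sigma));
    }
    return out;
}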
} // namespace defect_gnn::graph
File: include/graph/crystal_graph.hpp
namespace defect_gnn::graph {
class CrystalGraph {
public:
CrystalGraph(const crystal::Structure& structure,
const NeighborList& neighbors,
const std::map<int, Eigen::VectorXd>& atom_embeddings,
double r_cutoff = 10.0);
// Node features: [num_atoms, feature_dim]
// Initially 92-dim CGCNN embeddings, later + topo features
const Eigen::MatrixXd& node_features() const;
// Edge index: [2, num_edges] - COO format (source, target)
const Eigen::MatrixXi& edge_index() const;
// Edge attributes: [num_edges, 100] - Gaussian RBF
const Eigen::MatrixXd& edge_attr() const;
// Target value (formation energy)
double target() const;
void set_target(double y);
// Add topological features to node features
void add_topo_features(const Eigen::MatrixXd& topo); // [num_atoms, n_pca]
size_t num_nodes() const;
size_t num_edges() const;
private:
Eigen::MatrixXd node_features_;
Eigen::MatrixXi edge_index_;
Eigen::MatrixXd edge_attr_;
double target_;
};
} // namespace defect_gnn::graph
Decision: Use C++ Ripser directly for WebAssembly compatibility and self-contained builds.
Ripser Source: https://github.com/Ripser/ripser
Step 1: Clone Ripser into third_party
cd third_party
git clone https://github.com/Ripser/ripser.git
# Remove .git to avoid nested repo issues
rm -rf ripser/.git
Step 2: Create wrapper header include/topology/ripser_wrapper.hpp
#pragma once
#include <Eigen/Dense>
#include <vector>
namespace defect_gnn::topology {
// Wrapper around Ripser's C++ implementation
// Ripser computes persistent homology via the Vietoris-Rips complex
struct PersistencePair {
double birth;
double death;
double persistence() const { return death - birth; }
};
using PersistenceDiagram = std::vector<PersistencePair>;
struct PersistenceResult {
PersistenceDiagram dim0; // Connected components (β₀)
PersistenceDiagram dim1; // Loops (β₁)
PersistenceDiagram dim2; // Voids (β₂)
};
// Compute persistence from distance matrix
// This wraps Ripser's core functionality
// max_dim: highest homology dimension to compute (0, 1, or 2)
// threshold: maximum filtration value (r_cutoff for Betti features)
PersistenceResult compute_persistence_from_distances(
const Eigen::MatrixXd& distance_matrix,
int max_dim = 2,
double threshold = 2.5
);
// Convenience: compute from point cloud (builds distance matrix internally)
PersistenceResult compute_persistence(
const Eigen::MatrixXd& point_cloud, // Nx3
double max_radius = 2.5
);
} // namespace defect_gnn::topology
Step 3: Implementation notes for src/topology/ripser_wrapper.cpp
// Ripser uses a specific input format. Key integration points:
//
// 1. Ripser expects a lower-triangular distance matrix or point cloud
// 2. Main computation is in ripser() function
// 3. Output is persistence pairs per dimension
//
// Integration approach:
// - Include ripser.cpp with modifications for library use
// - Or extract core computation into callable functions
// - Ripser uses compressed_lower_distance_matrix internally
//
// Build flag needed: -DUSE_COEFFICIENTS (for field coefficients)
Add to CMakeLists.txt:
# Ripser (persistent homology)
if(EXISTS "${CMAKE_SOURCE_DIR}/third_party/ripser/ripser.cpp")
set(RIPSER_SOURCES "${CMAKE_SOURCE_DIR}/third_party/ripser/ripser.cpp")
message(STATUS "Using Ripser for persistent homology")
# Ripser compile options
add_compile_definitions(USE_COEFFICIENTS)
else()
message(WARNING "Ripser not found. Run: make deps")
endif()
Reference: Defect_GNN/Betti_number.py:93-224
Key parameters:
- r_cutoff = 2.5 Å (different from the graph cutoff!)
- maxdim = 2 (compute β₀, β₁, β₂)
The persistence computation is handled by ripser_wrapper.hpp (defined in 2.1.1 above).
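A minimal sketch of the point-cloud entry point, assuming it only builds the pairwise distance matrix and forwards max_radius as the Rips threshold (the actual persistence computation stays inside the Ripser call):
// src/topology/ripser_wrapper.cpp (sketch)
namespace defect_gnn::topology {
PersistenceResult compute_persistence(const Eigen::MatrixXd& point_cloud,
                                      double max_radius) {
    const Eigen::Index n = point_cloud.rows();
    Eigen::MatrixXd dist(n, n);
    for (Eigen::Index i = 0; i < n; ++i)
        for (Eigen::Index j = 0; j < n; ++j)
            dist(i, j) = (point_cloud.row(i) - point_cloud.row(j)).norm();
    return compute_persistence_from_distances(dist, /*max_dim=*/2, /*threshold=*/max_radius);
}
} // namespace defect_gnn::topology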
Reference: Defect_GNN/Betti_number.py:143-254
Algorithm:
- For each atom i in structure:
- For each element type e:
- Build point cloud: atom i + all atoms of type e within cutoff
- Compute persistence diagrams
- Extract 35 statistical features
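A sketch of the per-element point-cloud construction (the helper name is an assumption; how the per-element diagrams are folded into the 35-dimensional vector follows Betti_number.py and is not reproduced here):
// Inside namespace defect_gnn::topology (e.g., src/topology/betti_features.cpp).
// Gather atom i plus all atoms of element `elem` within r_cutoff, as relative coordinates.
Eigen::MatrixXd element_point_cloud(const crystal::Structure& s,
                                    size_t i, int elem, double r_cutoff) {
    std::vector<Eigen::Vector3d> pts;
    pts.push_back(Eigen::Vector3d::Zero());                 // centre atom at the origin
    for (size_t j = 0; j < s.num_atoms(); ++j) {
        if (j == i || s.atoms()[j].element != elem) continue;
        if (s.distance(i, j) <= r_cutoff)
            pts.push_back(s.displacement(i, j));            // minimum-image displacement
    }
    Eigen::MatrixXd cloud(static_cast<Eigen::Index>(pts.size()), 3);
    for (Eigen::Index k = 0; k < cloud.rows(); ++k)
        cloud.row(k) = pts[static_cast<size_t>(k)].transpose();
    return cloud;
}
// Per atom: call compute_persistence(cloud, 2.5) for each element type and reduce the
// resulting diagrams to the statistics declared below.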
File: include/topology/betti_features.hpp
namespace defect_gnn::topology {
// 35 features per atom:
// - β₀: 5 stats (mean, std, max, min, weighted_sum of death times)
// - β₁: 15 stats (5 stats each for birth, death, persistence)
// - β₂: 15 stats (same as β₁)
//
// Reference: Betti_number.py:143-224 (getElementalBettiProperties)
struct BettiStatistics {
double mean, std, max, min, weighted_sum;
};
// Extract statistics from persistence diagram
BettiStatistics compute_statistics(const PersistenceDiagram& diagram,
bool use_death = true); // birth or death
// Compute all 35 features for one atom
// Reference: Betti_number.py:226-254 (getElementalBettiFeatures)
Eigen::VectorXd compute_atom_betti_features(
const crystal::Structure& structure,
size_t atom_idx,
double r_cutoff = 2.5
);
// Compute features for all atoms in structure
// Returns [num_atoms, 35] matrix
Eigen::MatrixXd compute_structure_betti_features(
const crystal::Structure& structure,
double r_cutoff = 2.5
);
} // namespace defect_gnn::topology
Reference: Defect_GNN/dataset_PCA.py:80-81
File: include/topology/pca.hpp
namespace defect_gnn::topology {
class PCA {
public:
// Fit PCA on training data
// n_components: 1, 2, 4, 6, 8, or 10 (paper tests these)
void fit(const Eigen::MatrixXd& X, int n_components = 6);
// Transform new data
Eigen::MatrixXd transform(const Eigen::MatrixXd& X) const;
// Save/load fitted PCA
void save(const std::string& path) const;
void load(const std::string& path);
// Explained variance ratio
Eigen::VectorXd explained_variance_ratio() const;
private:
Eigen::MatrixXd components_; // [n_components, 35]
Eigen::VectorXd mean_; // [35]
Eigen::VectorXd explained_var_;
};
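// Sketch of fit()/transform() via an SVD of the centred feature matrix (an
// Eigen-based assumption; the reference pipeline uses scikit-learn's PCA).
inline void PCA::fit(const Eigen::MatrixXd& X, int n_components) {
    mean_ = X.colwise().mean().transpose();                            // [35]
    Eigen::MatrixXd centered = X.rowwise() - mean_.transpose();
    Eigen::BDCSVD<Eigen::MatrixXd> svd(centered, Eigen::ComputeThinV);
    components_ = svd.matrixV().leftCols(n_components).transpose();    // [n_components, 35]
    explained_var_ = (svd.singularValues().array().square()
                      / double(X.rows() - 1)).matrix().head(n_components);
}
inline Eigen::MatrixXd PCA::transform(const Eigen::MatrixXd& X) const {
    return (X.rowwise() - mean_.transpose()) * components_.transpose();  // [N, n_components]
}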
} // namespace defect_gnn::topology
Recommendation: Start with a custom implementation for learning, then optionally port to libtorch for GPU.
Use Eigen for all matrix operations. Define convenience aliases.
File: include/nn/tensor.hpp
namespace defect_gnn::nn {
using Tensor1D = Eigen::VectorXd;
using Tensor2D = Eigen::MatrixXd;
// Activation functions
Tensor2D relu(const Tensor2D& x);
Tensor2D leaky_relu(const Tensor2D& x, double negative_slope = 0.2);
Tensor2D softmax(const Tensor2D& x, int dim = -1); // Along rows or cols
// Batch normalization (inference mode)
Tensor2D batch_norm(const Tensor2D& x,
const Tensor1D& gamma,
const Tensor1D& beta,
const Tensor1D& running_mean,
const Tensor1D& running_var,
double eps = 1e-5);
} // namespace defect_gnn::nn
File: include/nn/linear.hpp
namespace defect_gnn::nn {
class Linear {
public:
Linear(int in_features, int out_features, bool bias = true);
// Forward: Y = XW^T + b
Tensor2D forward(const Tensor2D& x) const;
// Weight access for loading pretrained
void set_weight(const Tensor2D& W);
void set_bias(const Tensor1D& b);
private:
Tensor2D weight_; // [out_features, in_features]
Tensor1D bias_; // [out_features]
bool use_bias_;
};
} // namespace defect_gnn::nn
Reference: PyTorch Geometric MessagePassing class
File: include/nn/conv/message_passing.hpp
namespace defect_gnn::nn {
class MessagePassing {
public:
// Core message passing:
// 1. message(x_i, x_j, edge_attr) - compute messages
// 2. aggregate(messages, edge_index) - sum/mean/max per node
// 3. update(x_i, aggregated) - update node features
virtual Tensor2D forward(
const Tensor2D& x, // [N, F_in]
const Eigen::MatrixXi& edge_index, // [2, E]
const Tensor2D& edge_attr // [E, D]
) = 0;
protected:
// Aggregation: scatter_add, scatter_mean, etc.
Tensor2D scatter_add(const Tensor2D& src,
const Eigen::VectorXi& index,
int dim_size);
virtual ~MessagePassing() = default;
};
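// Sketch of the sum aggregation: message rows in `src` ([E, F]) are accumulated
// into the output row given by each edge's target-node index.
inline Tensor2D MessagePassing::scatter_add(const Tensor2D& src,
                                            const Eigen::VectorXi& index,
                                            int dim_size) {
    Tensor2D out = Tensor2D::Zero(dim_size, src.cols());
    for (Eigen::Index e = 0; e < src.rows(); ++e)
        out.row(index(e)) += src.row(e);
    return out;
}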
} // namespace defect_gnn::nn
Reference: Configurational-Disorder/Disorder_GNN/model/model_embedding.py
Paper: Xie & Grossman, Phys. Rev. Lett. 2018
File: include/nn/conv/cgconv.hpp
namespace defect_gnn::nn {
// Crystal Graph Convolution
// x_i' = x_i + Σ_j σ(z_ij^f) ⊙ g(z_ij^s)
// where z_ij = concat(x_i, x_j, e_ij) @ W
//
// Reference: model_embedding.py CGConv usage
class CGConv : public MessagePassing {
public:
CGConv(int channels, int edge_dim, bool batch_norm = true);
Tensor2D forward(const Tensor2D& x,
const Eigen::MatrixXi& edge_index,
const Tensor2D& edge_attr) override;
void load_weights(/* weight dict */);
private:
Linear lin_f_; // Filter network
Linear lin_s_; // Core network
// BatchNorm params if used
};
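// Forward sketch, assuming sigmoid for the filter and softplus for the core
// (the CGCNN/PyG CGConv choice); edges run source -> target (rows 0 and 1 of
// edge_index), messages are summed at the target node, plus the residual term.
inline Tensor2D CGConv::forward(const Tensor2D& x,
                                const Eigen::MatrixXi& edge_index,
                                const Tensor2D& edge_attr) {
    Tensor2D out = x;                                    // residual: x_i
    for (Eigen::Index e = 0; e < edge_index.cols(); ++e) {
        const int j = edge_index(0, e);                  // source node
        const int i = edge_index(1, e);                  // target node
        Tensor2D z(1, 2 * x.cols() + edge_attr.cols());
        z << x.row(i), x.row(j), edge_attr.row(e);       // z_ij = [x_i | x_j | e_ij]
        Tensor2D gate = lin_f_.forward(z);
        Tensor2D core = lin_s_.forward(z);
        gate = (1.0 + (-gate.array()).exp()).inverse().matrix();   // sigmoid
        core = (1.0 + core.array().exp()).log().matrix();          // softplus
        out.row(i) += (gate.array() * core.array()).matrix();
    }
    return out;   // BatchNorm, if enabled, would follow here
}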
} // namespace defect_gnn::nn
Reference: model_embedding.py uses GATv2Conv
Paper: Brody et al., "How Attentive are Graph Attention Networks?"
File: include/nn/conv/gatv2conv.hpp
namespace defect_gnn::nn {
// GATv2: fixes static attention problem
// α_ij = softmax_j(a^T LeakyReLU(W[x_i || x_j || e_ij]))
// x_i' = Σ_j α_ij W_t x_j
//
// Multi-head: concatenate or average heads
class GATv2Conv : public MessagePassing {
public:
GATv2Conv(int in_channels, int out_channels,
int heads = 1, int edge_dim = -1,
double negative_slope = 0.2,
double dropout = 0.0); // Dropout only in training
Tensor2D forward(const Tensor2D& x,
const Eigen::MatrixXi& edge_index,
const Tensor2D& edge_attr) override;
private:
int heads_;
Linear lin_l_, lin_r_; // Left and right projections
Linear lin_edge_; // Edge feature projection
Tensor1D att_; // Attention vector
double negative_slope_;
};
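// Hypothetical helper shared by the attention convolutions (GATv2/Transformer):
// softmax of per-edge scores over each target node's incoming edges, with the
// usual max-subtraction for numerical stability. Needs <algorithm>, <cmath>, <limits>.
inline Tensor1D segment_softmax(const Tensor1D& scores,          // [E]
                                const Eigen::VectorXi& target,   // [E] target node per edge
                                int num_nodes) {
    Tensor1D node_max = Tensor1D::Constant(num_nodes,
                                           -std::numeric_limits<double>::infinity());
    for (Eigen::Index e = 0; e < scores.size(); ++e)
        node_max(target(e)) = std::max(node_max(target(e)), scores(e));
    Tensor1D alpha(scores.size());
    Tensor1D denom = Tensor1D::Zero(num_nodes);
    for (Eigen::Index e = 0; e < scores.size(); ++e) {
        alpha(e) = std::exp(scores(e) - node_max(target(e)));
        denom(target(e)) += alpha(e);
    }
    for (Eigen::Index e = 0; e < scores.size(); ++e)
        alpha(e) /= denom(target(e));
    return alpha;   // one attention coefficient per edge
}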
} // namespace defect_gnn::nn
Reference: model_embedding.py uses TransformerConv
Paper: Shi et al., "Masked Label Prediction"
File: include/nn/conv/transformer_conv.hpp
namespace defect_gnn::nn {
// Graph Transformer
// α_ij = softmax_j((W_Q x_i)^T (W_K x_j + W_E e_ij) / sqrt(d))
// x_i' = W_1 x_i + Σ_j α_ij (W_V x_j + W_E e_ij)
class TransformerConv : public MessagePassing {
public:
TransformerConv(int in_channels, int out_channels,
int heads = 1, int edge_dim = -1);
Tensor2D forward(const Tensor2D& x,
const Eigen::MatrixXi& edge_index,
const Tensor2D& edge_attr) override;
private:
int heads_;
Linear lin_query_, lin_key_, lin_value_;
Linear lin_edge_;
Linear lin_skip_; // Skip connection W_1
};
} // namespace defect_gnn::nn
CRITICAL: Paper shows global_max_pool is key for defect tasks!
Reference: Paper Section "Model Performance" + Table 2
File: include/nn/pool.hpp
#pragma once
#include "nn/tensor.hpp"
#include <Eigen/Dense>
namespace defect_gnn::nn {
// Global max pooling: take max over all nodes per graph
// batch: [N] tensor indicating which graph each node belongs to
// Returns: [num_graphs, F]
//
// IMPORTANT: Paper shows this outperforms mean pooling for defects
Tensor2D global_max_pool(const Tensor2D& x,
const Eigen::VectorXi& batch,
int num_graphs);
// Also implement mean for comparison
Tensor2D global_mean_pool(const Tensor2D& x,
const Eigen::VectorXi& batch,
int num_graphs);
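// Sketch of the max variant: running per-graph maximum over node features,
// driven by the batch index vector. Needs <limits>.
inline Tensor2D global_max_pool(const Tensor2D& x,
                                const Eigen::VectorXi& batch,
                                int num_graphs) {
    Tensor2D out = Tensor2D::Constant(num_graphs, x.cols(),
                                      -std::numeric_limits<double>::infinity());
    for (Eigen::Index n = 0; n < x.rows(); ++n)
        out.row(batch(n)) = out.row(batch(n)).cwiseMax(x.row(n));
    return out;
}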
} // namespace defect_gnn::nn
Reference: model_embedding.py - all three model classes
File: include/nn/model.hpp
namespace defect_gnn::nn {
enum class ModelType { CGNN, GAT, Transformer };
enum class PoolType { Max, Mean };
class DefectGNN {
public:
DefectGNN(ModelType type,
int num_hidden_layers,
int num_hidden_channels,
int num_edge_features,
int num_topo_features, // PCA components (0 to disable)
int num_heads = 1, // For GAT/Transformer
PoolType pool = PoolType::Max);
// Forward pass
// Returns: [batch_size] formation energy predictions
Tensor1D forward(const std::vector<graph::CrystalGraph>& batch);
// Load pretrained weights from PyTorch checkpoint
void load_weights(const std::string& path);
private:
// Embedding: maps atomic number to hidden_channels - topo_features
Tensor2D embedding_; // [100, embed_dim]
// Convolution layers
std::vector<std::unique_ptr<MessagePassing>> conv_layers_;
// MLP head
Linear lin1_, lin2_;
// BatchNorm after lin1
int num_topo_features_;
PoolType pool_type_;
};
} // namespace defect_gnn::nn
Key difference from Configurational-Disorder code:
// In forward():
// 1. Embed atomic numbers
x = embedding_lookup(data.x); // [N, hidden - topo_dim]
// 2. Concatenate topo features (THE KEY MODIFICATION)
if (num_topo_features_ > 0) {
x = concat(x, data.topo_features, dim=1); // [N, hidden]
}
// 3. Message passing with SKIP CONNECTIONS (sum features from all layers)
global_feature = zeros(batch_size, hidden);
for (auto& conv : conv_layers_) {
x = relu(conv->forward(x, edge_index, edge_attr));
global_feature += global_max_pool(x, batch); // MAX not mean!
}
// 4. MLP
return lin2(relu(lin1(global_feature)));
Reference: Defect_GNN/dataset_PCA.py + Configurational-Disorder/Disorder_GNN/dataset.py
File: include/train/dataset.hpp
namespace defect_gnn::train {
class DefectDataset {
public:
DefectDataset(const std::string& root,
int n_pca_components = 6,
bool precompute_betti = true);
// Access
const graph::CrystalGraph& operator[](size_t idx) const;
size_t size() const;
// Shuffle for training
void shuffle();
// Split into train/val/test
static std::tuple<DefectDataset, DefectDataset, DefectDataset>
random_split(DefectDataset& full,
double train_ratio = 0.6,
double val_ratio = 0.2);
private:
std::vector<graph::CrystalGraph> graphs_;
topology::PCA pca_; // Fitted on training set
};
} // namespace defect_gnn::train
File: include/train/dataloader.hpp
namespace defect_gnn::train {
// Collate multiple graphs into a batch
// - Concatenate node features with batch index
// - Offset edge indices
// - Stack targets
struct BatchedGraph {
Eigen::MatrixXd x; // [total_nodes, F]
Eigen::MatrixXi edge_index; // [2, total_edges]
Eigen::MatrixXd edge_attr; // [total_edges, D]
Eigen::MatrixXd topo; // [total_nodes, n_pca]
Eigen::VectorXi batch; // [total_nodes] - graph membership
Eigen::VectorXd y; // [batch_size] - targets
int num_graphs;
};
class DataLoader {
public:
DataLoader(DefectDataset& dataset, int batch_size, bool shuffle = true);
// Iterator interface
class Iterator;
Iterator begin();
Iterator end();
size_t num_batches() const;
private:
DefectDataset& dataset_;
int batch_size_;
std::vector<size_t> indices_;
};
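// Sketch of batch collation (hypothetical helper used by DataLoader::Iterator):
// node features are stacked, edge indices are offset by the running node count,
// and `batch` records graph membership for the pooling layer.
inline BatchedGraph collate(const std::vector<const graph::CrystalGraph*>& graphs) {
    BatchedGraph b;
    b.num_graphs = static_cast<int>(graphs.size());
    Eigen::Index total_nodes = 0, total_edges = 0;
    for (const auto* g : graphs) {
        total_nodes += static_cast<Eigen::Index>(g->num_nodes());
        total_edges += static_cast<Eigen::Index>(g->num_edges());
    }
    b.x.resize(total_nodes, graphs.front()->node_features().cols());
    b.edge_index.resize(2, total_edges);
    b.edge_attr.resize(total_edges, graphs.front()->edge_attr().cols());
    b.batch.resize(total_nodes);
    b.y.resize(b.num_graphs);
    Eigen::Index node_off = 0, edge_off = 0;
    for (int g = 0; g < b.num_graphs; ++g) {
        const auto& gr = *graphs[g];
        const Eigen::Index n_nodes = static_cast<Eigen::Index>(gr.num_nodes());
        const Eigen::Index n_edges = static_cast<Eigen::Index>(gr.num_edges());
        b.x.middleRows(node_off, n_nodes) = gr.node_features();
        b.edge_index.middleCols(edge_off, n_edges) =
            (gr.edge_index().array() + static_cast<int>(node_off)).matrix();
        b.edge_attr.middleRows(edge_off, n_edges) = gr.edge_attr();
        b.batch.segment(node_off, n_nodes).setConstant(g);
        b.y(g) = gr.target();
        node_off += n_nodes;
        edge_off += n_edges;
    }
    // b.topo is filled analogously when the topological features are kept separate.
    return b;
}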
} // namespace defect_gnn::train
File: include/train/optimizer.hpp
namespace defect_gnn::train {
class Adam {
public:
Adam(/* model parameters */, double lr = 1e-3,
double beta1 = 0.9, double beta2 = 0.999,
double weight_decay = 5e-4, double eps = 1e-8);
void step(/* gradients */);
void zero_grad();
private:
// First and second moment estimates for each parameter
};
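// Update-rule sketch for one parameter matrix (helper name and buffers are
// assumptions; the moments m, v and the step counter t live in the optimizer).
// Needs <cmath>.
inline void adam_update(Eigen::MatrixXd& W, const Eigen::MatrixXd& grad,
                        Eigen::MatrixXd& m, Eigen::MatrixXd& v, int t,
                        double lr, double beta1, double beta2,
                        double weight_decay, double eps) {
    Eigen::MatrixXd g = grad + weight_decay * W;                 // L2 penalty
    m = beta1 * m + (1.0 - beta1) * g;
    v = beta2 * v + (1.0 - beta2) * g.cwiseProduct(g);
    const double bc1 = 1.0 - std::pow(beta1, t);                 // bias corrections
    const double bc2 = 1.0 - std::pow(beta2, t);
    W.array() -= lr * (m.array() / bc1) / ((v.array() / bc2).sqrt() + eps);
}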
} // namespace defect_gnn::train
Note: For inference-only (forward pass), you don't need backprop. Training in C++ is educational but not required if you train in Python and export weights.
Reference: Configurational-Disorder/Disorder_GNN/GNN.py
File: include/train/trainer.hpp
namespace defect_gnn::train {
struct TrainConfig {
int max_epochs = 200;
int batch_size = 128;
double lr = 0.01;
double weight_decay = 5e-4;
bool early_stopping = true;
int patience = 8;
};
class Trainer {
public:
Trainer(nn::DefectGNN& model, const TrainConfig& config);
void fit(DefectDataset& train_set, DefectDataset& val_set);
double evaluate(DefectDataset& test_set); // Returns MAE
// Loss: MAE (paper uses sum of abs errors)
double compute_loss(const Tensor1D& pred, const Tensor1D& target);
private:
nn::DefectGNN& model_;
TrainConfig config_;
std::vector<double> train_losses_, val_losses_;
};
} // namespace defect_gnn::train
- Unit tests for each component
- Integration test: Compare forward pass output with PyTorch
- Reproduce paper results: Train from scratch, verify MAE
File: scripts/validate_against_pytorch.py
"""
Load same weights in both PyTorch and C++,
run forward pass on same input,
compare outputs.
"""
# 1. Export PyTorch model weights to JSON/binary
# 2. Run C++ forward pass
# 3. Compare numerically (should match to ~1e-5)
| Model | Pooling | Topo Features | MAE (eV) |
|---|---|---|---|
| Transformer | Mean | No | 1.55 |
| Transformer | Max | No | 1.55 |
| Transformer | Max | Yes | 0.72 |
| GAT | Max | Yes | 0.94 |
| CGNN | Max | Yes | 0.98 |
Success criterion: Achieve MAE ≤ 0.80 eV with Transformer + Max Pool + Topo Features.
Interactive 3D crystal viewer where users can:
- Load perovskite structure
- Click atoms to create vacancies
- See predicted formation energy update in real-time
- View "heatmap" of formation energies for all sites
┌─────────────────────────────────────────────────────────────┐
│ Browser │
│ ┌─────────────────┐ ┌─────────────────────────────────┐ │
│ │ React UI │◄──►│ WebAssembly Module │ │
│ │ (Three.js) │ │ - Structure parsing │ │
│ │ │ │ - Graph construction │ │
│ │ - 3D viewer │ │ - Betti features (precomputed)│ │
│ │ - Controls │ │ - GNN forward pass │ │
│ │ - Heatmap │ │ │ │
│ └────────┬────────┘ └──────────────────────────────────┘│
│ │ │
│ └──── Embind API ─────────────────────────────────┘│
└─────────────────────────────────────────────────────────────┘
// Exposed to JavaScript
EMSCRIPTEN_BINDINGS(defect_gnn) {
function("loadStructure", &load_structure_from_string);
function("createVacancy", &create_vacancy);
function("predictFormationEnergy", &predict_formation_energy);
function("getAllSiteEnergies", &get_all_site_energies); // For heatmap
class_<Structure>("Structure")
.function("numAtoms", &Structure::num_atoms)
.function("getAtomPosition", &Structure::get_atom_position)
.function("getAtomElement", &Structure::get_atom_element);
}
- Pre-compute Betti features for sample structures (avoid Ripser in browser)
- Quantize weights to reduce model size
- Limit to Transformer (best performing)
- Use ONNX Runtime Web as alternative to custom inference
web/
├── public/
│ └── wasm/
│ ├── defect_gnn.js # Emscripten glue
│ └── defect_gnn.wasm # Binary
├── src/
│ ├── components/
│ │ ├── CrystalViewer.tsx # Three.js 3D view
│ │ ├── DefectSelector.tsx # Click-to-create-vacancy
│ │ ├── EnergyDisplay.tsx # Show prediction
│ │ └── HeatmapOverlay.tsx # Color atoms by energy
│ ├── hooks/
│ │ └── useDefectGNN.ts # WASM loading & API
│ └── App.tsx
└── sample_structures/
└── BaTiO3_supercell.vasp # Pre-loaded demo
| Your File | Reference Source | Notes |
|---|---|---|
| vasp_parser.cpp | Betti_number.py:15-46 | VASP format parsing |
| structure.cpp | — | Crystal structure + Atom class |
| edge_features.cpp | utilities.py:25-32 | Gaussian RBF |
| crystal_graph.cpp | utilities.py:34-71 | structureToGraph |
| ripser_wrapper.cpp | Ripser C++ source | Persistence computation |
| betti_features.cpp | Betti_number.py:93-254 | 35 Betti statistics |
| pca.cpp | dataset_PCA.py:80-81 | Dimensionality reduction |
| cgconv.cpp | PyG CGConv + paper | |
| gatv2conv.cpp | PyG GATv2Conv | |
| transformer_conv.cpp | PyG TransformerConv | |
| pool.cpp | PyG global pooling | Max pool is critical |
| model.cpp | model_embedding.py | Add topo concat + max pool |
| dataset.cpp | dataset_PCA.py | |
| trainer.cpp | GNN.py | Training loop |
| Library | Purpose | Install |
|---|---|---|
| Eigen | Linear algebra | Header-only, copy to third_party/ |
| nlohmann/json | JSON parsing | Header-only |
| nanoflann | KD-tree | Header-only |
| Ripser | Persistent homology | Clone from GitHub |
| Emscripten | WASM compilation | emsdk install latest |
- Parse VASP files
- Build crystal structures with PBC
- Construct neighbor lists
- Generate Gaussian RBF edge features
- Build crystal graph objects
- Integrate C++ Ripser via ripser_wrapper
- Compute persistence diagrams
- Extract 35 Betti statistics per atom
- Implement PCA for dimensionality reduction
- Integrate topo features into graph
- Implement Linear layer
- Implement CGConv
- Implement GATv2Conv
- Implement TransformerConv
- Implement global_max_pool
- Build full model with topo feature injection
- Dataset class with preprocessing
- DataLoader with batching
- Training loop (or just use PyTorch + export)
- Weight loading from PyTorch checkpoint
- Unit tests pass
- Forward pass matches PyTorch
- Reproduce paper MAE (~0.72 eV)
- Compile core to WASM
- Embind API
- Three.js crystal viewer
- Interactive defect creation
- Formation energy heatmap