A Comprehensive Graph Machine Learning Library Built on Jittor
Documentation • Examples • Installation • Quick Start • Models
JittorGeometric 2.0 is a state-of-the-art graph machine learning library built on Jittor, a deep learning framework developed in China. It provides comprehensive support for Graph Neural Network (GNN) research and applications, with an emphasis on performance, flexibility, and scalability.
- JIT Compilation: Just-in-time compilation allows code to be modified dynamically without pre-compilation overhead
- Optimized Sparse Operations: High-performance sparse matrix computation accelerated by cuSPARSE
- Comprehensive Model Zoo: 40+ models covering classic, spectral, dynamic, molecular, and transformer-based GNNs
- Rich Dataset Support: Built-in loaders for popular graph datasets (Planetoid, OGB, Reddit, etc.)
- Distributed Training: Multi-GPU and multi-node training with MPI
- Dynamic Graph Processing: Event-based dynamic graph support with parallel processing
- Mini-batch Support: Efficient mini-batch training for large-scale graphs
- Ascend-GNN: GNN support for Ascend NPUs
- Extended Model Categories: Graph transformers, self-supervised learning, and recommendation systems
### Dataset Selection
```python
import os.path as osp
from jittor_geometric.datasets import Planetoid
import jittor_geometric.transforms as T
import jittor as jt

dataset = 'cora'
path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data', dataset)
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
data = dataset[0]
v_num = data.x.shape[0]  # number of nodes
```
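The `T.NormalizeFeatures()` transform passed to `Planetoid` rescales each node's feature vector. The sketch below shows the idea in plain Python, assuming the transform behaves like PyTorch Geometric's transform of the same name for non-negative features such as Cora's bag-of-words vectors: each row is divided by its sum, with the sum clamped to at least 1. `normalize_features` is a hypothetical helper, not a JittorGeometric API.

```python
def normalize_features(x):
    """Row-normalize a feature matrix (list of lists) so non-empty rows sum to 1.

    Assumption: mirrors PyG-style NormalizeFeatures for non-negative features,
    dividing by max(row_sum, 1) so all-zero rows stay zero instead of raising.
    """
    return [[v / max(sum(row), 1.0) for v in row] for row in x]


# Three nodes with toy bag-of-words counts; node 1 has no features at all.
x = [[1.0, 0.0, 1.0], [0.0, 0.0, 0.0], [2.0, 2.0, 4.0]]
print(normalize_features(x))  # [[0.5, 0.0, 0.5], [0.0, 0.0, 0.0], [0.25, 0.25, 0.5]]
```

Normalizing rows keeps high-degree documents with many words from dominating the first linear layer purely because of their raw counts.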
### Data Preprocessing
```python
from jittor_geometric.ops import cootocsr, cootocsc
from jittor_geometric.nn.conv.gcn_conv import gcn_norm

# Add self-loops and apply symmetric GCN normalization to the edge weights
edge_index, edge_weight = data.edge_index, data.edge_attr
edge_index, edge_weight = gcn_norm(
    edge_index, edge_weight, v_num,
    improved=False, add_self_loops=True)

# Build compressed sparse column (CSC) and row (CSR) layouts of the normalized graph
with jt.no_grad():
    data.csc = cootocsc(edge_index, edge_weight, v_num)
    data.csr = cootocsr(edge_index, edge_weight, v_num)
```
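The preprocessing step relies on `gcn_norm`, which applies the standard GCN normalization $\hat{A} = D^{-1/2}(A + I)D^{-1/2}$, and then converts the COO edge list (row 0 of `edge_index` holds source ids, row 1 target ids) into compressed layouts. The plain-Python sketch below makes that arithmetic visible on a toy path graph. `gcn_norm_sketch` and `coo_to_csr` are hypothetical helpers, not JittorGeometric APIs; we assume the library follows PyTorch Geometric's convention of summing degrees over target nodes, and for simplicity we add a self-loop to every node (the toy graph has none to begin with).

```python
import math


def gcn_norm_sketch(edge_index, edge_weight, num_nodes, improved=False, add_self_loops=True):
    """Return (edge_index, weights) with w_ij / sqrt(deg_i * deg_j) per edge.

    Hypothetical helper; degrees are summed over target nodes (assumption).
    """
    src, dst = list(edge_index[0]), list(edge_index[1])
    w = list(edge_weight) if edge_weight is not None else [1.0] * len(src)
    if add_self_loops:
        fill = 2.0 if improved else 1.0  # improved=True doubles the self-loop weight
        for v in range(num_nodes):
            src.append(v)
            dst.append(v)
            w.append(fill)
    deg = [0.0] * num_nodes
    for d, wt in zip(dst, w):
        deg[d] += wt
    inv_sqrt = [1.0 / math.sqrt(x) if x > 0 else 0.0 for x in deg]
    norm = [inv_sqrt[s] * wt * inv_sqrt[d] for s, d, wt in zip(src, dst, w)]
    return [src, dst], norm


def coo_to_csr(edge_index, edge_weight, num_nodes):
    """Group COO edges by source row; returns (row_ptr, col_idx, values)."""
    order = sorted(range(len(edge_weight)),
                   key=lambda e: (edge_index[0][e], edge_index[1][e]))
    row_ptr = [0] * (num_nodes + 1)
    for e in order:
        row_ptr[edge_index[0][e] + 1] += 1  # count edges per row
    for v in range(num_nodes):
        row_ptr[v + 1] += row_ptr[v]        # prefix sum gives row offsets
    col_idx = [edge_index[1][e] for e in order]
    values = [edge_weight[e] for e in order]
    return row_ptr, col_idx, values


toy_edge_index = [[0, 1, 1, 2], [1, 0, 2, 1]]  # path graph 0-1-2, both directions
ei, w = gcn_norm_sketch(toy_edge_index, None, 3)
row_ptr, col_idx, values = coo_to_csr(ei, w, 3)
print(row_ptr)  # [0, 2, 5, 7]
print(col_idx)  # [0, 1, 0, 1, 2, 1, 2]
```

With self-loops the degrees are 2, 3 and 2, so node 0's self-loop gets weight 1/sqrt(2·2) = 0.5 and the edge 0→1 gets 1/sqrt(2·3). A CSC layout is the same grouping done by target column instead of source row.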
### Model Definition
```python
from jittor import nn
from jittor_geometric.nn import GCNConv

class GCN(nn.Module):
    def __init__(self, dataset, dropout=0.8):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels=dataset.num_features, out_channels=256)
        self.conv2 = GCNConv(in_channels=256, out_channels=dataset.num_classes)
        self.dropout = dropout

    # Jittor modules define `execute` instead of PyTorch's `forward`
    def execute(self):
        # Uses the global `data` prepared in the previous steps
        x, csc, csr = data.x, data.csc, data.csr
        x = nn.relu(self.conv1(x, csc, csr))
        x = nn.dropout(x, self.dropout, is_train=self.training)
        x = self.conv2(x, csc, csr)
        return nn.log_softmax(x, dim=1)
```
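Conceptually, each `GCNConv` call in `execute` performs one propagation step $H' = \hat{A}HW$ over the normalized graph. The plain-Python sketch below spells that out with a CSR sparse-times-dense product. It omits the bias term and is not how JittorGeometric's cuSPARSE-backed kernels are implemented; all function names are hypothetical.

```python
# Hypothetical helpers illustrating one GCN layer: H' = A_hat @ (H @ W),
# where A_hat is the normalized adjacency (with self-loops) stored in CSR form.


def matmul(x, w):
    """Dense product x @ w for row-major lists of lists."""
    return [[sum(row[k] * w[k][c] for k in range(len(w))) for c in range(len(w[0]))]
            for row in x]


def spmm_csr(row_ptr, col_idx, values, h):
    """Sparse (CSR) times dense: out[i] = sum over stored j of A[i][j] * h[j]."""
    dim = len(h[0])
    out = [[0.0] * dim for _ in range(len(row_ptr) - 1)]
    for i in range(len(row_ptr) - 1):
        for k in range(row_ptr[i], row_ptr[i + 1]):
            j, a = col_idx[k], values[k]
            for f in range(dim):
                out[i][f] += a * h[j][f]
    return out


def relu(h):
    return [[max(0.0, v) for v in row] for row in h]


def gcn_layer(row_ptr, col_idx, values, h, weight):
    """Linear transform first, then aggregate neighbors with A_hat (bias omitted)."""
    return spmm_csr(row_ptr, col_idx, values, matmul(h, weight))


# Two connected nodes with self-loops: every degree is 2, so each A_hat entry is 1/2.
row_ptr, col_idx, values = [0, 2, 4], [0, 1, 0, 1], [0.5, 0.5, 0.5, 0.5]
h = [[1.0, 0.0], [0.0, 1.0]]  # 2 nodes, 2 input features
weight = [[2.0], [1.0]]       # maps 2 input features to 1 output channel
print(relu(gcn_layer(row_ptr, col_idx, values, h, weight)))  # [[1.5], [1.5]]
```

Because the normalized adjacency of an undirected graph is symmetric, grouping the nonzeros by source or by target gives the same product here; the model above stacks two such layers with a ReLU and dropout in between.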
### Training
```python
model = GCN(dataset)
optimizer = nn.Adam(params=model.parameters(), lr=0.001, weight_decay=5e-4)

for epoch in range(200):
    model.train()
    pred = model()[data.train_mask]
    label = data.y[data.train_mask]
    loss = nn.nll_loss(pred, label)
    optimizer.step(loss)  # Jittor runs backward and the parameter update in one call
```

JittorGeometric 2.0 includes implementations of 40+ state-of-the-art GNN models:

### Classic GNNs

| Model | Year | Venue | Description |
|---|---|---|---|
| ChebNet | 2016 | NeurIPS | Spectral graph convolutions |
| GCN | 2017 | ICLR | Graph Convolutional Networks |
| GraphSAGE | 2017 | NeurIPS | Inductive graph learning |
| GAT | 2018 | ICLR | Graph Attention Networks |
| SGC | 2019 | ICML | Simplified Graph Convolution |
| APPNP | 2019 | ICLR | Approximate Personalized Propagation |
| GCNII | 2020 | ICML | Deeper Graph Convolutional Networks |

### Spectral GNNs

| Model | Year | Venue | Description |
|---|---|---|---|
| GPRGNN | 2021 | ICLR | Generalized PageRank GNN |
| BernNet | 2021 | NeurIPS | Bernstein polynomial filters |
| ChebNetII | 2022 | NeurIPS | Improved Chebyshev filters |
| EvenNet | 2022 | NeurIPS | Even polynomial filters |
| OptBasis | 2023 | ICML | Optimal basis functions |

### Dynamic Graph Models

| Model | Year | Venue | Description |
|---|---|---|---|
| JODIE | 2019 | SIGKDD | Temporal interaction networks |
| DyRep | 2019 | ICLR | Dynamic representation learning |
| TGN | 2020 | ArXiv | Temporal Graph Networks |
| GraphMixer | 2022 | ICLR | MLP-based dynamic graphs |
| DyGFormer | 2023 | NeurIPS | Dynamic graph transformers |

### Molecular GNNs

| Model | Year | Venue | Description |
|---|---|---|---|
| SchNet | 2017 | NeurIPS | Continuous-filter convolutions |
| DimeNet | 2020 | ICLR | Directional message passing |
| EGNN | 2021 | ICML | Equivariant Graph Networks |
| Graphormer | 2021 | NeurIPS | Graph transformers for molecules |
| SphereNet | 2022 | ICLR | Spherical message passing |
| Uni-Mol | 2023 | ICLR | Universal molecular representation |
| Transformer-M | 2023 | ICLR | Molecular transformers |

### Self-Supervised Learning

| Model | Year | Venue | Description |
|---|---|---|---|
| DGI | 2019 | ICLR | Deep Graph Infomax |
| MVGRL | 2020 | ICML | Multi-view contrastive learning |
| GRACE | 2020 | ICML | Graph contrastive learning |
| PolyGCL | 2024 | ICLR | Polynomial graph contrastive learning |

### Recommendation Systems

| Model | Year | Venue | Description |
|---|---|---|---|
| SASREC | 2018 | ICDM | Self-attentive sequential recommendation |
| SGNNHN | 2020 | CIKM | Star graph neural networks for session-based recommendation |
| CRAFT | 2025 | ArXiv | Cross-attention recommendation |

### Graph Embedding and Collaborative Filtering

| Model | Year | Venue | Description |
|---|---|---|---|
| DeepWalk | 2014 | KDD | Random walk embeddings |
| LINE | 2015 | WWW | Large-scale information network embedding |
| Node2Vec | 2016 | KDD | Scalable feature learning |
| LightGCN | 2020 | SIGIR | Simplified GCN for recommendation |
| DirectAU | 2022 | KDD | Direct alignment and uniformity |
| SimGCL | 2024 | KAIS | Simple graph contrastive learning |
| XSimGCL | 2024 | TKDE | Extreme simple graph contrastive learning |

### Graph Transformers

| Model | Year | Venue | Description |
|---|---|---|---|
| SGFormer | 2023 | NeurIPS | Simplifying graph transformers |
| NAGFormer | 2023 | ICLR | Neighborhood aggregation transformers |
| PolyFormer | 2024 | KDD | Polynomial-based graph transformers |
- Create a conda environment:
  ```bash
  conda create -n jittorgeometric python=3.10
  conda activate jittorgeometric
  ```
- Install Jittor:
  ```bash
  python -m pip install git+https://github.com/Jittor/jittor.git
  ```
  or follow the Jittor official documentation.
- Install dependencies:
  ```bash
  pip install astunparse==1.6.3 autograd==1.7.0 cupy==13.3.0 numpy==1.24.0 \
      pandas==2.2.3 Pillow==11.1.0 PyMetis==2023.1.1 six==1.16.0 \
      pyparsing==3.2 scipy==1.15.1 setuptools==69.5.1 sympy==1.13.3 \
      tqdm==4.66.4 einops huggingface_hub==0.27.1 networkx==3.4.2 \
      scikit-learn==1.7.1 rdkit==2025.3.5 seaborn==0.13.2 \
      alive-progress==3.3.0
  ```
- Install JittorGeometric:
  ```bash
  git clone https://github.com/AlgRUC/JittorGeometric.git
  cd JittorGeometric
  pip install .
  ```
- Verify the installation:
  ```bash
  python examples/gcn_example.py
  ```
Optionally, install MPI support for distributed training:
```bash
conda install -c conda-forge openmpi=4.0.5
conda install -c conda-forge mpi4py
```

### Requirements

- Python 3.10+
- CUDA 11.0+ (for GPU support)
- Jittor 1.3.0+
- CuPy (for CUDA operations)
- NumPy, SciPy, NetworkX
- For distributed training: OpenMPI 4.0.5+ and mpi4py
JittorGeometric 2.0 supports distributed training across multiple GPUs and nodes. On a single machine, launch two processes with:
```bash
mpiexec -n 2 python dist_gcn.py --num_parts 2 --dataset reddit
```
For multi-node training:
- Configure your hostfile:
  ```
  172.31.195.15 slots=1
  172.31.195.16 slots=1
  ```
- Partition the graph:
  ```bash
  python dist_partition.py --dataset reddit --num_parts 2 --use_gdc
  ```
- Launch distributed training:
  ```bash
  mpirun -n 2 --hostfile hostfile \
      --prefix /path/to/conda/env \
      python dist_gcn.py --num_parts 2 --dataset reddit
  ```

For detailed distributed training setup, see examples/README.md.
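Partitioning splits the nodes among workers, and a good partition keeps the number of cut edges low: in a typical partitioned GCN, each worker must obtain the features of remote neighbors at every layer. The sketch below measures the edge cut and these remote "halo" nodes for a given node-to-partition assignment. It is a hypothetical illustration and does not reproduce the algorithm in `dist_partition.py` (PyMetis is among the listed dependencies, but we make no claim about how it is used) or the communication pattern of `dist_gcn.py`.

```python
# Hypothetical helpers for reasoning about a graph partition.
# part[v] is the partition (worker) that owns node v.


def contiguous_partition(num_nodes, num_parts):
    """Naive baseline: assign node ids to num_parts contiguous blocks."""
    size = -(-num_nodes // num_parts)  # ceiling division
    return [v // size for v in range(num_nodes)]


def edge_cut(edge_index, part):
    """Count directed edges whose endpoints are owned by different partitions."""
    return sum(1 for s, d in zip(edge_index[0], edge_index[1]) if part[s] != part[d])


def halo_nodes(edge_index, part, p):
    """Remote source nodes whose features partition p needs to aggregate its own targets."""
    return {s for s, d in zip(edge_index[0], edge_index[1])
            if part[d] == p and part[s] != p}


# Path graph 0-1-2-3, stored in both directions (COO: sources, targets).
edge_index = [[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]
blocks = contiguous_partition(4, 2)        # [0, 0, 1, 1]
print(edge_cut(edge_index, blocks))        # 2
print(edge_cut(edge_index, [0, 1, 0, 1]))  # 6: alternating owners cut every edge
print(halo_nodes(edge_index, blocks, 0))   # {2}
```

Graph partitioners such as METIS search for assignments with low edge cut and balanced part sizes; the contiguous split here is only a baseline that happens to be optimal for a path graph.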
Run a specific example:
```bash
python examples/gcn_example.py
```

Comprehensive documentation is available at https://algruc.github.io/JittorGeometric/index.html
This project is actively maintained by the JittorGeometric Team at Renmin University of China and Northeastern University.
- Project Lead: [email protected]
- Contributors: See Contributors
JittorGeometric is released under the Apache 2.0 License.
- Jittor Team for the deep learning framework
- PyTorch Geometric for inspiration
- All contributors and users of JittorGeometric
