A fast, differentiable, JIT-free, debugging-friendly finite element library for PyTorch.
TensorMesh is a finite element method (FEM) library built natively on PyTorch. It is designed to solve partial differential equations (PDEs) with the ergonomics of modern deep learning frameworks — automatic differentiation, GPU acceleration, eager execution — without sacrificing the rigour of classical FEM. Custom weak forms are written in pure Python; the library takes care of tensorized assembly, sparse linear algebra, boundary conditions, and time integration.
- GPU-native & differentiable. Built on PyTorch from the ground up. Moving an entire FEM workflow to the GPU takes a single line of code — every downstream assembly, solve, and gradient inherits the device automatically, with no separate backend or data-marshalling step. Native autograd flows seamlessly through assembly and solve, enabling end-to-end differentiable PDE pipelines.
- High-performance tensorized assembly. A fully tensorized Map-Reduce algorithm powered by TensorGalerkin, which fuses element-wise operations into monolithic GPU kernels, eliminating Python-level loops and delivering order-of-magnitude speedups over CPU-based FEM stacks.
- JIT-free & debugging-friendly. Eager execution with no compilation overhead. Dynamic meshes, adaptive refinement, and interactive workflows just work — no recompilation latency, no opaque traces.
- Comprehensive element & mesh support. Triangular, tetrahedral, pyramid, and prismatic elements with automated mesh generation for common geometries and seamless Gmsh / VTK-HDF5 I/O.
- Flexible solvers. Powered by torch-sla, our companion library for differentiable sparse linear algebra. Linear, nonlinear, and eigenvalue solvers run across multiple backends on CPU and GPU, with full autograd support, batched solves, and distributed multi-GPU scaling.
- Pythonic API. Custom weak forms in pure Python — no separate DSL, no form compiler. If you can write PyTorch, you can write FEM.
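The Map-Reduce assembly idea can be illustrated with a minimal, dependency-free sketch for linear (P1) triangles and the Laplace form a(u, v) = ∫ ∇u · ∇v dΩ. This is an illustration under stated assumptions, not TensorGalerkin's implementation. The helper names `p1_local_stiffness` and `assemble` are hypothetical. The "map" is a plain list comprehension rather than a batched GPU kernel, and the "reduce" is a dictionary scatter-add rather than a sparse tensor.

```python
from collections import defaultdict


def p1_local_stiffness(p0, p1, p2):
    """Map step: 3x3 element matrix of one linear triangle for a(u, v) = ∫ ∇u·∇v dΩ."""
    (x0, y0), (x1, y1), (x2, y2) = p0, p1, p2
    # Signed determinant = twice the signed triangle area.
    det = (x1 - x0) * (y2 - y0) - (x2 - x0) * (y1 - y0)
    area = abs(det) / 2.0
    # Gradients of the barycentric basis functions, constant on the element.
    grads = [
        ((y1 - y2) / det, (x2 - x1) / det),
        ((y2 - y0) / det, (x0 - x2) / det),
        ((y0 - y1) / det, (x1 - x0) / det),
    ]
    return [[area * (gi[0] * gj[0] + gi[1] * gj[1]) for gj in grads] for gi in grads]


def assemble(points, cells):
    """Map every cell to its element matrix, then reduce by scatter-adding into K."""
    local = [p1_local_stiffness(*(points[v] for v in cell)) for cell in cells]  # map
    K = defaultdict(float)
    for cell, Ke in zip(cells, local):                                          # reduce
        for a, i in enumerate(cell):
            for b, j in enumerate(cell):
                K[(i, j)] += Ke[a][b]
    return dict(K)  # sparse (row, col) -> value; absent keys are structural zeros


# Unit square split into two triangles.
points = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
cells = [(0, 1, 2), (0, 2, 3)]
K = assemble(points, cells)
for i in range(len(points)):
    print([round(K.get((i, j), 0.0), 3) for j in range(len(points))])
```

In PyTorch, the reduce step would typically be a scatter-add (for example `Tensor.index_add_`) or the summation of duplicate entries when a COO sparse tensor is coalesced.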
Requirements: Python ≥ 3.10, PyTorch ≥ 2.0.
```bash
pip install tensormesh-fem            # CPU only
pip install "tensormesh-fem[gpu]"     # + CUDA sparse solvers (CuPy + cuDSS)
```

The base install ships only the CPU sparse stack (SciPy / native PyTorch via torch-sla). The `[gpu]` extra pulls in both CUDA backends; if you only want one, use `[cupy]` or `[cudss]` instead:

```bash
pip install "tensormesh-fem[cupy]"    # CuPy CUDA backend (iterative + SuperLU)
pip install "tensormesh-fem[cudss]"   # cuDSS CUDA backend (fastest GPU direct)
```

The quotes are needed because `[...]` is a shell glob pattern.
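Install from source (for development)

```bash
git clone https://github.com/camlab-ethz/TensorMesh.git
cd TensorMesh
pip install -e ".[test]"
```

After installing, sanity-check the install:

```bash
python -m tensormesh.verify_install
```

To see which sparse-solver backends are usable on your machine, along with a one-line install hint for any that are not, run:

```python
import torch_sla
torch_sla.show_backends()
```

Solve the Poisson problem −Δu = f on the unit square with u = 0 on the boundary, where f = 2π² sin(πx) sin(πy), so that the exact solution is u = sin(πx) sin(πy):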
```python
import math
import torch
from tensormesh import ElementAssembler, NodeAssembler, Mesh, Condenser

# 1. Generate a triangular mesh of the unit square.
mesh = Mesh.gen_rectangle(chara_length=0.05)

# 2. Stiffness weak form: a(u, v) = ∫ ∇u · ∇v dΩ
class LaplaceAssembler(ElementAssembler):
    def forward(self, gradu, gradv):
        return gradu @ gradv

# 3. Load weak form: l(v) = ∫ f v dΩ
class SourceAssembler(NodeAssembler):
    def forward(self, v, f):
        return f * v

# 4. Source term, evaluated at every mesh node.
x, y = mesh.points[:, 0], mesh.points[:, 1]
f_vals = 2 * math.pi**2 * torch.sin(math.pi * x) * torch.sin(math.pi * y)

# 5. Assemble the stiffness matrix and load vector.
K = LaplaceAssembler.from_mesh(mesh)()
b = SourceAssembler.from_mesh(mesh)(point_data={"f": f_vals})

# 6. Apply Dirichlet BCs by static condensation, then solve.
condenser = Condenser(mesh.boundary_mask)
K_, b_ = condenser(K, b)
u_ = K_.solve(b_, verbose=True)
u = condenser.recover(u_)

# 7. Compare against the analytical solution.
u_exact = torch.sin(math.pi * x) * torch.sin(math.pi * y)
print(f"L2 error: {(u - u_exact).norm() / u_exact.norm():.3e}")
```

Output:

```text
[torch-sla] solve: n=431, nnz=2859, dtype=float64, device=cpu, symmetric=True, spd=False, backend=scipy, method=lu
L2 error: 3.135e-03
```
The workflow is Mesh → Assembler → SparseMatrix → Condenser → Solve. Move everything to the GPU with a single `mesh = mesh.cuda()`; enable gradients with `mesh.points.requires_grad_(True)` and the same script becomes differentiable end to end, ready to serve as the forward model of an inverse problem.
See the full walkthrough in the Quickstart.
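To make the Mesh → Assembler → Condenser → Solve pipeline concrete without depending on TensorMesh itself, here is a dependency-free 1D analogue of the quickstart: −u'' = π² sin(πx) on (0, 1) with u = 0 at both ends, whose exact solution is sin(πx). It is a sketch under stated assumptions. All function names are hypothetical. The load uses lumped (trapezoidal) quadrature. `condense` / `recover` only mirror the roles of `Condenser` and `condenser.recover` for homogeneous Dirichlet data. A dense Gaussian elimination stands in for TensorMesh's sparse solvers.

```python
import math


def assemble_1d(n):
    """Assemble K and b for -u'' = pi^2 sin(pi x) on (0, 1) with n linear elements."""
    h = 1.0 / n
    x = [i * h for i in range(n + 1)]  # the "mesh": n + 1 equally spaced nodes
    K = [[0.0] * (n + 1) for _ in range(n + 1)]
    b = [0.0] * (n + 1)
    for e in range(n):
        i, j = e, e + 1  # element e spans nodes i and j
        # Element stiffness (1/h) * [[1, -1], [-1, 1]], scatter-added into K.
        K[i][i] += 1.0 / h
        K[j][j] += 1.0 / h
        K[i][j] -= 1.0 / h
        K[j][i] -= 1.0 / h
        # Lumped (trapezoidal) load: each endpoint receives (h / 2) * f(x_k).
        for k in (i, j):
            b[k] += 0.5 * h * math.pi**2 * math.sin(math.pi * x[k])
    return x, K, b


def condense(K, b, boundary):
    """Drop rows/columns of Dirichlet nodes (assumes u = 0 there)."""
    free = [i for i in range(len(b)) if i not in boundary]
    return free, [[K[i][j] for j in free] for i in free], [b[i] for i in free]


def solve_dense(A, rhs):
    """Gaussian elimination without pivoting; adequate for this SPD system."""
    n = len(rhs)
    A, rhs = [row[:] for row in A], rhs[:]
    for k in range(n):
        for i in range(k + 1, n):
            m = A[i][k] / A[k][k]
            if m != 0.0:
                for j in range(k, n):
                    A[i][j] -= m * A[k][j]
                rhs[i] -= m * rhs[k]
    u = [0.0] * n
    for i in reversed(range(n)):
        u[i] = (rhs[i] - sum(A[i][j] * u[j] for j in range(i + 1, n))) / A[i][i]
    return u


def recover(free, u_free, size):
    """Scatter the reduced solution back into a full vector; boundary values stay 0."""
    u = [0.0] * size
    for i, val in zip(free, u_free):
        u[i] = val
    return u


def relative_error(n):
    x, K, b = assemble_1d(n)
    free, K_, b_ = condense(K, b, boundary={0, n})
    u = recover(free, solve_dense(K_, b_), n + 1)
    u_exact = [math.sin(math.pi * xi) for xi in x]
    num = math.sqrt(sum((a - e) ** 2 for a, e in zip(u, u_exact)))
    return num / math.sqrt(sum(e**2 for e in u_exact))


for n in (16, 32, 64):
    print(f"n={n:3d}  relative nodal error: {relative_error(n):.3e}")
```

Halving h cuts the error by roughly a factor of 4, consistent with second-order accuracy at the nodes for linear elements with this load.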
A small selection from the example gallery:
| Category | Path | Description |
|---|---|---|
| Basics | `examples/basics/` | Mesh visualization, basis functions, element gallery |
| Poisson | `examples/poisson/` | 2D / 3D Poisson, batched RHS, h-adaptivity |
| Diffusion | `examples/diffusion/` | Heat equation, Allen-Cahn phase field |
| Wave | `examples/wave/` | Wave equation with central-difference scheme |
| Solid | `examples/solid/` | Cantilever beam, hyperelasticity, contact, plasticity |
| Fluid | `examples/fluid/` | Lid-driven cavity, cylinder flow, flow past obstacles, Rayleigh-Bénard, Taylor-Green |
| Magnetostatics | `examples/maxwell/` | 3D Maxwell: magnetic field around a current-carrying wire via a stabilized nodal curl-curl formulation |
| Inverse design | `examples/inverse_design/` | Coefficient-field identification and density-based topology optimization, all via autograd |
| Physics-informed | `examples/physics_informed/` | Train a neural network to minimize the assembled Galerkin residual |
| Dataset | `examples/dataset/` | Batch dataset generation for ML (heat, wave, Poisson) |
| Distributed | `examples/distributed/` | Graph coloring, mesh partitioning, multi-GPU assembly |
| Feature | FEniCS | scikit-fem | JAX-FEM | torch-fem | TensorMesh |
|---|---|---|---|---|---|
| Custom weak forms (Pythonic) | ✅ | ❌ | ❌ | ✅ | |
| Easy install | ❌ | ✅ | ✅ | ✅ | |
| Easy debug | ❌ | ✅ | ❌ | ✅ | ✅ |
| Easy I/O | ❌ | ❌ | ❌ | ❌ | ✅ |
| Large meshes | ✅ | ✅ | ❌ | ❌ | ✅ |
| GPU support | ✅ | ❌ | ✅ | ✅ | ✅ |
| Efficiency | ✅ | ❌ | ✅ | ✅ | |
| End-to-end autograd | ❌ | ✅ | ✅ | ✅ | |
| Deep-learning integration | ❌ | ❌ | ✅ | ✅ | |
| Maturity | ✅ | ✅ |
- Custom weak forms (Pythonic): user-defined bilinear / linear forms written directly in Python, without a separate DSL such as UFL.
- End-to-end autograd: gradients flow natively through the entire pipeline. FEniCS supports this via the external `dolfin-adjoint` package.
- Maturity: reflects project age, ecosystem size, and production deployments.
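What "gradients flow through the solve" means can be seen from the adjoint identity that differentiable linear solvers commonly rely on. For K(θ)u = b with a θ-independent load and an objective J = gᵀu, the gradient is dJ/dθ = −λᵀ(∂K/∂θ)u, where λ solves Kᵀλ = g. One extra linear solve therefore yields the gradient. The stdlib sketch below checks this identity against central finite differences on a hypothetical 3×3 system K(θ) = θA + I. It illustrates the mathematics only and is not torch-sla's implementation or API.

```python
def solve(A, b):
    """Dense Gaussian elimination with partial pivoting (tiny systems only)."""
    n = len(b)
    M = [row[:] + [rhs] for row, rhs in zip(A, b)]  # augmented matrix [A | b]
    for k in range(n):
        p = max(range(k, n), key=lambda r: abs(M[r][k]))
        M[k], M[p] = M[p], M[k]
        for r in range(k + 1, n):
            f = M[r][k] / M[k][k]
            for c in range(k, n + 1):
                M[r][c] -= f * M[k][c]
    x = [0.0] * n
    for r in reversed(range(n)):
        x[r] = (M[r][n] - sum(M[r][c] * x[c] for c in range(r + 1, n))) / M[r][r]
    return x


# Hypothetical parametrised system K(theta) = theta * A + R with a fixed load b.
A = [[2.0, -1.0, 0.0], [-1.0, 2.0, -1.0], [0.0, -1.0, 2.0]]  # dK/dtheta
R = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
b = [1.0, 1.0, 1.0]
g = [0.0, 1.0, 0.0]  # objective J = g . u, i.e. the middle nodal value


def K(theta):
    return [[theta * a + r for a, r in zip(ra, rr)] for ra, rr in zip(A, R)]


def objective(theta):
    u = solve(K(theta), b)
    return sum(gi * ui for gi, ui in zip(g, u))


def adjoint_gradient(theta):
    Kt = K(theta)
    u = solve(Kt, b)                                   # forward solve
    Kt_T = [list(col) for col in zip(*Kt)]             # transpose (symmetric here)
    lam = solve(Kt_T, g)                               # adjoint solve: K^T lam = g
    dK_u = [sum(a * uj for a, uj in zip(row, u)) for row in A]
    return -sum(li * vi for li, vi in zip(lam, dK_u))  # -lam^T (dK/dtheta) u


theta, eps = 0.7, 1e-6
fd = (objective(theta + eps) - objective(theta - eps)) / (2 * eps)
print(f"adjoint: {adjoint_gradient(theta):.8f}   finite difference: {fd:.8f}")
```

The adjoint route costs one extra solve per scalar objective regardless of how many parameters θ has, which is why it scales to coefficient fields and topology optimization, whereas finite differences need extra solves per parameter.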
The core workflow: Mesh → Assembler → SparseMatrix → Condenser → Solve.
| Module | Description |
|---|---|
| `tensormesh.mesh` | Mesh data structure; built-in generators (`gen_rectangle`, `gen_circle`, `gen_cube`, `gen_L`, …); Gmsh / VTK-HDF5 I/O |
| `tensormesh.element` | Shape functions, quadrature rules, element transformations (geometric order 1–4) |
| `tensormesh.assemble` | `ElementAssembler`, `NodeAssembler`, `FacetAssembler` for matrix and vector assembly |
| `tensormesh.sparse` | `SparseMatrix` (subclass of `torch_sla.SparseTensor`); linear & nonlinear sparse solves via torch-sla backends (SciPy / Eigen / native PyTorch / CuPy / cuDSS) |
| `tensormesh.operator` | `Condenser` for Dirichlet boundary conditions via static condensation |
| `tensormesh.ode` | Time integrators: explicit / implicit Euler, midpoint, Runge–Kutta |
| `tensormesh.dataset` | Parametric PDE dataset generation (Poisson, Heat, Wave, linear elasticity) |
| `tensormesh.visualization` | Matplotlib and PyVista plotting backends |
| `tensormesh.functional` | Tensor utilities for FEM (elasticity, Voigt notation, common ops) |
| `tensormesh.material` | Material property definitions for solid mechanics |
| `tensormesh.optimizer` | Optimization algorithms (e.g. OC for topology optimization) |
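The explicit / implicit choice offered for time integration matters most for stiff semi-discrete systems such as the heat equation. The stdlib sketch below applies forward and backward Euler to du/dt = −Lu, where L = (1/h²) tridiag(−1, 2, −1) is the lumped-mass linear-element Laplacian on the interior nodes of (0, 1). It is not the `tensormesh.ode` API, and all names are hypothetical. Forward Euler is stable only for roughly dt ≤ h²/2. Backward Euler is unconditionally stable but needs a linear solve each step; here that solve is done with the Thomas algorithm.

```python
import math


def apply_L(u, h):
    """Return L u, with L = (1/h^2) tridiag(-1, 2, -1) and u = 0 beyond both ends."""
    n = len(u)
    out = []
    for i in range(n):
        left = u[i - 1] if i > 0 else 0.0
        right = u[i + 1] if i < n - 1 else 0.0
        out.append((2.0 * u[i] - left - right) / h**2)
    return out


def solve_tridiag(a, d, c, rhs):
    """Thomas algorithm for constant sub-diagonal a, diagonal d, super-diagonal c."""
    n = len(rhs)
    cp, dp = [0.0] * n, [0.0] * n
    cp[0], dp[0] = c / d, rhs[0] / d
    for i in range(1, n):
        denom = d - a * cp[i - 1]
        cp[i] = c / denom
        dp[i] = (rhs[i] - a * dp[i - 1]) / denom
    x = [0.0] * n
    x[-1] = dp[-1]
    for i in range(n - 2, -1, -1):
        x[i] = dp[i] - cp[i] * x[i + 1]
    return x


def integrate(n, dt, steps, implicit):
    """Integrate du/dt = -L u on n interior nodes, from sin(pi x) plus a tiny perturbation."""
    h = 1.0 / (n + 1)
    x = [(i + 1) * h for i in range(n)]
    # The 1e-3 copy of the highest resolvable mode lets any instability show up.
    u = [math.sin(math.pi * xi) + 1e-3 * math.sin(n * math.pi * xi) for xi in x]
    r = dt / h**2
    for _ in range(steps):
        if implicit:
            # Backward Euler: (I + dt L) u_new = u
            u = solve_tridiag(-r, 1.0 + 2.0 * r, -r, u)
        else:
            # Forward Euler: u_new = u - dt L u
            u = [ui - dt * li for ui, li in zip(u, apply_L(u, h))]
    return u


n, dt, steps = 19, 0.002, 50  # h = 0.05; dt is above the forward-Euler limit of about h^2 / 2
T = dt * steps
u_imp = integrate(n, dt, steps, implicit=True)
u_exp = integrate(n, dt, steps, implicit=False)
print(f"exact peak    {math.exp(-math.pi**2 * T):.4f}")
print(f"implicit peak {max(map(abs, u_imp)):.4f}")
print(f"explicit peak {max(map(abs, u_exp)):.3e}")
```

With this step size the implicit peak stays within about 1% of the exact decay exp(−π²T), while the explicit run amplifies the high-frequency perturbation by orders of magnitude.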
Full documentation, including a user guide, an example gallery, the API reference, and performance benchmarks, lives at docs.tensor-mesh.com.
Key entry points:
- Getting started — installation, quickstart, and an install smoke-test.
- User guide — meshes, weak forms, boundary conditions, linear solvers, time integration, differentiability.
- Example gallery — runnable examples from Poisson to Navier–Stokes and topology optimization.
- API reference — module-by-module signatures and docstrings.
- Performance — benchmarks against FEniCS / Firedrake / MFEM / scikit-fem / JAX-FEM / torch-fem.
- Discord — real-time chat, help channels, and showcase.
- GitHub Discussions — announcements, Q&A, ideas & RFCs.
- GitHub Issues — bug reports and feature requests.
- Email to Shizheng Wen at shizheng.wen@sam.math.ethz.ch — collaborations and partnerships.
Contributions are welcome. See CONTRIBUTING.md for the development setup, test workflow, documentation build, and PR conventions.
TensorMesh is the FEM solver component of the TensorGalerkin framework. If you use TensorMesh in your research, please cite the TensorGalerkin paper:
```bibtex
@article{wen2026tensorgalerkin,
  title   = {Learning, Solving and Optimizing PDEs with {TensorGalerkin}:
             an Efficient High-Performance Galerkin Assembly Algorithm},
  author  = {Wen, Shizheng and Chi, Mingyuan and Yu, Tianwei and
             Moseley, Ben and Michelis, Mike Yan and Ren, Pu and
             Sun, Hao and Mishra, Siddhartha},
  journal = {arXiv preprint arXiv:2602.05052},
  year    = {2026}
}
```

If your work also relies on torch-sla (TensorMesh's solver backend), please additionally cite:

```bibtex
@article{chi2026torchsla,
  title   = {torch-sla: Differentiable Sparse Linear Algebra with Adjoint
             Solvers and Sparse Tensor Parallelism for PyTorch},
  author  = {Chi, Mingyuan and Wen, Shizheng},
  journal = {arXiv preprint arXiv:2601.13994},
  year    = {2026}
}
```










