Skip to content

camlab-ethz/TensorMesh

Repository files navigation

TensorMesh

A fast, differentiable, JIT-free, debugging-friendly finite element library for PyTorch.

Documentation  |  Installation  |  Quickstart  |  Examples  |  Citation

PyPI version Python versions PyTorch 2.0+ License: Apache-2.0 Documentation


TensorMesh is a finite element method (FEM) library built natively on PyTorch. It is designed to solve partial differential equations (PDEs) with the ergonomics of modern deep learning frameworks — automatic differentiation, GPU acceleration, eager execution — without sacrificing the rigour of classical FEM. Custom weak forms are written in pure Python; the library takes care of tensorized assembly, sparse linear algebra, boundary conditions, and time integration.

Core strengths

  • GPU-native & differentiable. Built on PyTorch from the ground up. Moving an entire FEM workflow to the GPU takes a single line of code — every downstream assembly, solve, and gradient inherits the device automatically, with no separate backend or data-marshalling step. Native autograd flows seamlessly through assembly and solve, enabling end-to-end differentiable PDE pipelines.
  • High-performance tensorized assembly. A fully tensorized Map-Reduce algorithm powered by TensorGalerkin, which fuses element-wise operations into monolithic GPU kernels, eliminating Python-level loops and delivering order-of-magnitude speedups over CPU-based FEM stacks.
  • JIT-free & debugging-friendly. Eager execution with no compilation overhead. Dynamic meshes, adaptive refinement, and interactive workflows just work — no recompilation latency, no opaque traces.
  • Comprehensive element & mesh support. Triangular, tetrahedral, pyramid, and prismatic elements with automated mesh generation for common geometries and seamless Gmsh / VTKHDF5 I/O.
  • Flexible Solvers. Powered by torch-sla, our companion library for differentiable sparse linear algebra. Linear, nonlinear, and eigenvalue solvers run across multiple backends on CPU and GPU, with full autograd support, batched solves, and distributed multi-GPU scaling.
  • Pythonic API. Custom weak forms in pure Python — no separate DSL, no form compiler. If you can write PyTorch, you can write FEM.

Installation

Requirements: Python ≥ 3.10, PyTorch ≥ 2.0.

pip install tensormesh-fem             # CPU only
pip install "tensormesh-fem[gpu]"      # + CUDA sparse solvers (CuPy + cuDSS)

The base install ships only the CPU sparse stack (SciPy / native PyTorch via torch-sla). The [gpu] extra pulls in both CUDA backends; if you only want one, use [cupy] or [cudss] instead:

pip install "tensormesh-fem[cupy]"     # CuPy CUDA backend (iterative + SuperLU)
pip install "tensormesh-fem[cudss]"    # cuDSS CUDA backend (fastest GPU direct)

The quotes are needed because [...] is a shell glob character.

Install from source (for development)
git clone https://github.com/camlab-ethz/TensorMesh.git
cd TensorMesh
pip install -e ".[test]"

After installing, sanity-check the install:

python -m tensormesh.verify_install

To see which sparse-solver backends are usable on your machine — and a one-line install hint for any that are not — run:

import torch_sla
torch_sla.show_backends()

Quickstart

Solve $-\Delta u = f$ on the unit square with homogeneous Dirichlet boundary conditions, where $f(x, y) = 2\pi^2\sin(\pi x)\sin(\pi y)$ so the exact solution is $u(x, y) = \sin(\pi x)\sin(\pi y)$:

import math
import torch
from tensormesh import ElementAssembler, NodeAssembler, Mesh, Condenser

# 1. Generate a triangular mesh of the unit square.
mesh = Mesh.gen_rectangle(chara_length=0.05)

# 2. Stiffness weak form:  a(u, v) = ∫ ∇u · ∇v dΩ
class LaplaceAssembler(ElementAssembler):
    def forward(self, gradu, gradv):
        return gradu @ gradv

# 3. Load weak form:  l(v) = ∫ f v dΩ
class SourceAssembler(NodeAssembler):
    def forward(self, v, f):
        return f * v

# 4. Source term, evaluated at every mesh node.
x, y = mesh.points[:, 0], mesh.points[:, 1]
f_vals = 2 * math.pi**2 * torch.sin(math.pi * x) * torch.sin(math.pi * y)

# 5. Assemble the stiffness matrix and load vector.
K = LaplaceAssembler.from_mesh(mesh)()
b = SourceAssembler.from_mesh(mesh)(point_data={"f": f_vals})

# 6. Apply Dirichlet BCs by static condensation, then solve.
condenser = Condenser(mesh.boundary_mask)
K_, b_ = condenser(K, b)
u_ = K_.solve(b_, verbose = True)
u = condenser.recover(u_)

# 7. Compare against the analytical solution.
u_exact = torch.sin(math.pi * x) * torch.sin(math.pi * y)
print(f"L2 error: {(u - u_exact).norm() / u_exact.norm():.3e}")
[torch-sla] solve: n=431, nnz=2859, dtype=float64, device=cpu, symmetric=True, spd=False, backend=scipy, method=lu
L2 error: 3.135e-03

The workflow is Mesh → Assembler → SparseMatrix → Condenser → Solve. Move everything to GPU with a single mesh = mesh.cuda(); enable gradients with mesh.points.requires_grad_(True) and the same script becomes an inverse problem.

See the full walkthrough in the Quickstart.

Examples

A small selection from the example gallery:

3D Poisson
3D Poisson — tetrahedral mesh, cut view of the scalar field.
Allen–Cahn phase field
Allen–Cahn phase field — nonlinear time evolution with Newton iteration per step.
2D wave equation
Wave equation — explicit central-difference time integration.
Hyperelastic rubber
Hyperelastic rubber — large-deformation solid mechanics with a Newton solver.
Lid-driven cavity flow
Lid-driven cavity — incompressible Navier–Stokes; velocity field and streamlines.
Magnetostatic field
Magnetostatics — 3D magnetic field around a current-carrying wire (stabilized nodal curl-curl).
Topology optimization
Topology optimization — compliance minimization via the Optimality Criteria method.
Physics-informed Galerkin training
Physics-informed learning — a network trained to minimize the assembled Galerkin residual.
Category Path Description
Basics examples/basics/ Mesh visualization, basis functions, element gallery
Poisson examples/poisson/ 2D / 3D Poisson, batched RHS, h-adaptivity
Diffusion examples/diffusion/ Heat equation, Allen-Cahn phase field
Wave examples/wave/ Wave equation with central-difference scheme
Solid examples/solid/ Cantilever beam, hyperelasticity, contact, plasticity
Fluid examples/fluid/ Lid-driven cavity, cylinder flow, flow past obstacles, Rayleigh-Bénard, Taylor-Green
Magnetostatics examples/maxwell/ 3D Maxwell: magnetic field around a current-carrying wire via a stabilized nodal curl-curl formulation
Inverse design examples/inverse_design/ Coefficient-field identification and density-based topology optimization, all via autograd
Physics-informed examples/physics_informed/ Train a neural network to minimize the assembled Galerkin residual
Dataset examples/dataset/ Batch dataset generation for ML (heat, wave, Poisson)
Distributed examples/distributed/ Graph coloring, mesh partitioning, multi-GPU assembly

Feature comparison

Feature FEniCS scikit-fem JAX-FEM torch-fem TensorMesh
Custom weak forms (Pythonic) ⚠️
Easy install ⚠️
Easy debug
Easy I/O
Large meshes
GPU support
Efficiency ⚠️
End-to-end autograd ⚠️
Deep-learning integration ⚠️
Maturity ⚠️ ⚠️ ⚠️

Custom Weak Forms (Pythonic) — user-defined bilinear / linear forms directly in Python, without a separate DSL such as UFL. End-to-End Autograd — gradients flow natively through the entire pipeline; FEniCS supports this via the external dolfin-adjoint package. Maturity — reflects project age, ecosystem size, and production deployments.

Architecture

The core workflow: Mesh → Assembler → SparseMatrix → Condenser → Solve.

Module Description
tensormesh.mesh Mesh data structure; built-in generators (gen_rectangle, gen_circle, gen_cube, gen_L, …); Gmsh / VTK-HDF5 I/O
tensormesh.element Shape functions, quadrature rules, element transformations (geometric order 1–4)
tensormesh.assemble ElementAssembler, NodeAssembler, FacetAssembler for matrix and vector assembly
tensormesh.sparse SparseMatrix (subclass of torch_sla.SparseTensor); linear & nonlinear sparse solves via torch-sla backends (SciPy / Eigen / native PyTorch / CuPy / cuDSS)
tensormesh.operator Condenser for Dirichlet boundary conditions via static condensation
tensormesh.ode Time integrators: explicit / implicit Euler, midpoint, Runge–Kutta
tensormesh.dataset Parametric PDE dataset generation (Poisson, Heat, Wave, linear elasticity)
tensormesh.visualization Matplotlib and PyVista plotting backends
tensormesh.functional Tensor utilities for FEM (elasticity, Voigt notation, common ops)
tensormesh.material Material property definitions for solid mechanics
tensormesh.optimizer Optimization algorithms (e.g. OC for topology optimization)

Documentation

Full documentation, including a user guide, an example gallery, the API reference, and performance benchmarks, lives at docs.tensor-mesh.com.

Key entry points:

  • Getting started — installation, quickstart, and an install smoke-test.
  • User guide — meshes, weak forms, boundary conditions, linear solvers, time integration, differentiability.
  • Example gallery — runnable examples from Poisson to Navier–Stokes and topology optimization.
  • API reference — module-by-module signatures and docstrings.
  • Performance — benchmarks against FEniCS / Firedrake / MFEM / scikit-fem / JAX-FEM / torch-fem.

Community

Contributing

Contributions are welcome — see CONTRIBUTING.md for the development setup, test workflow, documentation build, and PR conventions.

Citation

TensorMesh is the FEM solver component of the TensorGalerkin framework. If you use TensorMesh in your research, please cite the TensorGalerkin paper:

@article{wen2026tensorgalerkin,
  title   = {Learning, Solving and Optimizing PDEs with {TensorGalerkin}:
             an Efficient High-Performance Galerkin Assembly Algorithm},
  author  = {Wen, Shizheng and Chi, Mingyuan and Yu, Tianwei and
             Moseley, Ben and Michelis, Mike Yan and Ren, Pu and
             Sun, Hao and Mishra, Siddhartha},
  journal = {arXiv preprint arXiv:2602.05052},
  year    = {2026}
}

If your work also relies on torch-sla (TensorMesh's solver backend), please additionally cite:

@article{chi2026torchsla,
  title   = {torch-sla: Differentiable Sparse Linear Algebra with Adjoint
             Solvers and Sparse Tensor Parallelism for PyTorch},
  author  = {Chi, Mingyuan and Wen, Shizheng},
  journal = {arXiv preprint arXiv:2601.13994},
  year    = {2026}
}

Acknowledgements

CAMLab — ETH Zürich         ETH AI Center     ETH Zürich

About

A fast, differentiable, JIT-free, debugging-friendly finite element library for PyTorch.

Topics

Resources

License

Contributing

Stars

Watchers

Forks

Packages

 
 
 

Contributors

Languages