Skip to content

Autograd support for affine transformations #2490

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added
- Adjoint support for differentiating w.r.t. parameters (`center`, `size`) of transformed `td.Box`.
- Adjoint support for differentiating w.r.t. rotation angle of `td.Transform`.
- Adjoint support for differentiating through chained affine transformations (`.scaled(...).rotated(...).translated(...)`).

## [2.8.4] - 2025-05-15

### Added
Expand Down
120 changes: 120 additions & 0 deletions tests/test_components/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2208,3 +2208,123 @@ def objective(x):

with pytest.raises(ValueError):
g = ag.grad(objective)(1.0)


def make_sim_rotation(center: tuple, size: tuple, angle: float, axis: int):
    """Build a single-box scattering simulation, optionally rotating the box.

    Parameters
    ----------
    center : tuple
        Center handed to ``td.Box``.
    size : tuple
        Size handed to ``td.Box``.
    angle : float
        Rotation angle forwarded to ``Box.rotated``. ``None`` keeps the plain,
        unrotated box.
    axis : int
        Rotation axis forwarded to ``Box.rotated``; unused when ``angle`` is ``None``.

    Returns
    -------
    td.Simulation
        Cubic domain containing one dielectric box, one ``Ez`` point dipole and
        one point field monitor named ``"point"``.
    """
    wavelength = 1.5
    freq0 = td.C_0 / wavelength
    L = 10 * wavelength
    buffer = 1.0 * wavelength

    # Scatterer: plain box, rotated only when an angle was supplied.
    plain_box = td.Box(center=center, size=size)
    geometry = plain_box if angle is None else plain_box.rotated(angle, axis)
    scatterer = td.Structure(geometry=geometry, medium=td.Medium(permittivity=2.0))

    # Dipole source sits one buffer inside the -x face of the domain.
    pulse = td.GaussianPulse(freq0=freq0, fwidth=freq0 / 10.0)
    dipole = td.PointDipole(
        center=(-L / 2 + buffer, 0, 0),
        source_time=pulse,
        polarization="Ez",
    )

    # Point monitor sits one buffer inside the +x face, shifted off-axis in y and z.
    probe_center = (+L / 2 - buffer, 0.5 * buffer, 0.5 * buffer)
    probe = td.FieldMonitor(
        center=probe_center,
        size=(0.0, 0.0, 0.0),
        freqs=[freq0],
        name="point",
    )

    return td.Simulation(
        size=(L, L, L),
        grid_spec=td.GridSpec.auto(min_steps_per_wvl=50),
        structures=[scatterer],
        sources=[dipole],
        monitors=[probe],
        run_time=120 / freq0,
    )


def objective_fn(center, size, angle, axis):
    """Run the rotated-box simulation and return the summed intensity at the ``"point"`` monitor."""
    sim_data = web.run(
        make_sim_rotation(center, size, angle, axis),
        task_name="emulated_rot_test",
        local_gradient=True,
        verbose=False,
    )
    intensity = sim_data.get_intensity("point")
    return anp.sum(intensity.values)


Comment on lines +2261 to +2262
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: make_sim_rotation could be reused by storing sim as a fixture instead of recreating for each test

def get_grad(center, size, angle, axis):
    """Return ``(value, d/d_center, d/d_size)`` of ``objective_fn`` at fixed ``angle`` and ``axis``.

    Only ``center`` and ``size`` are differentiated; ``angle`` and ``axis`` are
    closed over as constants.
    """
    value_and_grads = ag.value_and_grad(
        lambda c, s: objective_fn(c, s, angle, axis),
        argnum=(0, 1),
    )
    value, grads = value_and_grads(center, size)
    grad_center, grad_size = grads
    return value, grad_center, grad_size


@pytest.mark.numerical
@pytest.mark.parametrize(
    "angle_deg, axis",
    [
        (0.0, 1),
        (180.0, 1),
        (90.0, 1),
        (270.0, 1),
    ],
)
def test_box_rotation_gradients(use_emulated_run, angle_deg, axis):
    """Compare box gradients with and without a rotation of the box.

    Gradients w.r.t. ``center`` and ``size`` are computed once for the plain box
    and, for the 90°, 180° and 270° cases, once more with the box rotated by
    ``angle_deg`` about ``axis``. The two sets are then compared component-wise
    using the permutation / sign pattern asserted for that angle. The 0° case
    only checks that the plain-box gradients are not identically zero.
    """
    center0 = (0.0, 0.0, 0.0)
    size0 = (2.0, 2.0, 2.0)
    angle_rad = np.deg2rad(angle_deg)

    # Gradients for the plain (unrotated) box.
    _, grad_c, grad_s = get_grad(center0, size0, angle=None, axis=None)
    plain_cx, plain_cy, plain_cz = grad_c
    plain_sx, plain_sy, plain_sz = grad_s

    assert not np.allclose(grad_c, 0.0), "center gradient is all zero."
    assert not np.allclose(grad_s, 0.0), "size gradient is all zero."

    if angle_deg == 180.0:
        # rotating 180° about y => (x,z) become negated, y stays same
        _, grad_c_ref, grad_s_ref = get_grad(center0, size0, angle_rad, axis)

        assert np.allclose(grad_c[0], -grad_c_ref[0], atol=1e-6), "center_x sign mismatch"
        assert np.allclose(grad_c[1], grad_c_ref[1], atol=1e-6), "center_y mismatch"
        assert np.allclose(grad_c[2], -grad_c_ref[2], atol=1e-6), "center_z sign mismatch"
        assert np.allclose(grad_s, grad_s_ref, atol=1e-6), "size grads changed unexpectedly"

    elif angle_deg == 90.0:
        # rotating 90° about y => new x= old z, new z=- old x, y stays same
        _, grad_c_ref, grad_s_ref = get_grad(center0, size0, angle_rad, axis)
        rot_cx, rot_cy, rot_cz = grad_c_ref
        rot_sx, rot_sy, rot_sz = grad_s_ref

        assert np.allclose(plain_cx, rot_cz, atol=1e-6), "center_x != old center_z"
        assert np.allclose(plain_cy, rot_cy, atol=1e-6), "center_y changed unexpectedly"
        assert np.allclose(plain_cz, -rot_cx, atol=1e-6), "center_z != - old center_x"

        assert np.allclose(plain_sx, rot_sz, atol=1e-6), "size_x != old size_z"
        assert np.allclose(plain_sy, rot_sy, atol=1e-6), "size_y changed unexpectedly"
        assert np.allclose(plain_sz, rot_sx, atol=1e-6), "size_z != old size_x"

    elif angle_deg == 270.0:
        # rotating 270° about y => new x= - old z, new z= old x, y stays same
        _, grad_c_ref, grad_s_ref = get_grad(center0, size0, angle_rad, axis)
        rot_cx, rot_cy, rot_cz = grad_c_ref
        rot_sx, rot_sy, rot_sz = grad_s_ref

        assert np.allclose(plain_cx, -rot_cz, atol=1e-6), "center_x != - old center_z"
        assert np.allclose(plain_cy, rot_cy, atol=1e-6), "center_y changed unexpectedly"
        assert np.allclose(plain_cz, rot_cx, atol=1e-6), "center_z != old center_x"

        assert np.allclose(plain_sx, rot_sz, atol=1e-6), "size_x != old size_z"
        assert np.allclose(plain_sy, rot_sy, atol=1e-6), "size_y changed unexpectedly"
        assert np.allclose(plain_sz, rot_sx, atol=1e-6), "size_z != old size_x"
253 changes: 253 additions & 0 deletions tests/test_components/test_box_chained_derivatives.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
"""
FD vs AD checks for every affine parameter of a dielectric box:
• centre c = (cx,cy,cz)
• size a = (ax,ay,az)
• rotation θ about each axis
• scale s = (sx,sy,sz)
• translation t = (tx,ty,tz)
"""

import atexit
import os
from collections import defaultdict

import autograd
import autograd.numpy as anp
import matplotlib.pyplot as plt
import numpy as np
import pytest
import tidy3d as td
import tidy3d.web as web
from autograd import tuple

# ───────── switches ───────────────────────────────────────────────────────
# Read by `_save_and_plot`, which this module registers with `atexit`.
SAVE = False  # when True, dump the collected gradients to OUT/gradients.npz
PLOT = True  # when True, write a bar chart of FD-vs-AD relative errors as a png
OUT = "./fd_ad_all_results"  # output directory, created on demand

# ───────── physical / geometric constants ─────────────────────────────────
λ = 1.5  # wavelength (presumably in microns, tidy3d's length unit; confirm)
f0 = td.C_0 / λ  # central frequency of the source pulse and the monitor frequency
Lbox = 10 * λ  # edge length of the cubic simulation domain
buffer = 1.0 * λ  # length used to offset the source and monitor from the domain faces
Tsim = 120 / f0  # simulation run time

# baseline shape parameters ( **plain Python lists / tuples** )
center0 = [0.0, 0.0, 0.0]  # box center before any transformation
size0 = [2.0, 2.0, 2.0]  # box size before any transformation
eps_box = 2.0  # relative permittivity of the box medium

theta0 = np.pi / 4  # baseline rotation (x-axis)
axis0 = 0  # axis index used together with theta0
scale0 = [1.2, 1.3, 0.9]  # baseline per-axis scale factors (x, y, z)
trans0 = [0.45, 0.20, 0.30]  # baseline translation (x, y, z)

# finite-difference steps
Δθ = 0.015  # half-step for the rotation angle (presumably radians; confirm)
Δxyz = 0.03  # half-step shared by the scale, translation, size and center tests

AXES = (0, 1, 2)  # axis indices the tests are parametrized over
LAB = {0: "x", 1: "y", 2: "z"}  # axis index -> label used in test ids and plot labels

# Results collected by the tests, keyed by (parameter tag, axis index, "fd" | "ad").
plots = defaultdict(list)


# ───────── simulation builder ─────────────────────────────────────────────
def make_simulation(center, size, scale, trans, theta, axis):
    """Assemble the Tidy3D simulation for one set of affine box parameters.

    The box is built from ``center`` and ``size`` and then transformed in this
    order: rotation by ``theta`` about ``axis``, per-axis scaling by ``scale``,
    translation by ``trans``. ``scale`` and ``trans`` are indexed as
    ``[0], [1], [2]`` for x, y, z.

    Returns a ``td.Simulation`` with one ``Ez`` point dipole, one point field
    monitor named ``"m"`` and the transformed box as the only structure.
    """
    # Geometry: rotate -> scale -> translate, applied to the plain box.
    plain_box = td.Box(center=tuple(center), size=tuple(size))
    rotated = plain_box.rotated(theta, axis=axis)
    scaled = rotated.scaled(x=scale[0], y=scale[1], z=scale[2])
    geom = scaled.translated(x=trans[0], y=trans[1], z=trans[2])
    box_structure = td.Structure(geometry=geom, medium=td.Medium(permittivity=eps_box))

    # Dipole source half a buffer inside the -x face.
    pulse = td.GaussianPulse(freq0=f0, fwidth=f0 / 10)
    dipole = td.PointDipole(
        center=(-Lbox / 2 + 0.5 * buffer, 0, 0),
        source_time=pulse,
        polarization="Ez",
    )

    # Point monitor half a buffer inside the +x face, shifted off-axis in y and z.
    probe = td.FieldMonitor(
        center=(+Lbox / 2 - 0.5 * buffer, 2.0 * buffer, 2.0 * buffer),
        size=(0, 0, 0),
        freqs=[f0],
        name="m",
    )

    return td.Simulation(
        size=(Lbox, Lbox, Lbox),
        run_time=Tsim,
        grid_spec=td.GridSpec.auto(min_steps_per_wvl=50),
        sources=[dipole],
        monitors=[probe],
        structures=[box_structure],
    )


def objective(c, a, s, t, θ, ax):
    """Run the simulation for the given affine parameters and sum the intensity at monitor ``"m"``.

    Arguments are, in order: center, size, scale, translation, rotation angle
    and rotation axis, all forwarded unchanged to ``make_simulation``.
    """
    data = web.run(
        make_simulation(c, a, s, t, θ, ax),
        task_name="fd_ad_all",
        verbose=False,
        local_gradient=True,
    )
    intensity = data.get_intensity("m")
    return anp.sum(intensity.values)


def finite_diff(fun, x0, δ):
    """Central-difference estimate of the derivative of ``fun`` at ``x0``.

    ``fun`` is evaluated at ``x0 + δ`` and ``x0 - δ`` (in that order) and the
    difference is divided by the full step ``2 * δ``.
    """
    forward = fun(x0 + δ)
    backward = fun(x0 - δ)
    return (forward - backward) / (2 * δ)


def _assert_close(fd_val, ad_val, tag, axis=None, extra="", tol=0.35):
    """Raise ``AssertionError`` unless FD and AD agree to within ``tol``.

    The relative difference is ``|FD - AD| / max(|FD|, 1e-12)``; the floor keeps
    the division defined when the finite-difference value is exactly zero.

    Parameters
    ----------
    fd_val, ad_val
        Finite-difference and autograd gradient values to compare.
    tag : str
        Parameter name used as the prefix of the error message.
    axis : int, optional
        Axis index; 0, 1 and 2 are rendered as ``_x``, ``_y``, ``_z`` after
        ``tag``. Any other value adds no suffix.
    extra : str
        Free-form text appended after the axis suffix.
    tol : float
        Relative difference at or above which the check fails.

    Raises
    ------
    AssertionError
        If the relative difference is ``>= tol`` or is NaN.

    Notes
    -----
    NOTE(review): the default ``tol`` of 0.35 (35%) is loose for a numerical
    comparison; tighten it or document the justification once the achievable
    FD/AD agreement is known.
    """
    rel = abs(fd_val - ad_val) / max(abs(fd_val), 1e-12)
    # `not (rel < tol)` instead of `rel >= tol`: every comparison with NaN is
    # False, so the latter would let a NaN gradient pass silently.
    if not (rel < tol):
        ax_lbl = {0: "x", 1: "y", 2: "z"}.get(axis, "")
        axis_str = f"_{ax_lbl}" if ax_lbl else ""
        raise AssertionError(
            f"{tag}{axis_str}{extra}: FD–AD mismatch "
            f"(rel diff {rel:.2%}) | FD={fd_val:.4e}, AD={ad_val:.4e}"
        )


# 1. ROTATION θ_k (k = 0,1,2)
# Rotation angles at which the FD and AD gradients are compared
# (presumably radians, given the np.pi multiples; confirm).
θ_vals = np.array([0, np.pi / 4, np.pi / 2])


@pytest.mark.numerical
@pytest.mark.parametrize("k", AXES, ids=[f"θ_{LAB[a]}" for a in AXES])
@pytest.mark.parametrize("θ", θ_vals)
def test_fd_vs_ad_rotation(k, θ):
    """Compare autograd and central-FD gradients w.r.t. the rotation angle about axis ``k``.

    All other affine parameters stay at their module-level baselines. Both
    gradients are appended to ``plots`` as ``(θ, gradient)`` pairs before the
    checks run, so a failing case still shows up in the summary.
    """

    def f(th):
        return objective(center0, size0, scale0, trans0, th, k)

    g_ad = autograd.grad(f)(θ)
    g_fd = finite_diff(f, θ, Δθ)

    plots[("θ", k, "fd")].append((θ, g_fd))
    plots[("θ", k, "ad")].append((θ, g_ad))

    assert np.isfinite(g_fd) and np.isfinite(g_ad), "NaN/Inf in FD or AD result"
    _assert_close(g_fd, g_ad, tag="θ", axis=k)


# 2. SCALE s_k
@pytest.mark.numerical
@pytest.mark.parametrize("k", AXES, ids=[f"s_{LAB[a]}" for a in AXES])
def test_fd_vs_ad_scale(k):
    """Compare autograd and central-FD gradients w.r.t. scale component ``k``."""

    def intensity_vs_scale(sk):
        # Replace only component k; the others stay at their baseline values.
        scale = [sk if i == k else s_i for i, s_i in enumerate(scale0)]
        return objective(center0, size0, scale, trans0, theta0, axis0)

    baseline = scale0[k]
    grad_ad = autograd.grad(intensity_vs_scale)(baseline)
    grad_fd = finite_diff(intensity_vs_scale, baseline, Δxyz)

    # Record before asserting so failing cases still reach the summary.
    for kind, grad in (("fd", grad_fd), ("ad", grad_ad)):
        plots[("s", k, kind)].append(grad)

    both_finite = np.isfinite(grad_fd) and np.isfinite(grad_ad)
    assert both_finite, "NaN/Inf in FD or AD result"
    _assert_close(grad_fd, grad_ad, tag="scale", axis=k)


# 3. TRANSLATION t_k
@pytest.mark.numerical
@pytest.mark.parametrize("k", AXES, ids=[f"t_{LAB[a]}" for a in AXES])
def test_fd_vs_ad_translation(k):
    """Compare autograd and central-FD gradients w.r.t. translation component ``k``."""

    def intensity_vs_shift(tk):
        # Replace only component k; the others stay at their baseline values.
        shift = [tk if i == k else t_i for i, t_i in enumerate(trans0)]
        return objective(center0, size0, scale0, shift, theta0, axis0)

    baseline = trans0[k]
    grad_ad = autograd.grad(intensity_vs_shift)(baseline)
    grad_fd = finite_diff(intensity_vs_shift, baseline, Δxyz)

    # Record before asserting so failing cases still reach the summary.
    for kind, grad in (("fd", grad_fd), ("ad", grad_ad)):
        plots[("t", k, kind)].append(grad)

    both_finite = np.isfinite(grad_fd) and np.isfinite(grad_ad)
    assert both_finite, "NaN/Inf in FD or AD result"
    _assert_close(grad_fd, grad_ad, tag="trans", axis=k)


# 4. SIZE a_k
@pytest.mark.numerical
@pytest.mark.parametrize("k", AXES, ids=[f"a_{LAB[a]}" for a in AXES])
def test_fd_vs_ad_size(k):
    """Compare autograd and central-FD gradients w.r.t. box size component ``k``."""

    def intensity_vs_size(ak):
        # Replace only component k; the others stay at their baseline values.
        size = [ak if i == k else a_i for i, a_i in enumerate(size0)]
        return objective(center0, size, scale0, trans0, theta0, axis0)

    baseline = size0[k]
    grad_ad = autograd.grad(intensity_vs_size)(baseline)
    grad_fd = finite_diff(intensity_vs_size, baseline, Δxyz)

    # Record before asserting so failing cases still reach the summary.
    for kind, grad in (("fd", grad_fd), ("ad", grad_ad)):
        plots[("a", k, kind)].append(grad)

    both_finite = np.isfinite(grad_fd) and np.isfinite(grad_ad)
    assert both_finite, "NaN/Inf in FD or AD result"
    _assert_close(grad_fd, grad_ad, tag="size", axis=k)


# 5. CENTRE c_k
@pytest.mark.numerical
@pytest.mark.parametrize("k", AXES, ids=[f"c_{LAB[a]}" for a in AXES])
def test_fd_vs_ad_center(k):
    """Compare autograd and central-FD gradients w.r.t. box center component ``k``."""

    def intensity_vs_center(ck):
        # Replace only component k; the others stay at their baseline values.
        center = [ck if i == k else c_i for i, c_i in enumerate(center0)]
        return objective(center, size0, scale0, trans0, theta0, axis0)

    baseline = center0[k]
    grad_ad = autograd.grad(intensity_vs_center)(baseline)
    grad_fd = finite_diff(intensity_vs_center, baseline, Δxyz)

    # Record before asserting so failing cases still reach the summary.
    for kind, grad in (("fd", grad_fd), ("ad", grad_ad)):
        plots[("c", k, kind)].append(grad)

    both_finite = np.isfinite(grad_fd) and np.isfinite(grad_ad)
    assert both_finite, "NaN/Inf in FD or AD result"
    _assert_close(grad_fd, grad_ad, tag="center", axis=k)


# ───────── save / plot after test run ────────────────────────────
def _save_and_plot():
    """Write the collected FD/AD gradients to disk and/or plot their relative errors.

    Does nothing when both ``SAVE`` and ``PLOT`` are False. Otherwise ``OUT`` is
    created if needed, then:

    * ``SAVE``: every entry of ``plots`` is stored in ``OUT/gradients.npz``.
      Array names are the key parts joined with ``"_"`` (e.g. ``"s_0_fd"``).
    * ``PLOT``: a bar chart of ``|FD - AD| / max(|FD|, 1e-12)`` for every
      recorded parameter is saved as ``OUT/bar_rel_error_all_params.png``.
      Skipped when no gradients were recorded.
    """
    if not (SAVE or PLOT):
        return

    os.makedirs(OUT, exist_ok=True)

    if SAVE:
        # `plots` is keyed by tuples, but np.savez* receives arrays as keyword
        # arguments, which must be strings; unpacking the tuple-keyed dict
        # directly raises TypeError.
        arrays = {"_".join(str(part) for part in key): vals for key, vals in plots.items()}
        np.savez_compressed(os.path.join(OUT, "gradients.npz"), **arrays)

    labels, errs = [], []

    # Rotation entries are (θ, gradient) pairs. Sort on θ only, so FD and AD
    # stay paired by angle and gradient values are never compared.
    for k in AXES:
        fd_pairs = sorted(plots.get(("θ", k, "fd"), []), key=lambda pair: pair[0])
        ad_pairs = sorted(plots.get(("θ", k, "ad"), []), key=lambda pair: pair[0])
        for (theta, fd), (_, ad) in zip(fd_pairs, ad_pairs):
            labels.append(f"θ_{LAB[k]} {theta:.2f}")
            errs.append(abs(fd - ad) / max(abs(fd), 1e-12))

    # The remaining parameters store one bare gradient value per axis.
    for tag, pretty in zip(("s", "t", "a", "c"), ("scale", "trans", "size", "center")):
        for k in AXES:
            fd_vals = plots.get((tag, k, "fd"), [])
            ad_vals = plots.get((tag, k, "ad"), [])
            if fd_vals:
                rel = abs(fd_vals[0] - ad_vals[0]) / max(abs(fd_vals[0]), 1e-12)
                labels.append(f"{pretty}_{LAB[k]}")
                errs.append(rel)

    if not (PLOT and errs):
        return

    plt.figure(figsize=(12, 5))
    plt.bar(range(len(errs)), errs)
    plt.xticks(range(len(errs)), labels, rotation=75, ha="right", fontsize=8)
    plt.ylabel("relative |FD – AD| / max(|FD|)")
    plt.title("FD vs AD relative errors for all parameters")
    plt.tight_layout()
    fname = os.path.join(OUT, "bar_rel_error_all_params.png")
    plt.savefig(fname, dpi=150)
    plt.close()
    print(f"[plot] saved {fname}")


# Run the summary at interpreter exit, so it sees whatever the tests appended to `plots`.
atexit.register(_save_and_plot)
Loading