-
Notifications
You must be signed in to change notification settings - Fork 58
Autograd support for affine transformations #2490
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
c647f6e
0cbf837
fdf0649
33e649f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2208,3 +2208,123 @@ def objective(x): | |
|
||
with pytest.raises(ValueError): | ||
g = ag.grad(objective)(1.0) | ||
|
||
|
||
def make_sim_rotation(center: tuple, size: tuple, angle: float, axis: int):
    """Build a simulation holding a single dielectric box, optionally rotated.

    Parameters
    ----------
    center : tuple
        Center handed to ``td.Box`` (before any rotation is applied).
    size : tuple
        Size handed to ``td.Box`` (before any rotation is applied).
    angle : float
        Angle forwarded to ``Box.rotated``. ``None`` skips the rotation and the
        plain box is used as-is. (The caller in this file passes radians.)
    axis : int
        Axis forwarded to ``Box.rotated``; ignored when ``angle`` is ``None``.

    Returns
    -------
    td.Simulation
        Cubic domain with one point dipole source, one point field monitor
        named ``"point"`` and the (possibly rotated) box as only structure.
    """
    wavelength = 1.5
    L = 10 * wavelength
    freq0 = td.C_0 / wavelength
    buffer = 1.0 * wavelength

    # Geometry first: start from the axis-aligned box, rotate only on request.
    geometry = td.Box(center=center, size=size)
    if angle is not None:
        geometry = geometry.rotated(angle, axis)
    scatterer = td.Structure(
        geometry=geometry,
        medium=td.Medium(permittivity=2.0),
    )

    # Ez-polarized dipole sitting one buffer inside the -x boundary.
    pulse = td.GaussianPulse(freq0=freq0, fwidth=freq0 / 10.0)
    dipole = td.PointDipole(
        center=(-L / 2 + buffer, 0, 0),
        source_time=pulse,
        polarization="Ez",
    )

    # Zero-size monitor one buffer inside the +x boundary, offset in y and z.
    monitor_center = (+L / 2 - buffer, 0.5 * buffer, 0.5 * buffer)
    probe = td.FieldMonitor(
        center=monitor_center,
        size=(0.0, 0.0, 0.0),
        freqs=[freq0],
        name="point",
    )

    return td.Simulation(
        size=(L, L, L),
        grid_spec=td.GridSpec.auto(min_steps_per_wvl=50),
        structures=[scatterer],
        sources=[dipole],
        monitors=[probe],
        run_time=120 / freq0,
    )
|
||
|
||
def objective_fn(center, size, angle, axis):
    """Run the rotated-box simulation and return the summed monitor intensity.

    The arguments are forwarded unchanged to ``make_sim_rotation``; the scalar
    returned is ``anp.sum`` over the intensity recorded by the ``"point"``
    monitor.
    """
    simulation = make_sim_rotation(center, size, angle, axis)
    result = web.run(
        simulation,
        task_name="emulated_rot_test",
        local_gradient=True,
        verbose=False,
    )
    intensity = result.get_intensity("point")
    return anp.sum(intensity.values)
|
||
|
||
def get_grad(center, size, angle, axis):
    """Return the objective value and its gradients w.r.t. ``center`` and ``size``.

    ``angle`` and ``axis`` are held fixed (closed over) and are not
    differentiated.

    Returns
    -------
    tuple
        ``(value, grad_center, grad_size)``.
    """

    def objective_of_box(box_center, box_size):
        return objective_fn(box_center, box_size, angle, axis)

    value_and_grads = ag.value_and_grad(objective_of_box, argnum=(0, 1))
    value, grads = value_and_grads(center, size)
    grad_center, grad_size = grads
    return value, grad_center, grad_size
|
||
|
||
@pytest.mark.numerical
@pytest.mark.parametrize(
    "angle_deg, axis",
    [
        (0.0, 1),
        (180.0, 1),
        (90.0, 1),
        (270.0, 1),
    ],
)
def test_box_rotation_gradients(use_emulated_run, angle_deg, axis):
    """Check how center/size gradients of a box map under right-angle rotations.

    Gradients of the un-rotated box are compared component by component with
    the gradients obtained when the same box is rotated by ``angle_deg``
    about ``axis``. For 0 degrees only the "gradient is not all zero" sanity
    checks run; no rotated reference is computed in that case.
    """
    center0 = (0.0, 0.0, 0.0)
    size0 = (2.0, 2.0, 2.0)
    angle_rad = np.deg2rad(angle_deg)

    # Gradients of the plain, un-rotated box.
    _, grad_center, grad_size = get_grad(center0, size0, angle=None, axis=None)
    center_x, center_y, center_z = grad_center
    size_x, size_y, size_z = grad_size

    assert not np.allclose(grad_center, 0.0), "center gradient is all zero."
    assert not np.allclose(grad_size, 0.0), "size gradient is all zero."

    if angle_deg not in (180.0, 90.0, 270.0):
        # No component mapping is asserted for other angles (e.g. 0 degrees).
        return

    # Gradients of the rotated box; needed by every remaining case, so it is
    # computed once instead of once per branch.
    _, rot_grad_center, rot_grad_size = get_grad(center0, size0, angle_rad, axis)
    rot_center_x, rot_center_y, rot_center_z = rot_grad_center
    rot_size_x, rot_size_y, rot_size_z = rot_grad_size

    if angle_deg == 180.0:
        # rotating 180° about y => (x,z) become negated, y stays same
        assert np.allclose(center_x, -rot_center_x, atol=1e-6), "center_x sign mismatch"
        assert np.allclose(center_y, rot_center_y, atol=1e-6), "center_y mismatch"
        assert np.allclose(center_z, -rot_center_z, atol=1e-6), "center_z sign mismatch"
        assert np.allclose(grad_size, rot_grad_size, atol=1e-6), "size grads changed unexpectedly"
        return

    if angle_deg == 90.0:
        # rotating 90° about y => new x= old z, new z=- old x, y stays same
        assert np.allclose(center_x, rot_center_z, atol=1e-6), "center_x != old center_z"
        assert np.allclose(center_y, rot_center_y, atol=1e-6), "center_y changed unexpectedly"
        assert np.allclose(center_z, -rot_center_x, atol=1e-6), "center_z != - old center_x"
    else:
        # rotating 270° about y => new x= - old z, new z= old x, y stays same
        assert np.allclose(center_x, -rot_center_z, atol=1e-6), "center_x != - old center_z"
        assert np.allclose(center_y, rot_center_y, atol=1e-6), "center_y changed unexpectedly"
        assert np.allclose(center_z, rot_center_x, atol=1e-6), "center_z != old center_x"

    # Sizes carry no sign, so 90° and 270° both just swap the x and z components.
    assert np.allclose(size_x, rot_size_z, atol=1e-6), "size_x != old size_z"
    assert np.allclose(size_y, rot_size_y, atol=1e-6), "size_y changed unexpectedly"
    assert np.allclose(size_z, rot_size_x, atol=1e-6), "size_z != old size_x"
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,253 @@ | ||
""" | ||
FD vs AD checks for every affine parameter of a dielectric box: | ||
• centre c = (cx,cy,cz) | ||
• size a = (ax,ay,az) | ||
• rotation θ about each axis | ||
• scale s = (sx,sy,sz) | ||
• translation t = (tx,ty,tz) | ||
""" | ||
|
||
import atexit | ||
import os | ||
from collections import defaultdict | ||
|
||
import autograd | ||
import autograd.numpy as anp | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import pytest | ||
import tidy3d as td | ||
import tidy3d.web as web | ||
from autograd import tuple | ||
|
||
# ───────── switches ───────────────────────────────────────────────────────
SAVE = False  # dump the collected gradients to a compressed .npz archive
PLOT = True  # render the relative-error bar chart as a png
OUT = "./fd_ad_all_results"  # directory receiving the archive and the chart

# ───────── axis bookkeeping ───────────────────────────────────────────────
AXES = (0, 1, 2)
LAB = {0: "x", 1: "y", 2: "z"}

# ───────── physical / geometric constants ─────────────────────────────────
λ = 1.5  # wavelength; every other length below is derived from it
f0 = td.C_0 / λ  # matching central frequency
Lbox = 10 * λ  # edge length of the cubic simulation domain
buffer = 1.0 * λ  # reference distance used to place source and monitor
Tsim = 120 / f0  # simulation run time

# baseline shape parameters ( **plain Python lists / tuples** )
center0 = [0.0, 0.0, 0.0]
size0 = [2.0, 2.0, 2.0]
eps_box = 2.0  # relative permittivity of the box

# baseline affine transformation applied on top of the box
theta0 = np.pi / 4  # baseline rotation (x-axis)
axis0 = 0
scale0 = [1.2, 1.3, 0.9]
trans0 = [0.45, 0.20, 0.30]

# finite-difference steps
Δθ = 0.015  # step for the rotation angle
Δxyz = 0.03  # step shared by scale, translation, size and centre

# gradients recorded by the tests, keyed by (parameter tag, axis, "fd" | "ad")
plots = defaultdict(list)
|
||
|
||
# ───────── simulation builder ───────────────────────────────────────────── | ||
def make_simulation(center, size, scale, trans, theta, axis):
    """Return a Tidy3D Simulation with the requested affine parameters.

    The box is built from ``center`` and ``size`` and then, in this order,
    rotated by ``theta`` about ``axis``, scaled per axis by ``scale`` and
    translated by ``trans``. ``scale`` and ``trans`` are indexed at 0, 1, 2
    for x, y, z.
    """
    # Affine chain, one step per line: rotate -> scale -> translate.
    box = td.Box(center=tuple(center), size=tuple(size))
    rotated = box.rotated(theta, axis=axis)
    scaled = rotated.scaled(x=scale[0], y=scale[1], z=scale[2])
    geometry = scaled.translated(x=trans[0], y=trans[1], z=trans[2])
    structure = td.Structure(
        geometry=geometry,
        medium=td.Medium(permittivity=eps_box),
    )

    # Ez-polarized point dipole half a buffer inside the -x boundary.
    source = td.PointDipole(
        center=(-Lbox / 2 + 0.5 * buffer, 0, 0),
        source_time=td.GaussianPulse(freq0=f0, fwidth=f0 / 10),
        polarization="Ez",
    )

    # Zero-size field monitor "m" near the +x boundary, offset in y and z.
    monitor = td.FieldMonitor(
        center=(+Lbox / 2 - 0.5 * buffer, 2.0 * buffer, 2.0 * buffer),
        size=(0, 0, 0),
        freqs=[f0],
        name="m",
    )

    simulation = td.Simulation(
        size=(Lbox, Lbox, Lbox),
        run_time=Tsim,
        grid_spec=td.GridSpec.auto(min_steps_per_wvl=50),
        sources=[source],
        monitors=[monitor],
        structures=[structure],
    )
    return simulation
|
||
|
||
def objective(c, a, s, t, θ, ax):
    """Scalar objective: summed intensity seen by monitor ``"m"``.

    Arguments are centre ``c``, size ``a``, scale ``s``, translation ``t``,
    rotation angle ``θ`` and rotation axis ``ax``; they are forwarded to
    ``make_simulation`` in that order.
    """
    sim_data = web.run(
        make_simulation(c, a, s, t, θ, ax),
        task_name="fd_ad_all",
        verbose=False,
        local_gradient=True,
    )
    intensity = sim_data.get_intensity("m")
    return anp.sum(intensity.values)
|
||
|
||
def finite_diff(fun, x0, δ):
    """Central finite-difference estimate of the derivative of ``fun`` at ``x0``.

    ``fun`` is evaluated at ``x0 + δ`` and then at ``x0 - δ``; the difference
    of the two values divided by ``2 * δ`` is returned.
    """
    forward = fun(x0 + δ)
    backward = fun(x0 - δ)
    return (forward - backward) / (2 * δ)
|
||
|
||
def _assert_close(fd_val, ad_val, tag, axis=None, extra="", tol=0.35): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. style: The tolerance value 0.35 (35%) seems quite high for numerical comparison. Consider tightening or documenting justification. |
||
""" | ||
Raise AssertionError if |FD-AD|/max(|FD|,1e-12) >= tol. | ||
""" | ||
rel = abs(fd_val - ad_val) / max(abs(fd_val), 1e-12) | ||
if rel >= tol: | ||
ax_lbl = {0: "x", 1: "y", 2: "z"}.get(axis, "") | ||
axis_str = f"_{ax_lbl}" if ax_lbl else "" | ||
raise AssertionError( | ||
f"{tag}{axis_str}{extra}: FD–AD mismatch " | ||
f"(rel diff {rel:.2%}) | FD={fd_val:.4e}, AD={ad_val:.4e}" | ||
) | ||
|
||
|
||
# 1. ROTATION θ_k   (k = 0,1,2)
θ_vals = np.array([0, np.pi / 4, np.pi / 2])


@pytest.mark.numerical
@pytest.mark.parametrize("k", AXES, ids=[f"θ_{LAB[a]}" for a in AXES])
@pytest.mark.parametrize("θ", θ_vals)
def test_fd_vs_ad_rotation(k, θ):
    """FD vs AD derivative of the objective w.r.t. the rotation angle about axis ``k``."""

    def objective_of_angle(angle):
        return objective(center0, size0, scale0, trans0, angle, k)

    grad_ad = autograd.grad(objective_of_angle)(θ)
    grad_fd = finite_diff(objective_of_angle, θ, Δθ)

    # Keep (angle, gradient) pairs so the summary plot can label each bar.
    for kind, grad in (("fd", grad_fd), ("ad", grad_ad)):
        plots[("θ", k, kind)].append((θ, grad))

    assert np.isfinite(grad_fd) and np.isfinite(grad_ad)
    _assert_close(grad_fd, grad_ad, tag="θ", axis=k)
|
||
|
||
# 2. SCALE s_k
@pytest.mark.numerical
@pytest.mark.parametrize("k", AXES, ids=[f"s_{LAB[a]}" for a in AXES])
def test_fd_vs_ad_scale(k):
    """FD vs AD derivative of the objective w.r.t. the scale factor along axis ``k``."""

    def objective_of_scale(value):
        # Copy the baseline so only component k is varied.
        scale = [*scale0]
        scale[k] = value
        return objective(center0, size0, scale, trans0, theta0, axis0)

    baseline = scale0[k]
    grad_ad = autograd.grad(objective_of_scale)(baseline)
    grad_fd = finite_diff(objective_of_scale, baseline, Δxyz)

    for kind, grad in (("fd", grad_fd), ("ad", grad_ad)):
        plots[("s", k, kind)].append(grad)

    assert np.isfinite(grad_fd) and np.isfinite(grad_ad), "NaN/Inf in FD or AD result"
    _assert_close(grad_fd, grad_ad, tag="scale", axis=k)
|
||
|
||
# 3. TRANSLATION t_k
@pytest.mark.numerical
@pytest.mark.parametrize("k", AXES, ids=[f"t_{LAB[a]}" for a in AXES])
def test_fd_vs_ad_translation(k):
    """FD vs AD derivative of the objective w.r.t. the translation along axis ``k``."""

    def objective_of_shift(value):
        # Copy the baseline so only component k is varied.
        shift = [*trans0]
        shift[k] = value
        return objective(center0, size0, scale0, shift, theta0, axis0)

    baseline = trans0[k]
    grad_ad = autograd.grad(objective_of_shift)(baseline)
    grad_fd = finite_diff(objective_of_shift, baseline, Δxyz)

    for kind, grad in (("fd", grad_fd), ("ad", grad_ad)):
        plots[("t", k, kind)].append(grad)

    assert np.isfinite(grad_fd) and np.isfinite(grad_ad), "NaN/Inf in FD or AD result"
    _assert_close(grad_fd, grad_ad, tag="trans", axis=k)
|
||
|
||
# 4. SIZE a_k
@pytest.mark.numerical
@pytest.mark.parametrize("k", AXES, ids=[f"a_{LAB[a]}" for a in AXES])
def test_fd_vs_ad_size(k):
    """FD vs AD derivative of the objective w.r.t. the box size along axis ``k``."""

    def objective_of_size(value):
        # Copy the baseline so only component k is varied.
        extents = [*size0]
        extents[k] = value
        return objective(center0, extents, scale0, trans0, theta0, axis0)

    baseline = size0[k]
    grad_ad = autograd.grad(objective_of_size)(baseline)
    grad_fd = finite_diff(objective_of_size, baseline, Δxyz)

    for kind, grad in (("fd", grad_fd), ("ad", grad_ad)):
        plots[("a", k, kind)].append(grad)

    assert np.isfinite(grad_fd) and np.isfinite(grad_ad), "NaN/Inf in FD or AD result"
    _assert_close(grad_fd, grad_ad, tag="size", axis=k)
|
||
|
||
# 5. CENTRE c_k
@pytest.mark.numerical
@pytest.mark.parametrize("k", AXES, ids=[f"c_{LAB[a]}" for a in AXES])
def test_fd_vs_ad_center(k):
    """FD vs AD derivative of the objective w.r.t. the box centre along axis ``k``."""

    def objective_of_center(value):
        # Copy the baseline so only component k is varied.
        position = [*center0]
        position[k] = value
        return objective(position, size0, scale0, trans0, theta0, axis0)

    baseline = center0[k]
    grad_ad = autograd.grad(objective_of_center)(baseline)
    grad_fd = finite_diff(objective_of_center, baseline, Δxyz)

    for kind, grad in (("fd", grad_fd), ("ad", grad_ad)):
        plots[("c", k, kind)].append(grad)

    assert np.isfinite(grad_fd) and np.isfinite(grad_ad), "NaN/Inf in FD or AD result"
    _assert_close(grad_fd, grad_ad, tag="center", axis=k)
|
||
|
||
# ───────── save / plot after test run ──────────────────────────── | ||
def _save_and_plot():
    """Persist and/or plot the gradients collected in ``plots``.

    Does nothing when both ``SAVE`` and ``PLOT`` are off. Otherwise ``OUT`` is
    created if needed; with ``SAVE`` the raw data goes to ``gradients.npz``,
    and with ``PLOT`` a bar chart of the relative FD/AD error of every
    recorded parameter is written to ``bar_rel_error_all_params.png``.
    """
    if not (SAVE or PLOT):
        return

    os.makedirs(OUT, exist_ok=True)

    if SAVE:
        # ``plots`` is keyed by tuples such as ("θ", 0, "fd"), but keyword
        # unpacking only accepts string keys; flatten each key to "θ_0_fd".
        arrays = {
            "_".join(str(part) for part in key): np.asarray(values)
            for key, values in plots.items()
        }
        np.savez_compressed(os.path.join(OUT, "gradients.npz"), **arrays)

    labels, errs = [], []

    # Rotation entries are (angle, gradient) pairs; sorting aligns FD with AD.
    for k in AXES:
        fd_pairs = plots.get(("θ", k, "fd"), [])
        ad_pairs = plots.get(("θ", k, "ad"), [])
        for (theta, fd), (_, ad) in zip(sorted(fd_pairs), sorted(ad_pairs)):
            labels.append(f"θ_{LAB[k]} {theta:.2f}")
            errs.append(abs(fd - ad) / max(abs(fd), 1e-12))

    # The remaining parameters store bare gradient values, one per axis.
    for tag, pretty in (("s", "scale"), ("t", "trans"), ("a", "size"), ("c", "center")):
        for k in AXES:
            fd_vals = plots.get((tag, k, "fd"), [])
            ad_vals = plots.get((tag, k, "ad"), [])
            # Both lists are needed; skip the axis if either one is empty.
            if fd_vals and ad_vals:
                rel = abs(fd_vals[0] - ad_vals[0]) / max(abs(fd_vals[0]), 1e-12)
                labels.append(f"{pretty}_{LAB[k]}")
                errs.append(rel)

    if not (PLOT and errs):
        return

    positions = range(len(errs))
    plt.figure(figsize=(12, 5))
    plt.bar(positions, errs)
    plt.xticks(positions, labels, rotation=75, ha="right", fontsize=8)
    plt.ylabel("relative |FD – AD| / max(|FD|)")
    plt.title("FD vs AD relative errors for all parameters")
    plt.tight_layout()
    fname = os.path.join(OUT, "bar_rel_error_all_params.png")
    plt.savefig(fname, dpi=150)
    plt.close()
    print(f"[plot] saved {fname}")


# Run once at interpreter exit, after the tests have filled ``plots``.
atexit.register(_save_and_plot)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: make_sim_rotation could be reused by storing sim as a fixture instead of recreating for each test