diff --git a/CHANGELOG.md b/CHANGELOG.md index 459b1840c..6f6d0b30c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Significantly improved performance of the `tidy3d.plugins.autograd.grey_dilation` morphological operation and its gradient calculation. The new implementation is orders of magnitude faster, especially for large arrays and kernel sizes. +- Support for gradients with respect to the `conductivity` of a `CustomMedium`. ### Fixed - Arrow lengths are now scaled consistently in the X and Y directions, and their lengths no longer exceed the height of the plot window. diff --git a/tests/test_components/test_autograd.py b/tests/test_components/test_autograd.py index 92f482d6a..244a94da8 100644 --- a/tests/test_components/test_autograd.py +++ b/tests/test_components/test_autograd.py @@ -315,17 +315,34 @@ def make_structures(params: anp.ndarray) -> dict[str, td.Structure]: eps_arr = 1.01 + 0.5 * (anp.tanh(matrix @ params).reshape(DA_SHAPE) + 1) nx, ny, nz = eps_arr.shape + da_coords = { + "x": np.linspace(-0.5, 0.5, nx), + "y": np.linspace(-0.5, 0.5, ny), + "z": np.linspace(-0.5, 0.5, nz), + } custom_med = td.Structure( geometry=box, medium=td.CustomMedium( permittivity=td.SpatialDataArray( eps_arr, - coords={ - "x": np.linspace(-0.5, 0.5, nx), - "y": np.linspace(-0.5, 0.5, ny), - "z": np.linspace(-0.5, 0.5, nz), - }, + coords=da_coords, + ), + ), + ) + + # custom medium with variable permittivity and conductivity data + conductivity_arr = 0.01 * (anp.tanh(matrix @ params).reshape(DA_SHAPE) + 1) + custom_med_with_conductivity = td.Structure( + geometry=box, + medium=td.CustomMedium( + permittivity=td.SpatialDataArray( + eps_arr, + coords=da_coords, + ), + conductivity=td.SpatialDataArray( + conductivity_arr, + coords=da_coords, ), ), ) @@ -333,12 +350,7 @@ def make_structures(params: anp.ndarray) -> dict[str, td.Structure]: # custom medium with vector 
# Stub out the field-derivative helper so the test controls the returned gradient.
# The stub's parameter names must mirror the real ``_derivative_field_cmp``: the medium
# code invokes it with keyword arguments (``spatial_data=...``), so keeping the old
# ``eps_data`` name would raise ``TypeError`` (unexpected keyword) before the stubbed
# value is ever returned. ``component`` keeps its default so both call styles work.
monkeypatch.setattr(
    td.CustomPoleResidue,
    "_derivative_field_cmp",
    lambda self, E_der_map, spatial_data, dim, freqs, component="real": dJ_deps / 3.0,
)
geometry=td.Box(center=(0, 0, 0), size=(1, 1, 1)), + medium=td.CustomMedium( + permittivity=td.SpatialDataArray(eps_arr, coords=coords), + conductivity=td.SpatialDataArray(conductivity_arr, coords=coords), + ), + ) + + sim = SIM_BASE.updated_copy( + structures=[custom_med_struct], + monitors=[monitor], + ) + + data = run( + sim, + path=str(tmp_path / "sim_test.hdf5"), + task_name="conductivity_only_grad_test", + verbose=False, + ) + return postprocess(data, data[monitor.name]) + + val, grad = ag.value_and_grad(objective)(params0) + + assert anp.all(grad != 0.0), "some gradients are 0 for conductivity-only test" diff --git a/tests/test_components/test_autograd_conductivity_numerical.py b/tests/test_components/test_autograd_conductivity_numerical.py new file mode 100644 index 000000000..54fdf5330 --- /dev/null +++ b/tests/test_components/test_autograd_conductivity_numerical.py @@ -0,0 +1,519 @@ +"""Test autograd conductivity gradients by comparing to numerically computed finite difference gradients. + +This test validates the implementation of conductivity gradients for CustomMedium in the autograd +system. It creates simulations with CustomMedium objects that have constant permittivity and +variable conductivity, then compares the gradients computed by autograd against finite difference +approximations. + +The test covers: +- Multiple wavelengths (near-infrared at 1.55μm and long-wave infrared at 15.5μm) +- Different conductivity scales (0.1x, 1x, 10x base conductivity) +- Various monitor configurations (point, 1D, 2D, and 3D monitors) +- Different background indices + +Similar to test_autograd_numerical.py but specifically tests conductivity gradients instead of +permittivity gradients. This addresses the feature request by Greg Roberts to add numerical +validation for conductivity gradients in CustomMedium. 
def make_base_sim(
    mesh_wvl_um,
    adj_wvl_um,
    monitor_size_wvl,
    box_for_override,
    monitor_bg_index=1.0,
    run_time=1e-11,
):
    """Build the simulation shared by every conductivity gradient check.

    The returned simulation contains a plane-wave source, a single field monitor
    named ``"monitor_fields"`` and a uniform background slab around that monitor;
    the conductivity block under test is added by the caller.

    Parameters
    ----------
    mesh_wvl_um : float
        Wavelength (um) that sets the domain size, the grid resolution and the
        source / monitor placement.
    adj_wvl_um : float
        Wavelength (um) of the source centre frequency, which is also the single
        monitor frequency.
    monitor_size_wvl : tuple
        Monitor extent along (x, y, z), expressed in multiples of ``mesh_wvl_um``.
    box_for_override : td.Box
        Region in which the grid step is overridden to
        ``mesh_wvl_um / MESH_FACTOR_DESIGN`` along all three axes.
    monitor_bg_index : float = 1.0
        Refractive index of the slab surrounding the monitor (stored as its square,
        the permittivity).
    run_time : float = 1e-11
        Simulation run time passed to ``td.Simulation``.

    Returns
    -------
    td.Simulation
        Base simulation without the conductivity structure.
    """
    domain = get_sim_geometry(mesh_wvl_um)
    size_x, size_y, size_z = domain.size

    # Refined grid inside the override box.
    design_step = mesh_wvl_um / MESH_FACTOR_DESIGN
    override = td.MeshOverrideStructure(
        geometry=box_for_override,
        dl=[design_step] * 3,
    )

    # Gaussian pulse centred on the adjoint wavelength, spanning +/-10% in wavelength.
    shortest_wvl_um = 0.9 * adj_wvl_um
    longest_wvl_um = 1.1 * adj_wvl_um
    center_freq = td.C_0 / adj_wvl_um
    bandwidth = td.C_0 * ((1.0 / shortest_wvl_um) - (1.0 / longest_wvl_um))

    # Plane wave spanning the full transverse cross-section, launched towards +z.
    plane_wave = td.PlaneWave(
        center=(0, 0, -2 * mesh_wvl_um),
        size=(size_x, size_y, 0),
        source_time=td.GaussianPulse(freq0=center_freq, fwidth=bandwidth),
        direction="+",
    )

    # Monitor sits a quarter of the domain height above the origin.
    monitor_z = 0.25 * size_z
    monitor = td.FieldMonitor(
        center=(0, 0, monitor_z),
        size=tuple(extent * mesh_wvl_um for extent in monitor_size_wvl),
        name="monitor_fields",
        freqs=[center_freq],
    )

    # Slab of uniform index around the monitor; twice the transverse domain size.
    background_slab = td.Structure(
        geometry=td.Box(
            center=(0, 0, monitor_z + mesh_wvl_um),
            size=(2 * size_x, 2 * size_y, mesh_wvl_um + 0.5 * size_z),
        ),
        medium=td.Medium(permittivity=monitor_bg_index**2),
    )

    return td.Simulation(
        center=domain.center,
        size=domain.size,
        grid_spec=td.GridSpec.auto(
            min_steps_per_wvl=30,
            wavelength=mesh_wvl_um,
            override_structures=[override],
        ),
        structures=[background_slab],
        sources=[plane_wave],
        monitors=[monitor],
        run_time=run_time,
        boundary_spec=td.BoundarySpec(**{axis: td.Boundary.pml() for axis in "xyz"}),
        subpixel=True,
    )
def make_eval_fns(monitor_size_wvl):
    """Create the objective evaluators that apply to a given monitor shape.

    An ``intensity`` evaluator is always returned. A ``flux`` evaluator is added
    only when exactly two of the monitor's dimensions are non-zero (a planar monitor).

    Parameters
    ----------
    monitor_size_wvl : tuple
        Monitor size in wavelengths (x, y, z); entries close to zero count as
        collapsed dimensions.

    Returns
    -------
    tuple
        ``(eval_fns, eval_fn_names)``: a list of callables taking simulation data and a
        parallel list of their names.
    """
    num_nonzero_spatial_dims = 3 - np.sum(np.isclose(monitor_size_wvl, 0))

    def intensity(sim_data):
        """Sum |Ex|^2 + |Ey|^2 + |Ez|^2 at the central sample of the monitor."""
        field_data = sim_data["monitor_fields"]
        shape_x, shape_y, shape_z, *_ = field_data.Ex.values.shape
        # NOTE(review): the centre index comes from Ex's shape and is reused for Ey and
        # Ez -- this assumes all three components share the same grid shape; confirm.
        center = (shape_x // 2, shape_y // 2, shape_z // 2)

        return np.sum(
            np.abs(field_data.Ex.values[center]) ** 2
            + np.abs(field_data.Ey.values[center]) ** 2
            + np.abs(field_data.Ez.values[center]) ** 2
        )

    eval_fns = [intensity]
    eval_fn_names = ["intensity"]

    if num_nonzero_spatial_dims == 2:

        def flux(sim_data):
            """Sum the flux values reported by the monitor."""
            field_data = sim_data["monitor_fields"]
            return np.sum(field_data.flux.values)

        eval_fns.append(flux)
        eval_fn_names.append("flux")

    return eval_fns, eval_fn_names
@pytest.mark.numerical
@pytest.mark.parametrize(
    "conductivity_data_test_parameters, dir_name",
    zip(
        conductivity_data_test_parameters,
        ([NUMERICAL_RESULTS_DATA_DIR] if SAVE_FD_ADJ_DATA else [None])
        * len(conductivity_data_test_parameters),
    ),
    indirect=["dir_name"],
)
def test_finite_difference_conductivity_data(
    conductivity_data_test_parameters, rng, tmp_path, create_directory
):
    """Check adjoint conductivity gradients of a ``CustomMedium`` against finite differences.

    Procedure:

    1. Build a block whose conductivity is a uniform seed value
       (``CONDUCTIVITY_SEED * conductivity_scale``).
    2. Evaluate the objective and its autograd (adjoint) gradient at the seed.
    3. For ``NUM_FINITE_DIFFERENCE`` smoothed, unit-norm random directions, compute a
       centered finite difference of the objective along the direction and the projection
       of the adjoint gradient onto the same direction.
    4. Require the norm of the mismatch to stay below ``RMS_THRESHOLD`` times the norm of
       the finite-difference values.

    Parameters
    ----------
    conductivity_data_test_parameters : dict
        One parameter set: wavelengths, monitor size, background index, conductivity
        scale, evaluation function (and its name) and the test number.
    rng :
        Random number generator fixture; only ``rng.random`` is used.
    tmp_path : pathlib.Path
        Temporary directory in which the simulation files are written.
    create_directory : fixture
        Requested for its side effect only; it is not referenced in the body.
    """
    # Make sure the output location exists before anything is written to it.
    if PLOT_FD_ADJ_COMPARISON or SAVE_FD_ADJ_DATA:
        import os

        os.makedirs(NUMERICAL_RESULTS_DATA_DIR, exist_ok=True)

    (
        mesh_wvl_um,
        adj_wvl_um,
        monitor_size_wvl,
        monitor_bg_index,
        conductivity_scale,
        eval_fn,
        eval_fn_name,
        test_number,
    ) = operator.itemgetter(
        "mesh_wvl_um",
        "adj_wvl_um",
        "monitor_size_wvl",
        "monitor_bg_index",
        "conductivity_scale",
        "eval_fn",
        "eval_fn_name",
        "test_number",
    )(conductivity_data_test_parameters)

    # Design block: one mesh wavelength wide, half a mesh wavelength thick.
    dim_um = mesh_wvl_um
    thickness_um = 0.5 * mesh_wvl_um
    block = td.Box(center=(0, 0, 0), size=(dim_um, dim_um, thickness_um))

    # Number of conductivity samples per axis, matching the overridden grid step.
    dl_design = mesh_wvl_um / MESH_FACTOR_DESIGN
    dim = 1 + int(dim_um / dl_design)
    Nz = 1 + int(thickness_um / dl_design)

    sim_geometry = get_sim_geometry(mesh_wvl_um)
    box_for_override = td.Box(
        center=(0, 0, 0), size=sim_geometry.size[0:2] + (thickness_um + mesh_wvl_um,)
    )

    sim_path_dir = tmp_path / f"test{test_number}"
    sim_path_dir.mkdir()

    def create_sim_base():
        """Build the base simulation for this parameter set."""
        return make_base_sim(
            mesh_wvl_um=mesh_wvl_um,
            adj_wvl_um=adj_wvl_um,
            monitor_size_wvl=monitor_size_wvl,
            box_for_override=box_for_override,
            monitor_bg_index=monitor_bg_index,
        )

    objective = create_objective_function(
        block,
        create_sim_base,
        eval_fn,
        sim_path_dir=str(sim_path_dir),
        dims=(dim, dim, Nz),
    )

    # Adjoint gradient at the uniform seed conductivity (objective value is not needed).
    conductivity_init = CONDUCTIVITY_SEED * conductivity_scale * np.ones((dim, dim, Nz))
    _, adj_grad = ag.value_and_grad(objective)([conductivity_init])

    # empirical step size for finite difference
    # NOTE(review): the step is absolute while the seed scales with ``conductivity_scale``;
    # for the smallest scale a perturbed sample could presumably dip below zero -- confirm
    # that this is acceptable for ``CustomMedium``.
    fd_step = 0.1

    all_conductivity = []
    pattern_dot_adj_gradient = np.zeros(NUM_FINITE_DIFFERENCE)

    for fd_idx in range(NUM_FINITE_DIFFERENCE):
        # Smooth, zero-centred random direction with unit L2 norm.
        random_pattern = rng.random((dim, dim, Nz)) - 0.5
        random_pattern = gaussian_filter(random_pattern, sigma=3)
        random_pattern /= np.linalg.norm(random_pattern)

        pattern_dot_adj_gradient[fd_idx] = np.sum(random_pattern * adj_grad)

        # Stored as consecutive (up, down) pairs; the indexing below relies on this order.
        all_conductivity.append(conductivity_init + fd_step * random_pattern)
        all_conductivity.append(conductivity_init - fd_step * random_pattern)

    # All perturbed simulations are evaluated in a single batched objective call.
    all_obj = objective(all_conductivity)

    fd_grad = np.zeros(NUM_FINITE_DIFFERENCE)
    for fd_idx in range(NUM_FINITE_DIFFERENCE):
        obj_up = all_obj[2 * fd_idx]
        obj_down = all_obj[2 * fd_idx + 1]
        fd_grad[fd_idx] = (obj_up - obj_down) / (2 * fd_step)

    rms_error = np.linalg.norm(fd_grad - pattern_dot_adj_gradient)
    fd_mag = np.linalg.norm(fd_grad)
    adj_mag = np.linalg.norm(pattern_dot_adj_gradient)
    # Signed mean relative error, reported for information only (not asserted on).
    percentage_error = 100.0 * np.mean(
        (fd_grad - pattern_dot_adj_gradient) / (fd_grad + np.finfo(np.float64).eps)
    )

    print("\n" * 3)
    print("-" * 20)
    print(f"Numerical test #{test_number}")
    print(f"Mesh and adjoint wavelengths: {mesh_wvl_um}, {adj_wvl_um}")
    print(f"Monitor size: {monitor_size_wvl}")
    print(f"Background index for monitor: {monitor_bg_index}")
    print(f"Conductivity scale: {conductivity_scale}")
    print(f"Eval function: {eval_fn_name}")
    print(f"RMS Error: {rms_error}")
    print(f"FD, Adj magnitudes: {fd_mag}, {adj_mag}")
    print(f"Percentage Error: {percentage_error}")
    print("-" * 20)
    print("\n" * 3)

    assert rms_error < RMS_THRESHOLD * fd_mag, "RMS error magnitude too large"

    test_results = np.zeros((2, NUM_FINITE_DIFFERENCE))
    test_results[SAVE_FD_LOC, :] = fd_grad
    test_results[SAVE_ADJ_LOC, :] = pattern_dot_adj_gradient

    # ``test_number`` is deliberately left untouched here so that the plot title and the
    # output file names carry the same number as the log above and the ``tmp_path``
    # sub-directory of this parameter set.
    if PLOT_FD_ADJ_COMPARISON:
        plt.figure(figsize=(10, 6))
        plt.plot(pattern_dot_adj_gradient, color="g", linewidth=2.0, label="Adjoint")
        plt.plot(fd_grad, color="b", linewidth=1.5, linestyle="--", label="Finite difference")
        plt.title(f"Gradient comparison for {eval_fn_name} (Test #{test_number})")
        plt.xlabel("Sample number")
        plt.ylabel("Gradient value")
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Save the plot
        plot_filename = f"{NUMERICAL_RESULTS_DATA_DIR}/gradient_comparison_test_{test_number}_{eval_fn_name}.png"
        plt.savefig(plot_filename, dpi=150, bbox_inches="tight")
        print(f"Plot saved to: {plot_filename}")

        plt.show()
        plt.close()

    if SAVE_FD_ADJ_DATA:
        np.save(f"{NUMERICAL_RESULTS_DATA_DIR}/results_{test_number}.npy", test_results)
in "xyz"] + eps_coordinate_shape = [ + len(spatial_data.coords[dim]) for dim in spatial_data.dims if dim in "xyz" + ] # compute sizes along each of the interpolation dimensions sizes_list = [] @@ -2833,14 +2835,27 @@ def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradField vjps = {} for field_path in derivative_info.paths: - if field_path == ("permittivity",): + if field_path[0] == "permittivity": vjp_array = 0.0 for dim in "xyz": vjp_array += self._derivative_field_cmp( E_der_map=derivative_info.E_der_map, - eps_data=self.permittivity, + spatial_data=self.permittivity, dim=dim, freqs=np.atleast_1d(derivative_info.frequency), + component="real", + ) + vjps[field_path] = vjp_array + + elif field_path[0] == "conductivity": + vjp_array = 0.0 + for dim in "xyz": + vjp_array += self._derivative_field_cmp( + E_der_map=derivative_info.E_der_map, + spatial_data=self.conductivity, + dim=dim, + freqs=np.atleast_1d(derivative_info.frequency), + component="imag", ) vjps[field_path] = vjp_array @@ -2849,11 +2864,11 @@ def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradField dim = key[-1] vjps[field_path] = self._derivative_field_cmp( E_der_map=derivative_info.E_der_map, - eps_data=self.eps_dataset.field_components[key], + spatial_data=self.eps_dataset.field_components[key], dim=dim, freqs=np.atleast_1d(derivative_info.frequency), + component="complex", ) - else: raise NotImplementedError( f"No derivative defined for 'CustomMedium' field: {field_path}." 
@@ -2864,15 +2879,18 @@ def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradField def _derivative_field_cmp( self, E_der_map: ElectromagneticFieldDataset, - eps_data: PermittivityDataset, + spatial_data: CustomSpatialDataTypeAnnotated, dim: str, freqs: NDArray, + component: str = "real", ) -> np.ndarray: - """Compute derivative with respect to the ``dim`` components within the custom medium.""" - coords_interp = {key: eps_data.coords[key] for key in "xyz"} + """Compute the derivative with respect to a material property component.""" + coords_interp = {key: spatial_data.coords[key] for key in "xyz"} coords_interp = {key: val for key, val in coords_interp.items() if len(val) > 1} - eps_coordinate_shape = [len(eps_data.coords[dim]) for dim in eps_data.dims if dim in "xyz"] + eps_coordinate_shape = [ + len(spatial_data.coords[dim]) for dim in spatial_data.dims if dim in "xyz" + ] E_der_dim_interp = E_der_map[f"E{dim}"].sel(f=freqs) @@ -2910,10 +2928,26 @@ def _derivative_field_cmp( # if sizes_list is empty, then reduce() fails d_vol = np.array(1.0) - # TODO: probably this could be more robust. eg if the DataArray has weird edge cases - E_der_dim_interp = ( - E_der_dim_interp.interp(**coords_interp, assume_sorted=True).fillna(0.0).real.sum("f") - ) + E_der_dim_interp_complex = E_der_dim_interp.interp( + **coords_interp, assume_sorted=True + ).fillna(0.0) + + if component == "imag": + # convert from derivative w.r.t. complex permittivity to derivative w.r.t. 
conductivity + E_der_dim_interp = E_der_dim_interp_complex.imag + # frequency-dependent scaling must be applied before summing over frequencies + for freq in freqs: + vjp_imag = Medium.eps_sigma_to_eps_complex( + eps_real=0, sigma=E_der_dim_interp.sel(f=freq), freq=freq + ).imag + E_der_dim_interp.loc[{"f": freq}] = vjp_imag + elif component == "complex": + # for complex permittivity in eps_dataset, return the full complex derivative + E_der_dim_interp = E_der_dim_interp_complex + else: + E_der_dim_interp = E_der_dim_interp_complex.real + + E_der_dim_interp = E_der_dim_interp.sum("f") try: E_der_dim_interp = E_der_dim_interp * d_vol.reshape(E_der_dim_interp.shape) @@ -3909,7 +3943,7 @@ def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradField for dim in "xyz": dJ_deps_complex += self._derivative_field_cmp( E_der_map=derivative_info.E_der_map, - eps_data=self.eps_inf, + spatial_data=self.eps_inf, dim=dim, freqs=np.atleast_1d(derivative_info.frequency), )