Skip to content

add source differentiation #2710

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: develop
Choose a base branch
from
Draft

add source differentiation #2710

wants to merge 7 commits into from

Conversation

tylerflex
Copy link
Collaborator

@tylerflex tylerflex commented Aug 1, 2025

adds ability to diff w.r.t. sources. but doesn't actually implement the VJP yet!

Greptile Summary

This PR adds foundational infrastructure for source differentiation in Tidy3D's autograd system, enabling automatic differentiation with respect to source parameters like CustomCurrentSource. The changes establish the framework to treat sources as differentiable components alongside existing structure differentiation capabilities.

The implementation follows the established autograd pattern in the codebase where differentiable components implement a _compute_derivatives method and are tracked through path-based systems. The main architectural changes include:

  • Path system extension: Updated from single starting_path to multiple starting_paths throughout the autograd pipeline to handle both structures and sources simultaneously
  • Component type distinction: Added logic to differentiate between 'structures' and 'sources' when processing gradients, with separate handling paths for each
  • Source integration: Added CustomCurrentSource._compute_derivatives() method and updated the web API autograd system to recognize sources as traceable components
  • Testing infrastructure: Added test coverage for the new source differentiation pathway

The changes integrate cleanly with the existing autograd architecture by extending the path-based field tracking system used for structures. The _strip_traced_fields method signature was updated across the codebase to support multiple starting paths, enabling efficient batch processing of both structure and source parameters in a single operation.

PR Description Notes:

  • The PR body contains a minor grammatical issue: 'diff w.r.t.' could be more clearly written as 'differentiate with respect to'

Important Files Changed

Changed Files
Filename Score Overview
tidy3d/components/source/current.py 4/5 Added placeholder _compute_derivatives method to CustomCurrentSource with proper logging and zero gradient returns
tidy3d/web/api/autograd/autograd.py 4/5 Extended autograd pipeline to handle sources alongside structures, added source gradient processing logic with placeholder implementation
tidy3d/components/base.py 4/5 Modified _strip_traced_fields to accept multiple starting paths instead of single path for batch processing
tidy3d/components/simulation.py 4/5 Updated adjoint monitor creation to handle both structures and sources with placeholder logic for source processing
tests/test_components/test_autograd.py 4/5 Added comprehensive test for source autograd functionality and updated existing tests for new method signatures
tidy3d/components/geometry/base.py 4/5 Updated ClipOperation validator to use new starting_paths parameter in _strip_traced_fields call
tidy3d/components/grid/grid_spec.py 4/5 Modified grid generation logic to use updated _strip_traced_fields method signature
tidy3d/plugins/autograd/README.md 4/5 Added documentation for source differentiation capabilities with clear placeholder status indication
CHANGELOG.md 5/5 Added clear changelog entry documenting the infrastructure addition with appropriate caveats about placeholder implementation
plan.md 4/5 Added comprehensive implementation plan documenting design decisions and future roadmap for source differentiation
docs/faq 2/5 Submodule update resulted in empty file, potentially breaking FAQ documentation functionality

Confidence score: 4/5

  • This PR is relatively safe to merge as a foundational infrastructure change with clear placeholder implementations
  • Score reflects well-structured incremental development approach with proper documentation and testing, though actual functionality is not yet implemented
  • Pay close attention to docs/faq submodule which appears to have issues that could break documentation

@tylerflex tylerflex marked this pull request as draft August 1, 2025 18:48
Copy link

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

11 files reviewed, 7 comments

Edit Code Review Bot Settings | Greptile


def objective(source_amplitude):
# Create traced field data for the source
field_data = source_amplitude * np.ones((10, 10, 1, 1))
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

syntax: Missing import for np - the example uses np.ones() but only imports anp

Suggested change
field_data = source_amplitude * np.ones((10, 10, 1, 1))
field_data = source_amplitude * anp.ones((10, 10, 1, 1))

def objective(source_amplitude):
# Create traced field data for the source
field_data = source_amplitude * np.ones((10, 10, 1, 1))
scalar_field = td.ScalarFieldDataArray(field_data, coords=coords)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

syntax: Variable coords is undefined in the example code

Comment on lines 1005 to 1006
if isinstance(starting_paths[0], str):
starting_paths = (starting_paths,)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: The type check isinstance(starting_paths[0], str) could fail with IndexError if starting_paths is an empty tuple. Should check starting_paths and isinstance(starting_paths[0], str) instead.

Suggested change
if isinstance(starting_paths[0], str):
starting_paths = (starting_paths,)
if starting_paths and isinstance(starting_paths[0], str):
starting_paths = (starting_paths,)

sim = td.Simulation(
size=(2.0, 2.0, 2.0),
sources=[custom_source],
monitors=[td.FieldMonitor(size=(1.0, 1.0, 0.0), center=(0, 0, 0), freqs=[2e14])]
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: Monitor name should match the name used in line 266 (load_field_monitor expects a monitor name)

Suggested change
monitors=[td.FieldMonitor(size=(1.0, 1.0, 0.0), center=(0, 0, 0), freqs=[2e14])]
monitors=[td.FieldMonitor(size=(1.0, 1.0, 0.0), center=(0, 0, 0), freqs=[2e14], name="field_monitor")]

Comment on lines 1230 to 1247
derivative_info = DerivativeInfo(
paths=source_paths,
E_der_map={}, # Placeholder - source-specific field maps needed
D_der_map={}, # Placeholder - source-specific field maps needed
E_fwd=None, # Placeholder - source-specific forward fields needed
E_adj=None, # Placeholder - source-specific adjoint fields needed
D_fwd=None, # Placeholder - source-specific forward fields needed
D_adj=None, # Placeholder - source-specific adjoint fields needed
eps_data=None, # Not applicable for sources
eps_in=None, # Not applicable for sources
eps_out=None, # Not applicable for sources
eps_background=None, # Not applicable for sources
frequencies=np.array([]), # Placeholder
eps_no_structure=None, # Not applicable for sources
eps_inf_structure=None, # Not applicable for sources
bounds=((0, 0, 0), (0, 0, 0)), # Placeholder
bounds_intersect=((0, 0, 0), (0, 0, 0)), # Placeholder
)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: DerivativeInfo is created with placeholder/empty values for all source-specific fields. This will likely cause issues when _compute_derivatives is actually called on sources.

return sim_fields_vjp
else:
# Fallback for sources without _compute_derivatives method
td.log.warning(f"Source {source_index} does not have _compute_derivatives method")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: Warning message uses f-string with Source {source_index} but doesn't identify which source type or provide enough context for debugging.

Context Used: Rule - Make log messages and warnings informative by including relevant context and enclosing variable names in single quotes. (link)


# For now, return placeholder gradients with expected structure
# This is a placeholder for future implementation
import tidy3d as td
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: Consider using td.log.debug() instead of td.log.info() since this is a placeholder message that will appear frequently during development

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant