autograd support for Job.run() and Batch.run() #2722

Open
wants to merge 4 commits into
base: develop

tylerflex
Collaborator

@tylerflex tylerflex commented Aug 6, 2025

Greptile Summary

This PR implements autograd (automatic differentiation) support for Job.run() and Batch.run() methods in the Tidy3D web API. The core change enables gradient computation through these high-level object-oriented interfaces, which is essential for inverse design and optimization workflows.

The implementation follows a comprehensive refactoring plan documented in AUTOGRAD_REFACTOR_PLAN.md. The main architectural changes include:

  1. Module Reorganization: BatchData class is extracted from container.py into its own batch_data.py module to break circular import dependencies that prevented container classes from importing autograd functions.

  2. Function Delegation: Job.run() and Batch.run() methods are refactored to delegate to autograd-compatible functions (run_autograd() and run_async_autograd()) that automatically detect if simulations contain traced parameters and route them appropriately.

  3. Async Function Extraction: Batch processing logic is extracted from the Batch class into standalone async functions (upload_async, start_async, monitor_async, etc.) to enable autograd compatibility through function-based primitives.

  4. Sequential Processing: A new run_async() function in webapi.py provides autograd-compatible batch processing by sequentially calling the proven single-simulation run() function for each simulation in a batch.

The changes maintain full backward compatibility: existing code using Job.run() and Batch.run() continues to work as before, and these methods can now be called inside functions differentiated with autograd for gradient-based optimization. The refactoring also reduces code duplication by consolidating async operations into reusable functions.
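
The delegation described above can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the PR's implementation: `Traced`, `has_tracers`, `run_dispatch`, and this `Job` class are hypothetical stand-ins, and the two run paths are stubbed with lambdas instead of real web API calls.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class Traced:
    """Hypothetical stand-in for an autograd tracer box wrapping a value."""

    value: float


def has_tracers(obj: Any) -> bool:
    """Recursively check whether a nested dict/list/tuple holds a Traced leaf."""
    if isinstance(obj, Traced):
        return True
    if isinstance(obj, dict):
        return any(has_tracers(v) for v in obj.values())
    if isinstance(obj, (list, tuple)):
        return any(has_tracers(v) for v in obj)
    return False


def run_dispatch(sim: Any, run_plain: Callable, run_traced: Callable) -> Any:
    """Route to the differentiable path only when traced parameters are present."""
    return run_traced(sim) if has_tracers(sim) else run_plain(sim)


class Job:
    """Hypothetical container whose run() is a thin wrapper over run_dispatch."""

    def __init__(self, simulation: Any):
        self.simulation = simulation

    def run(self) -> str:
        # Both paths are stubs here; they only report which branch was taken.
        return run_dispatch(
            self.simulation,
            run_plain=lambda sim: "plain",
            run_traced=lambda sim: "traced",
        )
```

With this shape, untraced inputs never touch the differentiable path, which is how existing callers would keep their current behavior.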

Important Files Changed

Files and Ratings

| Filename | Score | Overview |
| --- | --- | --- |
| tidy3d/web/api/container.py | 4/5 | Major refactor making Job.run() and Batch.run() delegate to autograd-compatible functions |
| tidy3d/web/api/autograd/autograd.py | 3/5 | Refactored to use direct webapi calls instead of Job/Batch objects for autograd compatibility |
| tidy3d/web/api/asynchronous.py | 3/5 | Extracted async batch operations into standalone functions with potential infinite loop issue |
| tidy3d/web/api/webapi.py | 4/5 | Added sequential run_async function for autograd-compatible batch processing |
| tidy3d/web/api/batch_data.py | 4/5 | New module containing extracted BatchData class with proper Mapping interface |
| tidy3d/web/__init__.py | 5/5 | Import reorganization to prevent circular dependencies while maintaining public API |
| tidy3d/plugins/adjoint/web.py | 4/5 | Updated imports for BatchData module reorganization |
| tests/test_components/test_autograd.py | 4/5 | Added comprehensive tests for Job/Batch autograd integration |
| AUTOGRAD_REFACTOR_PLAN.md | 4/5 | Comprehensive documentation of refactoring approach with clear implementation plan |
| tests/utils.py | 5/5 | Simple import path update for BatchData module reorganization |
| tidy3d/plugins/design/design.py | 5/5 | Import reorganization separating BatchData from container imports |
| tests/test_plugins/test_adjoint.py | 5/5 | Updated BatchData import path for new module structure |
| tidy3d/plugins/smatrix/component_modelers/modal.py | 5/5 | Updated BatchData import path for module reorganization |
| tidy3d/plugins/smatrix/component_modelers/base.py | 5/5 | Split imports for Batch and BatchData into separate modules |
| tidy3d/plugins/smatrix/component_modelers/terminal.py | 5/5 | Updated BatchData import path for new module structure |
| docs/notebooks | 4/5 | Updated submodule reference likely containing new autograd examples |

Confidence score: 3/5

  • This PR contains complex architectural changes with potential issues in the monitoring loop and task ID handling that could cause runtime problems
  • Score reflects the complexity of autograd integration and potential edge cases in async processing and circular import handling
  • Pay close attention to asynchronous.py monitoring loop and autograd.py task ID generation for potential infinite loops or missing task references

Sequence Diagram

sequenceDiagram
    participant User
    participant Job
    participant Batch
    participant AutogradRun as "autograd.run"
    participant AutogradRunAsync as "autograd.run_async"
    participant WebAPI as "webapi functions"
    participant Server as "Tidy3D Server"

    Note over User, Server: Job.run() - Single Simulation Flow
    User->>+Job: job.run(path="data.hdf5")
    Note over Job: Job.run() now calls autograd.run() internally
    Job->>+AutogradRun: run(simulation, task_name, path, ...)
    Note over AutogradRun: Checks if simulation is autograd-compatible
    alt Simulation has autograd tracers
        AutogradRun->>AutogradRun: _run_primitive() - forward pass
        AutogradRun->>+WebAPI: _run_tidy3d() - actual simulation
        WebAPI->>+Server: upload/start/monitor/download
        Server-->>-WebAPI: simulation_data
        WebAPI-->>-AutogradRun: simulation_data
        AutogradRun->>AutogradRun: postprocess_run() - insert tracers
    else No autograd tracers
        AutogradRun->>+WebAPI: run_webapi() - regular run
        WebAPI->>+Server: upload/start/monitor/download
        Server-->>-WebAPI: simulation_data
        WebAPI-->>-AutogradRun: simulation_data
    end
    AutogradRun-->>-Job: simulation_data (autograd-compatible)
    Job-->>-User: simulation_data

    Note over User, Server: Batch.run() - Multiple Simulations Flow
    User->>+Batch: batch.run(path_dir="./data/")
    Note over Batch: Batch.run() now calls autograd.run_async() internally
    Batch->>+AutogradRunAsync: run_async(simulations, path_dir, ...)
    Note over AutogradRunAsync: Checks if all simulations are autograd-compatible
    alt All simulations have autograd tracers
        AutogradRunAsync->>AutogradRunAsync: _run_async_primitive() - forward pass
        AutogradRunAsync->>+WebAPI: _run_async_tidy3d() - actual batch
        WebAPI->>+Server: batch upload/start/monitor/download
        Server-->>-WebAPI: batch_data
        WebAPI-->>-AutogradRunAsync: batch_data
        AutogradRunAsync->>AutogradRunAsync: postprocess batch - insert tracers
    else No autograd tracers in simulations
        AutogradRunAsync->>+WebAPI: run_async_webapi() - regular batch
        WebAPI->>+Server: batch upload/start/monitor/download
        Server-->>-WebAPI: batch_data
        WebAPI-->>-AutogradRunAsync: batch_data
    end
    AutogradRunAsync-->>-Batch: batch_data (autograd-compatible)
    Batch-->>-User: batch_data

    Note over User, Server: Gradient Computation (when tracers present)
    User->>+AutogradRun: autograd.grad(objective)(params)
    Note over AutogradRun: VJP backward pass triggered
    AutogradRun->>AutogradRun: _run_bwd() - construct adjoint sims
    AutogradRun->>+WebAPI: run adjoint simulations
    WebAPI->>+Server: adjoint upload/start/monitor/download
    Server-->>-WebAPI: adjoint_data
    WebAPI-->>-AutogradRun: adjoint_data
    AutogradRun->>AutogradRun: postprocess_adj() - compute gradients
    AutogradRun-->>-User: gradients

@greptile-apps greptile-apps bot left a comment

16 files reviewed, 13 comments

Comment on lines +107 to +118
```python
# tidy3d/web/api/batch_data.py
from __future__ import annotations
from typing import Mapping
from tidy3d.components.base import Tidy3dBaseModel
from tidy3d.web.core.constants import TaskName
from .tidy3d_stub import SimulationDataType
import tidy3d.web.api.webapi as web # Only import webapi, not container

class BatchData(Tidy3dBaseModel, Mapping):
    # Move entire BatchData class here with minimal dependencies
    ...
```

logic: The import structure in the BatchData extraction example may create new dependencies. Consider whether importing webapi directly might recreate circular dependencies.
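
One way to keep the extracted module free of web API imports is to inject the loading function instead of importing it. The sketch below is an assumption-laden illustration, not the PR's BatchData: `LazyBatchData` and its `loader` parameter are hypothetical names, and it subclasses only `collections.abc.Mapping` rather than Tidy3dBaseModel.

```python
from collections.abc import Mapping
from typing import Callable, Dict, Iterator


class LazyBatchData(Mapping):
    """Hypothetical read-only mapping from task name to lazily loaded data."""

    def __init__(self, task_paths: Dict[str, str], loader: Callable[[str], object]):
        self._task_paths = dict(task_paths)
        # The loader is passed in, so this module needs no web API import.
        self._loader = loader

    def __getitem__(self, task_name: str) -> object:
        # Data is loaded on access; unknown task names raise KeyError.
        return self._loader(self._task_paths[task_name])

    def __iter__(self) -> Iterator[str]:
        return iter(self._task_paths)

    def __len__(self) -> int:
        return len(self._task_paths)
```

Defining `__getitem__`, `__iter__`, and `__len__` is enough for `Mapping` to supply `keys()`, `items()`, `values()`, `get()`, and `in`, so callers can treat the object like a read-only dict.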

# Test that autograd works through emulated Job.run() (demonstrates end goal)
val, grad = ag.value_and_grad(objective)(params0)
print(f"Job.run() with emulation - value: {val}, grad: {grad}")
# After refactor, this should work without emulation by calling autograd.run() internally

style: Remove this comment about 'after refactor' since the refactor appears to be implemented based on the test assertions below

Context Used: Rule - Remove commented-out or obsolete code; rely on version control for history. (link)

Comment on lines +160 to +173
2. **Simplify Batch class**:
```python
# container.py
class Batch(WebContainer):
    def upload(self) -> None:
        """Thin wrapper around asynchronous.upload_batch()"""
        from .asynchronous import upload_batch
        upload_batch(self.simulations, **self._get_batch_kwargs())

    def run(self, path_dir: str = DEFAULT_DATA_DIR) -> BatchData:
        """Thin wrapper that calls asynchronous functions"""
        from .asynchronous import run_async
        return run_async(self.simulations, path_dir=path_dir, **self._get_batch_kwargs())
```

logic: The Batch.run() implementation imports run_async from asynchronous module, but Step 3 is about moving batch logic TO asynchronous.py. This creates a circular reference where Batch calls asynchronous.run_async which may need Batch internals.

# Test that autograd works through emulated Batch.run() (demonstrates end goal)
val, grad = ag.value_and_grad(objective)(params0)
print(f"Batch.run() with emulation - value: {val}, grad: {grad}")
# After refactor, this should work without emulation by calling autograd.run_async() internally

style: Remove this comment about 'after refactor' since the refactor appears to be implemented based on the test assertions below

Context Used: Rule - Remove commented-out or obsolete code; rely on version control for history. (link)

Comment on lines +347 to +357
1. **Lazy imports**: Use function-level imports in container.py to avoid circular dependencies
2. **Backwards compatibility**: Keep existing methods working for users who don't need autograd
3. **Clean separation**: BatchData becomes independent, asynchronous.py contains core logic
4. **Minimal changes**: Autograd functions remain unchanged, containers adapt to them

### Potential Issues and Solutions

1. **Import timing**: If circular imports persist, use `TYPE_CHECKING` and runtime imports
2. **Test mocking**: May need to update mock strategy to handle the new call paths
3. **Plugin compatibility**: Verify that S-matrix and other plugins still work correctly
4. **Performance**: The function-level imports might add small overhead

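style: Function-level imports add a small cost on every method call. After the first import Python serves the module from `sys.modules`, so the repeated cost is a cache lookup and a name binding, not a full import. If that matters on a hot path, consider caching the imported functions as class or module attributes after the first import.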
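
A minimal sketch of such caching, using only the standard library. `lazy_attr` and `call_sqrt` are hypothetical names, and `math.sqrt` stands in for a function that would otherwise be imported at call time from a module involved in the import cycle.

```python
import functools
import importlib


@functools.lru_cache(maxsize=None)
def lazy_attr(module_name: str, attr: str):
    """Import module_name on first use and cache the resolved attribute."""
    return getattr(importlib.import_module(module_name), attr)


def call_sqrt(x: float) -> float:
    # Stand-in for a container method that defers its import to call time.
    return lazy_attr("math", "sqrt")(x)
```

The import is still deferred until the first call, so the circular dependency stays broken at module load time, while later calls reuse the cached function object.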


# We need to get the task_id from the run somehow
# For now, let's extract it from run_kwargs if it's there, or generate a placeholder
task_id = run_kwargs.get("task_id", f"autograd_{task_name}")

logic: Using placeholder task_id generation instead of getting actual task_id from webapi.run() could cause issues with autograd_fwd simulation types that need real task IDs for file uploads

batch_data = run_async_webapi(simulations=sims, **web_kwargs)

# Generate task_ids - we'll need to get these from somewhere or create placeholders
task_ids = {task_name: f"autograd_batch_{task_name}" for task_name in simulations.keys()}

logic: Placeholder task_id generation may not work correctly for autograd_fwd cases that require real task IDs for sim_fields_keys uploads

Comment on lines 1285 to 1286
# For now, use a placeholder task_id since the individual run() calls handle everything
task_ids[task_name] = f"run_async_{task_name}"

logic: Using placeholder task IDs like f"run_async_{task_name}" may cause issues if other code expects real task IDs. Consider using the actual task IDs returned by the run() function calls.
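
A sketch of what recording real IDs could look like, assuming the per-simulation run call can be made to report the task ID it was assigned. `RunResult`, `run_all`, and the `run_one` callable are hypothetical; the source does not show what run() actually returns.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Tuple


@dataclass(frozen=True)
class RunResult:
    """Hypothetical return type pairing the data with the server-assigned ID."""

    task_id: str
    data: object


def run_all(
    simulations: Dict[str, object],
    run_one: Callable[[str, object], RunResult],
) -> Tuple[Dict[str, object], Dict[str, str]]:
    """Run each simulation in order and record the task ID each run reports."""
    batch_data: Dict[str, object] = {}
    task_ids: Dict[str, str] = {}
    for task_name, sim in simulations.items():
        result = run_one(task_name, sim)
        batch_data[task_name] = result.data
        # Store the reported ID instead of a name-derived placeholder.
        task_ids[task_name] = result.task_id
    return batch_data, task_ids
```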

"aborted",
)

max_task_name = max(len(task_name) for task_name in task_ids.keys())

logic: If task_ids is empty, max() raises ValueError on the empty sequence. Guard against an empty batch before this line, or pass default=0 to max().
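
A small sketch of the guarded form. `max_name_width` is a hypothetical helper name; the `default` keyword of the built-in `max()` is standard Python.

```python
def max_name_width(task_ids: dict) -> int:
    """Width of the longest task name, or 0 when there are no tasks."""
    return max((len(name) for name in task_ids), default=0)


def unguarded_width(task_ids: dict) -> int:
    """Same computation without a default; raises ValueError on an empty dict."""
    return max(len(name) for name in task_ids)
```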

simulations=simulations,
folder_name=folder_name,
callback_url=callback_url,
num_workers=num_workers or len(simulations),

style: Using num_workers or len(simulations) can be problematic if num_workers=0 - should use num_workers if num_workers is not None else len(simulations)

Suggested change
- num_workers=num_workers or len(simulations),
+ num_workers=num_workers if num_workers is not None else len(simulations),
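
The difference between the two forms shows up only for falsy values such as 0. The helper names below are hypothetical and exist only to put the two expressions side by side; whether 0 workers is a meaningful setting for this API is not stated in the source.

```python
from typing import Optional


def resolve_workers_or(num_workers: Optional[int], n_sims: int) -> int:
    # Truthiness check: 0 is falsy, so it is silently replaced by n_sims.
    return num_workers or n_sims


def resolve_workers_none(num_workers: Optional[int], n_sims: int) -> int:
    # Identity check: only a missing value (None) falls back to n_sims.
    return num_workers if num_workers is not None else n_sims
```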
