autograd support for Job.run() and Batch.run() #2722
base: develop
16 files reviewed, 13 comments
```python
# tidy3d/web/api/batch_data.py
from __future__ import annotations
from typing import Mapping
from tidy3d.components.base import Tidy3dBaseModel
from tidy3d.web.core.constants import TaskName
from .tidy3d_stub import SimulationDataType
import tidy3d.web.api.webapi as web  # Only import webapi, not container


class BatchData(Tidy3dBaseModel, Mapping):
    # Move entire BatchData class here with minimal dependencies
```
logic: The import structure in the BatchData extraction example may create new dependencies. Consider whether importing `webapi` directly might recreate circular dependencies.
# Test that autograd works through emulated Job.run() (demonstrates end goal)
val, grad = ag.value_and_grad(objective)(params0)
print(f"Job.run() with emulation - value: {val}, grad: {grad}")
# After refactor, this should work without emulation by calling autograd.run() internally
style: Remove this comment about 'after refactor', since the test assertions below suggest the refactor is already implemented.
Context Used: Rule - Remove commented-out or obsolete code; rely on version control for history.
2. **Simplify Batch class**:
```python
# container.py
class Batch(WebContainer):
    def upload(self) -> None:
        """Thin wrapper around asynchronous.upload_batch()"""
        from .asynchronous import upload_batch
        upload_batch(self.simulations, **self._get_batch_kwargs())

    def run(self, path_dir: str = DEFAULT_DATA_DIR) -> BatchData:
        """Thin wrapper that calls asynchronous functions"""
        from .asynchronous import run_async
        return run_async(self.simulations, path_dir=path_dir, **self._get_batch_kwargs())
```
logic: The `Batch.run()` implementation imports `run_async` from the `asynchronous` module, but Step 3 is about moving batch logic into `asynchronous.py`. This can create a circular reference in which `Batch` calls `asynchronous.run_async`, which in turn may need `Batch` internals.
# Test that autograd works through emulated Batch.run() (demonstrates end goal)
val, grad = ag.value_and_grad(objective)(params0)
print(f"Batch.run() with emulation - value: {val}, grad: {grad}")
# After refactor, this should work without emulation by calling autograd.run_async() internally
style: Remove this comment about 'after refactor', since the test assertions below suggest the refactor is already implemented.
Context Used: Rule - Remove commented-out or obsolete code; rely on version control for history.
1. **Lazy imports**: Use function-level imports in container.py to avoid circular dependencies
2. **Backwards compatibility**: Keep existing methods working for users who don't need autograd
3. **Clean separation**: BatchData becomes independent, asynchronous.py contains core logic
4. **Minimal changes**: Autograd functions remain unchanged, containers adapt to them

### Potential Issues and Solutions

1. **Import timing**: If circular imports persist, use `TYPE_CHECKING` and runtime imports
2. **Test mocking**: May need to update mock strategy to handle the new call paths
3. **Plugin compatibility**: Verify that S-matrix and other plugins still work correctly
4. **Performance**: The function-level imports might add small overhead
style: Function-level imports will indeed add runtime overhead on every method call. Consider caching the imported functions as class/module attributes after first import to minimize performance impact.
# We need to get the task_id from the run somehow
# For now, let's extract it from run_kwargs if it's there, or generate a placeholder
task_id = run_kwargs.get("task_id", f"autograd_{task_name}")
logic: Using placeholder task_id generation instead of getting actual task_id from webapi.run() could cause issues with autograd_fwd simulation types that need real task IDs for file uploads
batch_data = run_async_webapi(simulations=sims, **web_kwargs)

# Generate task_ids - we'll need to get these from somewhere or create placeholders
task_ids = {task_name: f"autograd_batch_{task_name}" for task_name in simulations.keys()}
logic: Placeholder task_id generation may not work correctly for autograd_fwd cases that require real task IDs for sim_fields_keys uploads
tidy3d/web/api/webapi.py
Outdated
# For now, use a placeholder task_id since the individual run() calls handle everything
task_ids[task_name] = f"run_async_{task_name}"
logic: Using placeholder task IDs like `f"run_async_{task_name}"` may cause issues if other code expects real task IDs. Consider using the actual task IDs returned by the `run()` function calls.
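One way to avoid placeholders is to have the per-simulation call report its task ID alongside the data and to collect those IDs while iterating. A minimal sketch of that pattern follows. Note the assumptions: `fake_run`, its return shape, and the ID format are hypothetical stand-ins and do not reflect the actual `webapi.run()` signature.

```python
import uuid
from typing import Callable, Dict, Tuple


def fake_run(simulation: str, task_name: str) -> Tuple[str, str]:
    """Hypothetical stand-in for a run call that returns (task_id, data)."""
    task_id = f"task-{uuid.uuid4()}"  # invented ID format, for illustration only
    return task_id, f"data for {simulation}"


def run_sequentially(
    simulations: Dict[str, str],
    run_fn: Callable[[str, str], Tuple[str, str]] = fake_run,
) -> Tuple[Dict[str, str], Dict[str, str]]:
    """Run each simulation in turn, recording the task ID each call reports."""
    task_ids: Dict[str, str] = {}
    results: Dict[str, str] = {}
    for task_name, simulation in simulations.items():
        task_id, data = run_fn(simulation, task_name)
        task_ids[task_name] = task_id  # ID reported by the call, not a placeholder
        results[task_name] = data
    return task_ids, results


task_ids, results = run_sequentially({"sim_a": "A", "sim_b": "B"})
```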
"aborted",
)

max_task_name = max(len(task_name) for task_name in task_ids.keys())
logic: Potential failure if `task_ids` is empty: `max()` raises `ValueError` when given an empty iterable.
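The built-in `max()` accepts a `default` keyword that is returned when the iterable is empty, which is one way to guard this line. A small sketch, where `max_name_length` is a hypothetical helper name and not part of the PR:

```python
from typing import Dict


def max_name_length(task_ids: Dict[str, str]) -> int:
    """Length of the longest task name, or 0 when there are no tasks."""
    return max((len(task_name) for task_name in task_ids), default=0)


empty_width = max_name_length({})
populated_width = max_name_length({"a": "id-1", "long_name": "id-2"})
```

An alternative is to return early when `task_ids` is empty, which also skips the monitoring loop entirely.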
simulations=simulations,
folder_name=folder_name,
callback_url=callback_url,
num_workers=num_workers or len(simulations),
style: Using `num_workers or len(simulations)` can be problematic if `num_workers=0`; it should use `num_workers if num_workers is not None else len(simulations)`.
num_workers=num_workers or len(simulations),
num_workers=num_workers if num_workers is not None else len(simulations),
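The difference between the two forms only shows up for falsy values other than `None`, such as `0`. The sketch below compares them side by side; the helper names are hypothetical, and whether `num_workers=0` is a meaningful setting in this API is not stated in the PR.

```python
from typing import Optional


def resolve_with_or(num_workers: Optional[int], n_simulations: int) -> int:
    """Original form: any falsy value (None or 0) falls back to n_simulations."""
    return num_workers or n_simulations


def resolve_with_none_check(num_workers: Optional[int], n_simulations: int) -> int:
    """Suggested form: only None falls back, so an explicit 0 is preserved."""
    return num_workers if num_workers is not None else n_simulations


# The two forms agree for None and for positive values, and differ only for 0.
for value in (None, 0, 4):
    print(value, resolve_with_or(value, 5), resolve_with_none_check(value, 5))
```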
Greptile Summary

This PR implements autograd (automatic differentiation) support for the `Job.run()` and `Batch.run()` methods in the Tidy3D web API. The core change enables gradient computation through these high-level object-oriented interfaces, which is essential for inverse design and optimization workflows.

The implementation follows a refactoring plan documented in `AUTOGRAD_REFACTOR_PLAN.md`. The main architectural changes include:

- **Module Reorganization**: The `BatchData` class is extracted from `container.py` into its own `batch_data.py` module to break circular import dependencies that prevented container classes from importing autograd functions.
- **Function Delegation**: The `Job.run()` and `Batch.run()` methods are refactored to delegate to autograd-compatible functions (`run_autograd()` and `run_async_autograd()`) that automatically detect whether simulations contain traced parameters and route them appropriately.
- **Async Function Extraction**: Batch processing logic is extracted from the `Batch` class into standalone async functions (`upload_async`, `start_async`, `monitor_async`, etc.) to enable autograd compatibility through function-based primitives.
- **Sequential Processing**: A new `run_async()` function in `webapi.py` provides autograd-compatible batch processing by sequentially calling the single-simulation `run()` function for each simulation in a batch.

The changes maintain full backward compatibility: existing code using `Job.run()` and `Batch.run()` will continue to work exactly as before, but these methods can now be used inside functions differentiated with autograd for gradient-based optimization. The refactoring also reduces code duplication by consolidating async operations into reusable functions.

Important Files Changed
Files and Ratings
tidy3d/web/api/container.py
tidy3d/web/api/autograd/autograd.py
tidy3d/web/api/asynchronous.py
tidy3d/web/api/webapi.py
tidy3d/web/api/batch_data.py
tidy3d/web/__init__.py
tidy3d/plugins/adjoint/web.py
tests/test_components/test_autograd.py
AUTOGRAD_REFACTOR_PLAN.md
tests/utils.py
tidy3d/plugins/design/design.py
tests/test_plugins/test_adjoint.py
tidy3d/plugins/smatrix/component_modelers/modal.py
tidy3d/plugins/smatrix/component_modelers/base.py
tidy3d/plugins/smatrix/component_modelers/terminal.py
docs/notebooks
Confidence score: 3/5
Review the `asynchronous.py` monitoring loop and the `autograd.py` task ID generation for potential infinite loops or missing task references.

Sequence Diagram