
Add optional JAX support to accelerate some partial derivatives #418

@kanekosh

Description of feature

When a dense VLM mesh is used, compute_partials in some components (e.g., eval_mtx in aerodynamics) becomes a bottleneck for derivative computations. These partials could be accelerated by replacing the current analytical derivatives with automatic differentiation (AD).
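As a rough illustration of "replacing analytical derivatives with AD", the sketch below uses JAX forward-mode AD (jax.jacfwd) on a hypothetical AIC-style kernel, A[i, j] = 1 / |c_i - v_j|. This kernel is a simplified stand-in chosen for this example, not OAS's eval_mtx, and the function names are hypothetical. The Jacobians are reshaped to 2D (output size x input size), the dense layout OpenMDAO uses for partials[of, wrt].

```python
import numpy as np
import jax
import jax.numpy as jnp

# Use float64 so AD results match the NumPy-based analytical partials
jax.config.update("jax_enable_x64", True)


def influence_kernel(ctrl_pts, vort_pts):
    """Hypothetical stand-in for an AIC-style kernel (not OAS's eval_mtx).

    ctrl_pts: (n, 3), vort_pts: (m, 3) -> A: (n, m) with A[i, j] = 1 / |c_i - v_j|.
    """
    diff = ctrl_pts[:, None, :] - vort_pts[None, :, :]  # (n, m, 3)
    return 1.0 / jnp.sqrt(jnp.sum(diff**2, axis=-1))


# Forward-mode AD with respect to both inputs, compiled with jit
_jac = jax.jit(jax.jacfwd(influence_kernel, argnums=(0, 1)))


def ad_partials(ctrl_pts, vort_pts):
    """Return dA/dctrl and dA/dvort as dense 2D arrays (output size x input size)."""
    n, m = ctrl_pts.shape[0], vort_pts.shape[0]
    d_ctrl, d_vort = _jac(ctrl_pts, vort_pts)  # shapes (n, m, n, 3) and (n, m, m, 3)
    # Row-major reshape matches OpenMDAO's C-order flattening of outputs and inputs
    return (np.asarray(d_ctrl).reshape(n * m, n * 3),
            np.asarray(d_vort).reshape(n * m, m * 3))


rng = np.random.default_rng(0)
ctrl = rng.standard_normal((4, 3))
vort = rng.standard_normal((5, 3)) + 10.0  # offset keeps |c - v| away from zero
J_ctrl, J_vort = ad_partials(ctrl, vort)
print(J_ctrl.shape, J_vort.shape)  # (20, 12) (20, 15)
```

Each A[i, j] depends only on c_i and v_j, so most entries of these dense Jacobians are structural zeros. For large meshes, an actual implementation would likely need to exploit that sparsity rather than form the full dense arrays.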
Aditya Deshpande and Sriram Bommakanti tried this for their AE588 project and showed that AD did accelerate the partials. Their prototype implementation can be found in their fork. Note that they applied AD to only part of the compute_partials computation, not to the entire partial computation.
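Applying AD to only part of compute_partials amounts to a hybrid chain rule: cheap stages keep their analytical Jacobians, and AD is used only for the expensive stage. The sketch below assumes a hypothetical two-stage computation, p = S @ x (linear, so dp/dx = S) followed by y = g(p); neither stage is taken from OAS or from the prototype in the fork.

```python
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def g(p):
    # Hypothetical stand-in for an expensive nonlinear stage
    return jnp.sin(p) * jnp.sum(p**2)


def hybrid_jacobian(S, x):
    """dy/dx for y = g(S @ x): analytical for the linear stage, AD for g."""
    p = S @ x
    dp_dx = S                                # analytical: Jacobian of a linear map
    dy_dp = jax.jacfwd(g)(jnp.asarray(p))    # AD only for the expensive stage
    return np.asarray(dy_dp) @ dp_dx         # chain rule: dy/dx = dy/dp @ dp/dx


rng = np.random.default_rng(1)
S = rng.standard_normal((4, 6))
x = rng.standard_normal(6)
J_hybrid = hybrid_jacobian(S, x)
print(J_hybrid.shape)  # (4, 6)
```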

AD support should be optional because we don't want to add JAX as a hard dependency (for now), and AD likely offers no performance benefit for moderate mesh sizes.
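One way to keep JAX optional is to detect it at import time without importing it, and to let a user option select the backend. The helper name, the option values (True, False, "auto"), and the threshold parameter below are hypothetical; the threshold is a placeholder to be set from benchmarks, not a known value.

```python
import importlib.util

# Detect JAX without importing it, so the package still works when JAX is absent
HAS_JAX = importlib.util.find_spec("jax") is not None


def resolve_partials_backend(use_jax, num_panels=0, auto_threshold=None):
    """Hypothetical helper mapping a user option to a partials backend.

    use_jax: True, False, or "auto".
    auto_threshold: panel count at or above which "auto" picks JAX
        (placeholder; None means "auto" always falls back to analytical).
    """
    if use_jax == "auto":  # checked first because the string "auto" is truthy
        if HAS_JAX and auto_threshold is not None and num_panels >= auto_threshold:
            return "jax"
        return "analytical"
    if use_jax and not HAS_JAX:
        raise ImportError("use_jax=True requires JAX, which is not installed")
    return "jax" if use_jax else "analytical"


print(resolve_partials_backend(False))  # analytical
```

With this design, jax itself would be imported only inside the JAX code path, so users without JAX never hit an ImportError unless they explicitly request it.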

Potential solution

  1. Run profiling and identify the components that can be accelerated by AD. eval_mtx is one, but there could be others.
  2. Replace (part of) the compute_partials method with AD. We'll need to try out multiple AD options as Aditya and Sriram did.
  3. Add JAX as an optional dependency in setup.py.
  4. Add a documentation page on AD: ideally, suggest a mesh size threshold above which AD becomes faster than the default analytical partials.
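For steps 1 and 4, the components themselves would first be located with a profiler (e.g., Python's cProfile), and the threshold could then be found by timing analytical and AD partials over increasing mesh sizes. The harness below sketches that comparison on a hypothetical toy kernel (d(1/|c_i - v_j|)/dc_i), not on eval_mtx, so its timings say nothing about OAS. The AD side uses nested jax.vmap over jax.grad so it produces the same compact (n, m, 3) layout as the analytical version instead of a dense Jacobian. A warm-up call keeps JIT compilation out of the timings; in practice compilation recurs for every new input shape and may matter for short runs.

```python
import time
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def analytical_partials(ctrl, vort):
    """Hand-derived d(1/|c_i - v_j|)/dc_i = -(c_i - v_j) / |c_i - v_j|^3, shape (n, m, 3)."""
    diff = ctrl[:, None, :] - vort[None, :, :]
    r = np.sqrt(np.sum(diff**2, axis=-1))
    return -diff / r[..., None] ** 3


def pair_kernel(c, v):
    # 1 / |c - v| for one control point c (3,) and one vortex point v (3,)
    return 1.0 / jnp.sqrt(jnp.sum((c - v) ** 2))


# Gradient w.r.t. c for every (i, j) pair; nested vmap yields an (n, m, 3) array
ad_partials = jax.jit(
    jax.vmap(jax.vmap(jax.grad(pair_kernel), in_axes=(None, 0)), in_axes=(0, None))
)


def _wait(out):
    # JAX dispatches asynchronously; block so the timer measures the actual work
    if hasattr(out, "block_until_ready"):
        out.block_until_ready()
    return out


def best_time(fn, *args, repeats=5):
    """Best wall-clock time over several calls, after one warm-up call (triggers JIT)."""
    _wait(fn(*args))
    best = float("inf")
    for _ in range(repeats):
        t0 = time.perf_counter()
        _wait(fn(*args))
        best = min(best, time.perf_counter() - t0)
    return best


rng = np.random.default_rng(0)
results = []
for n in (10, 100, 400):  # number of control / vortex points (placeholder sizes)
    ctrl = rng.standard_normal((n, 3))
    vort = rng.standard_normal((n, 3)) + 10.0  # offset keeps |c - v| away from zero
    results.append((n, best_time(analytical_partials, ctrl, vort),
                    best_time(ad_partials, ctrl, vort)))

for n, t_an, t_ad in results:
    print(f"n={n:4d}  analytical={t_an:.2e} s  AD={t_ad:.2e} s")
```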
