Speed improvements (from profiling) #538

Draft
jank324 wants to merge 98 commits into master from speed-improvements

jank324 commented Aug 8, 2025

Description

This PR does various things to improve Cheetah's speed:

  • Remove verify_device_and_dtype. This changes how Cheetah handles dtypes in general; the new behaviour is more in line with PyTorch itself.
  • Reduce the number of torch operations. Each operation incurs a kernel invocation, which for scalars costs far more than the computation itself, so we look for operations that do as much work as possible in as few kernel invocations as possible.
  • Choose faster PyTorch operations.
  • Reduce branching (this also repairs some breaks in the compute graph). To do so, we either ignore cases such as division by zero where they would be unphysical, or rearrange equations so that the edge cases requiring a branch no longer appear.
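
As a rough illustration of the last point, the sketch below compares a guarded sin(x)/x with a branch-free version built on torch.sinc. This is a generic PyTorch example, not code taken from Cheetah; the function names are hypothetical, and whether Cheetah uses torch.sinc anywhere is an assumption.

```python
import math

import torch


def sin_over_x_branching(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical branching variant: two torch.where calls guard x == 0.
    safe_x = torch.where(x == 0, torch.ones_like(x), x)
    return torch.where(x == 0, torch.ones_like(x), torch.sin(safe_x) / safe_x)


def sin_over_x_branch_free(x: torch.Tensor) -> torch.Tensor:
    # torch.sinc(t) computes sin(pi * t) / (pi * t) and returns 1 at t == 0,
    # so a single call covers the edge case without an explicit branch.
    return torch.sinc(x / math.pi)


x = torch.tensor([0.0, 0.5, 1.0, 2.0], dtype=torch.float64)
branching = sin_over_x_branching(x)
branch_free = sin_over_x_branch_free(x)
```

Both functions return the same values here, but the second one needs fewer torch operations and contains no data-dependent selection.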

Motivation and Context

  • I have raised an issue to propose this change (required for new features and bug fixes)

From the last 0.5.x release to the first 0.6.x release, Cheetah slowed down by a factor of about 5 to 6. The primary cause appears to be the use of tensors in all computations, even scalar ones, which is needed to maintain differentiability. A further slowdown of about 30% occurred between the 0.6.x releases and the most recent release, 0.7.5. There are several reasons for this, one of which is verify_device_and_dtype.

On a side note, operations like torch.where and torch.clamp can break the compute graph by selecting a subgraph that no longer contains results from the parameters one may be optimising over.
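
One way to observe this effect is to check the gradient that reaches a parameter. The following is a minimal sketch in plain PyTorch with a hypothetical scalar parameter k and made-up constants; it is not taken from Cheetah.

```python
import torch

# Hypothetical parameter that one might be optimising over.
k = torch.tensor(2.0, requires_grad=True)

# k * 3.0 = 6.0 exceeds the upper bound, so the output is the constant 1.0
# and no gradient flows back to k.
clamped = torch.clamp(k * 3.0, max=1.0)
clamped.backward()
grad_clamp = k.grad.clone()

k.grad = None  # reset before the second backward pass

# The condition is true, so torch.where selects the constant branch and the
# branch that depends on k contributes nothing to the result.
selected = torch.where(k > 1.0, torch.tensor(5.0), k * 3.0)
selected.backward()
grad_where = k.grad.clone()
```

In both cases the gradient with respect to k is exactly zero, so a gradient-based optimiser would receive no signal for k at this point.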

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist

  • I have updated the changelog accordingly (required).
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • I have reformatted the code and checked that formatting passes (required).
  • I have fixed all issues found by flake8 (required).
  • I have ensured that all pytest tests pass (required).
  • I have run pytest on a machine with a CUDA GPU and made sure all tests pass (required).
  • I have checked that the documentation builds (required).

Note: We are using a maximum length of 88 characters per line.

jank324 added the enhancement label Aug 8, 2025
jank324 and others added 24 commits August 8, 2025 12:16
…ole only to benchmark speed"

This reverts commit b15eadc.
This reverts commit 3cb301d.
jank324 mentioned this pull request Sep 4, 2025
14 tasks

jank324 commented Sep 5, 2025

We decided to split the items in this PR into separate PRs to keep better track of whether the improvements have any effect:

jank324 mentioned this pull request Sep 5, 2025
14 tasks
jank324 added a commit that referenced this pull request Sep 5, 2025
jank324 added a commit that referenced this pull request Sep 8, 2025

jank324 commented Sep 19, 2025

I just realised that there are places where we compute a reducing operation like sum or mean, and then unsqueeze again. It would be much more efficient to call the operation with keepdim=True.
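
For illustration, here is a minimal sketch of the two variants on a made-up tensor. In PyTorch the argument is spelled keepdim; the tensor shape and the variable names are only assumptions for this example.

```python
import torch

# Made-up input of shape (3, 4).
x = torch.arange(12, dtype=torch.float32).reshape(3, 4)

# Two operations: reduce over the last dimension, then add it back.
two_step = x.sum(dim=-1).unsqueeze(-1)

# One operation: reduce and keep the reduced dimension with size 1.
one_step = x.sum(dim=-1, keepdim=True)
```

Both results have shape (3, 1) and hold the same values; the second variant needs one torch operation fewer.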

jank324 commented Sep 19, 2025

[Figure: benchmark_commits_plot]

jank324 commented Sep 19, 2025

[Figure: benchmark_commits_plot]

Labels

bug, enhancement

3 participants