Draft
Conversation
14 tasks
This reverts commit 3cb301d.
Member
Author
|
We decided to split the items in this PR into separate PRs to keep better track of whether the improvements have any effect: |
jank324
added a commit
that referenced
this pull request
Sep 5, 2025
jank324
added a commit
that referenced
this pull request
Sep 5, 2025
jank324
added a commit
that referenced
this pull request
Sep 5, 2025
jank324
added a commit
that referenced
this pull request
Sep 8, 2025
14 tasks
14 tasks
jank324
added a commit
that referenced
this pull request
Sep 10, 2025
This was referenced Sep 10, 2025
jank324
added a commit
that referenced
this pull request
Sep 10, 2025
14 tasks
jank324
added a commit
that referenced
this pull request
Sep 12, 2025
This was referenced Sep 12, 2025
Member
Author
|
I just realised that there are places where we compute a reducing operation like |
Member
Author
Member
Author
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.


Description
This PR does various things to improve Cheetah's speed:
verify_device_and_dtype, which comes with changes to how dtypes are handled by Cheetah in general. The changed variant is now more in line with PyTorch itself.torchoperations, because each of them comes with a kernel invocation that is much more costly than the computation itself for scalars. We look for operations that do as much as possible with as few kernel invocations as possible.Motivation and Context
From the last 0.5.x release to the first 0.6.x release, Cheetah slowed down by about 5-6x. It appears that the primary cause of this is the full use of tensors in all computations, even scalar ones. The latter is needed to maintain differentiability. An additional slowdown of about 30% was caused between 0.6.x releases and the most recent 0.7.5 release. The reasons for this are multifold, but include for example
verify_device_and_dtype.On a side note, operations like
torch.whereandtorch.clampcan break the compute graph, by selecting a subgraph that no longer contains results from parameters that one may be optimising over.Types of changes
Checklist
flake8(required).pytesttests pass (required).pyteston a machine with a CUDA GPU and made sure all tests pass (required).Note: We are using a maximum length of 88 characters per line.