Replace torch.op(x) with x.op() for speed#556
Merged
Conversation
….com/desy-ml/cheetah into call-ops-on-tensors-instead-of-torch
Member
Author
jank324
commented
Sep 26, 2025
jank324
commented
Sep 26, 2025
jank324
commented
Sep 26, 2025
Contributor
There was a problem hiding this comment.
Pull Request Overview
This PR replaces torch.op(x) function calls with x.op() method calls throughout the codebase for marginal but consistent performance improvements. The changes affect mathematical operations like sqrt(), sin(), cos(), abs(), sum(), any(), and others across the entire codebase.
Key changes include:
- Replacing
torch.sqrt(x)withx.sqrt(),torch.sin(x)withx.sin(), etc. - Replacing
torch.sum(x)withx.sum(),torch.any(x)withx.any() - Renaming function parameters from
inputtoinputsfor consistency in statistics utilities - Adding documentation to CONTRIBUTING.md with examples of the faster patterns
Reviewed Changes
Copilot reviewed 25 out of 25 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
| tests/test_statistics.py | Updated test assertions to use .isnan() method |
| tests/test_space_charge_kick.py | Replaced torch math functions with tensor methods in test calculations |
| tests/test_cavity.py | Updated assertions to use tensor methods |
| cheetah/utils/statistics.py | Replaced torch functions with tensor methods; renamed parameters from input to inputs |
| cheetah/utils/physics.py | Replaced torch.sqrt() with .sqrt() |
| cheetah/utils/kde.py | Replaced torch functions with tensor methods |
| cheetah/utils/bmadx.py | Replaced torch math and utility functions with tensor methods |
| cheetah/track_methods.py | Replaced torch functions with tensor methods in transfer map calculations |
| cheetah/particles/particle_beam.py | Replaced torch functions with tensor methods throughout beam calculations |
| cheetah/particles/parameter_beam.py | Replaced torch functions with tensor methods in parameter beam |
| cheetah/particles/beam.py | Replaced torch functions with tensor methods in base beam class |
| cheetah/converters/bmad.py | Updated phase calculation to use .rad2deg() method |
| cheetah/accelerator/*.py | Replaced torch functions with tensor methods across all accelerator elements |
| CONTRIBUTING.md | Added documentation for the faster operation patterns |
| CHANGELOG.md | Updated changelog to reference this PR |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Hespe
approved these changes
Nov 12, 2025
Hespe
approved these changes
Nov 12, 2025
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.

Description
Sub- / replacement PR for #538.
Calling
x.op()seems to be marginally but consistently faster thantorch.op(x).Motivation and Context
Types of changes
Checklist
flake8(required).pytesttests pass (required).pyteston a machine with a CUDA GPU and made sure all tests pass (required).Note: We are using a maximum length of 88 characters per line.