Conversation
|
At the moment, the benchmarks for this PR on my laptop look like this. @Hespe ... this is not at all what we thought ...
|
|
Just putting this here as a point to remember: I also made some minor formatting changes in this PR that should be merged, even if we decide to discard its main change. As a further point. Even if the main change does not bring a meaningful improvement in speed, I think it might be worth merging to align the interface with PyTorch itself. |
|
I know this doesn't really give any speed advantage as it seems, but I still think the interface change is a good one to align the Cheetah interface with the one of Modules included in PyTorch directly. |
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
| :param device: Device on which to create the element's tensors. | ||
| :param dtype: Data type of the element's tensors. |
There was a problem hiding this comment.
Should we be more explicit on how these arguments are used? I.e., that they does not influence the device or dtype of tensors that are passed, only the defaults?
| @@ -29,8 +29,8 @@ def are_all_the_same_dtype(tensors: list[torch.Tensor]) -> torch.dtype: | |||
|
|
|||
| def verify_device_and_dtype( | |||
|
I also just realised that I didn't remove any of the |
|
We should check the speed again after removing the |
There was a problem hiding this comment.
Pull Request Overview
This PR removes the verify_device_and_dtype function call from all Cheetah object initializations, making the library more similar to PyTorch's native modules where users are responsible for ensuring device and dtype consistency between tensors and modules.
Key changes:
- Removes device/dtype verification in element constructors
- Simplifies tensor creation using factory_kwargs pattern
- Updates tests to explicitly specify device/dtype where needed
- Removes the
verify_device_and_dtypeutility function entirely
Reviewed Changes
Copilot reviewed 37 out of 37 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
| cheetah/utils/argument_verification.py | Completely removes the verify_device_and_dtype function and related utilities |
| cheetah/accelerator/*.py | Updates all accelerator elements to remove verification calls and use direct tensor assignment |
| cheetah/particles/*.py | Updates beam and species classes to remove verification and simplify tensor handling |
| cheetah/track_methods.py | Simplifies tensor creation using factory_kwargs pattern |
| tests/*.py | Updates tests to explicitly specify dtype for tensor parameters and use .to() method for module conversion |
Comments suppressed due to low confidence (1)
cheetah/track_methods.py:1
- Multiple duplicate return statements at the end of the
misalignment_matrixfunction. Only one return statement should remain.
"""Utility functions for creating transfer maps for elements."""
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>




Description
Removes the
verify_device_and_dtypecall in the initialisation of all Cheetah objects. This means that now the user is expected to ensure that thedtypeanddeviceof all tensors passed to an object match those passed to the object itself. This interface is very similar to that of modules defined in PyTorch itself.Sub- / replacement-PR for #538.
Motivation and Context
We found / believe that this insures a large speed penalty in tracking in Cheetah (see #538). At the same time, the users can be expected to match devices and dtypes themselves. This is also done in PyTorch-native Modules.
Types of changes
Checklist
flake8(required).pytesttests pass (required).pyteston a machine with a CUDA GPU and made sure all tests pass (required).Note: We are using a maximum length of 88 characters per line.