Skip to content

Remove verify_device_and_dtype#552

Merged
jank324 merged 12 commits intomasterfrom
remove-verify-device-and-dtype
Sep 10, 2025
Merged

Remove verify_device_and_dtype#552
jank324 merged 12 commits intomasterfrom
remove-verify-device-and-dtype

Conversation

@jank324
Copy link
Copy Markdown
Member

@jank324 jank324 commented Sep 5, 2025

Description

Removes the verify_device_and_dtype call in the initialisation of all Cheetah objects. This means that now the user is expected to ensure that thedtype and device of all tensors passed to an object match those passed to the object itself. This interface is very similar to that of modules defined in PyTorch itself.

Sub- / replacement-PR for #538.

Motivation and Context

  • I have raised an issue to propose this change (required for new features and bug fixes)

We found / believe that this insures a large speed penalty in tracking in Cheetah (see #538). At the same time, the users can be expected to match devices and dtypes themselves. This is also done in PyTorch-native Modules.

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist

  • I have updated the changelog accordingly (required).
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • I have reformatted the code and checked that formatting passes (required).
  • I have have fixed all issues found by flake8 (required).
  • I have ensured that all pytest tests pass (required).
  • I have run pytest on a machine with a CUDA GPU and made sure all tests pass (required).
  • I have checked that the documentation builds (required).

Note: We are using a maximum length of 88 characters per line.

@jank324
Copy link
Copy Markdown
Member Author

jank324 commented Sep 5, 2025

At the moment, the benchmarks for this PR on my laptop look like this. @Hespe ... this is not at all what we thought ...

benchmark_commits_plot

@jank324
Copy link
Copy Markdown
Member Author

jank324 commented Sep 5, 2025

Just putting this here as a point to remember: I also made some minor formatting changes in this PR that should be merged, even if we decide to discard its main change.

As a further point. Even if the main change does not bring a meaningful improvement in speed, I think it might be worth merging to align the interface with PyTorch itself.

@jank324
Copy link
Copy Markdown
Member Author

jank324 commented Sep 8, 2025

I reran the commit benchmark script on my MacBook and a cluster node again ... still no improvement, which is weird.

benchmark_commits_plot benchmark_commits_plot-4

@jank324 jank324 marked this pull request as ready for review September 8, 2025 11:46
@jank324 jank324 requested review from Hespe and Copilot September 8, 2025 11:46

This comment was marked as outdated.

@jank324
Copy link
Copy Markdown
Member Author

jank324 commented Sep 8, 2025

I know this doesn't really give any speed advantage as it seems, but I still think the interface change is a good one to align the Cheetah interface with the one of Modules included in PyTorch directly.

jank324 and others added 2 commits September 8, 2025 13:52
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Comment on lines +24 to +25
:param device: Device on which to create the element's tensors.
:param dtype: Data type of the element's tensors.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we be more explicit on how these arguments are used? I.e., that they does not influence the device or dtype of tensors that are passed, only the defaults?

Comment thread cheetah/utils/argument_verification.py Outdated
@@ -29,8 +29,8 @@ def are_all_the_same_dtype(tensors: list[torch.Tensor]) -> torch.dtype:

def verify_device_and_dtype(
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be deleted entirely

@jank324 jank324 marked this pull request as draft September 8, 2025 12:25
@jank324
Copy link
Copy Markdown
Member Author

jank324 commented Sep 8, 2025

I also just realised that I didn't remove any of the as_tensor operations.

Comment thread tests/test_beam.py Outdated
@Hespe
Copy link
Copy Markdown
Member

Hespe commented Sep 8, 2025

We should check the speed again after removing the as_tensor calls.

@jank324
Copy link
Copy Markdown
Member Author

jank324 commented Sep 8, 2025

benchmark_commits_plot

@jank324 jank324 requested review from Hespe and Copilot September 10, 2025 08:38
@jank324 jank324 marked this pull request as ready for review September 10, 2025 08:38
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR removes the verify_device_and_dtype function call from all Cheetah object initializations, making the library more similar to PyTorch's native modules where users are responsible for ensuring device and dtype consistency between tensors and modules.

Key changes:

  • Removes device/dtype verification in element constructors
  • Simplifies tensor creation using factory_kwargs pattern
  • Updates tests to explicitly specify device/dtype where needed
  • Removes the verify_device_and_dtype utility function entirely

Reviewed Changes

Copilot reviewed 37 out of 37 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
cheetah/utils/argument_verification.py Completely removes the verify_device_and_dtype function and related utilities
cheetah/accelerator/*.py Updates all accelerator elements to remove verification calls and use direct tensor assignment
cheetah/particles/*.py Updates beam and species classes to remove verification and simplify tensor handling
cheetah/track_methods.py Simplifies tensor creation using factory_kwargs pattern
tests/*.py Updates tests to explicitly specify dtype for tensor parameters and use .to() method for module conversion
Comments suppressed due to low confidence (1)

cheetah/track_methods.py:1

  • Multiple duplicate return statements at the end of the misalignment_matrix function. Only one return statement should remain.
"""Utility functions for creating transfer maps for elements."""

Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.

Comment thread tests/test_species.py Outdated
Comment thread cheetah/particles/species.py
Comment thread cheetah/accelerator/dipole.py
Comment thread cheetah/accelerator/dipole.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@jank324 jank324 merged commit 276cb84 into master Sep 10, 2025
10 checks passed
@jank324 jank324 deleted the remove-verify-device-and-dtype branch September 10, 2025 09:41
@jank324 jank324 restored the remove-verify-device-and-dtype branch September 10, 2025 09:41
@Hespe Hespe deleted the remove-verify-device-and-dtype branch September 17, 2025 09:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants