-
-
Notifications
You must be signed in to change notification settings - Fork 653
Explore fixing test issues for older pytorch versions #3434
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
6f5a787
to
6b337b8
Compare
9b1a70a
to
08cf6f6
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR addresses compatibility issues with older PyTorch versions by removing or modifying features that were introduced in newer versions. The changes primarily focus on replacing stable=True
parameter in torch.argsort()
(added in PyTorch 1.13) with version-specific handling and fixing device comparison syntax.
- Remove
.bool()
calls on tensor operations that already return boolean tensors for older PyTorch compatibility - Replace direct device object comparisons with device type string comparisons
- Add version checks to conditionally use the
stable
parameter intorch.argsort()
Reviewed Changes
Copilot reviewed 7 out of 8 changed files in this pull request and generated 2 comments.
Show a summary per file
File | Description |
---|---|
tests/ignite/metrics/vision/test_object_detection_map.py | Replace device object comparisons with device type string comparisons |
tests/ignite/metrics/test_mean_average_precision.py | Remove redundant .bool() calls on torch.randint() operations |
tests/ignite/distributed/utils/test_native.py | Add version checks and test ordering for older PyTorch compatibility |
ignite/metrics/vision/object_detection_average_precision_recall.py | Add version-based handling for stable parameter in torch.argsort() and fix tensor operations |
ignite/metrics/mean_average_precision.py | Add version checks for stable parameter and fix device comparisons |
ignite/metrics/gan/fid.py | Add .numpy() calls for scipy operations that require numpy arrays |
.github/workflows/pytorch-version-tests.yml | Fix Python version string formatting and installation commands |
@@ -37,7 +37,7 @@ def fid_score( | |||
diff = mu1 - mu2 | |||
|
|||
# Product might be almost singular | |||
covmean, _ = scipy.linalg.sqrtm(sigma1.mm(sigma2), disp=False) | |||
covmean, _ = scipy.linalg.sqrtm(sigma1.mm(sigma2).numpy(), disp=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Converting the tensor to numpy adds an unnecessary GPU-to-CPU transfer. Consider checking if the input tensors are on GPU and handle the conversion more efficiently, or use PyTorch native operations if available.
Copilot uses AI. Check for mistakes.
tr_covmean = np.sum(np.sqrt(((np.diag(sigma1.numpy()) * eps) * (np.diag(sigma2.numpy()) * eps)) / (eps * eps))) | ||
|
||
return float(diff.dot(diff).item() + torch.trace(sigma1) + torch.trace(sigma2) - 2 * tr_covmean) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Multiple .numpy()
calls on the same tensors (sigma1 and sigma2) create redundant GPU-to-CPU transfers. Consider converting these tensors to numpy once and reusing the numpy arrays.
tr_covmean = np.sum(np.sqrt(((np.diag(sigma1.numpy()) * eps) * (np.diag(sigma2.numpy()) * eps)) / (eps * eps))) | |
return float(diff.dot(diff).item() + torch.trace(sigma1) + torch.trace(sigma2) - 2 * tr_covmean) | |
tr_covmean = np.sum(np.sqrt(((np.diag(sigma1) * eps) * (np.diag(sigma2) * eps)) / (eps * eps))) | |
return float(diff.dot(diff).item() + torch.trace(torch.tensor(sigma1)) + torch.trace(torch.tensor(sigma2)) - 2 * tr_covmean) |
Copilot uses AI. Check for mistakes.
08cf6f6
to
8cd193c
Compare
@vfdev-5 I think these fixes all the issues on pytorch versions. |
Ok, it seems now there are som typing issues
|
I dont understand where those issues are coming from now. I run mypy locally and no errors are reported. :\ |
Maybe, let's just ignore them with appropriate comment ? |
Supersedes #3356
Description:
Check list: