Reimplement gradient L2 norm computation with correct math and unit test #990
mandira15 wants to merge 3 commits into fossasia:master
Conversation
Reviewer's Guide
Reimplements the gradient L2 norm computation as a standalone helper using correct global L2 math, avoids deprecated/unsafe tensor APIs, and adds a focused unit test to validate return type and positivity after backprop.
Hey - I've found 3 issues and left some high-level feedback:
- The new `compute_gradient_l2_norm` is defined at the same indentation level as other instance methods but takes only `model` and no `self`, so it should either be moved to module level or annotated as a `@staticmethod` to avoid confusion and potential misuse.
- The docstring for `compute_gradient_l2_norm` is currently commented out with `#`; consider converting it into a proper triple-quoted docstring so that tooling and users can see the function description.
## Individual Comments
### Comment 1
<location path="py/visdom/__init__.py" line_range="1748-1762" />
<code_context>
+ if len(parameters) == 0:
+ return 0.0
+
+ device = parameters[0].device
+ total_norm = torch.zeros(1, device=device)
+
+ for param in parameters:
+ if param.grad is not None:
+ param_norm = param.grad.detach().norm(2)
</code_context>
<issue_to_address>
**suggestion:** Assuming all parameters and their gradients are on the same device may break in mixed-device setups.
`device` is taken from `parameters[0]` and used for `total_norm`, but you then iterate over all params. In mixed-device setups (model-parallel, sharded, etc.), this can cause device mismatch when accumulating norms. Consider either asserting all params/grads are on the same device, or computing norms per device and combining them safely.
```suggestion
import torch
parameters = list(model.parameters())
if len(parameters) == 0:
return 0.0
# Accumulate on CPU to safely handle mixed-device (model-parallel/sharded) setups.
total_norm_sq = 0.0
for param in parameters:
if param.grad is not None:
# Compute norm on the parameter's device, then move the scalar to CPU.
param_norm = param.grad.detach().norm(2).item()
total_norm_sq += param_norm ** 2
total_norm = total_norm_sq ** 0.5
return float(total_norm)
```
</issue_to_address>
### Comment 2
<location path="tests/test_gradient_norm.py" line_range="5" />
<code_context>
+import torch.nn as nn
+from visdom import compute_gradient_l2_norm
+
+def test_compute_gradient_l2_norm_returns_float():
+ model = nn.Linear(5, 1)
+ x = torch.randn(3, 5)
</code_context>
<issue_to_address>
**suggestion (testing):** Add a test for the case where no `.backward()` has been called so all gradients are `None` and the function should return `0.0`.
Please also add a test that calls `compute_gradient_l2_norm(model)` before any backward pass and asserts it returns `0.0`, confirming that parameters with `grad is None` are handled correctly.
</issue_to_address>
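Such a test could look like the following sketch (illustrative only; a local stand-in for the helper is defined so the snippet is self-contained, since the PR imports it from `visdom`):

```python
import torch
import torch.nn as nn

def compute_gradient_l2_norm(model):
    # Stand-in mirroring the PR's helper: global L2 norm over all gradients,
    # treating parameters whose grad is None as contributing nothing.
    total_sq = 0.0
    for param in model.parameters():
        if param.grad is not None:
            total_sq += param.grad.detach().norm(2).item() ** 2
    return float(total_sq ** 0.5)

def test_returns_zero_before_backward():
    # No backward() has been called, so every param.grad is None.
    model = nn.Linear(5, 1)
    assert compute_gradient_l2_norm(model) == 0.0
```

This covers the `grad is None` branch that the positivity test alone never exercises.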
### Comment 3
<location path="tests/test_gradient_norm.py" line_range="15-18" />
<code_context>
+ loss = criterion(output, y)
+ loss.backward()
+
+ norm = compute_gradient_l2_norm(model)
+
+ assert isinstance(norm, float)
+ assert norm > 0
\ No newline at end of file
</code_context>
<issue_to_address>
**suggestion (testing):** Add a deterministic test that checks the actual numeric value of the L2 norm for a small model.
This only checks type and positivity and doesn’t validate that the implementation actually computes `sqrt(sum(||g_i||^2))`. Please add a deterministic test (with a fixed seed) using a tiny model (e.g., one Linear layer) where you:
- run a forward/backward pass,
- manually compute the norm from `param.grad` tensors,
- and assert `compute_gradient_l2_norm(model)` matches that value (e.g., with `pytest.approx`).
</issue_to_address>
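A deterministic check along those lines might look like this (illustrative; the helper is stubbed locally so the snippet runs standalone):

```python
import math
import torch
import torch.nn as nn

def compute_gradient_l2_norm(model):
    # Stand-in mirroring the PR's helper.
    total_sq = 0.0
    for param in model.parameters():
        if param.grad is not None:
            total_sq += param.grad.detach().norm(2).item() ** 2
    return float(total_sq ** 0.5)

def test_matches_manually_computed_norm():
    torch.manual_seed(0)  # fixed seed for reproducibility
    model = nn.Linear(5, 1)
    x = torch.randn(3, 5)
    model(x).sum().backward()
    # Manually compute sqrt(sum(||g_i||^2)) from the raw gradient tensors.
    expected = math.sqrt(
        sum(p.grad.norm(2).item() ** 2 for p in model.parameters())
    )
    assert math.isclose(compute_gradient_l2_norm(model), expected, rel_tol=1e-6)
```

Comparing against a hand-rolled `sqrt(sum(...))` would catch regressions such as the double square root the PR fixes.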
Hi @norbusan, I have addressed all the requested changes and pushed the updates. Thank you for your guidance.
Description
This PR reworks the gradient L2 norm computation to ensure correct mathematical implementation, clean API behavior, and proper test coverage.
Motivation and Context
The previous implementation had issues related to incorrect L2 norm calculation (double square root), duplicate initialization, and unclear return type behavior. This update fixes the mathematical computation and ensures the function returns a consistent float value.
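The corrected math described above can be sketched as follows (an illustrative stand-in, not the exact PR code):

```python
import torch

def compute_gradient_l2_norm(model):
    """Global L2 norm of all parameter gradients: sqrt(sum(||g_i||^2))."""
    total_sq = 0.0
    for param in model.parameters():
        if param.grad is not None:
            # .norm(2) already applies a square root, so square it before
            # summing; taking another sqrt per parameter was the old bug.
            total_sq += param.grad.detach().norm(2).item() ** 2
    return float(total_sq ** 0.5)
```

For example, with two gradient tensors of norms 5 and 12, the global norm is sqrt(25 + 144) = 13.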
Changes
- Computes the global gradient norm as sqrt(sum(||g_i||^2))
- Avoids deprecated `.data` access on tensors

How Has This Been Tested?
Types of changes
Summary by Sourcery
Implement a correct global L2 gradient norm helper and add coverage for its behavior.