Skip to content

Commit 4454fab

Browse files
Remove code to support RMSNorm on old pytorch. (Comfy-Org#12499)
1 parent 1978f59 commit 4454fab

2 files changed

Lines changed: 6 additions & 55 deletions

File tree

comfy/ops.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import comfy.model_management
2222
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
2323
import comfy.float
24-
import comfy.rmsnorm
2524
import json
2625
import comfy.memory_management
2726
import comfy.pinned_memory
@@ -463,7 +462,7 @@ def forward(self, *args, **kwargs):
463462
else:
464463
return super().forward(*args, **kwargs)
465464

466-
class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
465+
class RMSNorm(torch.nn.RMSNorm, CastWeightBiasOp):
467466
def reset_parameters(self):
468467
self.bias = None
469468
return None
@@ -475,8 +474,7 @@ def forward_comfy_cast_weights(self, input):
475474
weight = None
476475
bias = None
477476
offload_stream = None
478-
x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
479-
# x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
477+
x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
480478
uncast_bias_weight(self, weight, bias, offload_stream)
481479
return x
482480

comfy/rmsnorm.py

Lines changed: 4 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,10 @@
11
import torch
22
import comfy.model_management
3-
import numbers
4-
import logging
5-
6-
RMSNorm = None
7-
8-
try:
9-
rms_norm_torch = torch.nn.functional.rms_norm
10-
RMSNorm = torch.nn.RMSNorm
11-
except:
12-
rms_norm_torch = None
13-
logging.warning("Please update pytorch to use native RMSNorm")
143

4+
RMSNorm = torch.nn.RMSNorm
155

166
def rms_norm(x, weight=None, eps=1e-6):
17-
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
18-
if weight is None:
19-
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
20-
else:
21-
return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
7+
if weight is None:
8+
return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
229
else:
23-
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
24-
if weight is None:
25-
return r
26-
else:
27-
return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
28-
29-
30-
if RMSNorm is None:
31-
class RMSNorm(torch.nn.Module):
32-
def __init__(
33-
self,
34-
normalized_shape,
35-
eps=1e-6,
36-
elementwise_affine=True,
37-
device=None,
38-
dtype=None,
39-
):
40-
factory_kwargs = {"device": device, "dtype": dtype}
41-
super().__init__()
42-
if isinstance(normalized_shape, numbers.Integral):
43-
# mypy error: incompatible types in assignment
44-
normalized_shape = (normalized_shape,) # type: ignore[assignment]
45-
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
46-
self.eps = eps
47-
self.elementwise_affine = elementwise_affine
48-
if self.elementwise_affine:
49-
self.weight = torch.nn.Parameter(
50-
torch.empty(self.normalized_shape, **factory_kwargs)
51-
)
52-
else:
53-
self.register_parameter("weight", None)
54-
self.bias = None
55-
56-
def forward(self, x):
57-
return rms_norm(x, self.weight, self.eps)
10+
return torch.nn.functional.rms_norm(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)

0 commit comments

Comments
 (0)