|
1 | 1 | import torch |
2 | 2 | import comfy.model_management |
3 | | -import numbers |
4 | | -import logging |
5 | | - |
6 | | -RMSNorm = None |
7 | | - |
8 | | -try: |
9 | | - rms_norm_torch = torch.nn.functional.rms_norm |
10 | | - RMSNorm = torch.nn.RMSNorm |
11 | | -except: |
12 | | - rms_norm_torch = None |
13 | | - logging.warning("Please update pytorch to use native RMSNorm") |
14 | 3 |
|
| 4 | +RMSNorm = torch.nn.RMSNorm |
15 | 5 |
|
16 | 6 | def rms_norm(x, weight=None, eps=1e-6): |
17 | | - if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()): |
18 | | - if weight is None: |
19 | | - return rms_norm_torch(x, (x.shape[-1],), eps=eps) |
20 | | - else: |
21 | | - return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps) |
| 7 | + if weight is None: |
| 8 | + return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps) |
22 | 9 | else: |
23 | | - r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps) |
24 | | - if weight is None: |
25 | | - return r |
26 | | - else: |
27 | | - return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device) |
28 | | - |
29 | | - |
30 | | -if RMSNorm is None: |
31 | | - class RMSNorm(torch.nn.Module): |
32 | | - def __init__( |
33 | | - self, |
34 | | - normalized_shape, |
35 | | - eps=1e-6, |
36 | | - elementwise_affine=True, |
37 | | - device=None, |
38 | | - dtype=None, |
39 | | - ): |
40 | | - factory_kwargs = {"device": device, "dtype": dtype} |
41 | | - super().__init__() |
42 | | - if isinstance(normalized_shape, numbers.Integral): |
43 | | - # mypy error: incompatible types in assignment |
44 | | - normalized_shape = (normalized_shape,) # type: ignore[assignment] |
45 | | - self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] |
46 | | - self.eps = eps |
47 | | - self.elementwise_affine = elementwise_affine |
48 | | - if self.elementwise_affine: |
49 | | - self.weight = torch.nn.Parameter( |
50 | | - torch.empty(self.normalized_shape, **factory_kwargs) |
51 | | - ) |
52 | | - else: |
53 | | - self.register_parameter("weight", None) |
54 | | - self.bias = None |
55 | | - |
56 | | - def forward(self, x): |
57 | | - return rms_norm(x, self.weight, self.eps) |
| 10 | + return torch.nn.functional.rms_norm(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps) |
0 commit comments