@@ -188,18 +188,14 @@ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
         self.weight = DefaultParameter((num_channels,), tp.float32)
         self.bias = DefaultParameter((num_channels,), tp.float32)
         self.eps = eps
+        self.num_channels = num_channels

     def forward(self, x: tp.Tensor) -> tp.Tensor:
-        original_dtype = x.dtype
-        x = tp.cast(x, tp.float32)
-        u = tp.mean(x, dim=1, keepdim=True)
-        s = tp.mean((x - u) ** 2, dim=1, keepdim=True)
-        x = (x - u) / tp.sqrt(s + self.eps)
-        w = tp.unsqueeze(tp.unsqueeze(self.weight, 1), 2)
-        b = tp.unsqueeze(tp.unsqueeze(self.bias, 1), 2)
-        x = w * x + b
-        x = tp.cast(x, original_dtype)
-        return x
+        x_perm = tp.permute(x, (0, 2, 3, 1))
+        ln = tp.LayerNorm(self.num_channels, dtype=x.dtype, eps=self.eps)
+        ln.weight = self.weight
+        ln.bias = self.bias
+        return tp.permute(ln(x_perm), (0, 3, 1, 2))


 def get_activation_fn(activation):
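For reference, a minimal NumPy sketch (not part of the commit; the shapes and the `np` helpers are illustrative assumptions) showing why the rewrite is equivalent: normalizing the channel axis of an NCHW tensor directly, as the removed code did, matches permuting to NHWC, layer-normalizing the last axis, and permuting back, as the added code does.

import numpy as np

eps = 1e-6
x = np.random.randn(2, 8, 4, 4).astype(np.float32)  # NCHW input (hypothetical shape)
weight = np.random.randn(8).astype(np.float32)       # per-channel scale
bias = np.random.randn(8).astype(np.float32)         # per-channel shift

# Removed path: normalize over the channel axis (dim 1) in place.
u = x.mean(axis=1, keepdims=True)
s = ((x - u) ** 2).mean(axis=1, keepdims=True)
old = weight[:, None, None] * ((x - u) / np.sqrt(s + eps)) + bias[:, None, None]

# Added path: permute NCHW -> NHWC, normalize the last axis, permute back.
xp = x.transpose(0, 2, 3, 1)
up = xp.mean(axis=-1, keepdims=True)
sp = ((xp - up) ** 2).mean(axis=-1, keepdims=True)
new = (weight * ((xp - up) / np.sqrt(sp + eps)) + bias).transpose(0, 3, 1, 2)

assert np.allclose(old, new, atol=1e-5)

One behavioral difference worth noting: the removed code upcast to float32 before normalizing and cast back afterwards, while the new code constructs tp.LayerNorm with dtype=x.dtype, so for reduced-precision inputs the normalization now runs in the input dtype.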