Skip to content

Commit 2bc6cd3

Browse files
committed
Merge branch 'apex_ln_fix' into 'main'
Fix for newer apex version See merge request ADLR/megatron-lm!1021
2 parents f64f91e + 1524ddc commit 2bc6cd3

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

megatron/model/fused_layer_norm.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
HAVE_PERSIST_LAYER_NORM = False
2020

2121
try:
22-
from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction
22+
from apex.normalization.fused_layer_norm import fused_layer_norm_affine
2323
except:
24-
FusedLayerNormAffineFunction = None
24+
fused_layer_norm_affine = None
2525

2626
global fused_layer_norm_cuda
2727
fused_layer_norm_cuda = None
@@ -79,9 +79,9 @@ def forward(self, input):
7979
weight = self.weight + 1 if self.apply_layernorm_1p else self.weight
8080

8181
if self.no_persist_layer_norm:
82-
assert FusedLayerNormAffineFunction is not None, \
83-
"FusedLayerNormAffineFunction is not available, please install apex from https://github.com/NVIDIA/apex"
84-
return FusedLayerNormAffineFunction.apply(input, weight, self.bias, self.normalized_shape, self.eps)
82+
assert fused_layer_norm_affine is not None, \
83+
"fused_layer_norm_affine is not available, please install apex from https://github.com/NVIDIA/apex"
84+
return fused_layer_norm_affine(input, weight, self.bias, self.normalized_shape, eps=self.eps)
8585
else:
8686
output = FastLayerNormFN.apply(input, weight, self.bias, self.eps)
8787

0 commit comments

Comments
 (0)