Skip to content

Commit a5b3972

Browse files
authored
fix(linear): remove linear op assert (#438)
1 parent c89c9db commit a5b3972

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

internlm/model/modules/linear.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -912,13 +912,13 @@ def __init__( # pylint: disable=W0231, W0233
912912
torch.empty(num_groups, in_features, local_multiple * multiple_of, device=device, dtype=dtype)
913913
)
914914
self.tp_dim = 2
915-
assert self.weight.shape[self.tp_dim] != out_features
915+
# assert self.weight.shape[self.tp_dim] != out_features
916916
elif split_mode == "row":
917917
self.weight = nn.Parameter(
918918
torch.empty(num_groups, local_multiple * multiple_of, out_features, device=device, dtype=dtype)
919919
)
920920
self.tp_dim = 1
921-
assert self.weight.shape[self.tp_dim] != in_features
921+
# assert self.weight.shape[self.tp_dim] != in_features
922922
elif split_mode == "weight":
923923
self.weight = nn.Parameter(
924924
torch.empty(local_multiple * multiple_of, out_features, device=device, dtype=dtype)

0 commit comments

Comments
 (0)