@@ -857,9 +857,6 @@ def forward(
                 "Cannot provide both embedding_selector and global_index."
             )

-        if x.dtype != self.pos_embd.dtype:
-            self.pos_embd = self.pos_embd.to(x.dtype)
-
         # Append positional embedding to input conditioning
         if self.pos_embd is not None:
             # Select positional embeddings with a selector function
@@ -947,15 +944,16 @@ def positional_embedding_indexing(
         """
         # If no global indices are provided, select all embeddings and expand
         # to match the batch size of the input
-        if x.dtype != self.pos_embd.dtype:
-            self.pos_embd = self.pos_embd.to(x.dtype)
+        pos_embd = self.pos_embd
+        if x.dtype != pos_embd.dtype:
+            pos_embd = pos_embd.to(x.dtype)

         if global_index is None:
             if self.lead_time_mode:
                 selected_pos_embd = []
-                if self.pos_embd is not None:
+                if pos_embd is not None:
                     selected_pos_embd.append(
-                        self.pos_embd[None].expand((x.shape[0], -1, -1, -1))
+                        pos_embd[None].expand((x.shape[0], -1, -1, -1))
                     )
                 if self.lt_embd is not None:
                     selected_pos_embd.append(
@@ -972,7 +970,7 @@ def positional_embedding_indexing(
                 if len(selected_pos_embd) > 0:
                     selected_pos_embd = torch.cat(selected_pos_embd, dim=1)
             else:
-                selected_pos_embd = self.pos_embd[None].expand(
+                selected_pos_embd = pos_embd[None].expand(
                     (x.shape[0], -1, -1, -1)
                 )  # (B, C_{PE}, H, W)

@@ -985,11 +983,11 @@ def positional_embedding_indexing(
             global_index = torch.reshape(
                 torch.permute(global_index, (1, 0, 2, 3)), (2, -1)
             )  # (P, 2, X, Y) to (2, P*X*Y)
-            selected_pos_embd = self.pos_embd[
+            selected_pos_embd = pos_embd[
                 :, global_index[0], global_index[1]
             ]  # (C_pe, P*X*Y)
             selected_pos_embd = torch.permute(
-                torch.reshape(selected_pos_embd, (self.pos_embd.shape[0], P, H, W)),
+                torch.reshape(selected_pos_embd, (pos_embd.shape[0], P, H, W)),
                 (1, 0, 2, 3),
             )  # (P, C_pe, X, Y)

@@ -1000,7 +998,7 @@ def positional_embedding_indexing(
             # Append positional and lead time embeddings to input conditioning
             if self.lead_time_mode:
                 embeds = []
-                if self.pos_embd is not None:
+                if pos_embd is not None:
                     embeds.append(selected_pos_embd)  # reuse code below
                 if self.lt_embd is not None:
                     lt_embds = self.lt_embd[
@@ -1086,15 +1084,12 @@ def positional_embedding_selector(
        ...     return patching.apply(emb[None].expand(batch_size, -1, -1, -1))
        >>>
        """
-        if x.dtype != self.pos_embd.dtype:
-            self.pos_embd = self.pos_embd.to(x.dtype)
+        embeddings = self.pos_embd
+        if x.dtype != embeddings.dtype:
+            embeddings = embeddings.to(x.dtype)
         if lead_time_label is not None:
             # all patches share same lead_time_label
-            embeddings = torch.cat(
-                [self.pos_embd, self.lt_embd[lead_time_label[0].int()]]
-            )
-        else:
-            embeddings = self.pos_embd
+            embeddings = torch.cat([embeddings, self.lt_embd[lead_time_label[0].int()]])
         return embedding_selector(embeddings)  # (B, N_pe, H, W)

     def _get_positional_embedding(self):
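The recurring change in this diff is to cast a temporary local copy of the stored positional embedding to the input's dtype instead of reassigning self.pos_embd in place, so the module's stored tensor keeps its original dtype across mixed-precision calls. Below is a minimal, self-contained sketch of that pattern; the PosEmbedExample class, its shapes, and the buffer registration are illustrative assumptions, not the actual SongUNetPosEmbd implementation.

import torch
from torch import nn


class PosEmbedExample(nn.Module):
    """Minimal sketch of the local-cast pattern applied in this diff.

    The class name, tensor shapes, and use of register_buffer are assumptions
    for illustration only, not the real SongUNetPosEmbd module.
    """

    def __init__(self, pe_channels: int = 4, height: int = 8, width: int = 8):
        super().__init__()
        # Stored positional embedding; stays float32 regardless of input dtype.
        self.register_buffer("pos_embd", torch.randn(pe_channels, height, width))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Cast a local copy instead of reassigning self.pos_embd, so a
        # half-precision call does not permanently downcast the stored buffer.
        pos_embd = self.pos_embd
        if x.dtype != pos_embd.dtype:
            pos_embd = pos_embd.to(x.dtype)
        # Append the positional embedding to the input along the channel dim.
        return torch.cat([x, pos_embd[None].expand(x.shape[0], -1, -1, -1)], dim=1)


model = PosEmbedExample()
x_half = torch.randn(2, 3, 8, 8).half()
out = model(x_half)
assert out.dtype == torch.float16
assert model.pos_embd.dtype == torch.float32  # stored embedding unchanged

With the previous in-place reassignment, a single reduced-precision forward pass would have left the stored embedding downcast for every subsequent call; the local cast matches the input dtype only for the duration of the call while leaving module state untouched.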