Commit b7e0382

Author: Julius Berner (committed)
Avoid dtype change of buffer/param and fix softmax dtype
1 parent bab3815 commit b7e0382

File tree

1 file changed: +13, -18 lines


physicsnemo/models/diffusion/song_unet.py

Lines changed: 13 additions & 18 deletions
@@ -857,9 +857,6 @@ def forward(
                 "Cannot provide both embedding_selector and global_index."
             )
 
-        if x.dtype != self.pos_embd.dtype:
-            self.pos_embd = self.pos_embd.to(x.dtype)
-
         # Append positional embedding to input conditioning
         if self.pos_embd is not None:
             # Select positional embeddings with a selector function
@@ -947,15 +944,16 @@ def positional_embedding_indexing(
         """
         # If no global indices are provided, select all embeddings and expand
         # to match the batch size of the input
-        if x.dtype != self.pos_embd.dtype:
-            self.pos_embd = self.pos_embd.to(x.dtype)
+        pos_embd = self.pos_embd
+        if x.dtype != pos_embd.dtype:
+            pos_embd = pos_embd.to(x.dtype)
 
         if global_index is None:
             if self.lead_time_mode:
                 selected_pos_embd = []
-                if self.pos_embd is not None:
+                if pos_embd is not None:
                     selected_pos_embd.append(
-                        self.pos_embd[None].expand((x.shape[0], -1, -1, -1))
+                        pos_embd[None].expand((x.shape[0], -1, -1, -1))
                     )
                 if self.lt_embd is not None:
                     selected_pos_embd.append(
@@ -972,7 +970,7 @@ def positional_embedding_indexing(
                 if len(selected_pos_embd) > 0:
                     selected_pos_embd = torch.cat(selected_pos_embd, dim=1)
             else:
-                selected_pos_embd = self.pos_embd[None].expand(
+                selected_pos_embd = pos_embd[None].expand(
                     (x.shape[0], -1, -1, -1)
                 )  # (B, C_{PE}, H, W)
 
@@ -985,11 +983,11 @@ def positional_embedding_indexing(
             global_index = torch.reshape(
                 torch.permute(global_index, (1, 0, 2, 3)), (2, -1)
             )  # (P, 2, X, Y) to (2, P*X*Y)
-            selected_pos_embd = self.pos_embd[
+            selected_pos_embd = pos_embd[
                 :, global_index[0], global_index[1]
             ]  # (C_pe, P*X*Y)
            selected_pos_embd = torch.permute(
-                torch.reshape(selected_pos_embd, (self.pos_embd.shape[0], P, H, W)),
+                torch.reshape(selected_pos_embd, (pos_embd.shape[0], P, H, W)),
                 (1, 0, 2, 3),
             )  # (P, C_pe, X, Y)
 
@@ -1000,7 +998,7 @@ def positional_embedding_indexing(
         # Append positional and lead time embeddings to input conditioning
         if self.lead_time_mode:
             embeds = []
-            if self.pos_embd is not None:
+            if pos_embd is not None:
                 embeds.append(selected_pos_embd)  # reuse code below
             if self.lt_embd is not None:
                 lt_embds = self.lt_embd[
@@ -1086,15 +1084,12 @@ def positional_embedding_selector(
         ...     return patching.apply(emb[None].expand(batch_size, -1, -1, -1))
         >>>
         """
-        if x.dtype != self.pos_embd.dtype:
-            self.pos_embd = self.pos_embd.to(x.dtype)
+        embeddings = self.pos_embd
+        if x.dtype != embeddings.dtype:
+            embeddings = embeddings.to(x.dtype)
         if lead_time_label is not None:
             # all patches share same lead_time_label
-            embeddings = torch.cat(
-                [self.pos_embd, self.lt_embd[lead_time_label[0].int()]]
-            )
-        else:
-            embeddings = self.pos_embd
+            embeddings = torch.cat([embeddings, self.lt_embd[lead_time_label[0].int()]])
         return embedding_selector(embeddings)  # (B, N_pe, H, W)
 
     def _get_positional_embedding(self):
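
For context on the buffer/param part of this change: reassigning a registered buffer inside forward (the removed `self.pos_embd = self.pos_embd.to(x.dtype)`) permanently overwrites the dtype stored on the module, so the input dtype leaks into the state dict and into later calls; casting into a local variable leaves the buffer untouched. The snippet below is a minimal sketch of that pattern only, with an invented module and shapes, not code from song_unet.py.

import torch
import torch.nn as nn


class PosEmbedExample(nn.Module):
    # Hypothetical module, used only to illustrate the local-cast pattern.
    def __init__(self, channels: int = 4, height: int = 8, width: int = 8):
        super().__init__()
        # Registered buffer: part of the state dict, typically kept in float32.
        self.register_buffer("pos_embd", torch.randn(channels, height, width))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Anti-pattern removed by the commit: `self.pos_embd = self.pos_embd.to(x.dtype)`
        # would downcast the stored buffer after the first low-precision call.
        # Pattern adopted instead: cast into a local variable only.
        pos_embd = self.pos_embd
        if x.dtype != pos_embd.dtype:
            pos_embd = pos_embd.to(x.dtype)
        return x + pos_embd[None].expand(x.shape[0], -1, -1, -1)


model = PosEmbedExample()
out = model(torch.randn(2, 4, 8, 8, dtype=torch.float16))
assert model.pos_embd.dtype == torch.float32  # buffer dtype is unchanged
assert out.dtype == torch.float16             # output follows the input dtype

Running the snippet with a float16 input leaves model.pos_embd in float32, whereas the removed in-place assignment would have changed the buffer after the first mixed-precision call.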
