Skip to content

Commit a0399f1

Browse files
committed
Fix variable shapes in sea_raft's loss
1 parent 542899c commit a0399f1

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

ptlflow/models/sea_raft/sea_raft.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,15 @@ def forward(self, outputs, inputs):
3434

3535
flow_loss = 0.0
3636
# exlude invalid pixels and extremely large diplacements
37-
mag = torch.sum(flow_gt**2, dim=1).sqrt()
37+
mag = torch.sum(flow_gt**2, dim=1, keepdim=True).sqrt()
3838
valid = (valid >= 0.5) & (mag < self.max_flow)
3939
for i in range(n_predictions):
4040
i_weight = self.gamma ** (n_predictions - i - 1)
4141
loss_i = outputs["nf_preds"][i]
4242
final_mask = (
4343
(~torch.isnan(loss_i.detach()))
4444
& (~torch.isinf(loss_i.detach()))
45-
& valid[:, None]
45+
& valid
4646
)
4747
flow_loss += i_weight * ((final_mask * loss_i).sum() / final_mask.sum())
4848

0 commit comments

Comments
 (0)