Skip to content

Commit 765f12e

Browse files
committed
fix paddle slicing
1 parent d446566 commit 765f12e

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

pygmtools/paddle_backend.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1284,10 +1284,9 @@ def permutation_loss(pred_dsmat, gt_perm, n1, n2):
12841284
loss = paddle.to_tensor(0., place=pred_dsmat.place)
12851285
n_sum = paddle.zeros_like(loss)
12861286
for b in range(batch_num):
1287-
batch_slice = [b, slice(n1[b]), slice(n2[b])]
12881287
loss += paddle.nn.functional.binary_cross_entropy(
1289-
pred_dsmat[batch_slice],
1290-
gt_perm[batch_slice],
1288+
pred_dsmat[b, :n1[b], :n2[b]],
1289+
gt_perm[b, :n1[b], :n2[b]],
12911290
reduction='sum')
12921291
n_sum += paddle.to_tensor(n1[b].astype(n_sum.dtype), place=pred_dsmat.place)
12931292

0 commit comments

Comments
 (0)