
Commit 453a246

Avoiding conditional indexing in positional encoding to avoid the possibility of empty tensors (#42090)
1 parent 52cbf39 commit 453a246

File tree

1 file changed: +2 −2 lines


src/transformers/models/speecht5/modeling_speecht5.py

Lines changed: 2 additions & 2 deletions
@@ -434,8 +434,8 @@ def forward(self, hidden_states):
         pos_seq = torch.arange(0, seq_len).to(device=hidden_states.device, dtype=torch.long)
         pos_seq = pos_seq[:, None] - pos_seq[None, :]

-        pos_seq[pos_seq < -self.max_length] = -self.max_length
-        pos_seq[pos_seq >= self.max_length] = self.max_length - 1
+        pos_seq = torch.where(pos_seq < -self.max_length, -self.max_length, pos_seq)
+        pos_seq = torch.where(pos_seq >= self.max_length, self.max_length - 1, pos_seq)
         pos_seq = pos_seq + self.max_length

         return self.pe_k(pos_seq)
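For context, here is a minimal standalone sketch of what this change does (the values max_length = 4 and seq_len = 6 are illustrative assumptions, not taken from the SpeechT5 config). Both variants clamp the relative positions to [-max_length, max_length - 1]; the difference is that torch.where is elementwise and shape-stable, whereas boolean-mask assignment indexes with a mask whose selection can be an empty tensor when no element matches.

import torch

# Illustrative stand-ins (assumptions, not the model's config values)
max_length = 4
seq_len = 6

# Relative position matrix: pos_seq[i, j] = i - j
pos_seq = torch.arange(0, seq_len, dtype=torch.long)
pos_seq = pos_seq[:, None] - pos_seq[None, :]

# Old approach: boolean-mask assignment. The mask may select an empty
# tensor (e.g. when seq_len <= max_length, no element exceeds the bounds).
clamped_old = pos_seq.clone()
clamped_old[clamped_old < -max_length] = -max_length
clamped_old[clamped_old >= max_length] = max_length - 1

# New approach: torch.where clamps to the same values without any
# data-dependent indexing.
clamped_new = torch.where(pos_seq < -max_length, -max_length, pos_seq)
clamped_new = torch.where(clamped_new >= max_length, max_length - 1, clamped_new)

assert torch.equal(clamped_old, clamped_new)
# Shift into [0, 2 * max_length) for the embedding lookup, as in forward()
print(clamped_new + max_length)

The two forms produce identical values; the torch.where version simply sidesteps mask indexing, which tends to be friendlier to tracing and export paths than a data-dependent in-place assignment.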
