1 parent 52cbf39 commit 453a246
src/transformers/models/speecht5/modeling_speecht5.py
@@ -434,8 +434,8 @@ def forward(self, hidden_states):
         pos_seq = torch.arange(0, seq_len).to(device=hidden_states.device, dtype=torch.long)
         pos_seq = pos_seq[:, None] - pos_seq[None, :]

-        pos_seq[pos_seq < -self.max_length] = -self.max_length
-        pos_seq[pos_seq >= self.max_length] = self.max_length - 1
+        pos_seq = torch.where(pos_seq < -self.max_length, -self.max_length, pos_seq)
+        pos_seq = torch.where(pos_seq >= self.max_length, self.max_length - 1, pos_seq)
         pos_seq = pos_seq + self.max_length

         return self.pe_k(pos_seq)
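The snippet below is a minimal standalone sketch (not part of the commit) checking that the new out-of-place torch.where clamping produces the same result as the original in-place boolean-mask assignment. The values for max_length and seq_len are arbitrary examples, and the module's self.max_length and hidden_states are replaced by plain locals; a plausible motivation for the change, not stated on this page, is that out-of-place torch.where avoids data-dependent in-place indexing, which tends to be friendlier to tracing and graph export.

# Illustrative equivalence check only; example values, not the real module state.
import torch

max_length = 4
seq_len = 10

pos_seq = torch.arange(0, seq_len, dtype=torch.long)
pos_seq = pos_seq[:, None] - pos_seq[None, :]

# Old approach: in-place assignment through a boolean mask.
clamped_old = pos_seq.clone()
clamped_old[clamped_old < -max_length] = -max_length
clamped_old[clamped_old >= max_length] = max_length - 1

# New approach: the same clamping applied out of place with torch.where.
clamped_new = torch.where(pos_seq < -max_length, -max_length, pos_seq)
clamped_new = torch.where(clamped_new >= max_length, max_length - 1, clamped_new)

# Both variants clamp relative positions to the range [-max_length, max_length - 1].
assert torch.equal(clamped_old, clamped_new)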