diff --git a/RWKV-v5/src/dataset.py b/RWKV-v5/src/dataset.py index 8ce76a1b0..ed45c734e 100644 --- a/RWKV-v5/src/dataset.py +++ b/RWKV-v5/src/dataset.py @@ -111,6 +111,7 @@ def __getitem__(self, idx): dix = self.data[i] x = torch.tensor(dix[:-1], dtype=torch.long) y = torch.tensor(dix[1:], dtype=torch.long) + return x, y else: ctx_len = args.ctx_len req_len = ctx_len + 1