diff --git a/ch02/04_bonus_dataloader-intuition/dataloader-intuition.ipynb b/ch02/04_bonus_dataloader-intuition/dataloader-intuition.ipynb index 78453ff..ea1b2d4 100644 --- a/ch02/04_bonus_dataloader-intuition/dataloader-intuition.ipynb +++ b/ch02/04_bonus_dataloader-intuition/dataloader-intuition.ipynb @@ -111,7 +111,25 @@ " return len(self.input_ids)\n", "\n", " def __getitem__(self, idx):\n", - " return self.input_ids[idx], self.target_ids[idx]" + " return self.input_ids[idx], self.target_ids[idx]\n", + "\n", + "def create_dataloader_v1(txt, batch_size=4, max_length=256, stride=128, shuffle=True, drop_last=True, num_workers=0):\n", + " # 初始化tokenizer(在这个例子中不需要真正的tokenizer,设为None)\n", + " tokenizer = None\n", + "\n", + " # 创建dataset\n", + " dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)\n", + "\n", + " # 创建dataloader\n", + " dataloader = DataLoader(\n", + " dataset,\n", + " batch_size=batch_size,\n", + " shuffle=shuffle,\n", + " drop_last=drop_last,\n", + " num_workers=num_workers\n", + " )\n", + "\n", + " return dataloader" ] }, {