|
104 | 104 | slow, |
105 | 105 | torch_device, |
106 | 106 | ) |
107 | | -from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend, check_target_module_exists |
| 107 | +from transformers.trainer_utils import ( |
| 108 | + PREFIX_CHECKPOINT_DIR, |
| 109 | + HPSearchBackend, |
| 110 | + check_target_module_exists, |
| 111 | + get_last_checkpoint, |
| 112 | +) |
108 | 113 | from transformers.training_args import OptimizerNames |
109 | 114 | from transformers.utils import ( |
110 | 115 | SAFE_WEIGHTS_INDEX_NAME, |
@@ -5104,6 +5109,142 @@ def create_dummy_dataset(): |
5104 | 5109 | final_model_path = os.path.join(final_checkpoint_path, SAFE_WEIGHTS_NAME) |
5105 | 5110 | self.assertTrue(os.path.exists(final_model_path), "Final model checkpoint was not saved!") |
5106 | 5111 |
|
| 5112 | + def test_resume_batch_order(self): |
| 5113 | + """ |
| 5114 | + Test that the dataloader order is reproducible when resuming from a mid-epoch checkpoint. |
| 5115 | + Resumes from checkpoint-6 so that training continues at step 7, partway through the second epoch. |
| 5116 | + """ |
| 5117 | + |
| 5118 | + # --- Helper classes and functions defined locally for this test --- |
| 5119 | + class DummyDataset(torch.utils.data.Dataset): |
| 5120 | + def __init__(self, size: int = 32): |
| 5121 | + self.size = size |
| 5122 | + self.data = torch.randn((size, 10)) |
| 5123 | + self.data[:, 0] = torch.arange(0, size) # Encode the data order |
| 5124 | + self.labels = torch.randint(0, 10, (size,)) |
| 5125 | + |
| 5126 | + def __len__(self) -> int: |
| 5127 | + return self.size |
| 5128 | + |
| 5129 | + def __getitem__(self, idx: int): |
| 5130 | + return {"input_ids": self.data[idx], "labels": self.labels[idx]} |
| 5131 | + |
| 5132 | + class DummyModel(nn.Module): |
| 5133 | + def __init__(self, size: int): |
| 5134 | + super().__init__() |
| 5135 | + self.fc = nn.Linear(10, 10, bias=False) |
| 5136 | + # data_order logs the order of data points seen by the model |
| 5137 | + self.register_buffer("data_order", torch.empty(0, dtype=torch.long)) |
| 5138 | + |
| 5139 | + def load_state_dict(self, state_dict, strict=True): |
| 5140 | + # Handle data_order buffer size mismatch during checkpoint loading |
| 5141 | + if "data_order" in state_dict: |
| 5142 | + saved_data_order = state_dict["data_order"] |
| 5143 | + if hasattr(self, "data_order") and self.data_order.shape != saved_data_order.shape: |
| 5144 | + # Resize the buffer to match the saved state |
| 5145 | + self.data_order = saved_data_order.clone() |
| 5146 | + |
| 5147 | + return super().load_state_dict(state_dict, strict=strict) |
| 5148 | + |
| 5149 | + def forward(self, input_ids: torch.Tensor, labels: torch.Tensor = None): |
| 5150 | + logits = self.fc(input_ids) |
| 5151 | + loss = None |
| 5152 | + if labels is not None: |
| 5153 | + loss_fn = nn.CrossEntropyLoss() |
| 5154 | + loss = loss_fn(logits, labels) |
| 5155 | + |
| 5156 | + # Log the data order for verification |
| 5157 | + data_indices = input_ids[:, 0].long() # Match the dtype of the data_order buffer |
| 5158 | + self.data_order = torch.cat([self.data_order, data_indices.detach().clone()]) |
| 5159 | + |
| 5160 | + return {"loss": loss, "logits": logits} |
| 5161 | + |
| 5162 | + # Scenario 1: Run baseline training to completion |
| 5163 | + # 1.1 Run training to completion |
| 5164 | + set_seed(42) |
| 5165 | + train_dataset = DummyDataset(size=10) |
| 5166 | + model_baseline = DummyModel(size=10) |
| 5167 | + |
| 5168 | + exp_dir_baseline = self.get_auto_remove_tmp_dir() |
| 5169 | + args_baseline = TrainingArguments( |
| 5170 | + output_dir=str(exp_dir_baseline), |
| 5171 | + seed=42, |
| 5172 | + learning_rate=0.1, |
| 5173 | + per_device_train_batch_size=2, |
| 5174 | + gradient_accumulation_steps=1, |
| 5175 | + save_strategy="steps", |
| 5176 | + save_steps=1, |
| 5177 | + num_train_epochs=3, |
| 5178 | + optim="sgd", |
| 5179 | + disable_tqdm=True, |
| 5180 | + dataloader_num_workers=0, # Ensures that the main process loads the data |
| 5181 | + report_to=[], # Disable wandb/tensorboard and other loggers |
| 5182 | + ) |
| 5183 | + |
| 5184 | + trainer_baseline = Trainer( |
| 5185 | + model=model_baseline, |
| 5186 | + args=args_baseline, |
| 5187 | + train_dataset=train_dataset, |
| 5188 | + ) |
| 5189 | + |
| 5190 | + trainer_baseline.train() |
| 5191 | + |
| 5192 | + # 1.2 Get the data order from the last saved checkpoint for the full run |
| 5193 | + last_checkpoint_path = get_last_checkpoint(exp_dir_baseline) |
| 5194 | + last_ckpt_num = int(os.path.basename(last_checkpoint_path).split("-")[1]) # Should be 15 (5 steps per epoch * 3 epochs) |
| 5195 | + |
| 5196 | + baseline_state_dict = safetensors.torch.load_file( |
| 5197 | + os.path.join(exp_dir_baseline, f"checkpoint-{last_ckpt_num}", "model.safetensors") |
| 5198 | + ) |
| 5199 | + baseline_data_order = baseline_state_dict["data_order"] |
| 5200 | + |
| 5201 | + # Scenario 2: Resume training from checkpoint in the middle of the second epoch |
| 5202 | + # 2.1 Resume from checkpoint-6 so that training continues at step 7, the second step of the second epoch |
| 5203 | + # Each epoch covers 10 samples, i.e. 5 steps with a batch size of 2 |
| 5204 | + target_ckpt_num = 7 |
| 5205 | + checkpoint_path = os.path.join(exp_dir_baseline, f"checkpoint-{target_ckpt_num - 1}") |
| 5206 | + |
| 5207 | + set_seed(42) |
| 5208 | + model_resume = DummyModel(size=10) |
| 5209 | + |
| 5210 | + exp_dir_resume = self.get_auto_remove_tmp_dir() |
| 5211 | + args_resume = TrainingArguments( |
| 5212 | + output_dir=str(exp_dir_resume), |
| 5213 | + seed=42, |
| 5214 | + learning_rate=0.1, |
| 5215 | + per_device_train_batch_size=2, |
| 5216 | + gradient_accumulation_steps=1, |
| 5217 | + save_strategy="steps", |
| 5218 | + save_steps=1, |
| 5219 | + num_train_epochs=3, |
| 5220 | + optim="sgd", |
| 5221 | + disable_tqdm=True, |
| 5222 | + dataloader_num_workers=0, # Ensures that the main process loads the data |
| 5223 | + report_to=[], # Disable wandb/tensorboard and other loggers |
| 5224 | + ) |
| 5225 | + |
| 5226 | + trainer_resume = Trainer( |
| 5227 | + model=model_resume, |
| 5228 | + args=args_resume, |
| 5229 | + train_dataset=train_dataset, |
| 5230 | + ) |
| 5231 | + |
| 5232 | + trainer_resume.train(resume_from_checkpoint=checkpoint_path) |
| 5233 | + |
| 5234 | + # 2.2 Get the data order from the last saved checkpoint for the resumed run |
| 5235 | + resumed_state_dict = safetensors.torch.load_file( |
| 5236 | + os.path.join(exp_dir_resume, f"checkpoint-{last_ckpt_num}", "model.safetensors") |
| 5237 | + ) |
| 5238 | + resumed_data_order = resumed_state_dict["data_order"] |
| 5239 | + |
| 5240 | + # 3. Compare results: the data order should be identical |
| 5241 | + self.assertTrue( |
| 5242 | + torch.equal(baseline_data_order, resumed_data_order), |
| 5243 | + f"Data order mismatch after resuming from a mid-epoch checkpoint.\n" |
| 5244 | + f"Baseline: {baseline_data_order}\n" |
| 5245 | + f"Resumed: {resumed_data_order}", |
| 5246 | + ) |
| 5247 | + |
5107 | 5248 |
|
5108 | 5249 | @require_torch |
5109 | 5250 | @is_staging_test |
|