Skip to content

Commit fd14b2a

Browse files
committed
add test checking order of sampled data points
1 parent fe1ffc4 commit fd14b2a

File tree

1 file changed

+142
-1
lines changed

1 file changed

+142
-1
lines changed

tests/trainer/test_trainer.py

Lines changed: 142 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,12 @@
104104
slow,
105105
torch_device,
106106
)
107-
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend, check_target_module_exists
107+
from transformers.trainer_utils import (
108+
PREFIX_CHECKPOINT_DIR,
109+
HPSearchBackend,
110+
check_target_module_exists,
111+
get_last_checkpoint,
112+
)
108113
from transformers.training_args import OptimizerNames
109114
from transformers.utils import (
110115
SAFE_WEIGHTS_INDEX_NAME,
@@ -5104,6 +5109,142 @@ def create_dummy_dataset():
51045109
final_model_path = os.path.join(final_checkpoint_path, SAFE_WEIGHTS_NAME)
51055110
self.assertTrue(os.path.exists(final_model_path), "Final model checkpoint was not saved!")
51065111

5112+
def test_resume_batch_order(self):
    """
    Test that the dataloader order is reproducible when resuming from a mid-epoch checkpoint.

    A baseline run trains to completion while the model records, in a buffer saved with its weights, the
    index of every sample it sees. A second run then resumes from `checkpoint-6` of the baseline, so the
    first step it executes is step 7. With 10 samples and a batch size of 2 on a single device, an epoch
    is 5 steps, which makes step 7 the second batch of epoch 1. The sample order stored in the last
    checkpoint of both runs must be identical.
    """

    # --- Helper classes and functions defined locally for this test ---
    class DummyDataset(torch.utils.data.Dataset):
        """Random features whose first column stores the index of the sample."""

        def __init__(self, size: int = 32):
            self.size = size
            self.data = torch.randn((size, 10))
            self.data[:, 0] = torch.arange(0, size)  # Encode the data order
            self.labels = torch.randint(0, 10, (size,))

        def __len__(self) -> int:
            return self.size

        def __getitem__(self, idx: int):
            return {"input_ids": self.data[idx], "labels": self.labels[idx]}

    class DummyModel(nn.Module):
        """Linear classifier that logs the index of every sample passed through `forward`."""

        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(10, 10, bias=False)
            # data_order logs the order of data points seen by the model. It is a (persistent) buffer so
            # that it is written to every checkpoint together with the weights.
            self.register_buffer("data_order", torch.empty(0, dtype=torch.long))

        def load_state_dict(self, state_dict, strict=True):
            # The buffer grows at every forward pass, so the saved tensor never has the shape of a freshly
            # created (empty) buffer. Resize the buffer first so that the regular loading logic succeeds.
            if "data_order" in state_dict:
                saved_data_order = state_dict["data_order"]
                if self.data_order.shape != saved_data_order.shape:
                    self.data_order = saved_data_order.clone()

            return super().load_state_dict(state_dict, strict=strict)

        def forward(self, input_ids: torch.Tensor, labels: torch.Tensor = None):
            logits = self.fc(input_ids)
            loss = None
            if labels is not None:
                loss_fn = nn.CrossEntropyLoss()
                loss = loss_fn(logits, labels)

            # Log the data order for verification. Cast to the dtype of the buffer (torch.long) instead of
            # relying on the implicit type promotion of `torch.cat`.
            data_indices = input_ids[:, 0].long()
            self.data_order = torch.cat([self.data_order, data_indices.detach().clone()])

            return {"loss": loss, "logits": logits}

    def make_training_args(output_dir) -> TrainingArguments:
        # Single source of truth for the hyperparameters: the baseline and the resumed run must be
        # configured identically, otherwise comparing their data order is meaningless.
        return TrainingArguments(
            output_dir=str(output_dir),
            seed=42,
            learning_rate=0.1,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=1,
            save_strategy="steps",
            save_steps=1,
            num_train_epochs=3,
            optim="sgd",
            disable_tqdm=True,
            dataloader_num_workers=0,  # Ensures that main process loads the data
            report_to=[],  # Disable wandb/tensorboard and other loggers
        )

    # Scenario 1: Run baseline training to completion
    # 1.1 Run training to completion
    set_seed(42)
    train_dataset = DummyDataset(size=10)
    model_baseline = DummyModel()

    exp_dir_baseline = self.get_auto_remove_tmp_dir()
    trainer_baseline = Trainer(
        model=model_baseline,
        args=make_training_args(exp_dir_baseline),
        train_dataset=train_dataset,
    )

    trainer_baseline.train()

    # 1.2 Get the data order from the last saved checkpoint for the full run
    last_checkpoint_path = get_last_checkpoint(exp_dir_baseline)
    self.assertIsNotNone(last_checkpoint_path, "Baseline training did not save any checkpoint!")
    # Expected to be 15 on a single device (3 epochs x 5 steps per epoch)
    last_ckpt_num = int(os.path.basename(last_checkpoint_path).split("-")[1])

    baseline_state_dict = safetensors.torch.load_file(
        os.path.join(exp_dir_baseline, f"checkpoint-{last_ckpt_num}", "model.safetensors")
    )
    baseline_data_order = baseline_state_dict["data_order"]

    # Scenario 2: Resume training from a checkpoint in the middle of the second epoch
    # 2.1 The first step executed after resuming is step `target_ckpt_num` (= 7), i.e. the second batch
    # of epoch 1, so the checkpoint to load is the one saved right before it (checkpoint-6).
    # 1 epoch consists of 10 points, so 5 steps with batch size 2
    target_ckpt_num = 7
    checkpoint_path = os.path.join(exp_dir_baseline, f"checkpoint-{target_ckpt_num - 1}")

    set_seed(42)
    model_resume = DummyModel()

    exp_dir_resume = self.get_auto_remove_tmp_dir()
    trainer_resume = Trainer(
        model=model_resume,
        args=make_training_args(exp_dir_resume),
        train_dataset=train_dataset,
    )

    trainer_resume.train(resume_from_checkpoint=checkpoint_path)

    # 2.2 Get the data order from the last saved checkpoint for the resumed run
    resumed_state_dict = safetensors.torch.load_file(
        os.path.join(exp_dir_resume, f"checkpoint-{last_ckpt_num}", "model.safetensors")
    )
    resumed_data_order = resumed_state_dict["data_order"]

    # 3. Compare results: the data order should be identical
    self.assertTrue(
        torch.equal(baseline_data_order, resumed_data_order),
        "Data order mismatch after resuming from checkpoint.\n"
        f"Baseline: {baseline_data_order}\n"
        f"Resumed: {resumed_data_order}",
    )
5247+
51075248

51085249
@require_torch
51095250
@is_staging_test

0 commit comments

Comments
 (0)