Skip to content

Commit 9eb3740

Browse files
littlebullGitSkafteNickiBordapre-commit-ci[bot]
authored
Prevent recursive symlink creation iwhen save_last='link' and save_top_k=-1 (#21186)
* Fix for issue #21110: Prevent recursive symlink creation in ModelCheckpoint - Added a check in _link_checkpoint to compare absolute paths of filepath and linkpath - Only create symlink if paths differ, avoiding self-linking when save_last='link' and save_top_k=-1 - Updated test to assert the fix prevents the recursive symlink bug * Address PR comments: refactor test to use boring classes and clarify assertion * Update test_model_checkpoint_additional_cases.py * chlog --------- Co-authored-by: Nicki Skafte Detlefsen <[email protected]> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent e4c4e9a commit 9eb3740

File tree

3 files changed

+28
-2
lines changed

3 files changed

+28
-2
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2424

2525
### Changed
2626

27-
- Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise. ([#9580](https://github.com/Lightning-AI/pytorch-lightning/pull/9580))
27+
- Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise ([#20896](https://github.com/Lightning-AI/pytorch-lightning/pull/20896))
28+
29+
30+
- Fixed preventing recursive symlink creation iwhen `save_last='link'` and `save_top_k=-1` ([#21186](https://github.com/Lightning-AI/pytorch-lightning/pull/21186))
2831

2932

3033
### Removed

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
484484

485485
@staticmethod
486486
def _link_checkpoint(trainer: "pl.Trainer", filepath: str, linkpath: str) -> None:
487-
if trainer.is_global_zero:
487+
if trainer.is_global_zero and os.path.abspath(filepath) != os.path.abspath(linkpath):
488488
if os.path.islink(linkpath) or os.path.isfile(linkpath):
489489
os.remove(linkpath)
490490
elif os.path.isdir(linkpath):

tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import math
2+
import os
23
from datetime import timedelta
34

45
import pytest
@@ -9,6 +10,7 @@
910

1011
from lightning.pytorch import LightningModule, Trainer, seed_everything
1112
from lightning.pytorch.callbacks import ModelCheckpoint
13+
from lightning.pytorch.demos.boring_classes import BoringModel
1214

1315

1416
class TinyDataset(Dataset):
@@ -206,3 +208,24 @@ def test_model_checkpoint_defer_until_next_validation_when_val_every_2_epochs(tm
206208
expected = max(val_scores) # last/maximum value occurs at final validation epoch
207209
actual = float(ckpt.best_model_score)
208210
assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6)
211+
212+
213+
def test_model_checkpoint_save_last_link_symlink_bug(tmp_path):
214+
"""Reproduce the bug where save_last='link' and save_top_k=-1 creates a recursive symlink."""
215+
trainer = Trainer(
216+
default_root_dir=tmp_path,
217+
max_epochs=2,
218+
callbacks=[ModelCheckpoint(dirpath=tmp_path, every_n_epochs=10, save_last="link", save_top_k=-1)],
219+
enable_checkpointing=True,
220+
enable_model_summary=False,
221+
logger=False,
222+
)
223+
224+
model = BoringModel()
225+
trainer.fit(model)
226+
227+
last_ckpt = tmp_path / "last.ckpt"
228+
assert last_ckpt.exists()
229+
# With the fix, if a symlink exists, it should not point to itself (preventing recursion)
230+
if os.path.islink(str(last_ckpt)):
231+
assert os.readlink(str(last_ckpt)) != str(last_ckpt)

0 commit comments

Comments
 (0)