|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
| 14 | +import copy |
14 | 15 | import logging
|
15 | 16 | import math
|
16 | 17 | import os
|
|
25 | 26 |
|
26 | 27 | from lightning.pytorch import Trainer, seed_everything
|
27 | 28 | from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
|
| 29 | +from lightning.pytorch.callbacks.early_stopping import EarlyStoppingReason |
28 | 30 | from lightning.pytorch.demos.boring_classes import BoringModel
|
29 | 31 | from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
30 | 32 | from tests_pytorch.helpers.datamodules import ClassifDataModule
|
@@ -505,3 +507,190 @@ def test_early_stopping_log_info(log_rank_zero_only, world_size, global_rank, ex
|
505 | 507 | log_mock.assert_called_once_with(expected_log)
|
506 | 508 | else:
|
507 | 509 | log_mock.assert_not_called()
|
| 510 | + |
| 511 | + |
class ModelWithHighLoss(BoringModel):
    """Boring model whose validation loss is pinned at a constant high value.

    Because the metric never improves, an ``EarlyStopping(patience=k)`` callback
    attached to it is guaranteed to exhaust its patience.
    """

    def on_validation_epoch_end(self):
        # Log the same value every epoch so no improvement is ever observed.
        self.log("val_loss", 10.0)
| 515 | + |
| 516 | + |
class ModelWithDecreasingLoss(BoringModel):
    """Boring model whose validation loss shrinks each epoch.

    Used to drive ``EarlyStopping(stopping_threshold=...)``: the scripted losses
    eventually drop below the threshold.
    """

    def __init__(self):
        super().__init__()
        # Scripted per-epoch losses; epochs beyond the schedule fall back to 0.1.
        self.epoch_losses = [5.0, 3.0, 1.0, 0.5]

    def on_validation_epoch_end(self):
        schedule = self.epoch_losses
        epoch = self.current_epoch
        loss = schedule[epoch] if epoch < len(schedule) else 0.1
        self.log("val_loss", loss)
| 525 | + |
| 526 | + |
class ModelWithIncreasingLoss(BoringModel):
    """Boring model whose validation loss grows each epoch.

    Used to drive ``EarlyStopping(divergence_threshold=...)``: the scripted losses
    eventually exceed the divergence threshold.
    """

    def __init__(self):
        super().__init__()
        # Scripted per-epoch losses; epochs beyond the schedule fall back to 15.0.
        self.epoch_losses = [1.0, 2.0, 5.0, 10.0]

    def on_validation_epoch_end(self):
        schedule = self.epoch_losses
        epoch = self.current_epoch
        loss = schedule[epoch] if epoch < len(schedule) else 15.0
        self.log("val_loss", loss)
| 535 | + |
| 536 | + |
class ModelWithNaNLoss(BoringModel):
    """Boring model whose validation loss turns NaN after two finite epochs.

    Used to drive ``EarlyStopping(check_finite=True)``: the third logged metric
    (and all later ones) is non-finite.
    """

    def __init__(self):
        super().__init__()
        # Two finite values, then NaN; epochs past the schedule stay NaN.
        self.epoch_losses = [1.0, 0.5, float("nan")]

    def on_validation_epoch_end(self):
        schedule = self.epoch_losses
        epoch = self.current_epoch
        loss = schedule[epoch] if epoch < len(schedule) else float("nan")
        self.log("val_loss", loss)
| 545 | + |
| 546 | + |
class ModelWithImprovingLoss(BoringModel):
    """Boring model whose validation loss keeps improving every epoch.

    With a loss that improves each epoch, ``EarlyStopping`` never triggers and the
    run completes normally (``NOT_STOPPED``).
    """

    def __init__(self):
        super().__init__()
        # Strictly decreasing schedule; epochs past it fall back to 0.1.
        self.epoch_losses = [5.0, 4.0, 3.0, 2.0, 1.0]

    def on_validation_epoch_end(self):
        schedule = self.epoch_losses
        epoch = self.current_epoch
        loss = schedule[epoch] if epoch < len(schedule) else 0.1
        self.log("val_loss", loss)
| 555 | + |
| 556 | + |
@pytest.mark.parametrize(
    (
        "model_cls",
        "early_stopping_kwargs",
        "trainer_kwargs",
        "expected_reason",
        "reason_message_substr",
        "should_stop",
        "state_dict_override",
    ),
    [
        # Patience exhausted
        (
            ModelWithHighLoss,
            {"monitor": "val_loss", "patience": 2, "verbose": True},
            {"max_epochs": 10, "enable_progress_bar": False},
            EarlyStoppingReason.PATIENCE_EXHAUSTED,
            "did not improve",
            True,
            None,
        ),
        # Stopping threshold
        (
            ModelWithDecreasingLoss,
            {"monitor": "val_loss", "stopping_threshold": 0.6, "mode": "min", "verbose": True},
            {"max_epochs": 10, "enable_progress_bar": False},
            EarlyStoppingReason.STOPPING_THRESHOLD,
            "Stopping threshold reached",
            True,
            None,
        ),
        # Divergence threshold
        (
            ModelWithIncreasingLoss,
            {"monitor": "val_loss", "divergence_threshold": 8.0, "mode": "min", "verbose": True},
            {"max_epochs": 10, "enable_progress_bar": False},
            EarlyStoppingReason.DIVERGENCE_THRESHOLD,
            "Divergence threshold reached",
            True,
            None,
        ),
        # Non-finite metric
        (
            ModelWithNaNLoss,
            {"monitor": "val_loss", "check_finite": True, "verbose": True},
            {"max_epochs": 10, "enable_progress_bar": False},
            EarlyStoppingReason.NON_FINITE_METRIC,
            "is not finite",
            True,
            None,
        ),
        # Not stopped (normal completion)
        (
            ModelWithImprovingLoss,
            {"monitor": "val_loss", "patience": 3, "verbose": True},
            {"max_epochs": 3, "enable_progress_bar": False},
            EarlyStoppingReason.NOT_STOPPED,
            None,
            False,
            None,
        ),
        # State persistence
        (
            None,
            {"monitor": "val_loss", "patience": 3},
            {},
            EarlyStoppingReason.PATIENCE_EXHAUSTED,
            "Test message",
            None,
            {"stopping_reason": EarlyStoppingReason.PATIENCE_EXHAUSTED, "stopping_reason_message": "Test message"},
        ),
        # Backward compatibility (old state dict)
        (
            None,
            {"monitor": "val_loss", "patience": 3},
            {},
            EarlyStoppingReason.NOT_STOPPED,
            None,
            None,
            {
                "wait_count": 2,
                "stopped_epoch": 5,
                "best_score": torch.tensor(0.5),
                "patience": 3,
            },
        ),
    ],
)
def test_early_stopping_reasons(
    tmp_path,
    model_cls,
    early_stopping_kwargs,
    trainer_kwargs,
    expected_reason,
    reason_message_substr,
    should_stop,
    state_dict_override,
):
    """Exercise every ``EarlyStoppingReason`` outcome in one parametrized test.

    Cases with ``state_dict_override`` set check (de)serialization of the stopping
    reason instead of running a training loop; all other cases fit a scripted model
    and assert on the reason recorded by the callback.
    """
    if state_dict_override is not None:
        callback = EarlyStopping(**early_stopping_kwargs)
        if "stopping_reason" in state_dict_override:
            # State persistence: the reason and its message must survive a
            # state_dict round-trip into a fresh callback instance.
            callback.stopping_reason = state_dict_override["stopping_reason"]
            callback.stopping_reason_message = state_dict_override["stopping_reason_message"]
            restored = EarlyStopping(**early_stopping_kwargs)
            restored.load_state_dict(callback.state_dict())
            assert restored.stopping_reason == expected_reason
            assert restored.stopping_reason_message == reason_message_substr
        else:
            # Backward compatibility: an old-format state dict (no reason keys)
            # loads cleanly and leaves the reason at its default.
            # deepcopy so the shared parametrize dict is never mutated.
            callback.load_state_dict(copy.deepcopy(state_dict_override))
            assert callback.stopping_reason == expected_reason
            assert callback.stopping_reason_message is None
            assert callback.wait_count == state_dict_override["wait_count"]
            assert callback.stopped_epoch == state_dict_override["stopped_epoch"]
        return

    # Training path: fit a scripted model and inspect the recorded reason.
    callback = EarlyStopping(**early_stopping_kwargs)
    trainer = Trainer(
        default_root_dir=tmp_path,
        callbacks=[callback],
        **trainer_kwargs,
    )
    trainer.fit(model_cls())

    assert callback.stopping_reason == expected_reason
    if reason_message_substr is None:
        assert callback.stopping_reason_message is None
    else:
        assert callback.stopping_reason_message is not None
        assert reason_message_substr in callback.stopping_reason_message
    if should_stop is True:
        assert callback.stopped_epoch > 0
    elif should_stop is False:
        assert callback.stopped_epoch == 0
0 commit comments