
Commit 6989e15

SkafteNicki and Borda authored
Expose stopping reasons in EarlyStopping callback (#21188)
* add public reason api
* add testing
* add to documentation
* changelog
* fix doctests
* fix unittests
* fix documentation
* fix typing

Co-authored-by: Jirka Borovec <[email protected]>
1 parent 9eb3740 commit 6989e15

File tree

4 files changed: +257, -1 lines changed

- docs/source-pytorch/common/early_stopping.rst
- src/lightning/pytorch/CHANGELOG.md
- src/lightning/pytorch/callbacks/early_stopping.py
- tests/tests_pytorch/callbacks/test_early_stopping.py

docs/source-pytorch/common/early_stopping.rst

Lines changed: 33 additions & 1 deletion
@@ -1,6 +1,7 @@
 .. testsetup:: *
 
-    from lightning.pytorch.callbacks.early_stopping import EarlyStopping
+    from lightning.pytorch.callbacks.early_stopping import EarlyStopping, EarlyStoppingReason
+    from lightning.pytorch import Trainer, LightningModule
 
 .. _early_stopping:
 
@@ -71,6 +72,37 @@ Additional parameters that stop training at extreme points:
 - ``check_on_train_epoch_end``: When turned on, it checks the metric at the end of a training epoch. Use this only when you are monitoring any metric logged within
   training-specific hooks on epoch-level.
 
+After training completes, you can programmatically check why early stopping occurred using the ``stopping_reason``
+attribute, which returns an ``EarlyStoppingReason`` enum value.
+
+.. code-block:: python
+
+    from lightning.pytorch.callbacks import EarlyStopping
+    from lightning.pytorch.callbacks.early_stopping import EarlyStoppingReason
+
+    early_stopping = EarlyStopping(monitor="val_loss", patience=3)
+    trainer = Trainer(callbacks=[early_stopping])
+    trainer.fit(model)
+
+    # Check why training stopped
+    if early_stopping.stopping_reason == EarlyStoppingReason.PATIENCE_EXHAUSTED:
+        print("Training stopped due to patience exhaustion")
+    elif early_stopping.stopping_reason == EarlyStoppingReason.STOPPING_THRESHOLD:
+        print("Training stopped due to reaching stopping threshold")
+    elif early_stopping.stopping_reason == EarlyStoppingReason.NOT_STOPPED:
+        print("Training completed normally without early stopping")
+
+    # Access human-readable message
+    if early_stopping.stopping_reason_message:
+        print(f"Details: {early_stopping.stopping_reason_message}")
+
+The available stopping reasons are:
+
+- ``NOT_STOPPED``: Training completed normally without early stopping
+- ``STOPPING_THRESHOLD``: Training stopped because the monitored metric reached the stopping threshold
+- ``DIVERGENCE_THRESHOLD``: Training stopped because the monitored metric exceeded the divergence threshold
+- ``PATIENCE_EXHAUSTED``: Training stopped because the metric didn't improve for the specified patience
+- ``NON_FINITE_METRIC``: Training stopped because the monitored metric became NaN or infinite
 
 In case you need early stopping in a different part of training, subclass :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping`
 and change where it is called:
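
The `if`/`elif` chain in the docs snippet above covers three outcomes; handling all five enum members uniformly can be tidier with a small lookup table. A minimal sketch, assuming `early_stopping` is a callback instance after `trainer.fit(...)` as in that snippet (the advice strings are illustrative, not from the library):

    from lightning.pytorch.callbacks.early_stopping import EarlyStoppingReason

    # One entry per enum member, so plain indexing cannot raise KeyError.
    advice = {
        EarlyStoppingReason.NOT_STOPPED: "ran to max_epochs without early stopping",
        EarlyStoppingReason.STOPPING_THRESHOLD: "target metric value reached",
        EarlyStoppingReason.DIVERGENCE_THRESHOLD: "metric diverged; try a lower learning rate",
        EarlyStoppingReason.PATIENCE_EXHAUSTED: "metric plateaued; consider more patience or an LR schedule",
        EarlyStoppingReason.NON_FINITE_METRIC: "metric became NaN/inf; check the loss computation",
    }
    print(advice[early_stopping.stopping_reason])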

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -19,6 +19,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added time-based validation support through `val_check_interval` ([#21071](https://github.com/Lightning-AI/pytorch-lightning/pull/21071))
 
 
+- Added attributes to access stopping reason in `EarlyStopping` callback ([#21188](https://github.com/Lightning-AI/pytorch-lightning/pull/21188))
+
+
 - Added support for variable batch size in `ThroughputMonitor` ([#20236](https://github.com/Lightning-AI/pytorch-lightning/pull/20236))
 
 

src/lightning/pytorch/callbacks/early_stopping.py

Lines changed: 32 additions & 0 deletions
@@ -20,6 +20,7 @@
 """
 
 import logging
+from enum import Enum
 from typing import Any, Callable, Optional
 
 import torch
@@ -34,6 +35,16 @@
 log = logging.getLogger(__name__)
 
 
+class EarlyStoppingReason(Enum):
+    """Enum for early stopping reasons."""
+
+    NOT_STOPPED = 0
+    STOPPING_THRESHOLD = 1
+    DIVERGENCE_THRESHOLD = 2
+    PATIENCE_EXHAUSTED = 3
+    NON_FINITE_METRIC = 4
+
+
 class EarlyStopping(Callback):
     r"""Monitor a metric and stop training when it stops improving.
 
@@ -65,6 +76,11 @@ class EarlyStopping(Callback):
             If this is ``False``, then the check runs at the end of the validation.
         log_rank_zero_only: When set ``True``, logs the status of the early stopping callback only for rank 0 process.
 
+    Attributes:
+        stopped_epoch: The epoch at which training was stopped. 0 if training was not stopped.
+        stopping_reason: An ``EarlyStoppingReason`` enum indicating why training was stopped.
+        stopping_reason_message: A human-readable message explaining why training was stopped.
+
     Raises:
         MisconfigurationException:
             If ``mode`` is none of ``"min"`` or ``"max"``.
@@ -75,8 +91,12 @@ class EarlyStopping(Callback):
 
     >>> from lightning.pytorch import Trainer
     >>> from lightning.pytorch.callbacks import EarlyStopping
+    >>> from lightning.pytorch.callbacks.early_stopping import EarlyStoppingReason
     >>> early_stopping = EarlyStopping('val_loss')
     >>> trainer = Trainer(callbacks=[early_stopping])
+    >>> # After training...
+    >>> if early_stopping.stopping_reason == EarlyStoppingReason.PATIENCE_EXHAUSTED:
+    ...     print("Training stopped due to patience exhaustion")
 
     .. tip:: Saving and restoring multiple early stopping callbacks at the same time is supported under variation in the
         following arguments:
@@ -117,6 +137,8 @@ def __init__(
         self.divergence_threshold = divergence_threshold
         self.wait_count = 0
         self.stopped_epoch = 0
+        self.stopping_reason = EarlyStoppingReason.NOT_STOPPED
+        self.stopping_reason_message: Optional[str] = None
         self._check_on_train_epoch_end = check_on_train_epoch_end
         self.log_rank_zero_only = log_rank_zero_only
 
@@ -169,6 +191,8 @@ def state_dict(self) -> dict[str, Any]:
             "stopped_epoch": self.stopped_epoch,
             "best_score": self.best_score,
             "patience": self.patience,
+            "stopping_reason": self.stopping_reason.value,
+            "stopping_reason_message": self.stopping_reason_message,
         }
 
     @override
@@ -177,6 +201,9 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.stopped_epoch = state_dict["stopped_epoch"]
        self.best_score = state_dict["best_score"]
        self.patience = state_dict["patience"]
+        stopping_reason_value = state_dict.get("stopping_reason", EarlyStoppingReason.NOT_STOPPED.value)
+        self.stopping_reason = EarlyStoppingReason(stopping_reason_value)
+        self.stopping_reason_message = state_dict.get("stopping_reason_message")
 
     def _should_skip_check(self, trainer: "pl.Trainer") -> bool:
         from lightning.pytorch.trainer.states import TrainerFn
@@ -212,6 +239,7 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
         trainer.should_stop = trainer.should_stop or should_stop
         if should_stop:
             self.stopped_epoch = trainer.current_epoch
+            self.stopping_reason_message = reason
         if reason and self.verbose:
             self._log_info(trainer, reason, self.log_rank_zero_only)
 
@@ -220,19 +248,22 @@ def _evaluate_stopping_criteria(self, current: Tensor) -> tuple[bool, Optional[str]]:
         reason = None
         if self.check_finite and not torch.isfinite(current):
             should_stop = True
+            self.stopping_reason = EarlyStoppingReason.NON_FINITE_METRIC
             reason = (
                 f"Monitored metric {self.monitor} = {current} is not finite."
                 f" Previous best value was {self.best_score:.3f}. Signaling Trainer to stop."
             )
         elif self.stopping_threshold is not None and self.monitor_op(current, self.stopping_threshold):
             should_stop = True
+            self.stopping_reason = EarlyStoppingReason.STOPPING_THRESHOLD
             reason = (
                 "Stopping threshold reached:"
                 f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.stopping_threshold}."
                 " Signaling Trainer to stop."
             )
         elif self.divergence_threshold is not None and self.monitor_op(-current, -self.divergence_threshold):
             should_stop = True
+            self.stopping_reason = EarlyStoppingReason.DIVERGENCE_THRESHOLD
             reason = (
                 "Divergence threshold reached:"
                 f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}."
@@ -247,6 +278,7 @@ def _evaluate_stopping_criteria(self, current: Tensor) -> tuple[bool, Optional[str]]:
             self.wait_count += 1
             if self.wait_count >= self.patience:
                 should_stop = True
+                self.stopping_reason = EarlyStoppingReason.PATIENCE_EXHAUSTED
                 reason = (
                     f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} records."
                     f" Best score: {self.best_score:.3f}. Signaling Trainer to stop."

tests/tests_pytorch/callbacks/test_early_stopping.py

Lines changed: 189 additions & 0 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import logging
 import math
 import os
@@ -25,6 +26,7 @@
 
 from lightning.pytorch import Trainer, seed_everything
 from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
+from lightning.pytorch.callbacks.early_stopping import EarlyStoppingReason
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from tests_pytorch.helpers.datamodules import ClassifDataModule
@@ -505,3 +507,190 @@ def test_early_stopping_log_info(log_rank_zero_only, world_size, global_rank, expected_log):
         log_mock.assert_called_once_with(expected_log)
     else:
         log_mock.assert_not_called()
+
+
+class ModelWithHighLoss(BoringModel):
+    def on_validation_epoch_end(self):
+        self.log("val_loss", 10.0)
+
+
+class ModelWithDecreasingLoss(BoringModel):
+    def __init__(self):
+        super().__init__()
+        self.epoch_losses = [5.0, 3.0, 1.0, 0.5]
+
+    def on_validation_epoch_end(self):
+        loss = self.epoch_losses[self.current_epoch] if self.current_epoch < len(self.epoch_losses) else 0.1
+        self.log("val_loss", loss)
+
+
+class ModelWithIncreasingLoss(BoringModel):
+    def __init__(self):
+        super().__init__()
+        self.epoch_losses = [1.0, 2.0, 5.0, 10.0]
+
+    def on_validation_epoch_end(self):
+        loss = self.epoch_losses[self.current_epoch] if self.current_epoch < len(self.epoch_losses) else 15.0
+        self.log("val_loss", loss)
+
+
+class ModelWithNaNLoss(BoringModel):
+    def __init__(self):
+        super().__init__()
+        self.epoch_losses = [1.0, 0.5, float("nan")]
+
+    def on_validation_epoch_end(self):
+        loss = self.epoch_losses[self.current_epoch] if self.current_epoch < len(self.epoch_losses) else float("nan")
+        self.log("val_loss", loss)
+
+
+class ModelWithImprovingLoss(BoringModel):
+    def __init__(self):
+        super().__init__()
+        self.epoch_losses = [5.0, 4.0, 3.0, 2.0, 1.0]
+
+    def on_validation_epoch_end(self):
+        loss = self.epoch_losses[self.current_epoch] if self.current_epoch < len(self.epoch_losses) else 0.1
+        self.log("val_loss", loss)
+
+
+@pytest.mark.parametrize(
+    (
+        "model_cls",
+        "early_stopping_kwargs",
+        "trainer_kwargs",
+        "expected_reason",
+        "reason_message_substr",
+        "should_stop",
+        "state_dict_override",
+    ),
+    [
+        # Patience exhausted
+        (
+            ModelWithHighLoss,
+            {"monitor": "val_loss", "patience": 2, "verbose": True},
+            {"max_epochs": 10, "enable_progress_bar": False},
+            EarlyStoppingReason.PATIENCE_EXHAUSTED,
+            "did not improve",
+            True,
+            None,
+        ),
+        # Stopping threshold
+        (
+            ModelWithDecreasingLoss,
+            {"monitor": "val_loss", "stopping_threshold": 0.6, "mode": "min", "verbose": True},
+            {"max_epochs": 10, "enable_progress_bar": False},
+            EarlyStoppingReason.STOPPING_THRESHOLD,
+            "Stopping threshold reached",
+            True,
+            None,
+        ),
+        # Divergence threshold
+        (
+            ModelWithIncreasingLoss,
+            {"monitor": "val_loss", "divergence_threshold": 8.0, "mode": "min", "verbose": True},
+            {"max_epochs": 10, "enable_progress_bar": False},
+            EarlyStoppingReason.DIVERGENCE_THRESHOLD,
+            "Divergence threshold reached",
+            True,
+            None,
+        ),
+        # Non-finite metric
+        (
+            ModelWithNaNLoss,
+            {"monitor": "val_loss", "check_finite": True, "verbose": True},
+            {"max_epochs": 10, "enable_progress_bar": False},
+            EarlyStoppingReason.NON_FINITE_METRIC,
+            "is not finite",
+            True,
+            None,
+        ),
+        # Not stopped (normal completion)
+        (
+            ModelWithImprovingLoss,
+            {"monitor": "val_loss", "patience": 3, "verbose": True},
+            {"max_epochs": 3, "enable_progress_bar": False},
+            EarlyStoppingReason.NOT_STOPPED,
+            None,
+            False,
+            None,
+        ),
+        # State persistence
+        (
+            None,
+            {"monitor": "val_loss", "patience": 3},
+            {},
+            EarlyStoppingReason.PATIENCE_EXHAUSTED,
+            "Test message",
+            None,
+            {"stopping_reason": EarlyStoppingReason.PATIENCE_EXHAUSTED, "stopping_reason_message": "Test message"},
+        ),
+        # Backward compatibility (old state dict)
+        (
+            None,
+            {"monitor": "val_loss", "patience": 3},
+            {},
+            EarlyStoppingReason.NOT_STOPPED,
+            None,
+            None,
+            {
+                "wait_count": 2,
+                "stopped_epoch": 5,
+                "best_score": torch.tensor(0.5),
+                "patience": 3,
+            },
+        ),
+    ],
+)
+def test_early_stopping_reasons(
+    tmp_path,
+    model_cls,
+    early_stopping_kwargs,
+    trainer_kwargs,
+    expected_reason,
+    reason_message_substr,
+    should_stop,
+    state_dict_override,
+):
+    """Test all early stopping reasons in a single parametrized test."""
+    if state_dict_override is not None:
+        early_stopping = EarlyStopping(**early_stopping_kwargs)
+        if "stopping_reason" in state_dict_override:
+            # State persistence test
+            early_stopping.stopping_reason = state_dict_override["stopping_reason"]
+            early_stopping.stopping_reason_message = state_dict_override["stopping_reason_message"]
+            state_dict = early_stopping.state_dict()
+            new_early_stopping = EarlyStopping(**early_stopping_kwargs)
+            new_early_stopping.load_state_dict(state_dict)
+            assert new_early_stopping.stopping_reason == expected_reason
+            assert new_early_stopping.stopping_reason_message == reason_message_substr
+        else:
+            # Backward compatibility test
+            early_stopping.load_state_dict(copy.deepcopy(state_dict_override))
+            assert early_stopping.stopping_reason == expected_reason
+            assert early_stopping.stopping_reason_message is None
+            assert early_stopping.wait_count == state_dict_override["wait_count"]
+            assert early_stopping.stopped_epoch == state_dict_override["stopped_epoch"]
+        return
+
+    # All other tests
+    model = model_cls()
+    early_stopping = EarlyStopping(**early_stopping_kwargs)
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        callbacks=[early_stopping],
+        **trainer_kwargs,
+    )
+    trainer.fit(model)
+
+    assert early_stopping.stopping_reason == expected_reason
+    if reason_message_substr is not None:
+        assert early_stopping.stopping_reason_message is not None
+        assert reason_message_substr in early_stopping.stopping_reason_message
+    else:
+        assert early_stopping.stopping_reason_message is None
+    if should_stop is not None:
+        if should_stop:
+            assert early_stopping.stopped_epoch > 0
+        else:
+            assert early_stopping.stopped_epoch == 0
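
Since every scenario lives in the single parametrized test above, the whole matrix can be selected by name; assuming a standard dev checkout of the repo, something like:

    pytest tests/tests_pytorch/callbacks/test_early_stopping.py -k test_early_stopping_reasons -v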
