
Commit b7ca4d3

SkafteNicki and Borda authored
Document missing trainer args (#21205)

* add missing trainer args
* Apply suggestions from code review

Co-authored-by: Jirka Borovec <[email protected]>
1 parent f1ed6a2 commit b7ca4d3

File tree

1 file changed

docs/source-pytorch/common/trainer.rst

Lines changed: 107 additions & 2 deletions
@@ -246,6 +246,27 @@ Example::
See also: :ref:`gradient_accumulation` to enable more fine-grained accumulation schedules.


barebones
^^^^^^^^^

Whether to run in "barebones mode", where all features that may impact raw speed are disabled. This is meant for
analyzing the Trainer overhead and is discouraged during regular training runs.

When enabled, the following features are automatically deactivated:

- Checkpointing: ``enable_checkpointing=False``
- Logging: ``logger=False``, ``log_every_n_steps=0``
- Progress bar: ``enable_progress_bar=False``
- Model summary: ``enable_model_summary=False``
- Sanity checking: ``num_sanity_val_steps=0``

.. testcode::

    # default used by the Trainer
    trainer = Trainer(barebones=False)

    # enable barebones mode for speed analysis
    trainer = Trainer(barebones=True)
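
As a rough illustration of the intended use, the snippet below times a single barebones epoch of a tiny model so the
loop overhead can be compared against a normal run. Everything except ``barebones=True`` (the model, data, and timing
code) is an assumption made for this sketch::

    import time

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from lightning.pytorch import LightningModule, Trainer


    class TinyModel(LightningModule):
        # hypothetical minimal module, used only to measure loop overhead
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.cross_entropy(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    data = DataLoader(TensorDataset(torch.randn(256, 32), torch.randint(0, 2, (256,))), batch_size=32)

    start = time.perf_counter()
    Trainer(barebones=True, max_epochs=1, accelerator="cpu").fit(TinyModel(), data)
    print(f"barebones epoch took {time.perf_counter() - start:.2f}s")
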

benchmark
^^^^^^^^^

@@ -364,6 +385,22 @@ will need to be set up to use remote filepaths.
    # default used by the Trainer
    trainer = Trainer(default_root_dir=os.getcwd())


detect_anomaly
^^^^^^^^^^^^^^

Enable anomaly detection for the autograd engine. This will significantly slow down compute speed and is recommended
only for model debugging.

.. testcode::

    # default used by the Trainer
    trainer = Trainer(detect_anomaly=False)

    # enable anomaly detection for debugging
    trainer = Trainer(detect_anomaly=True)
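
For intuition, this flag corresponds to PyTorch's autograd anomaly mode. A minimal sketch of the equivalent check in
plain PyTorch (the model and input are hypothetical; Lightning enables a comparable context around its training loop)::

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 1)  # hypothetical model
    x = torch.randn(8, 10)    # hypothetical batch

    # inside this context, autograd records extra metadata so that any backward
    # op producing NaN gradients raises with a traceback pointing at the forward
    # operation that created them
    with torch.autograd.detect_anomaly():
        loss = model(x).sum()
        loss.backward()
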

devices
^^^^^^^

@@ -548,6 +585,24 @@ impact to subsequent runs. These are the changes enabled:
- If using the CLI, the configuration file is not saved.


gradient_clip_algorithm
^^^^^^^^^^^^^^^^^^^^^^^

The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"`` to clip by value, and
``gradient_clip_algorithm="norm"`` to clip by norm. By default, it is set to ``"norm"``.

.. testcode::

    # default used by the Trainer (defaults to "norm" when gradient_clip_val is set)
    trainer = Trainer(gradient_clip_algorithm=None)

    # clip by value
    trainer = Trainer(gradient_clip_val=0.5, gradient_clip_algorithm="value")

    # clip by norm
    trainer = Trainer(gradient_clip_val=0.5, gradient_clip_algorithm="norm")
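
For intuition, the two algorithms behave like PyTorch's gradient clipping utilities applied to the parameters after
``backward()``; a rough sketch with a hypothetical model (Lightning applies the equivalent clipping for you before
each optimizer step)::

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 1)  # hypothetical model
    loss = model(torch.randn(8, 10)).sum()
    loss.backward()

    # "value": clamp every gradient element into [-0.5, 0.5]
    torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)

    # "norm": rescale all gradients together so their total 2-norm is at most 0.5
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
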

gradient_clip_val
^^^^^^^^^^^^^^^^^

@@ -624,6 +679,26 @@ Example::
    # run through only 10 batches of the training set each epoch
    trainer = Trainer(limit_train_batches=10)


limit_predict_batches
^^^^^^^^^^^^^^^^^^^^^

How much of the prediction dataset to check. The value is per device.

.. testcode::

    # default used by the Trainer
    trainer = Trainer(limit_predict_batches=1.0)

    # run through only 25% of the prediction set
    trainer = Trainer(limit_predict_batches=0.25)

    # run for only 10 batches
    trainer = Trainer(limit_predict_batches=10)

In the case of multiple prediction dataloaders, the limit applies to each dataloader individually.
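
To make the per-dataloader behaviour concrete, here is a self-contained sketch; the module and data are hypothetical
and exist only for illustration::

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from lightning.pytorch import LightningModule, Trainer


    class TinyPredictor(LightningModule):
        # hypothetical minimal module used only for prediction
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(8, 1)

        def predict_step(self, batch, batch_idx, dataloader_idx=0):
            (x,) = batch
            return self.layer(x)


    loader_a = DataLoader(TensorDataset(torch.randn(64, 8)), batch_size=4)  # 16 batches
    loader_b = DataLoader(TensorDataset(torch.randn(64, 8)), batch_size=4)  # 16 batches

    # the limit applies per dataloader: up to 10 batches are drawn from
    # loader_a and up to 10 batches from loader_b
    trainer = Trainer(limit_predict_batches=10, accelerator="cpu", logger=False)
    predictions = trainer.predict(TinyPredictor(), dataloaders=[loader_a, loader_b])
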

limit_test_batches
^^^^^^^^^^^^^^^^^^

@@ -801,6 +876,23 @@ For customizable options use the :class:`~lightning.pytorch.callbacks.timer.Time
In case ``max_time`` is used together with ``min_steps`` or ``min_epochs``, the ``min_*`` requirement
always has precedence.


model_registry
^^^^^^^^^^^^^^

If specified, the model will be uploaded to the Lightning Model Registry under the provided name.

.. testcode::

    # default used by the Trainer
    trainer = Trainer(model_registry=None)

    # specify the model name for the model hub upload
    trainer = Trainer(model_registry="my-model-name")

See the `Lightning Model Registry docs <https://lightning.ai/docs/overview/finetune-models/model-registry>`_ for more info.


num_nodes
^^^^^^^^^

@@ -875,12 +967,25 @@ Useful for quickly debugging or trying to overfit on purpose.

    # debug using a single consistent train batch and a single consistent val batch


plugins
^^^^^^^

Plugins allow you to connect arbitrary backends, precision libraries, clusters, etc., and to modify core Lightning logic.
Examples of plugin types:

- :ref:`Checkpoint IO <checkpointing_expert>`
- `TorchElastic <https://pytorch.org/elastic/0.2.2/index.html>`_
- :ref:`Precision Plugins <precision_expert>`
- :class:`~lightning.pytorch.plugins.environments.ClusterEnvironment`

.. testcode::

    # default used by the Trainer
    trainer = Trainer(plugins=None)

    # example using the built-in SLURM environment plugin
    from lightning.fabric.plugins.environments import SLURMEnvironment

    trainer = Trainer(plugins=[SLURMEnvironment()])

To define your own behavior, subclass the relevant class and pass it in. Here's an example linking up your own
:class:`~lightning.pytorch.plugins.environments.ClusterEnvironment`.
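
As a hedged sketch (not the documentation's own example), a custom environment might read its settings from
scheduler-provided environment variables; the variable names below are placeholders, and the exact set of required
methods can differ slightly between Lightning versions::

    import os

    from lightning.pytorch import Trainer
    from lightning.pytorch.plugins.environments import ClusterEnvironment


    class MyClusterEnvironment(ClusterEnvironment):
        @property
        def creates_processes_externally(self) -> bool:
            # processes are launched by the cluster scheduler, not by Lightning
            return True

        @property
        def main_address(self) -> str:
            return os.environ.get("MASTER_ADDRESS", "127.0.0.1")

        @property
        def main_port(self) -> int:
            return int(os.environ.get("MASTER_PORT", 12910))

        @staticmethod
        def detect() -> bool:
            return "MASTER_ADDRESS" in os.environ

        def world_size(self) -> int:
            return int(os.environ.get("WORLD_SIZE", 1))

        def set_world_size(self, size: int) -> None:
            pass

        def global_rank(self) -> int:
            return int(os.environ.get("RANK", 0))

        def set_global_rank(self, rank: int) -> None:
            pass

        def local_rank(self) -> int:
            return int(os.environ.get("LOCAL_RANK", 0))

        def node_rank(self) -> int:
            return int(os.environ.get("NODE_RANK", 0))


    trainer = Trainer(plugins=[MyClusterEnvironment()])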
