@@ -413,6 +413,35 @@ Number of devices to train on (``int``), which devices to train on (``list`` or
    # Training with GPU Accelerator using total number of gpus available on the system
    Trainer(accelerator="gpu")

+
+ enable_autolog_hparams
+ ^^^^^^^^^^^^^^^^^^^^^^
+
+ Whether to log hyperparameters at the start of a run. Defaults to True.
+
+ .. testcode::
+
+     # default used by the Trainer
+     trainer = Trainer(enable_autolog_hparams=True)
+
+     # disable logging hyperparams
+     trainer = Trainer(enable_autolog_hparams=False)
+
+ With the parameter set to ``False``, you can add custom code to log hyperparameters.
+
+ .. code-block:: python
+
+     model = LitModel()
+     trainer = Trainer(enable_autolog_hparams=False)
+     for logger in trainer.loggers:
+         if isinstance(logger, lightning.pytorch.loggers.CSVLogger):
+             logger.log_hyperparams(hparams_dict_1)
+         else:
+             logger.log_hyperparams(hparams_dict_2)
+
+ You can also use ``self.logger.log_hyperparams(...)`` inside ``LightningModule`` to log.
+
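+ For example, a minimal sketch of logging from inside the module (assuming the
+ hyperparameters were stored with ``save_hyperparameters()``; the ``learning_rate``
+ argument is only illustrative) might look like:
+
+ .. code-block:: python
+
+     class LitModel(LightningModule):
+         def __init__(self, learning_rate=1e-3):  # hypothetical hyperparameter
+             super().__init__()
+             # populates ``self.hparams``
+             self.save_hyperparameters()
+
+         def on_train_start(self):
+             # log manually when ``enable_autolog_hparams=False``
+             self.logger.log_hyperparams(self.hparams)
+
+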
enable_checkpointing
^^^^^^^^^^^^^^^^^^^^

@@ -443,6 +472,40 @@ See :doc:`Saving and Loading Checkpoints <../common/checkpointing>` for how to c
    # Add your callback to the callbacks list
    trainer = Trainer(callbacks=[checkpoint_callback])

+
+ enable_model_summary
+ ^^^^^^^^^^^^^^^^^^^^
+
+ Whether to enable or disable model summarization. Defaults to True.
+
+ .. testcode::
+
+     # default used by the Trainer
+     trainer = Trainer(enable_model_summary=True)
+
+     # disable summarization
+     trainer = Trainer(enable_model_summary=False)
+
+     # enable custom summarization
+     from lightning.pytorch.callbacks import ModelSummary
+
+     trainer = Trainer(enable_model_summary=True, callbacks=[ModelSummary(max_depth=-1)])
+
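+ If a nicer table is preferred, the summary callback can be swapped for
+ ``RichModelSummary`` (a sketch, assuming the ``rich`` package is installed):
+
+ .. code-block:: python
+
+     from lightning.pytorch.callbacks import RichModelSummary
+
+     # replaces the default summary with a rich-formatted table, two levels deep
+     trainer = Trainer(callbacks=[RichModelSummary(max_depth=2)])
+
+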
+ enable_progress_bar
+ ^^^^^^^^^^^^^^^^^^^
+
+ Whether to enable or disable the progress bar. Defaults to True.
+
+ .. testcode::
+
+     # default used by the Trainer
+     trainer = Trainer(enable_progress_bar=True)
+
+     # disable progress bar
+     trainer = Trainer(enable_progress_bar=False)
+
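+ To keep the bar but change how often it refreshes, a progress-bar callback can be
+ passed instead; a minimal sketch using the ``TQDMProgressBar`` callback:
+
+ .. code-block:: python
+
+     from lightning.pytorch.callbacks import TQDMProgressBar
+
+     # update the display every 50 batches instead of every batch
+     trainer = Trainer(callbacks=[TQDMProgressBar(refresh_rate=50)])
+
+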
fast_dev_run
^^^^^^^^^^^^

@@ -500,6 +563,39 @@ Gradient clipping value
    # default used by the Trainer
    trainer = Trainer(gradient_clip_val=None)

+
+ inference_mode
+ ^^^^^^^^^^^^^^
+
+ Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` mode during evaluation
+ (``validate``/``test``/``predict``).
+
+ .. testcode::
+
+     # default used by the Trainer
+     trainer = Trainer(inference_mode=True)
+
+     # Use `torch.no_grad` instead
+     trainer = Trainer(inference_mode=False)
+
+
+ With :func:`torch.inference_mode` disabled, you can enable gradients for your model layers if required.
+
+ .. code-block:: python
+
+     class LitModel(LightningModule):
+         def validation_step(self, batch, batch_idx):
+             preds = self.layer1(batch)
+             with torch.enable_grad():
+                 grad_preds = preds.requires_grad_()
+                 preds2 = self.layer2(grad_preds)
+
+
+     model = LitModel()
+     trainer = Trainer(inference_mode=False)
+     trainer.validate(model)
+
+
limit_train_batches
^^^^^^^^^^^^^^^^^^^

@@ -871,18 +967,6 @@ See the :doc:`profiler documentation <../tuning/profiler>` for more details.
    # advanced profiler for function-level stats, equivalent to `profiler=AdvancedProfiler()`
    trainer = Trainer(profiler="advanced")

- enable_progress_bar
- ^^^^^^^^^^^^^^^^^^^
-
- Whether to enable or disable the progress bar. Defaults to True.
-
- .. testcode::
-
-     # default used by the Trainer
-     trainer = Trainer(enable_progress_bar=True)
-
-     # disable progress bar
-     trainer = Trainer(enable_progress_bar=False)

reload_dataloaders_every_n_epochs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -917,28 +1001,6 @@ The pseudocode applies also to the ``val_dataloader``.

.. _replace-sampler-ddp:

- use_distributed_sampler
- ^^^^^^^^^^^^^^^^^^^^^^^
-
- See :paramref:`lightning.pytorch.trainer.Trainer.params.use_distributed_sampler`.
-
- .. testcode::
-
-     # default used by the Trainer
-     trainer = Trainer(use_distributed_sampler=True)
-
- By setting to False, you have to add your own distributed sampler:
-
- .. code-block:: python
-
-     # in your LightningModule or LightningDataModule
-     def train_dataloader(self):
-         dataset = ...
-         # default used by the Trainer
-         sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
-         dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
-         return dataloader
-

strategy
^^^^^^^^
@@ -982,6 +1044,29 @@ Enable synchronization between batchnorm layers across all GPUs.
    trainer = Trainer(sync_batchnorm=True)

+ use_distributed_sampler
+ ^^^^^^^^^^^^^^^^^^^^^^^
+
+ See :paramref:`lightning.pytorch.trainer.Trainer.params.use_distributed_sampler`.
+
+ .. testcode::
+
+     # default used by the Trainer
+     trainer = Trainer(use_distributed_sampler=True)
+
+ By setting it to ``False``, you have to add your own distributed sampler:
+
+ .. code-block:: python
+
+     # in your LightningModule or LightningDataModule
+     def train_dataloader(self):
+         dataset = ...
+         # default used by the Trainer
+         sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
+         dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
+         return dataloader
+
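+ The same applies to the evaluation dataloaders; a minimal sketch for a validation
+ split (no shuffling during evaluation) might look like:
+
+ .. code-block:: python
+
+     def val_dataloader(self):
+         dataset = ...
+         sampler = torch.utils.data.DistributedSampler(dataset, shuffle=False)
+         return DataLoader(dataset, batch_size=32, sampler=sampler)
+
+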
val_check_interval
^^^^^^^^^^^^^^^^^^

@@ -1058,84 +1143,6 @@ Can specify as float, int, or a time-based duration.
    # Total number of batches run
    total_fit_batches = total_train_batches + total_val_batches

-
- enable_model_summary
- ^^^^^^^^^^^^^^^^^^^^
-
- Whether to enable or disable the model summarization. Defaults to True.
-
- .. testcode::
-
-     # default used by the Trainer
-     trainer = Trainer(enable_model_summary=True)
-
-     # disable summarization
-     trainer = Trainer(enable_model_summary=False)
-
-     # enable custom summarization
-     from lightning.pytorch.callbacks import ModelSummary
-
-     trainer = Trainer(enable_model_summary=True, callbacks=[ModelSummary(max_depth=-1)])
-
-
- inference_mode
- ^^^^^^^^^^^^^^
-
- Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` mode during evaluation
- (``validate``/``test``/``predict``)
-
- .. testcode::
-
-     # default used by the Trainer
-     trainer = Trainer(inference_mode=True)
-
-     # Use `torch.no_grad` instead
-     trainer = Trainer(inference_mode=False)
-
-
- With :func:`torch.inference_mode` disabled, you can enable the grad of your model layers if required.
-
- .. code-block:: python
-
-     class LitModel(LightningModule):
-         def validation_step(self, batch, batch_idx):
-             preds = self.layer1(batch)
-             with torch.enable_grad():
-                 grad_preds = preds.requires_grad_()
-                 preds2 = self.layer2(grad_preds)
-
-
-     model = LitModel()
-     trainer = Trainer(inference_mode=False)
-     trainer.validate(model)
-
- enable_autolog_hparams
- ^^^^^^^^^^^^^^^^^^^^^^
-
- Whether to log hyperparameters at the start of a run. Defaults to True.
-
- .. testcode::
-
-     # default used by the Trainer
-     trainer = Trainer(enable_autolog_hparams=True)
-
-     # disable logging hyperparams
-     trainer = Trainer(enable_autolog_hparams=False)
-
- With the parameter set to false, you can add custom code to log hyperparameters.
-
- .. code-block:: python
-
-     model = LitModel()
-     trainer = Trainer(enable_autolog_hparams=False)
-     for logger in trainer.loggers:
-         if isinstance(logger, lightning.pytorch.loggers.CSVLogger):
-             logger.log_hyperparams(hparams_dict_1)
-         else:
-             logger.log_hyperparams(hparams_dict_2)
-
- You can also use `self.logger.log_hyperparams(...)` inside `LightningModule` to log.
-

-----

Trainer class API