@@ -491,6 +491,7 @@ class BoringCkptPathModel(BoringModel):
491
491
def __init__ (self , out_dim : int = 2 , hidden_dim : int = 2 ) -> None :
492
492
super ().__init__ ()
493
493
self .save_hyperparameters ()
494
+ self .hidden_dim = hidden_dim
494
495
self .layer = torch .nn .Linear (32 , out_dim )
495
496
496
497
@@ -526,6 +527,41 @@ def add_arguments_to_parser(self, parser):
526
527
assert "Parsing of ckpt_path hyperparameters failed" in err .getvalue ()
527
528
528
529
530
+ class BoringCkptPathSubclass (BoringCkptPathModel ):
531
+ def __init__ (self , extra : bool = True , ** kwargs ) -> None :
532
+ super ().__init__ (** kwargs )
533
+ self .extra = extra
534
+
535
+
536
+ def test_lightning_cli_ckpt_path_argument_hparams_subclass_mode (cleandir ):
537
+ class CkptPathCLI (LightningCLI ):
538
+ def add_arguments_to_parser (self , parser ):
539
+ parser .link_arguments ("model.init_args.out_dim" , "model.init_args.hidden_dim" , compute_fn = lambda x : x * 2 )
540
+
541
+ cli_args = ["fit" , "--model=BoringCkptPathSubclass" , "--model.out_dim=4" , "--trainer.max_epochs=1" ]
542
+ with mock .patch ("sys.argv" , ["any.py" ] + cli_args ):
543
+ cli = CkptPathCLI (BoringCkptPathModel , subclass_mode_model = True )
544
+
545
+ assert cli .config .fit .model .class_path .endswith (".BoringCkptPathSubclass" )
546
+ assert cli .config .fit .model .init_args == Namespace (out_dim = 4 , hidden_dim = 8 , extra = True )
547
+ hparams_path = Path (cli .trainer .log_dir ) / "hparams.yaml"
548
+ assert hparams_path .is_file ()
549
+ hparams = yaml .safe_load (hparams_path .read_text ())
550
+ assert hparams ["out_dim" ] == 4
551
+ assert hparams ["hidden_dim" ] == 8
552
+ assert hparams ["extra" ] is True
553
+
554
+ checkpoint_path = next (Path (cli .trainer .log_dir , "checkpoints" ).glob ("*.ckpt" ))
555
+ cli_args = ["predict" , "--model=BoringCkptPathModel" , f"--ckpt_path={ checkpoint_path } " ]
556
+ with mock .patch ("sys.argv" , ["any.py" ] + cli_args ):
557
+ cli = CkptPathCLI (BoringCkptPathModel , subclass_mode_model = True )
558
+
559
+ assert isinstance (cli .model , BoringCkptPathSubclass )
560
+ assert cli .model .hidden_dim == 8
561
+ assert cli .model .extra is True
562
+ assert cli .model .layer .out_features == 4
563
+
564
+
529
565
def test_lightning_cli_submodules (cleandir ):
530
566
class MainModule (BoringModel ):
531
567
def __init__ (self , submodule1 : LightningModule , submodule2 : LightningModule , main_param : int = 1 ):
0 commit comments