@@ -15,6 +15,7 @@ def reconstruct_from_samples(
1515 dtype = torch .float32 ,
1616 loss_fn = "ClampedL1" ,
1717 batch_size = 512 ,
18+ drop_last = True ,
1819):
1920
2021 optimizer = torch .optim .Adam (sdf .parametrization .parameters (), lr = lr )
@@ -57,8 +58,16 @@ def reconstruct_from_samples(
5758 raise NotImplementedError (f"Loss function { loss_fn } not available." )
5859
5960 dataset = TensorDataset (queries_ps_torch , gt_dist )
61+ if drop_last and (batch_size > len (dataset )):
62+ print (
63+ "Warning: drop_last was set to true, "
64+ f"but batch size ({ batch_size } ) is larger "
65+ f"than the size of the dataset ({ len (dataset )} )."
66+ " setting drop_last=False"
67+ )
68+ drop_last = False
6069 dataloader = DataLoader (
61- dataset , batch_size = batch_size , shuffle = True , drop_last = True
70+ dataset , batch_size = batch_size , shuffle = True , drop_last = drop_last
6271 )
6372
6473 for e in pbar :
@@ -86,6 +95,7 @@ def reconstruct_deepLS_from_samples(
8695 dtype = torch .float32 ,
8796 loss_fn = "ClampedL1" ,
8897 batch_size = 512 ,
98+ drop_last = True ,
8999):
90100
91101 optimizer = torch .optim .Adam (sdf .parametrization .parameters (), lr = lr )
@@ -104,8 +114,16 @@ def reconstruct_deepLS_from_samples(
104114 raise NotImplementedError (f"Loss function { loss_fn } not available." )
105115
106116 dataset = TensorDataset (sdfSample .samples , gt_dist )
117+ if drop_last and (batch_size > len (dataset )):
118+ print (
119+ "Warning: drop_last was set to true, "
120+ f"but batch size ({ batch_size } ) is larger "
121+ f"than the size of the dataset ({ len (dataset )} )."
122+ " setting drop_last=False"
123+ )
124+ drop_last = False
107125 dataloader = DataLoader (
108- dataset , batch_size = batch_size , shuffle = True , drop_last = True
126+ dataset , batch_size = batch_size , shuffle = True , drop_last = drop_last
109127 )
110128
111129 for e in pbar :
0 commit comments