Commit abb4b5f

added warning and changed drop_last flag when batch size is too large
1 parent: 2eeb900
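
Why the guard matters: with drop_last=True, a PyTorch DataLoader discards the final incomplete batch, so when batch_size exceeds the number of samples the loader yields no batches at all and the reconstruction loop silently does nothing. Below is a minimal standalone sketch of the pitfall and of the fallback this commit adds; the tensor shapes and variable names are illustrative, not taken from the repository.

import torch
from torch.utils.data import TensorDataset, DataLoader

# Tiny dataset: 10 samples, far fewer than the default batch_size of 512.
queries = torch.rand(10, 3)
gt_dist = torch.rand(10, 1)
dataset = TensorDataset(queries, gt_dist)

# With drop_last=True the single incomplete batch is discarded,
# so the loader is empty: len() is 10 // 512 == 0.
loader = DataLoader(dataset, batch_size=512, shuffle=True, drop_last=True)
print(len(loader))  # 0

# Guard mirroring the commit: fall back to drop_last=False when the
# batch size exceeds the dataset size, so at least one batch is produced.
batch_size, drop_last = 512, True
if drop_last and batch_size > len(dataset):
    print(
        f"Warning: batch size ({batch_size}) is larger than the "
        f"dataset ({len(dataset)}); setting drop_last=False"
    )
    drop_last = False
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)
print(len(loader))  # 1

With drop_last=False and a small sample set, every epoch trains on one partial batch, which is usually preferable to skipping optimization entirely.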

File tree: 1 file changed (+20, -2 lines)

DeepSDFStruct/deep_sdf/reconstruction.py

Lines changed: 20 additions & 2 deletions

@@ -15,6 +15,7 @@ def reconstruct_from_samples(
     dtype=torch.float32,
     loss_fn="ClampedL1",
     batch_size=512,
+    drop_last=True,
 ):
 
     optimizer = torch.optim.Adam(sdf.parametrization.parameters(), lr=lr)
@@ -57,8 +58,16 @@ def reconstruct_from_samples(
         raise NotImplementedError(f"Loss function {loss_fn} not available.")
 
     dataset = TensorDataset(queries_ps_torch, gt_dist)
+    if drop_last and (batch_size > len(dataset)):
+        print(
+            "Warning: drop_last was set to true, "
+            f"but batch size ({batch_size}) is larger "
+            f"than the size of the dataset ({len(dataset)})."
+            " setting drop_last=False"
+        )
+        drop_last = False
     dataloader = DataLoader(
-        dataset, batch_size=batch_size, shuffle=True, drop_last=True
+        dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last
     )
 
     for e in pbar:
@@ -86,6 +95,7 @@ def reconstruct_deepLS_from_samples(
     dtype=torch.float32,
     loss_fn="ClampedL1",
     batch_size=512,
+    drop_last=True,
 ):
 
     optimizer = torch.optim.Adam(sdf.parametrization.parameters(), lr=lr)
@@ -104,8 +114,16 @@ def reconstruct_deepLS_from_samples(
         raise NotImplementedError(f"Loss function {loss_fn} not available.")
 
     dataset = TensorDataset(sdfSample.samples, gt_dist)
+    if drop_last and (batch_size > len(dataset)):
+        print(
+            "Warning: drop_last was set to true, "
+            f"but batch size ({batch_size}) is larger "
+            f"than the size of the dataset ({len(dataset)})."
+            " setting drop_last=False"
+        )
+        drop_last = False
     dataloader = DataLoader(
-        dataset, batch_size=batch_size, shuffle=True, drop_last=True
+        dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last
     )
 
     for e in pbar:
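
Downstream effect: before this change, a sample set smaller than batch_size made both reconstruction loops run zero optimization steps without raising an error, so the latent parametrization was never updated. A self-contained sketch of that failure mode, using a toy latent parameter in place of sdf.parametrization (all names and shapes here are assumptions for illustration, not code from the repository):

import torch
from torch.utils.data import TensorDataset, DataLoader

# Toy stand-in for sdf.parametrization: one learnable latent code.
latent = torch.nn.Parameter(torch.zeros(8))
optimizer = torch.optim.Adam([latent], lr=1e-3)

samples = torch.rand(20, 3)  # 20 query points, far fewer than batch_size
gt_dist = torch.rand(20, 1)
loader = DataLoader(
    TensorDataset(samples, gt_dist), batch_size=512, shuffle=True, drop_last=True
)

steps = 0
for epoch in range(100):          # mirrors the `for e in pbar:` loop above
    for queries, dist in loader:  # empty loader: this body never runs
        optimizer.zero_grad()
        loss = ((latent.mean() - dist) ** 2).mean()  # dummy loss
        loss.backward()
        optimizer.step()
        steps += 1
print(steps)  # 0 -- no error is raised, the latent code is never updated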
