Skip to content

Commit c72d6d0

Commit message: Revert "Fix error when batch size is larger than the dataset" (#616)
This reverts commit 4b7c9a4.
1 parent: 559bd70 — commit: c72d6d0

File tree: 4 files changed (+6 lines added, −25 lines deleted)

src/metatrain/experimental/nanopet/trainer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def train(
149149
num_replicas=world_size,
150150
rank=rank,
151151
shuffle=True,
152-
drop_last=len(train_dataset) > self.hypers["batch_size"],
152+
drop_last=True,
153153
)
154154
for train_dataset in train_datasets
155155
]
@@ -181,9 +181,7 @@ def train(
181181
),
182182
drop_last=(
183183
# the sampler takes care of this (if present)
184-
# check if batch size > train_dataset
185-
len(train_dataset) > self.hypers["batch_size"]
186-
and train_sampler is None
184+
train_sampler is None
187185
),
188186
collate_fn=collate_fn,
189187
)

src/metatrain/pet/trainer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def train(
167167
num_replicas=world_size,
168168
rank=rank,
169169
shuffle=True,
170-
drop_last=len(train_dataset) > self.hypers["batch_size"],
170+
drop_last=True,
171171
)
172172
for train_dataset in train_datasets
173173
]
@@ -199,9 +199,7 @@ def train(
199199
),
200200
drop_last=(
201201
# the sampler takes care of this (if present)
202-
# check if batch size > train_dataset
203-
len(train_dataset) > self.hypers["batch_size"]
204-
and train_sampler is None
202+
train_sampler is None
205203
),
206204
collate_fn=collate_fn,
207205
)

src/metatrain/soap_bpnn/trainer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def train(
148148
num_replicas=world_size,
149149
rank=rank,
150150
shuffle=True,
151-
drop_last=len(train_dataset) > self.hypers["batch_size"],
151+
drop_last=True,
152152
)
153153
for train_dataset in train_datasets
154154
]
@@ -180,9 +180,7 @@ def train(
180180
),
181181
drop_last=(
182182
# the sampler takes care of this (if present)
183-
# check if batch size > train_dataset
184-
len(train_dataset) > self.hypers["batch_size"]
185-
and train_sampler is None
183+
train_sampler is None
186184
),
187185
collate_fn=collate_fn,
188186
)

tests/cli/test_train_model.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -302,19 +302,6 @@ def test_empty_training_set(monkeypatch, tmp_path, options):
302302
train_model(options)
303303

304304

305-
def test_batch_size_smaller_training_set(monkeypatch, tmp_path, options):
306-
"""Test that training still runs for batch size > train_size."""
307-
monkeypatch.chdir(tmp_path)
308-
309-
shutil.copy(DATASET_PATH_QM9, "qm9_reduced_100.xyz")
310-
311-
options["validation_set"] = 0.55
312-
options["test_set"] = 0.4
313-
options["architecture"]["training"]["batch_size"] = 1000
314-
315-
train_model(options)
316-
317-
318305
@pytest.mark.parametrize("split", [-0.1, 1.1])
319306
def test_wrong_test_split_size(split, monkeypatch, tmp_path, options):
320307
"""Test that an error is raised if the test split has the wrong size"""

Commit comments: 0