Skip to content

Commit 226ad69

Browse files
Restore support for passing device in CPU-only environment (#107)
Add test for passing device to check this works on CPU
1 parent 3c11b59 commit 226ad69

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

merlin/dataloader/loader_base.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,8 @@ def __init__(
7474
self.global_rank = global_rank or 0
7575
self.drop_last = drop_last
7676

77-
if device:
78-
self.device = device
79-
else:
80-
self.device = "cpu" if not HAS_GPU or dataset.cpu else 0
77+
device = device or 0
78+
self.device = "cpu" if not HAS_GPU or dataset.cpu else device
8179

8280
if self.device == "cpu":
8381
self._array_lib = np

tests/unit/dataloader/test_tf_dataloader.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,11 @@ def test_simple_model():
9898
_ = model.evaluate(loader)
9999

100100

101+
def test_with_device():
102+
dataset = Dataset(make_df({"a": [1]}))
103+
tf_dataloader.Loader(dataset, batch_size=1, device=1).peek()
104+
105+
101106
def test_nested_list():
102107
num_rows = 100
103108
batch_size = 12

0 commit comments

Comments
 (0)