Skip to content

Commit 9434c6b

Browse files
committed
Increasing test-coverage
1 parent 62f20db commit 9434c6b

File tree

3 files changed

+24
-15
lines changed

3 files changed

+24
-15
lines changed

merlin/models/torch/batch.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,6 @@ def target(self, name: str = "default") -> torch.Tensor:
281281
def __bool__(self) -> bool:
282282
return bool(self.features)
283283

284-
<<<<<<< HEAD
285284
def device(self) -> torch.device:
286285
"""Retrieves the device of the tensors in the Batch object.
287286
@@ -303,14 +302,6 @@ def sample_batch(
303302
batch_size: Optional[int] = None,
304303
shuffle: Optional[bool] = False,
305304
) -> Batch:
306-
=======
307-
308-
def sample_batch(
309-
dataset_or_loader: Union[Dataset, Loader],
310-
batch_size: Optional[int] = None,
311-
shuffle: Optional[bool] = False,
312-
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
313-
>>>>>>> d00e67a4 (Adding sample_batch & sample_features)
314305
"""Util function to generate a batch of input tensors from a merlin.io.Dataset instance
315306
316307
Parameters
@@ -330,6 +321,7 @@ def sample_batch(
330321
dictionary of target tensors.
331322
"""
332323

324+
<<<<<<< HEAD
333325
<<<<<<< HEAD
334326
if isinstance(data, Dataset):
335327
if not batch_size:
@@ -340,6 +332,11 @@ def sample_batch(
340332
else:
341333
raise ValueError(f"Expected Dataset or Loader instance, got: {data}")
342334
=======
335+
=======
336+
if not isinstance(dataset_or_loader, (Dataset, Loader)):
337+
raise ValueError(f"Expected Dataset or Loader instance, got {dataset_or_loader}")
338+
339+
>>>>>>> 519159b5 (Increasing test-coverage)
343340
if isinstance(dataset_or_loader, Dataset):
344341
if not batch_size:
345342
raise ValueError("Either use 'Loader' or specify 'batch_size'")

merlin/models/torch/router.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,13 @@ def add_route(
7575

7676
routing_module = self.selectable.select(selection)
7777
if module is not None:
78-
branch = Block(routing_module)
79-
if isinstance(module, ParallelBlock):
80-
branch = module.prepend(routing_module)
81-
8278
if hasattr(module, "setup_schema"):
8379
module.setup_schema(routing_module.schema)
8480

85-
branch.append(module)
81+
if isinstance(module, ParallelBlock):
82+
branch = module.prepend(routing_module)
83+
else:
84+
branch = Block(routing_module, module)
8685
else:
8786
branch = routing_module
8887

tests/unit/torch/utils/test_selection_utils.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from merlin.models.torch.utils.selection_utils import select_schema, selection_name
3+
from merlin.models.torch.utils.selection_utils import Selectable, select_schema, selection_name
44
from merlin.schema import ColumnSchema, Schema, Tags
55

66

@@ -70,3 +70,16 @@ def test_select_column(self):
7070

7171
assert selection_name(column) == column.name
7272
assert selection_name(ColumnSchema("user_id")) == column.name
73+
74+
def test_exception(self):
75+
with pytest.raises(ValueError, match="is not valid"):
76+
selection_name(1)
77+
78+
79+
class TestSelectable:
80+
def test_exception(self):
81+
selectable = Selectable()
82+
83+
assert hasattr(selectable, "setup_schema")
84+
with pytest.raises(NotImplementedError):
85+
selectable.select(1)

0 commit comments

Comments
 (0)