Skip to content

Commit 2a16af0

Browse files
committed
Fixing merge-conflicts
1 parent 8451906 commit 2a16af0

File tree

7 files changed

+57
-231
lines changed

7 files changed

+57
-231
lines changed

merlin/models/torch/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from merlin.models.torch.batch import Batch, Sequence
1818
from merlin.models.torch.block import Block, ParallelBlock
19+
from merlin.models.torch.blocks.mlp import MLPBlock
1920
from merlin.models.torch.inputs.embedding import EmbeddingTable, EmbeddingTables
2021
from merlin.models.torch.inputs.tabular import TabularInputBlock
2122
from merlin.models.torch.outputs.base import ModelOutput
@@ -31,6 +32,7 @@
3132
"EmbeddingTable",
3233
"EmbeddingTables",
3334
"ParallelBlock",
35+
"MLPBlock",
3436
"ModelOutput",
3537
"Sequence",
3638
"RegressionOutput",

merlin/models/torch/router.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,9 @@ def select(self, selection: Selection) -> "RouterBlock":
182182

183183
selected_branches = {}
184184
for key, val in self.branches.items():
185-
selected_branches[key] = select_container(val, selection)
185+
selected = select_container(val, selection)
186+
if selected:
187+
selected_branches[key] = selected
186188

187189
selectable = self.__class__(self.selectable.select(selection))
188190
for key, val in selected_branches.items():
@@ -222,6 +224,7 @@ class SelectKeys(nn.Module, Selectable):
222224

223225
def __init__(self, schema: Optional[Schema] = None):
224226
super().__init__()
227+
self.col_names: List[str] = []
225228
if schema:
226229
self.setup_schema(schema)
227230

@@ -231,7 +234,7 @@ def setup_schema(self, schema: Schema):
231234

232235
super().setup_schema(schema)
233236

234-
self.col_names: List[str] = schema.column_names
237+
self.col_names = schema.column_names
235238

236239
def select(self, selection: Selection) -> "SelectKeys":
237240
"""Select a subset of the schema based on the provided selection.
@@ -277,14 +280,28 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
277280

278281
return outputs
279282

283+
def extra_repr(self) -> str:
284+
return f"select: {', '.join(self.col_names)}"
285+
286+
def __bool__(self) -> bool:
287+
return bool(self.col_names)
288+
280289

281290
def select_container(container: BlockContainer, selection: Selection) -> BlockContainer:
282291
outputs = []
283292

284-
for module in container:
285-
if module.accepts_dict:
286-
outputs.append(module.select(selection))
293+
if hasattr(container.values[0], "select"):
294+
first = container.values[0].select(selection)
295+
if first:
296+
outputs.append(first)
287297
else:
288-
outputs.append(module)
298+
return BlockContainer()
299+
300+
if len(container.values) > 1:
301+
for module in container.values[1:]:
302+
if hasattr(module, "select"):
303+
outputs.append(module.select(selection))
304+
else:
305+
outputs.append(module)
289306

290307
return BlockContainer(*outputs, name=container._name)

merlin/models/torch/utils/selection_utils.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,7 @@ def select_schema(schema: Schema, selection: Selection) -> Schema:
3333
if isinstance(selection, Schema):
3434
selected = selection
3535
elif isinstance(selection, ColumnSchema):
36-
<<<<<<< HEAD
37-
<<<<<<< HEAD
3836
selected = Schema([schema[selection.name]])
39-
=======
40-
selected = schema[selection.name]
41-
>>>>>>> a2644079 (Add selection_utils)
42-
=======
43-
selected = Schema([schema[selection.name]])
44-
>>>>>>> 89a6f043 (Increase test-coverage)
4537
elif callable(selection):
4638
selected = selection(schema)
4739
elif isinstance(selection, Tags):

tests/conftest.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,12 @@
2020
from pathlib import Path
2121

2222
import distributed
23-
import numpy as np
2423
import pytest
2524

2625
from merlin.core.utils import Distributed
2726
from merlin.datasets.synthetic import generate_data
2827
from merlin.io import Dataset
2928
from merlin.models.utils import ci_utils
30-
from merlin.schema import ColumnSchema, Tags
3129

3230
REPO_ROOT = Path(__file__).parent.parent
3331

@@ -86,26 +84,6 @@ def dask_client() -> distributed.Client:
8684
pass
8785

8886

89-
@pytest.fixture
90-
def item_id_col_schema() -> ColumnSchema:
91-
return ColumnSchema(
92-
"item_id",
93-
dtype=np.int32,
94-
properties={"domain": {"min": 0, "max": 10, "name": "item_id"}},
95-
tags=[Tags.CATEGORICAL, Tags.ITEM_ID],
96-
)
97-
98-
99-
@pytest.fixture
100-
def user_id_col_schema() -> ColumnSchema:
101-
return ColumnSchema(
102-
"user_id",
103-
dtype=np.int32,
104-
properties={"domain": {"min": 0, "max": 20, "name": "user_id"}},
105-
tags=[Tags.CATEGORICAL, Tags.USER_ID],
106-
)
107-
108-
10987
def pytest_collection_modifyitems(items):
11088
changed_backends = ci_utils.get_changed_backends()
11189
full_name_to_alias = {v: k for k, v in ci_utils.BACKEND_ALIASES.items()}

tests/unit/torch/_conftest.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import numpy as np
2+
import pytest
3+
4+
from merlin.schema import ColumnSchema, Tags
5+
6+
7+
@pytest.fixture
8+
def item_id_col_schema() -> ColumnSchema:
9+
return ColumnSchema(
10+
"item_id",
11+
dtype=np.int32,
12+
properties={"domain": {"min": 0, "max": 10, "name": "item_id"}},
13+
tags=[Tags.CATEGORICAL, Tags.ITEM_ID],
14+
)
15+
16+
17+
@pytest.fixture
18+
def user_id_col_schema() -> ColumnSchema:
19+
return ColumnSchema(
20+
"user_id",
21+
dtype=np.int32,
22+
properties={"domain": {"min": 0, "max": 20, "name": "user_id"}},
23+
tags=[Tags.CATEGORICAL, Tags.USER_ID],
24+
)

tests/unit/torch/test_router.py

Lines changed: 8 additions & 174 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
<<<<<<< HEAD
2-
<<<<<<< HEAD
31
from typing import Dict
42

53
import pytest
@@ -99,142 +97,19 @@ def test_select(self):
9997
plus_one = PlusOneDict()
10098

10199
self.router.add_route(Tags.CONTINUOUS)
102-
self.router.add_route(Tags.CATEGORICAL)
100+
self.router.add_route(Tags.CATEGORICAL, mm.MLPBlock([10]))
101+
self.router.add_route(Tags.USER, mm.MLPBlock([10]).prepend(mm.SelectKeys(self.schema)))
102+
self.router.add_route(Tags.ITEM, mm.MLPBlock([10]))
103103
self.router.prepend(plus_one)
104104

105105
router = self.router.select(Tags.CATEGORICAL)
106106
assert router.selectable.schema == self.schema.select_by_tag(Tags.CATEGORICAL)
107107
assert router[0][0] == plus_one
108108

109-
def test_double_add(self):
110-
self.router.add_route(Tags.CONTINUOUS)
111-
with pytest.raises(ValueError):
112-
self.router.add_route(Tags.CONTINUOUS)
113-
114-
def test_nested(self):
115-
self.router.add_route(Tags.CONTINUOUS)
116-
117-
nested = self.router.nested_router()
118-
nested.add_route(Tags.USER)
119-
assert "user" in nested
120-
121-
outputs = module_utils.module_test(nested, self.batch.features)
122-
assert list(outputs.keys()) == ["user_age"]
123-
assert "user_age" in nested.output_schema().column_names
124-
<<<<<<< HEAD
125-
=======
126-
import pytest
127-
=======
128-
from typing import Dict
129-
>>>>>>> 89a6f043 (Increase test-coverage)
130-
131-
import pytest
132-
import torch
133-
from torch import nn
134-
135-
import merlin.models.torch as mm
136-
from merlin.models.torch.batch import Batch, sample_batch
137-
from merlin.models.torch.utils import module_utils
138-
from merlin.schema import ColumnSchema, Schema, Tags
139-
140-
141-
class ToFloat(nn.Module):
142-
def forward(self, x):
143-
return x.float()
144-
145-
146-
class PlusOneDict(nn.Module):
147-
def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
148-
return {k: v + 1 for k, v in inputs.items()}
149-
150-
151-
class TestRouterBlock:
152-
<<<<<<< HEAD
153-
...
154-
>>>>>>> a2644079 (Add selection_utils)
155-
=======
156-
@pytest.fixture(autouse=True)
157-
def setup_method(self, music_streaming_data):
158-
self.schema = music_streaming_data.schema
159-
self.router: mm.RouterBlock = mm.RouterBlock(self.schema)
160-
self.batch: Batch = sample_batch(music_streaming_data, batch_size=10)
161-
162-
def test_add_route(self):
163-
self.router.add_route(Tags.CONTINUOUS)
164-
165-
outputs = module_utils.module_test(self.router, self.batch.features)
166-
assert set(outputs.keys()) == set(self.schema.select_by_tag(Tags.CONTINUOUS).column_names)
167-
assert "continuous" in self.router
168-
assert len(self.router["continuous"]) == 1
169-
assert isinstance(self.router["continuous"][0], mm.SelectKeys)
170-
171-
def test_add_route_module(self):
172-
class CustomSelect(mm.SelectKeys):
173-
...
174-
175-
self.router.add_route(Tags.CONTINUOUS, CustomSelect())
176-
177-
outputs = self.router(self.batch.features)
178-
assert set(outputs.keys()) == set(self.schema.select_by_tag(Tags.CONTINUOUS).column_names)
179-
assert len(self.router["continuous"]) == 2
180-
assert isinstance(self.router["continuous"][0], mm.SelectKeys)
181-
assert isinstance(self.router["continuous"][1], CustomSelect)
182-
183-
def test_module_with_setup(self):
184-
class Dummy(nn.Module):
185-
def setup_schema(self, schema: Schema):
186-
self.schema = schema
187-
188-
def forward(self, x):
189-
return x
190-
191-
dummy = Dummy()
192-
self.router.add_route(Tags.CONTINUOUS, dummy)
193-
assert dummy.schema == mm.select_schema(self.schema, Tags.CONTINUOUS)
194-
195-
dummy_2 = Dummy()
196-
self.router.add_route_for_each(ColumnSchema("user_id"), dummy_2, shared=True)
197-
assert dummy_2.schema == mm.select_schema(self.schema, ColumnSchema("user_id"))
198-
199-
def test_add_route_parallel_block(self):
200-
class FakeEmbeddings(mm.ParallelBlock):
201-
...
202-
203-
self.router.add_route(Tags.CATEGORICAL, FakeEmbeddings())
204-
assert isinstance(self.router["categorical"], FakeEmbeddings)
205-
206-
@pytest.mark.parametrize("shared", [True, False])
207-
def test_add_route_for_each(self, shared):
208-
block = mm.Block(mm.Concat(), ToFloat(), nn.LazyLinear(10)).to(self.batch.device())
209-
self.router.add_route_for_each(Tags.CONTINUOUS, block, shared=shared)
210-
211-
dense_pos = self.router.branches["position"][1][-1]
212-
dense_age = self.router.branches["user_age"][1][-1]
213-
if shared:
214-
assert dense_pos == dense_age
215-
else:
216-
assert dense_pos != dense_age
217-
218-
outputs = self.router(self.batch.features)
219-
assert set(outputs.keys()) == set(self.schema.select_by_tag(Tags.CONTINUOUS).column_names)
220-
221-
for value in outputs.values():
222-
assert value.shape[-1] == 10
223-
224-
def test_add_route_for_each_list(self):
225-
self.router.add_route_for_each([ColumnSchema("user_id")], ToFloat())
226-
assert isinstance(self.router.branches["user_id"][1], ToFloat)
227-
228-
def test_select(self):
229-
plus_one = PlusOneDict()
230-
231-
self.router.add_route(Tags.CONTINUOUS)
232-
self.router.add_route(Tags.CATEGORICAL)
233-
self.router.prepend(plus_one)
234-
235-
router = self.router.select(Tags.CATEGORICAL)
236-
assert router.selectable.schema == self.schema.select_by_tag(Tags.CATEGORICAL)
237-
assert router[0][0] == plus_one
109+
user = self.router.select(Tags.USER)
110+
assert "item_recency" not in user.branches["continuous"][0].col_names
111+
assert "item_id" not in user.branches["categorical"][0].col_names
112+
assert "item" not in user.branches
238113

239114
def test_double_add(self):
240115
self.router.add_route(Tags.CONTINUOUS)
@@ -250,52 +125,12 @@ def test_nested(self):
250125

251126
outputs = module_utils.module_test(nested, self.batch.features)
252127
assert list(outputs.keys()) == ["user_age"]
253-
>>>>>>> 89a6f043 (Increase test-coverage)
254-
=======
255-
>>>>>>> 78386932 (Fix failined nested router test)
128+
assert "user_age" in nested.output_schema().column_names
256129

257130

258131
class TestSelectKeys:
259132
@pytest.fixture(autouse=True)
260133
def setup_method(self, music_streaming_data):
261-
<<<<<<< HEAD
262-
<<<<<<< HEAD
263-
self.batch: Batch = sample_batch(music_streaming_data, batch_size=10)
264-
self.schema: Schema = music_streaming_data.schema
265-
self.user_schema: Schema = mm.select_schema(self.schema, Tags.USER)
266-
267-
def test_forward(self):
268-
select_user = mm.SelectKeys(self.user_schema)
269-
outputs = select_user(self.batch.features)
270-
271-
assert select_user.schema == self.user_schema
272-
273-
for col in {"user_id", "country", "user_age"}:
274-
assert col in outputs
275-
276-
assert "user_genres__values" in outputs
277-
assert "user_genres__offsets" in outputs
278-
279-
def test_select(self):
280-
select_user = mm.SelectKeys(self.user_schema)
281-
282-
user_id = Schema([self.user_schema["user_id"]])
283-
assert select_user.select(ColumnSchema("user_id")).schema == user_id
284-
assert select_user.select(Tags.USER).schema == self.user_schema
285-
286-
def test_setup_schema(self):
287-
select_user = mm.SelectKeys()
288-
select_user.setup_schema(self.user_schema["user_id"])
289-
assert select_user.schema == Schema([self.user_schema["user_id"]])
290-
=======
291-
self.data = music_streaming_data
292-
self.schema = music_streaming_data.schema
293-
self.select_keys = SelectKeys(music_streaming_data.schema)
294-
295-
def test_forward(self):
296-
...
297-
>>>>>>> a2644079 (Add selection_utils)
298-
=======
299134
self.batch: Batch = sample_batch(music_streaming_data, batch_size=10)
300135
self.schema: Schema = music_streaming_data.schema
301136
self.user_schema: Schema = mm.select_schema(self.schema, Tags.USER)
@@ -323,4 +158,3 @@ def test_setup_schema(self):
323158
select_user = mm.SelectKeys()
324159
select_user.setup_schema(self.user_schema["user_id"])
325160
assert select_user.schema == Schema([self.user_schema["user_id"]])
326-
>>>>>>> 89a6f043 (Increase test-coverage)

0 commit comments

Comments
 (0)