Skip to content

Commit 8a97cb8

Browse files
authored
Adding RouterBlock (#1096)
* Adding improved doc-strings * Adding torch github-action + add copyright * Increase test-coverage * Expose Sequence in __init__ * Give n in repeat a default value * Adding BlockContainerDict * Adding ParallelBlock * Adding sample_batch & sample_features * Output Batch instead * Add selection_utils * Increase test-coverage * Fixing doc-strings * Some tiny doc-string updates * Increasing test-coverage * Fix failing nested router test * Making test-classes for functions camel-case
1 parent 06f877a commit 8a97cb8

File tree

7 files changed

+656
-6
lines changed

7 files changed

+656
-6
lines changed

merlin/models/torch/__init__.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,20 @@
1919
from merlin.models.torch.outputs.base import ModelOutput
2020
from merlin.models.torch.outputs.classification import BinaryOutput
2121
from merlin.models.torch.outputs.regression import RegressionOutput
22+
from merlin.models.torch.router import RouterBlock, SelectKeys, select_schema
2223
from merlin.models.torch.transforms.agg import Concat, Stack
2324

2425
__all__ = [
2526
"Batch",
26-
"Concat",
2727
"BinaryOutput",
2828
"Block",
29-
"ModelOutput",
3029
"ParallelBlock",
31-
"RegressionOutput",
30+
"ModelOutput",
3231
"Sequence",
32+
"RegressionOutput",
33+
"RouterBlock",
34+
"SelectKeys",
35+
"select_schema",
36+
"Concat",
3337
"Stack",
3438
]

merlin/models/torch/block.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,12 @@ def forward(
190190
outputs = {}
191191
for name, branch_container in self.branches.items():
192192
branch = inputs
193-
for module in branch_container.values:
194-
branch = module(branch, batch=batch)
193+
194+
if hasattr(branch_container, "branches"):
195+
branch = branch_container(branch, batch=batch)
196+
else:
197+
for module in branch_container.values:
198+
branch = module(branch, batch=batch)
195199

196200
if isinstance(branch, torch.Tensor):
197201
branch_dict = {name: branch}

merlin/models/torch/router.py

Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
from copy import deepcopy
2+
from typing import Dict, List, Optional
3+
4+
import torch
5+
from torch import nn
6+
7+
from merlin.models.torch.block import Block, ParallelBlock
8+
from merlin.models.torch.utils.selection_utils import (
9+
Selectable,
10+
Selection,
11+
select_schema,
12+
selection_name,
13+
)
14+
from merlin.schema import ColumnSchema, Schema
15+
16+
17+
class RouterBlock(ParallelBlock, Selectable):
    """Parallel block whose branches each start from a selection of features.

    Every route is created by calling ``select`` on the wrapped selectable and
    storing the result (optionally followed by a module) as a named branch.

    Example usage::

        router = RouterBlock(schema)
        router.add_route(Tags.CONTINUOUS)
        router.add_route(Tags.CATEGORICAL, mm.Embeddings(dim=64))
        router.add_route(Tags.EMBEDDING, mm.MLPBlock([64, 32]))

    Parameters
    ----------
    selectable : Selectable
        The object routes are selected from. A ``Schema`` is accepted as
        well and is wrapped in a ``SelectKeys`` module.

    Attributes
    ----------
    selectable : Selectable
        The object routes are selected from.
    """

    def __init__(self, selectable: Selectable):
        super().__init__()
        # A bare schema is turned into a module that filters inputs by column name.
        self.selectable: Selectable = (
            SelectKeys(selectable) if isinstance(selectable, Schema) else selectable
        )

    def add_route(
        self,
        selection: Selection,
        module: Optional[nn.Module] = None,
        name: Optional[str] = None,
    ) -> "RouterBlock":
        """Register a branch that handles the given selection.

        Example usage::

            router.add_route(Tags.CONTINUOUS)

        Example usage with module::

            router.add_route(Tags.CONTINUOUS, MLPBlock([64, 32]))

        Parameters
        ----------
        selection : Selection
            Passed to ``self.selectable.select`` to build the head of the branch.
        module : nn.Module, optional
            Module that runs after the selection. When it exposes
            ``setup_schema`` it is called with the schema of the selection.
            A ``ParallelBlock`` gets the selection prepended; any other module
            is chained after the selection in a ``Block``.
        name : str, optional
            Branch name. Falls back to ``selection_name(selection)``.

        Returns
        -------
        RouterBlock
            ``self``, so that calls can be chained.

        Raises
        ------
        ValueError
            If a branch with the resolved name is already registered.
        """

        route = self.selectable.select(selection)

        if module is None:
            branch = route
        else:
            if hasattr(module, "setup_schema"):
                module.setup_schema(route.schema)

            branch = (
                module.prepend(route)
                if isinstance(module, ParallelBlock)
                else Block(route, module)
            )

        branch_name: str = name or selection_name(selection)
        if branch_name in self.branches:
            raise ValueError(f"Branch with name {branch_name} already exists")

        self.branches[branch_name] = branch

        return self

    def add_route_for_each(
        self, selection: Selection, module: nn.Module, shared=False
    ) -> "RouterBlock":
        """Register one branch per column matched by the selection.

        Example usage::

            router.add_route_for_each(Tags.EMBEDDING, mm.MLPBlock([64, 32]))

        Parameters
        ----------
        selection : Selection
            The selection to resolve against ``self.selectable.schema``.
            A list or tuple of selections is processed item by item.
        module : nn.Module
            The module placed after each per-column selection.
        shared : bool, optional
            When true, the very same module instance is used for every column;
            otherwise each column receives its own ``deepcopy`` of ``module``.

        Returns
        -------
        RouterBlock
            ``self``, so that calls can be chained.
        """

        if isinstance(selection, (list, tuple)):
            for item in selection:
                self.add_route_for_each(item, module, shared=shared)

            return self

        for column in select_schema(self.selectable.schema, selection):
            # Each branch is named after the column it routes.
            self.add_route(
                column,
                module if shared else deepcopy(module),
                name=column.name,
            )

        return self

    def nested_router(self) -> "RouterBlock":
        """Return a new router that selects from this router.

        Useful for hierarchical routing: for instance, route continuous and
        categorical features differently first, and then route user- and
        item-features differently on top of that. Because the second level
        builds on the first, computation such as embedding tables can be
        shared between routes (e.g. user_genres and item_genres columns).

        Example usage::
            router = RouterBlock(selectable)
            # First level of routing: separate continuous and categorical features
            router.add_route(Tags.CONTINUOUS)
            router.add_route(Tags.CATEGORICAL, mm.Embeddings())

            # Second level of routing: separate user- and item-features
            two_tower = router.nested_router()
            two_tower.add_route(Tags.USER, mm.MLPBlock([64, 32]))
            two_tower.add_route(Tags.ITEM, mm.MLPBlock([64, 32]))

        Returns
        -------
        RouterBlock
            A new router block that uses the current block as its selectable.
        """

        if hasattr(self, "_forward_called"):
            # Schema tracking on this block is no longer needed once the
            # nested router takes over, so the registered handle is removed.
            self._handle.remove()

        return RouterBlock(self)

    def select(self, selection: Selection) -> "RouterBlock":
        """Build a router restricted to the given selection.

        ``select`` is applied to every branch (a branch container holding a
        single element is unwrapped to that element first) and to the wrapped
        selectable. The ``pre`` and ``post`` attributes of this block are
        carried over to the result.

        Parameters
        ----------
        selection : Selection
            The selection to apply to the branches.

        Returns
        -------
        RouterBlock
            A new instance of this class holding the selected branches.
        """

        # Branches are narrowed before the new router is created, mirroring
        # the order in which the underlying ``select`` calls are issued.
        narrowed = {
            key: (container[0] if len(container) == 1 else container).select(selection)
            for key, container in self.branches.items()
        }

        router = self.__class__(self.selectable.select(selection))
        for key, branch in narrowed.items():
            router.branches[key] = branch

        router.pre = self.pre
        router.post = self.post

        return router
197+
198+
199+
class SelectKeys(nn.Module, Selectable):
    """Filter tabular data based on a defined schema.

    Example usage::

        >>> select_keys = mm.SelectKeys(Schema(["user_id", "item_id"]))
        >>> inputs = {
        ...     "user_id": torch.tensor([1, 2, 3]),
        ...     "item_id": torch.tensor([4, 5, 6]),
        ...     "other_key": torch.tensor([7, 8, 9]),
        ... }
        >>> outputs = select_keys(inputs)
        >>> print(outputs.keys())
        dict_keys(['user_id', 'item_id'])

    Parameters
    ----------
    schema : Schema, optional
        The schema to use for selection. Default is None, in which case no
        column is selected until ``setup_schema`` is called.

    Attributes
    ----------
    col_names : list
        List of column names in the schema. Empty until a schema is set up.
    """

    def __init__(self, schema: Optional[Schema] = None):
        super().__init__()
        # Always define ``col_names`` so that ``forward`` does not fail with an
        # AttributeError when the module is created without a schema.
        self.col_names: List[str] = []
        if schema:
            self.setup_schema(schema)

    def setup_schema(self, schema: Schema):
        """Store the schema and cache its column names.

        Parameters
        ----------
        schema : Schema
            The schema to select with. A single ``ColumnSchema`` is accepted
            as well and is wrapped in a one-column ``Schema``.
        """

        if isinstance(schema, ColumnSchema):
            schema = Schema([schema])

        super().setup_schema(schema)

        self.col_names = schema.column_names

    def select(self, selection: Selection) -> "SelectKeys":
        """Select a subset of the schema based on the provided selection.

        Parameters
        ----------
        selection : Selection
            The selection to apply to the schema.

        Returns
        -------
        SelectKeys
            A new SelectKeys instance with the selected schema.
        """

        return SelectKeys(select_schema(self.schema, selection))

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Only keep the inputs that are present in the schema.

        Keys ending in ``__values`` or ``__offsets`` are matched against the
        schema by their name without that suffix, but are returned under
        their original key.

        Parameters
        ----------
        inputs : dict
            A dictionary of torch.Tensor objects.

        Returns
        -------
        dict
            A dictionary of torch.Tensor objects after selection.
        """

        outputs = {}

        for key, val in inputs.items():
            # Resolve the column name the key belongs to.
            _key = key
            if key.endswith("__values"):
                _key = key[: -len("__values")]
            elif key.endswith("__offsets"):
                _key = key[: -len("__offsets")]

            if _key in self.col_names:
                outputs[key] = val

        return outputs

0 commit comments

Comments
 (0)