Skip to content

Commit 7838693

Browse files
committed
Fix failined nested router test
1 parent 9434c6b commit 7838693

File tree

7 files changed

+5
-174
lines changed

7 files changed

+5
-174
lines changed

merlin/models/torch/batch.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from typing import Dict, Optional, Union
1818

1919
import torch
20-
import torch.nn
2120

2221
from merlin.dataloader.torch import Loader
2322
from merlin.io import Dataset
@@ -321,8 +320,6 @@ def sample_batch(
321320
dictionary of target tensors.
322321
"""
323322

324-
<<<<<<< HEAD
325-
<<<<<<< HEAD
326323
if isinstance(data, Dataset):
327324
if not batch_size:
328325
raise ValueError("Either use 'Loader' or specify 'batch_size'")
@@ -331,37 +328,16 @@ def sample_batch(
331328
loader = data
332329
else:
333330
raise ValueError(f"Expected Dataset or Loader instance, got: {data}")
334-
=======
335-
=======
336-
if not isinstance(dataset_or_loader, (Dataset, Loader)):
337-
raise ValueError(f"Expected Dataset or Loader instance, got {dataset_or_loader}")
338-
339-
>>>>>>> 519159b5 (Increasing test-coverage)
340-
if isinstance(dataset_or_loader, Dataset):
341-
if not batch_size:
342-
raise ValueError("Either use 'Loader' or specify 'batch_size'")
343-
loader = Loader(dataset_or_loader, batch_size=batch_size, shuffle=shuffle)
344-
else:
345-
loader = dataset_or_loader
346-
>>>>>>> d00e67a4 (Adding sample_batch & sample_features)
347331

348332
batch = loader.peek()
349333
# batch could be of type Prediction, so we can't unpack directly
350334
inputs, targets = batch[0], batch[1]
351335

352-
<<<<<<< HEAD
353336
return Batch(inputs, targets)
354337

355338

356339
def sample_features(
357340
data: Union[Dataset, Loader],
358-
=======
359-
return inputs, targets
360-
361-
362-
def sample_features(
363-
dataset_or_loader: Union[Dataset, Loader],
364-
>>>>>>> d00e67a4 (Adding sample_batch & sample_features)
365341
batch_size: Optional[int] = None,
366342
shuffle: Optional[bool] = False,
367343
) -> Dict[str, torch.Tensor]:
@@ -382,8 +358,4 @@ def sample_features(
382358
dictionary of feature tensors.
383359
"""
384360

385-
<<<<<<< HEAD
386361
return sample_batch(data, batch_size, shuffle).features
387-
=======
388-
return sample_batch(dataset_or_loader, batch_size, shuffle)[0]
389-
>>>>>>> d00e67a4 (Adding sample_batch & sample_features)

merlin/models/torch/container.py

Lines changed: 0 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,9 @@
1414
# limitations under the License.
1515
#
1616

17-
<<<<<<< HEAD
18-
<<<<<<< HEAD
1917
from copy import deepcopy
2018
from functools import reduce
2119
from typing import Dict, Iterator, Optional, Union
22-
=======
23-
from typing import Iterator, Optional, Union
24-
>>>>>>> 2dfe2782 (Adding torch github-action + add copyright)
25-
=======
26-
from copy import deepcopy
27-
from functools import reduce
28-
from typing import Dict, Iterator, Optional, Union
29-
>>>>>>> cbf7e956 (Adding BlockContainerDict)
3020

3121
from torch import nn
3222
from torch._jit_internal import _copy_to_script_wrapper
@@ -57,33 +47,25 @@ def __init__(self, *inputs: nn.Module, name: Optional[str] = None):
5747

5848
self._name: str = name
5949

60-
<<<<<<< HEAD
6150
def append(self, module: nn.Module, link: Optional[Link] = None):
62-
=======
63-
def append(self, module: nn.Module):
64-
>>>>>>> 793f47ef (Adding improved doc-strings)
6551
"""Appends a given module to the end of the list.
6652
6753
Parameters
6854
----------
6955
module : nn.Module
7056
The PyTorch module to be appended.
71-
<<<<<<< HEAD
7257
link : Optional[LinkType]
7358
The link to use for the module. If None, no link is used.
7459
This can either be a Module or a string, options are:
7560
- "residual": Adds a residual connection to the module.
7661
- "shortcut": Adds a shortcut connection to the module.
7762
- "shortcut-concat": Adds a shortcut connection by concatenating
7863
the input and output.
79-
=======
80-
>>>>>>> 793f47ef (Adding improved doc-strings)
8164
8265
Returns
8366
-------
8467
self
8568
"""
86-
<<<<<<< HEAD
8769
_module = self._check_link(module, link=link)
8870
self.values.append(self.wrap_module(_module))
8971

@@ -111,27 +93,6 @@ def prepend(self, module: nn.Module, link: Optional[Link] = None):
11193
return self.insert(0, module, link=link)
11294

11395
def insert(self, index: int, module: nn.Module, link: Optional[Link] = None):
114-
=======
115-
self.values.append(self.wrap_module(module))
116-
117-
return self
118-
119-
def prepend(self, module: nn.Module):
120-
"""Prepends a given module to the beginning of the list.
121-
122-
Parameters
123-
----------
124-
module : nn.Module
125-
The PyTorch module to be prepended.
126-
127-
Returns
128-
-------
129-
self
130-
"""
131-
return self.insert(0, module)
132-
133-
def insert(self, index: int, module: nn.Module):
134-
>>>>>>> 793f47ef (Adding improved doc-strings)
13596
"""Inserts a given module at the specified index.
13697
13798
Parameters
@@ -140,28 +101,20 @@ def insert(self, index: int, module: nn.Module):
140101
The index at which the module is to be inserted.
141102
module : nn.Module
142103
The PyTorch module to be inserted.
143-
<<<<<<< HEAD
144104
link : Optional[LinkType]
145105
The link to use for the module. If None, no link is used.
146106
This can either be a Module or a string, options are:
147107
- "residual": Adds a residual connection to the module.
148108
- "shortcut": Adds a shortcut connection to the module.
149109
- "shortcut-concat": Adds a shortcut connection by concatenating
150110
the input and output.
151-
=======
152-
>>>>>>> 793f47ef (Adding improved doc-strings)
153111
154112
Returns
155113
-------
156114
self
157115
"""
158-
<<<<<<< HEAD
159116
_module = self._check_link(module, link=link)
160117
self.values.insert(index, self.wrap_module(_module))
161-
=======
162-
163-
self.values.insert(index, self.wrap_module(module))
164-
>>>>>>> 793f47ef (Adding improved doc-strings)
165118

166119
return self
167120

@@ -222,7 +175,6 @@ def __repr__(self) -> str:
222175
def _get_name(self) -> str:
223176
return super()._get_name() if self._name is None else self._name
224177

225-
<<<<<<< HEAD
226178
def _check_link(self, module: nn.Module, link: Optional[LinkType] = None) -> nn.Module:
227179
if link:
228180
linked_module: Link = Link.parse(link)
@@ -246,31 +198,18 @@ class BlockContainerDict(nn.ModuleDict):
246198
An optional name for the BlockContainer.
247199
"""
248200

249-
=======
250-
251-
class BlockContainerDict(nn.ModuleDict):
252-
>>>>>>> cbf7e956 (Adding BlockContainerDict)
253201
def __init__(
254202
self, *inputs: Union[nn.Module, Dict[str, nn.Module]], name: Optional[str] = None
255203
) -> None:
256204
if not inputs:
257205
inputs = [{}]
258206

259-
<<<<<<< HEAD
260-
<<<<<<< HEAD
261-
=======
262-
if isinstance(inputs, tuple) and len(inputs) == 1 and isinstance(inputs[0], (list, tuple)):
263-
modules = inputs[0]
264-
>>>>>>> cbf7e956 (Adding BlockContainerDict)
265-
=======
266-
>>>>>>> 77ca69b4 (Adding ParallelBlock)
267207
if all(isinstance(x, dict) for x in inputs):
268208
modules = reduce(lambda a, b: dict(a, **b), inputs) # type: ignore
269209

270210
super().__init__(modules)
271211
self._name: str = name
272212

273-
<<<<<<< HEAD
274213
def append_to(
275214
self, name: str, module: nn.Module, link: Optional[LinkType] = None
276215
) -> "BlockContainerDict":
@@ -390,37 +329,6 @@ def prepend_for_each(
390329

391330
return self
392331

393-
=======
394-
def append_to(self, name: str, module: nn.Module) -> "BlockContainerDict":
395-
self._modules[name].append(module)
396-
397-
return self
398-
399-
def prepend_to(self, name: str, module: nn.Module) -> "BlockContainerDict":
400-
self._modules[name].prepend(module)
401-
402-
return self
403-
404-
# Append to all branches, optionally copying
405-
def append_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDict":
406-
for branch in self.values():
407-
_module = module if shared else deepcopy(module)
408-
branch.append(_module)
409-
410-
return self
411-
412-
def prepend_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDict":
413-
for branch in self.values():
414-
_module = module if shared else deepcopy(module)
415-
branch.prepend(_module)
416-
417-
return self
418-
419-
# This assumes same branches, we append it's content to each branch
420-
# def append_parallel(self, module: "BlockContainerDict") -> "BlockContainerDict":
421-
# ...
422-
423-
>>>>>>> cbf7e956 (Adding BlockContainerDict)
424332
def add_module(self, name: str, module: Optional[nn.Module]) -> None:
425333
if module and not isinstance(module, BlockContainer):
426334
module = BlockContainer(module, name=name[0].upper() + name[1:])

merlin/models/torch/router.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,10 @@ def nested_router(self) -> "RouterBlock":
159159
A new router block with the current block as its selectable.
160160
"""
161161

162+
if hasattr(self, "_forward_called"):
163+
# We don't need to track the schema since we will be using the nested router
164+
self._handle.remove()
165+
162166
return RouterBlock(self)
163167

164168
def select(self, selection: Selection) -> "RouterBlock":

tests/unit/torch/test_block.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,9 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
#
16-
<<<<<<< HEAD
17-
<<<<<<< HEAD
18-
=======
19-
>>>>>>> 77ca69b4 (Adding ParallelBlock)
2016
from typing import Dict, Tuple
2117

2218
import pytest
23-
<<<<<<< HEAD
24-
=======
25-
26-
>>>>>>> 2dfe2782 (Adding torch github-action + add copyright)
27-
=======
28-
>>>>>>> f7011343 (Increase test-coverage)
2919
import torch
3020
from torch import nn
3121

@@ -104,8 +94,6 @@ def test_repeat(self):
10494

10595
with pytest.raises(ValueError, match="n must be greater than 0"):
10696
block.repeat(0)
107-
<<<<<<< HEAD
108-
<<<<<<< HEAD
10997

11098
def test_repeat_with_link(self):
11199
block = Block(PlusOne())
@@ -133,8 +121,6 @@ def forward(self, inputs):
133121

134122
inputs = torch.randn(1, 3)
135123
assert torch.equal(block(inputs), inputs + 1)
136-
=======
137-
>>>>>>> 77ca69b4 (Adding ParallelBlock)
138124

139125

140126
class TestParallelBlock:
@@ -171,7 +157,6 @@ def test_forward_dict_duplicate(self):
171157
with pytest.raises(RuntimeError):
172158
pb(inputs)
173159

174-
<<<<<<< HEAD
175160
def test_schema_tracking(self):
176161
pb = ParallelBlock({"a": PlusOne(), "b": PlusOne()})
177162

@@ -186,8 +171,6 @@ def test_schema_tracking(self):
186171

187172
assert len(schema.select_by_tag(Tags.EMBEDDING)) == 2
188173

189-
=======
190-
>>>>>>> 77ca69b4 (Adding ParallelBlock)
191174
def test_forward_tuple(self):
192175
inputs = torch.randn(1, 3)
193176
pb = ParallelBlock({"test": PlusOneTuple()})
@@ -260,8 +243,3 @@ def test_getitem(self):
260243

261244
with pytest.raises(IndexError):
262245
pb["invalid_key"]
263-
<<<<<<< HEAD
264-
=======
265-
>>>>>>> f7011343 (Increase test-coverage)
266-
=======
267-
>>>>>>> 77ca69b4 (Adding ParallelBlock)

tests/unit/torch/test_container.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -178,17 +178,6 @@ def test_empty(self):
178178
container = BlockContainerDict()
179179
assert len(container) == 0
180180

181-
<<<<<<< HEAD
182-
<<<<<<< HEAD
183-
=======
184-
def test_list_of_dict(self):
185-
container = BlockContainerDict(({"test": self.module}))
186-
assert len(container) == 1
187-
assert "test" in container
188-
189-
>>>>>>> cbf7e956 (Adding BlockContainerDict)
190-
=======
191-
>>>>>>> 77ca69b4 (Adding ParallelBlock)
192181
def test_not_module(self):
193182
with pytest.raises(ValueError):
194183
BlockContainerDict({"test": "not a module"})
@@ -197,22 +186,16 @@ def test_append_to(self):
197186
self.container.append_to("test", self.module)
198187
assert "test" in self.container._modules
199188

200-
<<<<<<< HEAD
201189
self.container.append_to("test", self.module, link="residual")
202190
assert isinstance(self.container["test"][-1], link.Residual)
203191

204-
=======
205-
>>>>>>> cbf7e956 (Adding BlockContainerDict)
206192
def test_prepend_to(self):
207193
self.container.prepend_to("test", self.module)
208194
assert "test" in self.container._modules
209195

210-
<<<<<<< HEAD
211196
self.container.prepend_to("test", self.module, link="residual")
212197
assert isinstance(self.container["test"][0], link.Residual)
213198

214-
=======
215-
>>>>>>> cbf7e956 (Adding BlockContainerDict)
216199
def test_append_for_each(self):
217200
container = BlockContainerDict({"a": nn.Module(), "b": nn.Module()})
218201

@@ -227,13 +210,10 @@ def test_append_for_each(self):
227210
assert len(container["b"]) == 3
228211
assert container["a"][-1] == container["b"][-1]
229212

230-
<<<<<<< HEAD
231213
container.append_for_each(to_add, link="residual")
232214
assert isinstance(container["a"][-1], link.Residual)
233215
assert isinstance(container["b"][-1], link.Residual)
234216

235-
=======
236-
>>>>>>> cbf7e956 (Adding BlockContainerDict)
237217
def test_prepend_for_each(self):
238218
container = BlockContainerDict({"a": nn.Module(), "b": nn.Module()})
239219

@@ -247,10 +227,7 @@ def test_prepend_for_each(self):
247227
assert len(container["a"]) == 3
248228
assert len(container["b"]) == 3
249229
assert container["a"][0] == container["b"][0]
250-
<<<<<<< HEAD
251230

252231
container.prepend_for_each(to_add, link="residual")
253232
assert isinstance(container["a"][0], link.Residual)
254233
assert isinstance(container["b"][0], link.Residual)
255-
=======
256-
>>>>>>> cbf7e956 (Adding BlockContainerDict)

0 commit comments

Comments
 (0)