1- < << << << HEAD
2- < << << << HEAD
31from typing import Dict
42
53import pytest
@@ -99,142 +97,19 @@ def test_select(self):
9997 plus_one = PlusOneDict ()
10098
10199 self .router .add_route (Tags .CONTINUOUS )
102- self .router .add_route (Tags .CATEGORICAL )
100+ self .router .add_route (Tags .CATEGORICAL , mm .MLPBlock ([10 ]))
101+ self .router .add_route (Tags .USER , mm .MLPBlock ([10 ]).prepend (mm .SelectKeys (self .schema )))
102+ self .router .add_route (Tags .ITEM , mm .MLPBlock ([10 ]))
103103 self .router .prepend (plus_one )
104104
105105 router = self .router .select (Tags .CATEGORICAL )
106106 assert router .selectable .schema == self .schema .select_by_tag (Tags .CATEGORICAL )
107107 assert router [0 ][0 ] == plus_one
108108
109- def test_double_add (self ):
110- self .router .add_route (Tags .CONTINUOUS )
111- with pytest .raises (ValueError ):
112- self .router .add_route (Tags .CONTINUOUS )
113-
114- def test_nested (self ):
115- self .router .add_route (Tags .CONTINUOUS )
116-
117- nested = self .router .nested_router ()
118- nested .add_route (Tags .USER )
119- assert "user" in nested
120-
121- outputs = module_utils .module_test (nested , self .batch .features )
122- assert list (outputs .keys ()) == ["user_age" ]
123- assert "user_age" in nested .output_schema ().column_names
124- < << << << HEAD
125- == == == =
126- import pytest
127- == == == =
128- from typing import Dict
129- > >> >> >> 89 a6f043 (Increase test - coverage )
130-
131- import pytest
132- import torch
133- from torch import nn
134-
135- import merlin .models .torch as mm
136- from merlin .models .torch .batch import Batch , sample_batch
137- from merlin .models .torch .utils import module_utils
138- from merlin .schema import ColumnSchema , Schema , Tags
139-
140-
141- class ToFloat (nn .Module ):
142- def forward (self , x ):
143- return x .float ()
144-
145-
146- class PlusOneDict (nn .Module ):
147- def forward (self , inputs : Dict [str , torch .Tensor ]) -> Dict [str , torch .Tensor ]:
148- return {k : v + 1 for k , v in inputs .items ()}
149-
150-
151- class TestRouterBlock :
152- < << << << HEAD
153- ...
154- > >> >> >> a2644079 (Add selection_utils )
155- == == == =
156- @pytest .fixture (autouse = True )
157- def setup_method (self , music_streaming_data ):
158- self .schema = music_streaming_data .schema
159- self .router : mm .RouterBlock = mm .RouterBlock (self .schema )
160- self .batch : Batch = sample_batch (music_streaming_data , batch_size = 10 )
161-
162- def test_add_route (self ):
163- self .router .add_route (Tags .CONTINUOUS )
164-
165- outputs = module_utils .module_test (self .router , self .batch .features )
166- assert set (outputs .keys ()) == set (self .schema .select_by_tag (Tags .CONTINUOUS ).column_names )
167- assert "continuous" in self .router
168- assert len (self .router ["continuous" ]) == 1
169- assert isinstance (self .router ["continuous" ][0 ], mm .SelectKeys )
170-
171- def test_add_route_module (self ):
172- class CustomSelect (mm .SelectKeys ):
173- ...
174-
175- self .router .add_route (Tags .CONTINUOUS , CustomSelect ())
176-
177- outputs = self .router (self .batch .features )
178- assert set (outputs .keys ()) == set (self .schema .select_by_tag (Tags .CONTINUOUS ).column_names )
179- assert len (self .router ["continuous" ]) == 2
180- assert isinstance (self .router ["continuous" ][0 ], mm .SelectKeys )
181- assert isinstance (self .router ["continuous" ][1 ], CustomSelect )
182-
183- def test_module_with_setup (self ):
184- class Dummy (nn .Module ):
185- def setup_schema (self , schema : Schema ):
186- self .schema = schema
187-
188- def forward (self , x ):
189- return x
190-
191- dummy = Dummy ()
192- self .router .add_route (Tags .CONTINUOUS , dummy )
193- assert dummy .schema == mm .select_schema (self .schema , Tags .CONTINUOUS )
194-
195- dummy_2 = Dummy ()
196- self .router .add_route_for_each (ColumnSchema ("user_id" ), dummy_2 , shared = True )
197- assert dummy_2 .schema == mm .select_schema (self .schema , ColumnSchema ("user_id" ))
198-
199- def test_add_route_parallel_block (self ):
200- class FakeEmbeddings (mm .ParallelBlock ):
201- ...
202-
203- self .router .add_route (Tags .CATEGORICAL , FakeEmbeddings ())
204- assert isinstance (self .router ["categorical" ], FakeEmbeddings )
205-
206- @pytest .mark .parametrize ("shared" , [True , False ])
207- def test_add_route_for_each (self , shared ):
208- block = mm .Block (mm .Concat (), ToFloat (), nn .LazyLinear (10 )).to (self .batch .device ())
209- self .router .add_route_for_each (Tags .CONTINUOUS , block , shared = shared )
210-
211- dense_pos = self .router .branches ["position" ][1 ][- 1 ]
212- dense_age = self .router .branches ["user_age" ][1 ][- 1 ]
213- if shared :
214- assert dense_pos == dense_age
215- else :
216- assert dense_pos != dense_age
217-
218- outputs = self .router (self .batch .features )
219- assert set (outputs .keys ()) == set (self .schema .select_by_tag (Tags .CONTINUOUS ).column_names )
220-
221- for value in outputs .values ():
222- assert value .shape [- 1 ] == 10
223-
224- def test_add_route_for_each_list (self ):
225- self .router .add_route_for_each ([ColumnSchema ("user_id" )], ToFloat ())
226- assert isinstance (self .router .branches ["user_id" ][1 ], ToFloat )
227-
228- def test_select (self ):
229- plus_one = PlusOneDict ()
230-
231- self .router .add_route (Tags .CONTINUOUS )
232- self .router .add_route (Tags .CATEGORICAL )
233- self .router .prepend (plus_one )
234-
235- router = self .router .select (Tags .CATEGORICAL )
236- assert router .selectable .schema == self .schema .select_by_tag (Tags .CATEGORICAL )
237- assert router [0 ][0 ] == plus_one
109+ user = self .router .select (Tags .USER )
110+ assert "item_recency" not in user .branches ["continuous" ][0 ].col_names
111+ assert "item_id" not in user .branches ["categorical" ][0 ].col_names
112+ assert "item" not in user .branches
238113
239114 def test_double_add (self ):
240115 self .router .add_route (Tags .CONTINUOUS )
@@ -250,52 +125,12 @@ def test_nested(self):
250125
251126 outputs = module_utils .module_test (nested , self .batch .features )
252127 assert list (outputs .keys ()) == ["user_age" ]
253- >> >> >> > 89 a6f043 (Increase test - coverage )
254- == == == =
255- >> >> >> > 78386932 (Fix failined nested router test )
128+ assert "user_age" in nested .output_schema ().column_names
256129
257130
258131class TestSelectKeys :
259132 @pytest .fixture (autouse = True )
260133 def setup_method (self , music_streaming_data ):
261- < << << << HEAD
262- < << << << HEAD
263- self .batch : Batch = sample_batch (music_streaming_data , batch_size = 10 )
264- self .schema : Schema = music_streaming_data .schema
265- self .user_schema : Schema = mm .select_schema (self .schema , Tags .USER )
266-
267- def test_forward (self ):
268- select_user = mm .SelectKeys (self .user_schema )
269- outputs = select_user (self .batch .features )
270-
271- assert select_user .schema == self .user_schema
272-
273- for col in {"user_id" , "country" , "user_age" }:
274- assert col in outputs
275-
276- assert "user_genres__values" in outputs
277- assert "user_genres__offsets" in outputs
278-
279- def test_select (self ):
280- select_user = mm .SelectKeys (self .user_schema )
281-
282- user_id = Schema ([self .user_schema ["user_id" ]])
283- assert select_user .select (ColumnSchema ("user_id" )).schema == user_id
284- assert select_user .select (Tags .USER ).schema == self .user_schema
285-
286- def test_setup_schema (self ):
287- select_user = mm .SelectKeys ()
288- select_user .setup_schema (self .user_schema ["user_id" ])
289- assert select_user .schema == Schema ([self .user_schema ["user_id" ]])
290- == == == =
291- self .data = music_streaming_data
292- self .schema = music_streaming_data .schema
293- self .select_keys = SelectKeys (music_streaming_data .schema )
294-
295- def test_forward (self ):
296- ...
297- >> >> >> > a2644079 (Add selection_utils )
298- == == == =
299134 self .batch : Batch = sample_batch (music_streaming_data , batch_size = 10 )
300135 self .schema : Schema = music_streaming_data .schema
301136 self .user_schema : Schema = mm .select_schema (self .schema , Tags .USER )
@@ -323,4 +158,3 @@ def test_setup_schema(self):
323158 select_user = mm .SelectKeys ()
324159 select_user .setup_schema (self .user_schema ["user_id" ])
325160 assert select_user .schema == Schema ([self .user_schema ["user_id" ]])
326- >> >> >> > 89 a6f043 (Increase test - coverage )
0 commit comments