Skip to content

Commit 7a0e221

Browse files
add session-based example with pretrained embeddings (#1102)
* add synthetic dataset based on the SIGIR dataset * add notebook demonstrating training a transformer model with pretrained embeddings (including serving) * add a unit test for the notebook
1 parent 07826b8 commit 7a0e221

File tree

8 files changed

+17162
-0
lines changed

8 files changed

+17162
-0
lines changed

examples/usecases/transformers-next-item-prediction-with-pretrained-embeddings.ipynb

Lines changed: 16781 additions & 0 deletions
Large diffs are not rendered by default.

merlin/datasets/ecommerce/sigir/__init__.py

Whitespace-only changes.

merlin/datasets/ecommerce/sigir/browsing_train/__init__.py

Whitespace-only changes.
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
{
2+
"feature": [
3+
{
4+
"name": "session_id_hash",
5+
"type": "INT",
6+
"intDomain": {
7+
"name": "session_id_hash",
8+
"max": "999",
9+
"isCategorical": true
10+
},
11+
"annotation": {
12+
"tag": [
13+
"item_id",
14+
"item",
15+
"categorical",
16+
"id"
17+
],
18+
"extraMetadata": [
19+
{
20+
"num_buckets": null,
21+
"freq_threshold": 0.0,
22+
"max_size": 1000.0,
23+
"start_index": 0.0,
24+
"cat_path": ".//categories/unique.session_id_hash.parquet",
25+
"embedding_sizes": {
26+
"cardinality": 1000.0,
27+
"dimension": 77.0
28+
},
29+
"_dims": [
30+
[
31+
0.0,
32+
null
33+
]
34+
],
35+
"is_list": false,
36+
"is_ragged": false,
37+
"dtype_item_size": 64.0
38+
}
39+
]
40+
}
41+
},
42+
{
43+
"name": "event_type",
44+
"type": "INT",
45+
"intDomain": {
46+
"name": "event_type",
47+
"max": "2",
48+
"isCategorical": true
49+
},
50+
"annotation": {
51+
"tag": [
52+
"categorical"
53+
],
54+
"extraMetadata": [
55+
{
56+
"num_buckets": null,
57+
"freq_threshold": 0.0,
58+
"max_size": 1000.0,
59+
"start_index": 0.0,
60+
"cat_path": ".//categories/unique.event_type.parquet",
61+
"embedding_sizes": {
62+
"cardinality": 3.0,
63+
"dimension": 16.0
64+
},
65+
"_dims": [
66+
[
67+
0.0,
68+
null
69+
]
70+
],
71+
"is_list": false,
72+
"is_ragged": false,
73+
"dtype_item_size": 64.0
74+
}
75+
]
76+
}
77+
},
78+
{
79+
"name": "product_action",
80+
"type": "INT",
81+
"intDomain": {
82+
"name": "product_action",
83+
"max": "4",
84+
"isCategorical": true
85+
},
86+
"annotation": {
87+
"tag": [
88+
"categorical"
89+
],
90+
"extraMetadata": [
91+
{
92+
"num_buckets": null,
93+
"freq_threshold": 0.0,
94+
"max_size": 1000.0,
95+
"start_index": 0.0,
96+
"cat_path": ".//categories/unique.product_action.parquet",
97+
"embedding_sizes": {
98+
"cardinality": 5.0,
99+
"dimension": 16.0
100+
},
101+
"_dims": [
102+
[
103+
0.0,
104+
null
105+
]
106+
],
107+
"is_list": false,
108+
"is_ragged": false,
109+
"dtype_item_size": 64.0
110+
}
111+
]
112+
}
113+
},
114+
{
115+
"name": "product_sku_hash",
116+
"type": "INT",
117+
"intDomain": {
118+
"name": "product_sku_hash",
119+
"max": "999",
120+
"isCategorical": true
121+
},
122+
"annotation": {
123+
"tag": [
124+
"categorical"
125+
],
126+
"extraMetadata": [
127+
{
128+
"num_buckets": null,
129+
"freq_threshold": 0.0,
130+
"max_size": 1000.0,
131+
"start_index": 0.0,
132+
"cat_path": ".//categories/unique.product_sku_hash.parquet",
133+
"embedding_sizes": {
134+
"cardinality": 1000.0,
135+
"dimension": 77.0
136+
},
137+
"_dims": [
138+
[
139+
0.0,
140+
null
141+
]
142+
],
143+
"is_list": false,
144+
"is_ragged": false,
145+
"dtype_item_size": 64.0
146+
}
147+
]
148+
}
149+
},
150+
{
151+
"name": "hashed_url",
152+
"type": "INT",
153+
"intDomain": {
154+
"name": "hashed_url",
155+
"max": "999",
156+
"isCategorical": true
157+
},
158+
"annotation": {
159+
"tag": [
160+
"categorical"
161+
],
162+
"extraMetadata": [
163+
{
164+
"num_buckets": null,
165+
"freq_threshold": 0.0,
166+
"max_size": 1000.0,
167+
"start_index": 0.0,
168+
"cat_path": ".//categories/unique.hashed_url.parquet",
169+
"embedding_sizes": {
170+
"cardinality": 1000.0,
171+
"dimension": 77.0
172+
},
173+
"_dims": [
174+
[
175+
0.0,
176+
null
177+
]
178+
],
179+
"is_list": false,
180+
"is_ragged": false,
181+
"dtype_item_size": 64.0
182+
}
183+
]
184+
}
185+
},
186+
{
187+
"name": "server_timestamp_epoch_ms",
188+
"type": "FLOAT",
189+
"annotation": {
190+
"tag": [
191+
"continuous"
192+
],
193+
"extraMetadata": [
194+
{
195+
"_dims": [
196+
[
197+
0.0,
198+
null
199+
]
200+
],
201+
"is_list": false,
202+
"is_ragged": false,
203+
"dtype_item_size": 64.0
204+
}
205+
]
206+
}
207+
}
208+
]
209+
}

merlin/datasets/ecommerce/sigir/sku_information/__init__.py

Whitespace-only changes.
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
{
2+
"feature": [
3+
{
4+
"name": "product_sku_hash",
5+
"type": "INT",
6+
"intDomain": {
7+
"name": "product_sku_hash",
8+
"max": "999",
9+
"isCategorical": true
10+
},
11+
"annotation": {
12+
"tag": [
13+
"id",
14+
"categorical",
15+
"item"
16+
],
17+
"extraMetadata": [
18+
{
19+
"num_buckets": null,
20+
"freq_threshold": 0.0,
21+
"max_size": 1000.0,
22+
"cat_path": ".//categories/unique.product_sku_hash.parquet",
23+
"embedding_sizes": {
24+
"cardinality": 1000.0,
25+
"dimension": 77.0
26+
},
27+
"_dims": [
28+
[
29+
0.0,
30+
null
31+
]
32+
],
33+
"is_list": false,
34+
"is_ragged": false,
35+
"dtype_item_size": 64.0
36+
}
37+
]
38+
}
39+
},
40+
{
41+
"name": "description_vector",
42+
"type": "FLOAT",
43+
"floatDomain": {
44+
"min": -0.44,
45+
"max": 0.603
46+
},
47+
"annotation": {
48+
"tag": [
49+
"item"
50+
],
51+
"extraMetadata": [
52+
{
53+
"_dims": [
54+
[
55+
0.0,
56+
null
57+
],
58+
[
59+
50,
60+
50
61+
]
62+
],
63+
"is_list": true,
64+
"is_ragged": true,
65+
"dtype_item_size": 64.0
66+
}
67+
]
68+
}
69+
},
70+
{
71+
"name": "category_hash",
72+
"type": "INT",
73+
"intDomain": {
74+
"name": "category_hash",
75+
"max": "174",
76+
"isCategorical": true
77+
},
78+
"annotation": {
79+
"tag": [
80+
"item",
81+
"item_id",
82+
"categorical",
83+
"id"
84+
],
85+
"extraMetadata": [
86+
{
87+
"num_buckets": null,
88+
"freq_threshold": 0.0,
89+
"max_size": 1000.0,
90+
"start_index": 0.0,
91+
"cat_path": ".//categories/unique.category_hash.parquet",
92+
"embedding_sizes": {
93+
"cardinality": 175.0,
94+
"dimension": 29.0
95+
},
96+
"_dims": [
97+
[
98+
0.0,
99+
null
100+
]
101+
],
102+
"is_list": false,
103+
"is_ragged": false,
104+
"dtype_item_size": 64.0
105+
}
106+
]
107+
}
108+
},
109+
{
110+
"name": "price_bucket",
111+
"type": "FLOAT",
112+
"annotation": {
113+
"tag": [
114+
"continuous"
115+
],
116+
"extraMetadata": [
117+
{
118+
"_dims": [
119+
[
120+
0.0,
121+
null
122+
]
123+
],
124+
"is_list": false,
125+
"is_ragged": false,
126+
"dtype_item_size": 64.0
127+
}
128+
]
129+
}
130+
}
131+
]
132+
}

merlin/datasets/synthetic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
"booking.com": HERE / "ecommerce/booking/transformed/",
5050
"booking.com-raw": HERE / "ecommerce/booking/raw/",
5151
"transactions": HERE / "ecommerce/transactions",
52+
"sigir-browsing": HERE / "ecommerce/sigir/browsing_train",
53+
"sigir-sku": HERE / "ecommerce/sigir/sku_information",
5254
}
5355

5456

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import shutil
2+
3+
import pytest
4+
from testbook import testbook
5+
6+
from tests.conftest import REPO_ROOT
7+
8+
pytest.importorskip("transformers")
9+
utils = pytest.importorskip("merlin.systems.triton.utils")
10+
11+
TRITON_SERVER_PATH = shutil.which("tritonserver")
12+
13+
14+
@pytest.mark.skipif(not TRITON_SERVER_PATH, reason="triton server not found")
15+
@testbook(
16+
REPO_ROOT
17+
/ "examples/usecases/transformers-next-item-prediction-with-pretrained-embeddings.ipynb",
18+
timeout=720,
19+
execute=False,
20+
)
21+
@pytest.mark.notebook
22+
def test_next_item_prediction(tb, tmpdir):
23+
tb.inject(
24+
f"""
25+
import os, random
26+
os.environ["OUTPUT_DATA_DIR"] = "{tmpdir}"
27+
os.environ["NUM_EPOCHS"] = "1"
28+
os.environ["NUM_EXAMPLES"] = "1_500"
29+
os.environ["MINIMUM_SESSION_LENGTH"] = "2"
30+
"""
31+
)
32+
tb.execute_cell(list(range(0, 48)))
33+
34+
with utils.run_triton_server(f"{tmpdir}/ensemble", grpc_port=8001):
35+
tb.execute_cell(list(range(48, len(tb.cells))))
36+
37+
predicted_hashed_url_id = tb.ref("predicted_hashed_url_id").item()
38+
assert predicted_hashed_url_id >= 0 and predicted_hashed_url_id <= 1002

0 commit comments

Comments
 (0)