Skip to content

Commit 76b0b5e

Browse files
authored
fix train_test split (#1291)
1 parent 7d470ae commit 76b0b5e

File tree

6 files changed

+448
-472
lines changed

6 files changed

+448
-472
lines changed

examples/getting-started-movielens/01-Download-Convert.ipynb

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,6 @@
5151
"# External dependencies\n",
5252
"import os\n",
5353
"\n",
54-
"from sklearn.model_selection import train_test_split\n",
55-
"\n",
5654
"from nvtabular.utils import download_file\n",
5755
"\n",
5856
"# Get dataframe library - cudf or pandas\n",
@@ -89,7 +87,16 @@
8987
"cell_type": "code",
9088
"execution_count": 4,
9189
"metadata": {},
92-
"outputs": [],
90+
"outputs": [
91+
{
92+
"name": "stderr",
93+
"output_type": "stream",
94+
"text": [
95+
"downloading ml-25m.zip: 262MB [00:06, 42.1MB/s] \n",
96+
"unzipping files: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:04<00:00, 1.74files/s]\n"
97+
]
98+
}
99+
],
93100
"source": [
94101
"download_file(\n",
95102
" \"http://files.grouplens.org/datasets/movielens/ml-25m.zip\",\n",
@@ -415,7 +422,7 @@
415422
"cell_type": "markdown",
416423
"metadata": {},
417424
"source": [
418-
"We drop the timestamp column and split the ratings into training and test dataset. We use a simple random split."
425+
"We drop the timestamp column and split the ratings into training and test datasets. We use a simple random split."
419426
]
420427
},
421428
{
@@ -425,9 +432,15 @@
425432
"outputs": [],
426433
"source": [
427434
"ratings = ratings.drop(\"timestamp\", axis=1)\n",
428-
"# convert ratings to pandas df to use sklearn train_test_split func\n",
429-
"ratings = ratings.to_pandas()\n",
430-
"train, valid = train_test_split(ratings, test_size=0.2, random_state=42)"
435+
"\n",
436+
"# shuffle the dataset\n",
437+
"ratings = ratings.sample(len(ratings), replace=False)\n",
438+
"\n",
439+
"# split the train_df as training and validation data sets.\n",
440+
"num_valid = int(len(ratings) * 0.2)\n",
441+
"\n",
442+
"train = ratings[:-num_valid]\n",
443+
"valid = ratings[-num_valid:]"
431444
]
432445
},
433446
{

0 commit comments

Comments
 (0)