CountVectorizer example #160
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,212 @@ | ||
| { | ||
| "cells": [ | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "# Using Dask-ML's CountVectorizer\n", | ||
| "\n", | ||
| "Dask-ML includes a [CountVectorizer](https://ml.dask.org/modules/generated/dask_ml.feature_extraction.text.CountVectorizer.html#dask_ml.feature_extraction.text.CountVectorizer) that's appropriate for parallel / distributed processing of large datasets." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## Loading Data\n", | ||
| "\n", | ||
| "As we'll see later, Dask-ML's `CountVectorizer` benefits from using the `dask.distributed` scheduler, even on a single machine." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "from dask.distributed import Client\n", | ||
| "\n", | ||
| "client = Client(n_workers=4, threads_per_worker=1)\n", | ||
| "client" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "In this example, we'll work with the 20 newsgroups dataset from scikit-learn." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "import sklearn.datasets\n", | ||
| "\n", | ||
| "news = sklearn.datasets.fetch_20newsgroups()\n", | ||
| "news['data'][:2]" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "This returns a list of documents (strings). Dask-ML's `CountVectorizer` expects a `dask.bag.Bag` of documents. We'll use `dask.delayed` to load the 20 newsgroups in parallel, taking care to load the data on the workers and not place large values (like `news['data']`) in the the task graph. See https://docs.dask.org/en/latest/best-practices.html#load-data-with-dask and https://docs.dask.org/en/latest/delayed-best-practices.html#don-t-call-dask-delayed-on-other-dask-collections for more on these concepts.\n", | ||
| "\n", | ||
| "This example is a bit contrived to get a Bag with multiple partitions. Typically the full dataset would be partitioned into multiple files on disk, and you'd load one partition per file. In this case, we split the single file into multiple partitions by loading the data and then slicing." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "import dask\n", | ||
| "import numpy as np\n", | ||
| "import dask.bag as db\n", | ||
| "import toolz\n", | ||
| "\n", | ||
| "@dask.delayed\n", | ||
| "def load_news(slice_):\n", | ||
| " \"\"\"Load a slice of the 20 newsgroups dataset.\"\"\"\n", | ||
| " return sklearn.datasets.fetch_20newsgroups()['data'][slice_]\n", | ||
| "\n", | ||
| "npartitions = 10\n", | ||
| "partition_size = len(news['data']) // npartitions\n", | ||
| "\n", | ||
| "lengths = np.cumsum([partition_size] * npartitions)\n", | ||
| "lengths = [0] + list(lengths) + [None]\n", | ||
| "\n", | ||
| "slices = [slice(a, b) for a, b in\n", | ||
| " toolz.sliding_window(2, lengths)]\n", | ||
| "# Notice the persist here! More details later.\n", | ||
| "documents = db.from_delayed([load_news(x) for x in slices]).persist()\n", | ||
| "documents" | ||
| ] | ||
| }, | ||
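| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "As a quick sanity check, we can confirm that the bag has the expected number of partitions and peek at the first document, using the standard `dask.bag` methods `npartitions` and `take`." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "# Confirm the partitioning and peek at the first document.\n", | ||
| "print(documents.npartitions)\n", | ||
| "documents.take(1)" | ||
| ] | ||
| }, | ||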
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "import dask_ml.feature_extraction.text" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "vectorizer = dask_ml.feature_extraction.text.CountVectorizer()\n", | ||
| "%time result = vectorizer.fit_transform(documents)" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "The call to `fit_transform` did some work to discover the *vocabulary*, a mapping from terms in the documents to positions in the transformed result array." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "list(vectorizer.vocabulary_.items())[:5]" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "Speaking of the result, it's a Dask `Array` backed by `scipy.sparse.csr_matrix` objects. We can bring it back to the client with `.compute()`" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "local_result = result.compute()\n", | ||
| "local_result[:5].toarray()" | ||
| ] | ||
| }, | ||
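| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "Once computed, `local_result` is a single `scipy.sparse` matrix, so the usual sparse-matrix attributes apply: rows are documents and columns are vocabulary terms." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "# The concatenated local result is one sparse document-term matrix.\n", | ||
| "print(type(local_result))\n", | ||
| "print(local_result.shape)" | ||
| ] | ||
| }, | ||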
| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "Notice that we persisted `documents` earlier. If possible, persisting the input documents is preferable to avoid making two passes over the data. One to discover the vocabulary and a second to transform. If the dataset is larger than (distributed) memory, then two passes will be necessary." | ||
| ] | ||
| }, | ||
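| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "For comparison, here's a rough sketch of fitting on a bag that has *not* been persisted (timings will vary): each pass over the data re-runs `load_news` on the workers." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "# Rebuild the bag without persist; the raw data is reloaded on each pass.\n", | ||
| "documents_lazy = db.from_delayed([load_news(x) for x in slices])\n", | ||
| "%time _ = vectorizer.fit_transform(documents_lazy)" | ||
| ] | ||
| }, | ||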
| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## A note on vocabularies\n", | ||
| "\n", | ||
| "You can also provide a vocabulary ahead of time, which avoids the need for making two passes over the data. This makes operations like `vectorizer.transform` instantaneous, since no vocabulary needs to be discovered. However, vocabularies can become quite large. Consider persisting your data ahead of time to avoid bloating the size of the `CountVectorizer` object. Dask-ML's `CountVectorizer` works just fine when the `vocabulary` is a pointer to a piece of data on the cluster." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "vocabulary = vectorizer.vocabulary_\n", | ||
| "remote_vocabulary, = client.scatter([vocabulary], broadcast=True)\n", | ||
| "\n", | ||
| "vectorizer2 = dask_ml.feature_extraction.text.CountVectorizer(\n", | ||
| " vocabulary=remote_vocabulary\n", | ||
| ")" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "%time result = vectorizer2.transform(documents)" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "%time result.compute()" | ||
| ] | ||
| } | ||
| ], | ||
| "metadata": { | ||
| "kernelspec": { | ||
| "display_name": "Python 3", | ||
| "language": "python", | ||
| "name": "python3" | ||
| }, | ||
| "language_info": { | ||
| "codemirror_mode": { | ||
| "name": "ipython", | ||
| "version": 3 | ||
| }, | ||
| "file_extension": ".py", | ||
| "mimetype": "text/x-python", | ||
| "name": "python", | ||
| "nbconvert_exporter": "python", | ||
| "pygments_lexer": "ipython3", | ||
| "version": "3.7.6" | ||
| } | ||
| }, | ||
| "nbformat": 4, | ||
| "nbformat_minor": 2 | ||
| } | ||
We could also call `db.from_sequence(..., npartitions=10).persist()` and then call `client.rebalance()`. Given that people are going to blindly copy-paste whatever we do anyway, I'd personally rather they see this. It's a bit more in line with ordinary behavior, I think.
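
For reference, a minimal sketch of what that alternative might look like (assuming the raw newsgroups list fits comfortably in client memory, since `from_sequence` ships the data through the scheduler):

```python
import dask.bag as db
import sklearn.datasets
from dask.distributed import Client

client = Client(n_workers=4, threads_per_worker=1)

# Load the data locally, then partition the list into a 10-partition bag.
data = sklearn.datasets.fetch_20newsgroups()['data']
documents = db.from_sequence(data, npartitions=10).persist()

# Spread the persisted partitions evenly across the workers.
client.rebalance()
```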