Added Multi-Host TPU tutorial #9507
Thank you @vfdev-5 ! A few very straightforward comments and one question (should we use TensorBoard or XProf for profiling?)
Before diving into the code and the commands to execute, let us introduce some terminology. The section below is an adapted version of [JAX multi-host tutorial](https://docs.jax.dev/en/latest/multi_process.html#terminology).

We sometimes call each Python process running PyTorch/XLA computations a controller or a host, but the terms are essentially synonymous.

nit, suggested change:
We sometimes call each Python process running PyTorch/XLA computations a _controller_ or a _host_, but the terms are essentially synonymous.
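
For readers new to these terms, a minimal sketch of how each controller/host can identify itself; it assumes the `torch_xla.runtime` helpers `process_index()` and `process_count()`, which are not part of this PR:

```python
# Hedged sketch: each Python process (controller/host) reports its own index
# and the total number of participating processes.
import torch_xla.runtime as xr

print(f"I am controller/host {xr.process_index()} of {xr.process_count()}")
```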

### Google Cloud tools setup

We first need to install `gcloud` CLI. The official guide can be found [here](https://cloud.google.com/sdk/docs/install), below we provide the commands for Linux/Ubuntu:

Suggested change:
We first need to install the `gcloud` CLI. The [official guide](https://cloud.google.com/sdk/docs/install) has full installation instructions. Below we provide the commands for Linux/Ubuntu:

<details>
<summary>
How to instal gcloud CLI on Linux/Ubuntu

Suggested change:
How to install the gcloud CLI on Linux/Ubuntu

</details>

Next, we need to run the `gcloud` configuration by choosing the project, compute zone etc:

Suggested change:
Next, we need to run the `gcloud` configuration command by choosing the project and compute zone:

#### (Optional) GCS Fuse tool

We can optionally install `gcsfuse` tool to be able to mount Google Cloud Storage buckets. The official guide can be found [here](https://cloud.google.com/storage/docs/cloud-storage-fuse/install), below we provide the commands for Linux/Ubuntu:

Suggested change:
We can optionally install the `gcsfuse` tool to be able to mount Google Cloud Storage buckets. The [official guide](https://cloud.google.com/storage/docs/cloud-storage-fuse/install) has full installation instructions. Below we provide the commands for Linux/Ubuntu:

```
data, target = data.to(device), target.to(device)
xs.mark_sharding(data, conv_mesh, ("data", "dim1", None, None))
```
The model contains convolutional and fully-connected layers and we also shard them. Convolution's weights of shape `(OutFeat, InFeat, K, K)` are sharded such that a single worker has a local shard of shape `(OutFeat // D1, InFeat, K, K)`.

Suggested change:
The model contains convolutional and fully-connected layers and we also shard them. The convolution's weights of shape `(OutFeat, InFeat, K, K)` are sharded such that a single worker has a local shard of shape `(OutFeat // D1, InFeat, K, K)`.
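
To make the quoted snippet above easier to follow, here is a rough sketch of how a `conv_mesh` with `data` and `dim1` axes could be built; the axis sizes are illustrative and not taken from the PR:

```python
# Hedged sketch: build a 2D device mesh with "data" and "dim1" axes so that
# xs.mark_sharding(data, conv_mesh, ("data", "dim1", None, None)) splits the
# batch over "data" and the channel dimension over "dim1".
import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # enable SPMD mode before creating XLA tensors

num_devices = xr.global_runtime_device_count()
d1 = 2                                    # hypothetical size of the "dim1" axis
mesh_shape = (num_devices // d1, d1)      # axes: ("data", "dim1")
device_ids = np.array(range(num_devices))
conv_mesh = xs.Mesh(device_ids, mesh_shape, ("data", "dim1"))
```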
The fully-connected layer `fc1` is sharded similarly to the convolutions and `fc2` are sharded over two dimensions as following `(OutFeat // D1, InFeat // 2)`. In this script the model's sharding is defined directly inside the constructor method: |

Suggested change:
The fully-connected layer `fc1` is sharded similarly to the convolutions, and `fc2` is sharded over two dimensions, giving local shards of shape `(OutFeat // D1, InFeat // 2)`. In this script, the model's sharding is defined directly inside the constructor method:
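
As an illustration of the constructor-based sharding described above, a trimmed sketch (layer names and shapes are hypothetical, not the PR's exact code):

```python
# Hedged sketch: mark sharding on the parameters inside __init__, assuming the
# parameters are created directly on the XLA device and a mesh with axes
# ("data", "dim1") is passed in.
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs


class Net(nn.Module):
    def __init__(self, mesh):
        super().__init__()
        device = xm.xla_device()
        self.conv1 = nn.Conv2d(1, 32, 3).to(device)
        self.fc1 = nn.Linear(9216, 128).to(device)
        self.fc2 = nn.Linear(128, 10).to(device)
        # Conv weight (OutFeat, InFeat, K, K): shard OutFeat over "dim1",
        # so each worker holds a (OutFeat // D1, InFeat, K, K) local shard.
        xs.mark_sharding(self.conv1.weight, mesh, ("dim1", None, None, None))
        # fc1 weight (OutFeat, InFeat): sharded like the convolutions.
        xs.mark_sharding(self.fc1.weight, mesh, ("dim1", None))
        # fc2 weight: sharded over both mesh axes, i.e. both dimensions split.
        xs.mark_sharding(self.fc2.weight, mesh, ("dim1", "data"))
        # (forward() omitted for brevity)
```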

Once the TPU VM is ready, we copy previously created `mnist_xla.py` file to all

Suggested change:
Once the TPU VM is ready, we copy the previously created `mnist_xla.py` file to all

```
0 Training finished!
```

#### Profiler logs in TensorBoard
I believe we want to use XProf instead of TensorBoard, but we should confirm.
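
Whichever viewer the tutorial ends up recommending, the capture side looks roughly like the hedged sketch below; the port, log directory, and duration are placeholder values, not taken from the PR:

```python
# Hedged sketch: start the PyTorch/XLA profiler server inside the training
# script; captured traces can then be opened in TensorBoard's profile plugin
# or in XProf.
import torch_xla.debug.profiler as xp

server = xp.start_server(9012)  # keep the reference so the server stays alive

# ... training loop runs here ...

# From another terminal/process, capture a trace (values are examples):
# xp.trace("localhost:9012", logdir="/tmp/profile", duration_ms=10000)
```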

#### Troubleshooting

We can execute command on all workers together or on a single worker:

Suggested change:
We can execute commands on all workers together or on a single worker:
Hi folks - gentle ping. If you have any feedback, we're happy to address it. Thanks!
This is a draft of the multi-host tutorial, based on this gist: https://gist.github.com/vfdev-5/70f695e462443685a0922e79ce0ee899 and Chris Jones' mnist_xla.py code.
cc @melissawm