# Added Multi-Host TPU tutorial #9507


Draft: vfdev-5 wants to merge 1 commit into master from docs-learn-multi-host-tpu

## Conversation

@vfdev-5 (Contributor) commented Jul 24, 2025:

This is a draft of multi-host tutorial, based on this gist: https://gist.github.com/vfdev-5/70f695e462443685a0922e79ce0ee899 and Chris Jones' mnist_xla.py code.

cc @melissawm

@vfdev-5 force-pushed the docs-learn-multi-host-tpu branch from 4527a9f to 8c0d217 on July 24, 2025 at 09:54.
@melissawm (Contributor) left a comment:

Thank you @vfdev-5 ! A few very straightforward comments and one question (should we use TensorBoard or XProf for profiling?)


Before diving into the code and the commands to execute, let us introduce some terminology. The section below is an adapted version of the [JAX multi-host tutorial](https://docs.jax.dev/en/latest/multi_process.html#terminology).

We sometimes call each Python process running PyTorch/XLA computations a controller or a host, but the terms are essentially synonymous.

nit:

Suggested change
We sometimes call each Python process running PyTorch/XLA computations a controller or a host, but the terms are essentially synonymous.
We sometimes call each Python process running PyTorch/XLA computations a _controller_ or a _host_, but the terms are essentially synonymous.
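To make the controller/host terminology concrete, here is a minimal sketch of how a process can report its place in the topology; it assumes the `torch_xla.runtime` helpers below are present in the installed PyTorch/XLA version:

```
import torch_xla.runtime as xr

# Each Python process driving PyTorch/XLA computations is one controller/host.
# These helpers report where this process sits in the overall topology.
print(f"host {xr.process_index()} of {xr.process_count()}")
print(f"devices attached to this host: {xr.local_device_count()}")
print(f"devices in the whole job:      {xr.global_device_count()}")
```

Run across a multi-host TPU slice, each host prints its own index but the same global device count.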


### Google Cloud tools setup

We first need to install `gcloud` CLI. The official guide can be found [here](https://cloud.google.com/sdk/docs/install), below we provide the commands for Linux/Ubuntu:

Suggested change
We first need to install `gcloud` CLI. The official guide can be found [here](https://cloud.google.com/sdk/docs/install), below we provide the commands for Linux/Ubuntu:
We first need to install the `gcloud` CLI. The [official guide](https://cloud.google.com/sdk/docs/install) has full installation instructions. Below we provide the commands for Linux/Ubuntu:

<details>

<summary>
How to instal gcloud CLI on Linux/Ubuntu

Suggested change
How to instal gcloud CLI on Linux/Ubuntu
How to install the gcloud CLI on Linux/Ubuntu


</details>
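For reference, a sketch of the apt-based install on Debian/Ubuntu; follow the official guide linked above for the authoritative steps:

```
# Add Google Cloud's apt repository, then install the gcloud CLI.
sudo apt-get update && sudo apt-get install -y apt-transport-https ca-certificates gnupg curl
curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | \
    sudo gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg
echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | \
    sudo tee /etc/apt/sources.list.d/google-cloud-sdk.list
sudo apt-get update && sudo apt-get install -y google-cloud-cli
```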

Next, we need to run the `gcloud` configuration by choosing the project, compute zone etc:

Suggested change
Next, we need to run the `gcloud` configuration by choosing the project, compute zone etc:
Next, we need to run the `gcloud` configuration command by choosing the project and compute zone:
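For example (a sketch; the project ID and zone are placeholders to replace with your own):

```
gcloud auth login                             # authenticate your Google account
gcloud config set project my-tpu-project      # placeholder project ID
gcloud config set compute/zone us-central2-b  # a zone with TPU availability
```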


#### (Optional) GCS Fuse tool

We can optionally install `gcsfuse` tool to be able to mount Google Cloud Storage buckets. The official guide can be found [here](https://cloud.google.com/storage/docs/cloud-storage-fuse/install), below we provide the commands for Linux/Ubuntu:

Suggested change
We can optionally install `gcsfuse` tool to be able to mount Google Cloud Storage buckets. The official guide can be found [here](https://cloud.google.com/storage/docs/cloud-storage-fuse/install), below we provide the commands for Linux/Ubuntu:
We can optionally install the `gcsfuse` tool to be able to mount Google Cloud Storage buckets. The [official guide](https://cloud.google.com/storage/docs/cloud-storage-fuse/install) has full installation instructions. Below we provide the commands for Linux/Ubuntu:
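Once installed, mounting and unmounting a bucket looks like this (a sketch; the bucket name and mount point are placeholders):

```
mkdir -p /mnt/my-bucket
gcsfuse my-bucket /mnt/my-bucket   # mount gs://my-bucket locally
# ... read/write files under /mnt/my-bucket ...
fusermount -u /mnt/my-bucket       # unmount when done
```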

```
data, target = data.to(device), target.to(device)
xs.mark_sharding(data, conv_mesh, ("data", "dim1", None, None))
```
The model contains convolutional and fully-connected layers and we also shard them. Convolution's weights of shape `(OutFeat, InFeat, K, K)` are sharded such that a single worker has a local shard of shape `(OutFeat // D1, InFeat, K, K)`.

Suggested change
The model contains convolutional and fully-connected layers and we also shard them. Convolution's weights of shape `(OutFeat, InFeat, K, K)` are sharded such that a single worker has a local shard of shape `(OutFeat // D1, InFeat, K, K)`.
The model contains convolutional and fully-connected layers and we also shard them. The convolution's weights of shape `(OutFeat, InFeat, K, K)` are sharded such that a single worker has a local shard of shape `(OutFeat // D1, InFeat, K, K)`.

The fully-connected layer `fc1` is sharded similarly to the convolutions and `fc2` are sharded over two dimensions as following `(OutFeat // D1, InFeat // 2)`. In this script the model's sharding is defined directly inside the constructor method:

Suggested change
The fully-connected layer `fc1` is sharded similarly to the convolutions and `fc2` are sharded over two dimensions as following `(OutFeat // D1, InFeat // 2)`. In this script the model's sharding is defined directly inside the constructor method:
The fully-connected layer `fc1` is sharded similarly to the convolutions, and `fc2` is sharded over two dimensions, giving a local shard of shape `(OutFeat // D1, InFeat // 2)`. In this script, the model's sharding is defined directly inside the constructor method:


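For readers following along, a minimal sketch of what such constructor-side sharding could look like. Only the `conv_mesh` name and its axis names `('data', 'dim1')` come from the quoted fragments; the mesh construction, the `D1` degree, and the layer sizes are assumptions for illustration:

```
import numpy as np
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # enable the SPMD execution mode

# Assumed 2-D mesh: a 'data' axis for batch sharding and a 'dim1' axis
# of size D1 for sharding model weights.
num_devices = xr.global_runtime_device_count()
D1 = 2  # placeholder model-sharding degree
conv_mesh = xs.Mesh(np.arange(num_devices), (num_devices // D1, D1), ("data", "dim1"))

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        device = xm.xla_device()
        self.conv1 = nn.Conv2d(1, 32, 3).to(device)   # placeholder sizes
        self.fc2 = nn.Linear(128, 10).to(device)      # placeholder sizes
        # Conv weight (OutFeat, InFeat, K, K): shard OutFeat over 'dim1', so
        # each worker holds a local shard of shape (OutFeat // D1, InFeat, K, K).
        xs.mark_sharding(self.conv1.weight, conv_mesh, ("dim1", None, None, None))
        # fc2 weight (OutFeat, InFeat): shard over both mesh axes; the local
        # shard shape depends on the mesh dimensions.
        xs.mark_sharding(self.fc2.weight, conv_mesh, ("dim1", "data"))
```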

Once the TPU VM is ready, we copy previously created `mnist_xla.py` file to all

Suggested change
Once the TPU VM is ready, we copy previously created `mnist_xla.py` file to all
Once the TPU VM is ready, we copy the previously created `mnist_xla.py` file to all
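The copy itself can be done with `gcloud` (a sketch; `$TPU_NAME` and `$ZONE` are placeholders for the TPU VM name and its zone):

```
gcloud compute tpus tpu-vm scp mnist_xla.py $TPU_NAME: --worker=all --zone=$ZONE
```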

```
0 Training finished!
```

#### Profiler logs in TensorBoard

I believe we want to use XProf instead of TensorBoard, but we should confirm.
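Whichever viewer is chosen, the trace capture on the PyTorch/XLA side is the same. A sketch, assuming the `torch_xla.debug.profiler` API; the port and log directory are placeholders:

```
import torch_xla.debug.profiler as xp

# Start a profiler server once per process, before training begins.
server = xp.start_server(9012)

# Later, capture a trace of the running program into a log directory
# that TensorBoard (or XProf) can read.
xp.trace("localhost:9012", "/tmp/profile_logs", duration_ms=2000)
```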


#### Troubleshooting

We can execute command on all workers together or on a single worker:

Suggested change
We can execute command on all workers together or on a single worker:
We can execute commands on all workers together or on a single worker:
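For example (a sketch; placeholders as above, and the grepped process and log path are illustrative):

```
# Run a command on every worker of the TPU VM:
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --worker=all --command="ps aux | grep python"

# Or on a single worker (here, worker 0):
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --worker=0 --command="tail -n 50 /tmp/train.log"
```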

@melissawm (Contributor) commented:

Hello @pgmoka @bhavya01 - would you mind taking a look for correctness and scope of this tutorial? If you are happy with the general idea, we can remove this from draft and address any other feedback. Thank you!

@melissawm (Contributor) commented:

Hi folks - gentle ping. If you have any feedback, we're happy to address. Thanks!
