Labels: documentation (Improvements or additions to documentation)
Description
Documentation
Hi,
I am trying to figure out how the global batch size is calculated.
The only hint I could find is here.
However, it only documents `per_device_batch_size`, which is described as: "Sets the local batch size per accelerator chip."
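For example, on a v6e-8 TPU the following code returns the number of chips:

```python
import jax

# Number of TPU devices attached to this host
print(jax.local_device_count("tpu"))
```

It returns 8. However, this gets complicated on a larger slice such as a v6e-32, because `local_device_count` only counts the devices attached to the current host, not the whole slice. So the most reliable way to get the total number of chips is to look it up in the documentation.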
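To make the local/global distinction concrete, here is a minimal sketch using standard JAX calls: `jax.local_device_count()` counts the devices addressable by the current process (usually one process per host), `jax.device_count()` counts the devices across all participating processes, and `jax.process_count()` counts the processes. The helper name `device_counts` is hypothetical. On a single-host slice the local and global counts should match; on a multi-host slice, the global count is the one that corresponds to the total number of chips.

```python
import jax


def device_counts() -> dict:
    """Hypothetical helper: compare per-host and job-wide device counts."""
    return {
        # Devices addressable by this process (typically one process per host).
        "local": jax.local_device_count(),
        # Devices across all participating processes/hosts.
        "global": jax.device_count(),
        # Number of participating processes.
        "processes": jax.process_count(),
    }


if __name__ == "__main__":
    print(device_counts())
```

Run on every host of the slice, each host should report the same `"global"` value, while `"local"` only reflects that host's chips.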
So, as a rule of thumb: `global_batch_size = per_device_batch_size * number of chips`. It would be great if this could be documented (at least for TPU usage).
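A minimal sketch of the rule of thumb above, assuming it holds as stated (each chip processes exactly `per_device_batch_size` examples per step, with no gradient accumulation or other scaling applied). This is not confirmed by the documentation, and the function name `global_batch_size` is hypothetical. The key point is that `num_chips` must be the total across all hosts, not the per-host count.

```python
def global_batch_size(per_device_batch_size: int, num_chips: int) -> int:
    """Hypothetical helper implementing the rule of thumb from this issue.

    num_chips is the total number of chips in the slice (e.g. 32 on a
    v6e-32), not the number attached to a single host.
    """
    if per_device_batch_size <= 0 or num_chips <= 0:
        raise ValueError("batch size and chip count must be positive")
    # Rule of thumb: every chip contributes per_device_batch_size examples.
    return per_device_batch_size * num_chips


if __name__ == "__main__":
    print(global_batch_size(4, 8))   # single-host v6e-8
    print(global_batch_size(4, 32))  # multi-host v6e-32
```

With a hypothetical `per_device_batch_size` of 4, this gives 32 on a v6e-8 and 128 on a v6e-32. Inside a running JAX job, `jax.device_count()` is one way to obtain the total chip count instead of hard-coding it.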