dlami/index.rst (11 additions & 2 deletions)
@@ -17,10 +17,19 @@ to easily get started on single Neuron instance. Below sections describe the sup
Neuron Multi Framework DLAMI
----------------------------

Neuron Deep Learning AMI (DLAMI) is a multi-framework DLAMI that supports multiple Neuron frameworks and libraries. Each DLAMI is pre-installed with Neuron drivers and supports all Neuron instance types. Each virtual environment corresponds to a specific Neuron framework/library and comes pre-installed with all the Neuron libraries, including the Neuron compiler and Neuron runtime, needed for you to easily get started.
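
For example, after connecting to the instance, activate the virtual environment for the framework you want to use. The path below is illustrative only; the exact environment names and versions depend on the DLAMI release:

.. code:: bash

   # Illustrative path: check the DLAMI release notes for the actual environment name
   source /opt/aws_neuronx_venv_pytorch/bin/activate

   # Verify that the Neuron packages are visible in this environment
   pip list | grep -i neuron
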
.. note::

   Tensorflow-neuron 2.10 (inf1) released in SDK v2.20.2 is not compatible with the latest runtime in the v2.21 SDK.
   Code compiled with it will encounter runtime errors with the latest SDK 2.21.1 version.

   The Neuron team is aware of this issue, and it will be fixed in the next minor release.

   Please refer to `this page <https://github.com/aws-neuron/aws-neuron-sdk/issues/1071>`_ for more information on the issue and a temporary workaround.

frameworks/torch/torch-neuronx/programming-guide/training/pytorch-neuron-programming-guide.rst (70 additions & 72 deletions)

@@ -218,83 +218,82 @@ compiled and executed if there are extra mark-steps or functions with
implicit mark-steps. Additionally, more graphs can be generated if there
are different execution paths taken due to control-flows.

The previous sections "Automatic casting of float tensors to BFloat16" and "Automatic Mixed-Precision", which documented the ``XLA_USE_BF16``/``XLA_DOWNCAST_BF16`` environment variables and relied on CUDA's list of BF16-compatible operations (``convolution``, ``matmul``, ``bmm``, ``linear``, and so on), are removed and replaced by the sections below.

Full BF16 with stochastic rounding enabled
------------------------------------------

Previously, on torch-neuronx 2.1 and earlier, the environment variables ``XLA_USE_BF16`` or ``XLA_DOWNCAST_BF16`` provided full casting to BF16 with stochastic rounding enabled by default. These environment variables are deprecated in torch-neuronx 2.5, although still functional with warnings. To replace ``XLA_USE_BF16`` or ``XLA_DOWNCAST_BF16`` with stochastic rounding on Neuron, set ``NEURON_RT_STOCHASTIC_ROUNDING_EN=1`` and use the ``torch.nn.Module.to`` method to cast model floating-point parameters and buffers to data-type BF16 as follows:
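
A minimal sketch (assuming the model has already been constructed; the environment variable can also be exported in the shell before launching training):

.. code:: python

   import os
   import torch

   # enable stochastic rounding in the Neuron runtime
   os.environ["NEURON_RT_STOCHASTIC_ROUNDING_EN"] = "1"

   # model is created
   model.to(torch.bfloat16)
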
Similarly, if the optimizer states are to be kept in FP32, convert the gradients to FP32 before optimizer computations:
.. code:: python
   grad = p.grad.data.float()
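
For illustration, a sketch of where this conversion might sit inside a custom optimizer step (the loop structure is illustrative and not taken from a specific optimizer):

.. code:: python

   for group in self.param_groups:
       for p in group['params']:
           if p.grad is None:
               continue
           # upcast the BF16 gradient to FP32 before the optimizer math
           grad = p.grad.data.float()
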
For a full example, please see the :ref:`PyTorch Neuron BERT Pretraining Tutorial (Data-Parallel) <hf-bert-pretraining-tutorial>`, which has been updated to use ``torch.nn.Module.to`` instead of ``XLA_DOWNCAST_BF16``.

BF16 in GPU-compatible mode without stochastic rounding enabled
----------------------------------------------------------------

Full BF16 training in GPU-compatible mode would enable faster convergence without the need for stochastic rounding, but would require an FP32 copy of weights/parameters to be saved and used in the optimizer. To enable BF16 in GPU-compatible mode without stochastic rounding, use the ``torch.nn.Module.to`` method to cast the model's floating-point parameters and buffers to data-type BF16 as follows, without setting ``NEURON_RT_STOCHASTIC_ROUNDING_EN=1``:
.. code:: python
   # model is created
   model.to(torch.bfloat16)

In the initializer of the optimizer, for example AdamW, you can add code like the following snippet to make an FP32 copy of the weights:
.. code:: python
   # keep a copy of the weights in high precision (FP32)
   self.param_groups_highprec = []
   for group in self.param_groups:
       params = group['params']
       param_groups_highprec = [p.data.float() for p in params]
       # store the FP32 copies so they can be used in the optimizer step
       self.param_groups_highprec.append(param_groups_highprec)
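
This FP32 copy is then what the optimizer step updates, with the result copied back into the BF16 parameters. A sketch of that flow, with the actual update rule abbreviated to a placeholder:

.. code:: python

   # inside step(): update the FP32 master weights, then copy back into the BF16 params
   for group, params_highprec in zip(self.param_groups, self.param_groups_highprec):
       for p, p_fp32 in zip(group['params'], params_highprec):
           if p.grad is None:
               continue
           grad = p.grad.data.float()
           p_fp32.add_(grad, alpha=-group['lr'])  # placeholder for the real optimizer update
           p.data.copy_(p_fp32)                   # copy_ casts the FP32 result back to BF16
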
In the :ref:`PyTorch Neuron BERT Pretraining Tutorial (Data-Parallel) <hf-bert-pretraining-tutorial>`, this mode can be enabled by passing the ``--optimizer=AdamW_FP32ParamsCopy`` option to ``dp_bert_large_hf_pretrain_hdf5.py`` and setting ``NEURON_RT_STOCHASTIC_ROUNDING_EN=0`` (or leaving it unset).
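
For example (an illustrative invocation; the remaining training arguments required by the tutorial are omitted):

.. code:: bash

   NEURON_RT_STOCHASTIC_ROUNDING_EN=0 torchrun --nproc_per_node=32 \
       dp_bert_large_hf_pretrain_hdf5.py --optimizer=AdamW_FP32ParamsCopy
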
.. _automatic_mixed_precision_autocast:

BF16 automatic mixed precision using PyTorch Autocast
------------------------------------------------------

By default, the compiler automatically casts internal FP32 operations to
BF16. You can disable this and allow PyTorch's BF16 automatic mixed precision function (``torch.autocast``) to
do the casting of certain operations to operate in BF16.
To enable PyTorch's BF16 mixed-precision, first turn off the Neuron
compiler auto-cast:
.. code:: python
   import os

   os.environ["NEURON_CC_FLAGS"] = "--auto-cast=none"
Next, per recommendation from official PyTorch `torch.autocast documentation <https://pytorch.org/docs/stable/amp.html#autocasting>`__, place only
the forward-pass of the training step in the ``torch.autocast`` scope with ``xla`` device type:
.. code:: python
   with torch.autocast(dtype=torch.bfloat16, device_type='xla'):
       # forward pass

The device type is XLA because we are using PyTorch-XLA's autocast backend. The PyTorch-XLA `autocast mode source code <https://github.com/pytorch/xla/blob/master/torch_xla/csrc/autocast_mode.cpp>`_ lists which operations are cast to lower-precision BF16 (the "lower precision fp cast policy" section), which are kept in FP32 (the "fp32 cast policy" section), and which are promoted to the widest input type (the "promote" section).
Example showing the original training code snippet:
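
A minimal sketch of such a loop before autocast is added (the loss handling is illustrative; ``model``, ``optimizer``, and ``xm`` are assumed to be set up as in the rest of this guide):

.. code:: python

   def train_loop_fn(train_loader):
       for i, data in enumerate(train_loader):
           inputs = data[0]
           labels = data[3]
           outputs = model(inputs, labels=labels)
           loss = outputs.loss
           loss.backward()
           optimizer.step()
           xm.mark_step()
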
@@ -319,7 +318,7 @@ The following shows the training loop modified to use BF16 autocast:

.. code:: python

   def train_loop_fn(train_loader):
       for i, data in enumerate(train_loader):
           torch.cuda.is_bf16_supported = lambda: True
           with torch.autocast(dtype=torch.bfloat16, device_type='xla'):
               inputs = data[0]
               labels = data[3]
               outputs = model(inputs, labels=labels)

@@ -328,7 +327,7 @@ The following shows the training loop modified to use BF16 autocast:

.. code:: python

   # ... later in the same training loop ...
   optimizer.step()
   xm.mark_step()

For a full example of BF16 mixed-precision, see :ref:`PyTorch Neuron BERT Pretraining Tutorial (Data-Parallel) <hf-bert-pretraining-tutorial>`.
See the official PyTorch documentation for more details about ``torch.autocast``.

For best performance, you may try to aggregate the data transfers between host CPUs and devices.
For example, increasing the value of the ``batches_per_execution`` argument when instantiating ``MpDeviceLoader`` can help increase performance for certain models with frequent host-device traffic, such as ViT, as described in `a blog <https://towardsdatascience.com/ai-model-optimization-on-aws-inferentia-and-trainium-cfd48e85d5ac>`_. NOTE: Increasing the ``batches_per_execution`` value delays the mark-step for the number of batches specified by this value, which increases graph size and could lead to an out-of-memory (device OOM) error.
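
A sketch of what this looks like (the value ``4`` is illustrative, and ``train_loader`` is an ordinary PyTorch ``DataLoader``):

.. code:: python

   import torch_xla.core.xla_model as xm
   import torch_xla.distributed.parallel_loader as pl

   device = xm.xla_device()
   # aggregate host-to-device transfers and run 4 batches per execution
   train_device_loader = pl.MpDeviceLoader(train_loader, device, batches_per_execution=4)
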
Ensure common initial weights across workers
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -396,13 +401,6 @@ be loaded using ``serialization.load`` api. More information on this here: `Savi
FAQ
---

Removed FAQ entry: "What is the difference between Trainium and Inferentia?"

BERT pretraining tutorial

The BERT training script ``dp_bert_large_hf_pretrain_hdf5.py`` (`source <https://github.com/aws-neuron/aws-neuron-samples/blob/master/torch-neuronx/training/dp_bert_hf_pretrain/dp_bert_large_hf_pretrain_hdf5.py>`_)
can run on a Trainium instance (trn1.32xlarge) that contains the
appropriate Neuron runtime and Python dependencies.
@@ -60,7 +60,7 @@ For all the commands below, make sure you are in the virtual environment that yo
   source ~/aws_neuron_venv_pytorch/bin/activate

Next, clone the `AWS Neuron Samples repository <https://github.com/aws-neuron/aws-neuron-samples/>`_ and install requirements in the BERT tutorial directory ``aws-neuron-samples/torch-neuronx/training/dp_bert_hf_pretrain`` (`directory link <https://github.com/aws-neuron/aws-neuron-samples/tree/master/torch-neuronx/training/dp_bert_hf_pretrain>`_):
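
For example (assuming the tutorial's requirements file is named ``requirements.txt``):

.. code:: bash

   git clone https://github.com/aws-neuron/aws-neuron-samples.git
   cd aws-neuron-samples/torch-neuronx/training/dp_bert_hf_pretrain
   pip install -r requirements.txt
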

general/appnotes/torch-neuronx/introducing-pytorch-2-x.rst (2 additions & 2 deletions)

@@ -63,7 +63,7 @@ To migrate the training scripts from PyTorch NeuronX 2.1 to PyTorch NeuronX 2.5,
``xm`` below refers to ``torch_xla.core.xla_model`` and ``xr`` refers to ``torch_xla.runtime``
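
That is, the items below assume imports along the lines of:

.. code:: python

   import torch_xla.core.xla_model as xm
   import torch_xla.runtime as xr
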
* The environment variables ``XLA_DOWNCAST_BF16`` and ``XLA_USE_BF16`` are deprecated (a warning is shown when they are used). Please switch to automatic mixed precision or use the ``model.to(torch.bfloat16)`` command to convert the model to BF16 format (see :ref:`migration_from_xla_downcast_bf16`).
* The ``torch_xla.experimental.pjrt`` module, which was replaced by ``torch_xla.runtime`` in Torch-XLA 2.1, has been removed in Torch-XLA 2.5. Users should now utilize the ``torch_xla.runtime`` module as a replacement.
* ``torch_xla.runtime.using_pjrt`` is removed because PJRT is the sole Torch-XLA runtime.
* ``xm.all_reduce`` no longer operates in-place for single tensors. To fix this, please convert the single tensor to an array (e.g., ``[single_tensor]``) or assign the output of ``xm.all_reduce`` to a variable, as shown in the sketch below.
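
A sketch of both options (``single_tensor`` is assumed to be a tensor on the XLA device):

.. code:: python

   import torch_xla.core.xla_model as xm

   # Option 1: wrap the tensor in a list; the reduction is applied in place to the list element
   xm.all_reduce(xm.REDUCE_SUM, [single_tensor])

   # Option 2: use the value returned by all_reduce instead of relying on in-place behavior
   single_tensor = xm.all_reduce(xm.REDUCE_SUM, single_tensor)
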
@@ -108,7 +108,7 @@ This is a warning that ``torch_xla.core.xla_model.xrt_world_size()`` will be rem
WARNING:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.
This is a warning that ``torch_xla.core.xla_model.xla_model.get_ordinal()`` will be removed in a future release. Please switch to using ``torch_xla.runtime.global_ordinal`` instead.
AttributeError: module 'torch_xla.runtime' has no attribute 'using_pjrt'