Summary:
Now that we've landed the functorch-backed GradSampleModule, we also want to update the README that helps people navigate the different grad samplers.
For the content of this README I've also run benchmarks for all the options. Some results are surprising and hard to interpret, but overall the picture is mostly consistent.
## tl;dr
* There's no meaningful difference on CPU.
* functorch performance depends on the exact GPU setup: the same benchmark can be up to 4x slower or 2x faster than the hooks baseline depending on the GPU.
* ExpandedWeights is consistently 25-30% faster for linear layers, but not for conv layers.
## benchmarks
Numbers are runtime relative to the hooks baseline (lower is faster).

| device | benchmark | hooks | functorch | ExpandedWeights |
|:-------:|:-------:|:-------:|:-------:|:-------:|
| cpu | nn.Conv2d | 1x | 0.9x | 1x |
| cpu | nn.Linear | 1x | 1x | 0.9x |
| cpu | full epoch on CIFAR10 example | 1x | 1.5x | 1x |
| Tesla T4 (Google Colab) | nn.Conv2d | 1x | 4x | 0.9x |
| Tesla T4 (Google Colab) | nn.Linear | 1x | 1.25x | 0.75x |
| A100 (AWS) | nn.Conv2d | 1x | 0.5x | 1x |
| A100 (AWS) | nn.Linear | 1x | 1.5x | 0.75x |
| A100 (AWS) | full epoch on CIFAR10 example | 1x | 1.1x | 0.75x |
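For reference, here's a minimal sketch of how a grad sample mode is selected via `PrivacyEngine.make_private()`; the toy model, optimizer, and data loader below are placeholders, not the benchmark setup:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from opacus import PrivacyEngine

# Placeholder model, optimizer, and data; any Opacus-compatible setup works here.
model = torch.nn.Linear(16, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data_loader = DataLoader(
    TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,))),
    batch_size=8,
)

privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    grad_sample_mode="hooks",  # baseline; "functorch" and "ew" are the alternatives benchmarked above
)
```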
FYI samdow
Pull Request resolved: #497
Reviewed By: karthikprasad
Differential Revision: D39352067
Pulled By: ffuuugor
fbshipit-source-id: 19b4fff80fe3c1963fab24e1292ae625200bc749
opacus/grad_sample/README.md (42 additions, 13 deletions)
@@ -14,8 +14,11 @@ which one to use.
 improves upon ``GradSampleModule`` on performance and functionality.
 
 **TL;DR:** If you want a stable implementation, use ``GradSampleModule`` (`grad_sample_mode="hooks"`).
-If you want to experiment with the new functionality, try ``GradSampleModuleExpandedWeights`` (`grad_sample_mode="ew"`)
-and switch back to ``GradSampleModule`` if you encounter strange errors or unexpected behaviour.
+If you want to experiment with the new functionality, you have two options. Try
+``GradSampleModuleExpandedWeights`` (`grad_sample_mode="ew"`) for better performance, and `grad_sample_mode="functorch"`
+if your model is not supported by ``GradSampleModule``.
+
+Please switch back to ``GradSampleModule`` (`grad_sample_mode="hooks"`) if you encounter strange errors or unexpected behaviour.
 We'd also appreciate it if you report these to us
 
 ## Hooks-based approach
@@ -26,6 +29,23 @@ Computes per-sample gradients for a model using backward hooks. It requires cust
 trainable layer in the model. We provide such methods for most popular PyTorch layers. Additionally, clients can
 provide their own grad sampler for any new unsupported layer (see [tutorial](https://github.com/pytorch/opacus/blob/main/tutorials/guide_to_grad_sampler.ipynb))
 
+## Functorch approach
+- Model wrapping class: ``opacus.grad_sample.grad_sample_module.GradSampleModule`` with `force_functorch=True`
+- Keyword argument for ``PrivacyEngine.make_private()``: `grad_sample_mode="functorch"`
+
+[functorch](https://pytorch.org/functorch/stable/) provides JAX-like composable function transforms for PyTorch.
+With functorch we can compute per-sample gradients efficiently by using function transforms. With the efficient
+parallelization provided by `vmap`, we can obtain per-sample gradients for any function (i.e. any model) by
+doing essentially `vmap(grad(f(x)))`.
+
+Our experiments show that in most cases `vmap` computations are as fast as the manually written grad samplers used in
+the hooks-based approach.
+
+With the current implementation, `GradSampleModule` will use manual grad samplers for known modules (i.e. maintain the
+old behaviour for all previously supported models) and will only use functorch for unknown modules.
+
+With `force_functorch=True` passed to the constructor, `GradSampleModule` will rely exclusively on functorch.
+
 ## ExpandedWeights approach
 - Model wrapping class: ``opacus.grad_sample.gsm_exp_weights.GradSampleModuleExpandedWeights``
 - Keyword argument for ``PrivacyEngine.make_private()``: `grad_sample_mode="ew"`
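To illustrate the `vmap(grad(f(x)))` idea described in the new Functorch section above, here is a minimal per-sample-gradient sketch using functorch directly, on a toy linear model with random data (not Opacus's internal implementation):

```python
import torch
import torch.nn.functional as F
from functorch import make_functional, vmap, grad

model = torch.nn.Linear(4, 2)
fmodel, params = make_functional(model)  # functional version of the model + its parameters

samples = torch.randn(8, 4)
targets = torch.randint(0, 2, (8,))

def compute_loss(params, sample, target):
    # Operate on a single example: add a batch dimension of 1 for the forward pass.
    prediction = fmodel(params, sample.unsqueeze(0))
    return F.cross_entropy(prediction, target.unsqueeze(0))

# grad() differentiates w.r.t. params; vmap() maps over the batch dimension of the data
# while keeping params shared (in_dims=None). The result is a tuple of per-sample
# gradients, one tensor of shape (batch_size, *param.shape) per parameter.
per_sample_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(params, samples, targets)
```

Within Opacus, this mechanism is reached implicitly (unknown modules fall back to functorch) or explicitly via `force_functorch=True` / `grad_sample_mode="functorch"`.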
@@ -42,14 +62,23 @@ is roughly the same.
 Please note that these are known limitations and we plan to improve Expanded Weights and bridge the gap in feature completeness
 
 
-| xxx | Hooks | Expanded Weights |
-|:-----:|:-------:|:------------------:|
-| Required PyTorch version | 1.8+ | 1.13+ |
-| Development status | Underlying mechanism deprecated | Beta |
-| Performance | - | ✅ Likely up to 2.5x faster |
-| torchscript models | Not supported | ✅ Supported |
-| Client-provided grad sampler | ✅ Supported | Not supported |
-| `batch_first=False` | ✅ Supported | Not supported |
-| Most popular nn.* layers | ✅ Supported | ✅ Supported |
-| Recurrent networks | ✅ Supported | Not supported |
-| Padding `same` in Conv | ✅ Supported | Not supported |
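For completeness, a minimal sketch of using the Expanded Weights wrapper directly. This assumes the constructor takes the module as its first positional argument and that per-sample gradients are exposed via `grad_sample`, mirroring the hooks-based module; the toy model and data are placeholders:

```python
import torch
import torch.nn.functional as F
from opacus.grad_sample.gsm_exp_weights import GradSampleModuleExpandedWeights

# Toy model and batch; Expanded Weights requires PyTorch 1.13+ per the table above.
model = torch.nn.Linear(16, 2)
dp_model = GradSampleModuleExpandedWeights(model)

batch = torch.randn(8, 16)
targets = torch.randint(0, 2, (8,))

loss = F.cross_entropy(dp_model(batch), targets)  # mean reduction, matching the default loss_reduction
loss.backward()

# Per-sample gradients land in param.grad_sample, as with the hooks-based module.
for p in dp_model.parameters():
    print(p.grad_sample.shape)  # (batch_size, *p.shape)
```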