Commit 85cc546
[fix] Fix the activation checkpointing when using SwiGLUPackedFusedOp (#1127)
* [fix] Fix the activation checkpointing when using SwiGLUPackedFusedOp
According to the docs (https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function), the forward() method should not be called directly; the apply() method has to be used instead.
After removing the direct forward() call, activation checkpointing starts working; a short sketch of the calling convention follows.
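As a rough illustration of that calling convention, here is a minimal sketch. MySwiGLU is a hypothetical stand-in, not the actual SwiGLUPackedFusedOp implementation, and it assumes a PyTorch version whose checkpoint() accepts use_reentrant: going through apply() lets autograd (and therefore activation checkpointing) see the op, while calling forward() directly bypasses it.

```python
import torch
from torch.utils.checkpoint import checkpoint

class MySwiGLU(torch.autograd.Function):
    """Hypothetical stand-in for SwiGLUPackedFusedOp (not the xFormers code)."""

    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return torch.nn.functional.silu(x @ w)

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        z = x @ w
        s = torch.sigmoid(z)
        grad_z = grad_out * (s + z * s * (1 - s))  # d silu(z) / dz
        return grad_z @ w.t(), x.t() @ grad_z

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)

# Going through apply() records the op with autograd, so activation
# checkpointing can recompute it during the backward pass.
y = checkpoint(MySwiGLU.apply, x, w, use_reentrant=False)
y.sum().backward()

# Calling MySwiGLU.forward(...) directly would need a hand-built ctx and would
# not create a backward node, so checkpoint recomputation cannot work.
```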
* [fix] Fix the activation checkpointing when using SwiGLUPackedFusedOp
The IF conditional on the x.requires_grad state (used to switch between inference and training behavior) changes how forward() behaves when it is recomputed, which breaks activation checkpointing
(during the recomputation phase x is detached with requires_grad == False, so a different number of tensors is saved via save_for_backward()). A sketch of this failure mode follows the last item below.
* [fix] Fix the activation checkpointing when using SwiGLUPackedFusedOp by removing the inference path.
The IF conditional on the x.requires_grad state changes how forward() behaves when it is recomputed, which breaks activation checkpointing
(during the recomputation phase x is detached with requires_grad == False, so a different number of tensors is saved via save_for_backward()).
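To make the failure mode concrete, here is a minimal sketch (a hypothetical BranchyOp, not the actual SwiGLUPackedFusedOp code) of the anti-pattern the commit removes: forward() saves a different number of tensors depending on x.requires_grad, so a recomputed forward pass in which x arrives detached no longer matches what backward() expects to unpack.

```python
import torch

class BranchyOp(torch.autograd.Function):
    """Hypothetical illustration of the removed pattern (not the xFormers code)."""

    @staticmethod
    def forward(ctx, x, w):
        out = x @ w
        if x.requires_grad:
            # training path: keep both operands for the gradient computation
            ctx.save_for_backward(x, w)
        else:
            # "inference" path: keep nothing, assuming backward never runs
            ctx.save_for_backward()
        return out

    @staticmethod
    def backward(ctx, grad_out):
        # Expects the training path to have run; if the checkpoint recomputation
        # took the other branch (x detached, requires_grad == False), the number
        # of saved tensors no longer matches and this unpacking fails.
        x, w = ctx.saved_tensors
        return grad_out @ w.t(), x.t() @ grad_out

# Removing the branch makes forward() always save the same tensors, keeping the
# original and recomputed forward passes consistent.
```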
1 parent a97a1e0 · commit 85cc546

1 file changed: +5 additions, −11 deletions
(The diff body was not captured in this export; it replaces the code around original lines 104–108 and 214–219 of the changed file with new lines 104–107 and 213.)