Commit 85cc546

[fix] Fix the activation checkpointing when using SwiGLUPackedFusedOp (#1127)
* [fix] Fix the activation checkpointing when using SwiGLUPackedFusedOp. According to the docs (https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function), the forward() method should not be called directly; the apply() method has to be used instead. After removing the direct forward() call, activation checkpointing starts working.
* [fix] Fix the activation checkpointing when using SwiGLUPackedFusedOp. The if conditional on the x.requires_grad state (used to switch between inference and training modes) changes the behavior of the recomputation of the forward() method, which breaks activation checkpointing: during the recomputation phase x is detached with requires_grad == False, so a different number of tensors is saved by the save_for_backward() method.
* [fix] Fix the activation checkpointing when using SwiGLUPackedFusedOp by removing the inference path. The if conditional on the x.requires_grad state changes the behavior of the recomputation of the forward() method, which breaks activation checkpointing: during the recomputation phase x is detached with requires_grad == False, so a different number of tensors is saved by the save_for_backward() method.
1 parent a97a1e0 commit 85cc546
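
For context, the pattern this fix enforces can be sketched with the C++ autograd API. The following is a minimal, hypothetical example (the Square op and the square() wrapper are not part of xformers): a custom torch::autograd::Function is always invoked through apply(), which sets up the AutogradContext and the graph bookkeeping, rather than by calling its static forward() directly.

#include <torch/torch.h>

using torch::autograd::AutogradContext;
using torch::autograd::Function;
using torch::autograd::variable_list;

// Hypothetical op: y = x * x, saving x for the backward pass.
struct Square : public Function<Square> {
  static at::Tensor forward(AutogradContext* ctx, at::Tensor x) {
    ctx->save_for_backward({x});
    return x * x;
  }

  static variable_list backward(AutogradContext* ctx, variable_list grad_out) {
    auto x = ctx->get_saved_variables()[0];
    // d(x*x)/dx = 2x; one gradient per forward input.
    return {grad_out[0] * 2 * x};
  }
};

at::Tensor square(const at::Tensor& x) {
  // Always dispatch through apply(); calling Square::forward(nullptr, x)
  // directly would bypass the autograd bookkeeping that checkpointed
  // recomputation relies on.
  return Square::apply(x);
}

With this shape, a checkpointed recomputation of the forward goes through the same apply() path as the original call, so save_for_backward() stores the same set of tensors both times, which is the behavior the commit message describes.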

File tree: 1 file changed (+5, -11 lines)

xformers/csrc/swiglu/swiglu_packedw.cpp

Lines changed: 5 additions & 11 deletions
@@ -101,11 +101,10 @@ class SwiGLUPackedWeights
     auto x5 = torch::nn::functional::linear(
         x4, w3, b3.has_value() ? b3.value() : at::Tensor());

-    if (ctx != nullptr) {
-      ctx->save_for_backward({x, w1w2, w3, x1, x2});
-      ctx->saved_data["has_b1b2"] = b1b2.has_value();
-      ctx->saved_data["has_b3"] = b3.has_value();
-    }
+    ctx->save_for_backward({x, w1w2, w3, x1, x2});
+    ctx->saved_data["has_b1b2"] = b1b2.has_value();
+    ctx->saved_data["has_b3"] = b3.has_value();
+
     return x5;
   }

@@ -211,12 +210,7 @@ at::Tensor swiglu_packedw_cuda(
     const std::optional<at::Tensor> b1b2,
     const at::Tensor w3,
     const std::optional<at::Tensor> b3) {
-  if (x.requires_grad()) {
-    return SwiGLUPackedWeights::apply(x, w1w2, b1b2, w3, b3);
-  } else {
-    return SwiGLUPackedWeights::forward(
-        /* ctx */ nullptr, x, w1w2, b1b2, w3, b3);
-  }
+  return SwiGLUPackedWeights::apply(x, w1w2, b1b2, w3, b3);
 }
 } // namespace
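
For intuition on the second hunk, here is a separate, self-contained sketch (hypothetical names, not the library's code) of the dispatch pattern this commit removes. Branching on x.requires_grad() means that a checkpointed replay of the forward, which per the commit message runs on a detached input with requires_grad == False, can land on the other branch and save a different number of tensors than the backward expects.

#include <torch/torch.h>

using torch::autograd::AutogradContext;
using torch::autograd::Function;
using torch::autograd::variable_list;

// Hypothetical op: y = 2 * x, mirroring the old layout in which
// forward() tolerated ctx == nullptr for inference calls.
struct Scale2 : public Function<Scale2> {
  static at::Tensor forward(AutogradContext* ctx, at::Tensor x) {
    if (ctx != nullptr) {
      ctx->save_for_backward({x}); // training call: one tensor saved
    }                              // inference call: nothing saved
    return x * 2;
  }

  static variable_list backward(AutogradContext* ctx, variable_list grad_out) {
    return {grad_out[0] * 2};
  }
};

// Old-style dispatch keyed on the requires_grad flag of the input.
at::Tensor scale2_old_dispatch(const at::Tensor& x) {
  if (x.requires_grad()) {
    return Scale2::apply(x);                  // graph-building path
  }
  return Scale2::forward(/*ctx=*/nullptr, x); // no-graph path
}

int main() {
  auto x = torch::randn({4}, torch::requires_grad());
  auto y1 = scale2_old_dispatch(x);          // takes the apply() branch
  auto y2 = scale2_old_dispatch(x.detach()); // a detached replay silently
                                             // takes the inference branch
  return 0;
}

Removing the branch, as the diff above does, makes every call go through apply(), so both the original forward and any recomputation save the same tensors.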
