torch.compile fails when used together with activation checkpointing #121966
Comments
I think the issue in this case is that it's somewhat difficult for us to support torch.compile regions nested inside an activation-checkpointed region; the reverse arrangement is easier to handle.
Thanks for the insights! In our use case, we checkpoint the Transformer layer by layer, and use torch.compile to optimize small ops inside each Transformer layer, so doing the reverse might be a bit hard. I wonder if there is any chance we could get it to work?
The fundamental thing that's tricky here is the order of events under a checkpointed region: 1. the initial forwards pass, 2. the forwards pass that's recomputed during the backwards pass, and 3. the backwards pass itself. When the entire checkpoint region can sit under compile, this can be treated like normal autograd, with a "forwards" graph and a "backwards" graph (which contains the recomputed forwards plus the backwards itself). However, if you're only compiling a subset of the graph, there are now 3 distinct compiled regions that must run, with other operators in between.

I don't think this is impossible to support, particularly if you're fine with the forwards graph and the recomputed forwards graph sharing the same compiled region (and thus losing some of the potential performance benefit of not needing to write out activations). Another way to support this is to use torch.compile explicitly on the forwards and backwards "ops", although this is arguably a bit more annoying (see the sketch below). Perhaps we should just support graph breaks under activation checkpointing... conceptually this may not be that difficult if we just reuse the graph. cc: @albanD @soulitzer @bdhirsh @zou3519 ?
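To make this concrete, here is a rough sketch of the arrangements described above. All names (`layer`, `whole_region`, `subset_inside`, `ManualCheckpoint`) are illustrative; pattern (a) assumes a PyTorch version where non-reentrant checkpointing is traceable under torch.compile, and (c) is a hand-rolled illustration of "compiling the forwards and backwards ops explicitly", not a vetted recipe:

```python
import torch
from torch.utils.checkpoint import checkpoint

# Illustrative stand-in for a Transformer block.
layer = torch.nn.Sequential(
    torch.nn.Linear(64, 64), torch.nn.GELU(), torch.nn.Linear(64, 64)
)

# (a) The whole checkpointed region sits under compile: autograd sees one
# "forwards" graph and one "backwards" graph (recomputed forwards + backwards).
@torch.compile
def whole_region(x):
    return checkpoint(layer, x, use_reentrant=False)

# (b) Compile only a subset *inside* the checkpointed region: the compiled
# region must now run three times (forward, recomputed forward, backward)
# with eager ops in between -- the configuration this issue is about.
compiled_layer = torch.compile(layer)

def subset_inside(x):
    return checkpoint(compiled_layer, x, use_reentrant=False)

# (c) A manual checkpoint as an autograd.Function, where the initial forward
# and the recomputed forward share the same compiled region.
class ManualCheckpoint(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        with torch.no_grad():  # don't save intermediate activations
            return compiled_layer(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        with torch.enable_grad():  # recompute through the same compiled region
            x_re = x.detach().requires_grad_()
            y = compiled_layer(x_re)
        # Reentrant backward: also accumulates gradients into the layer's
        # parameters, like the reentrant checkpoint implementation does.
        torch.autograd.backward(y, grad_out)
        return x_re.grad

x = torch.randn(8, 64, requires_grad=True)
whole_region(x).sum().backward()            # (a) works when fully under compile
ManualCheckpoint.apply(x).sum().backward()  # (c) manual variant
# subset_inside(x).sum().backward()         # (b) the failing configuration
```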
Also seeing this issue. |
How do we solve this issue? Just use special ops to do the checkpointing and compilation?
I see the same issue when there's a graph break and @torch._disable_dynamo is commented out (I commented it out because I don't want to fall back to eager).
@Gy-Lu We actually have fixed this issue. Going to close it now. |
Hi, could you please give me a commit/pull request which fixes this issue? |
🐛 Describe the bug
The repro fails with a runtime error; the error only appears if `use_reentrant` is set to `False` (a reconstruction of the failing setup is sketched below).
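The original snippet and traceback were not preserved here. Based on the use case described in the comments (compiling small op regions inside a checkpointed layer), a minimal sketch of the failing configuration might look like the following; the function names are illustrative:

```python
import torch
from torch.utils.checkpoint import checkpoint

def small_ops(x):
    # Small op region optimized with torch.compile, as in the use case above.
    return torch.relu(x @ x.T)

compiled_small_ops = torch.compile(small_ops)

x = torch.randn(16, 16, requires_grad=True)
# Checkpointing a compiled callable with use_reentrant=False is the
# configuration reported to fail; the error does not appear with
# use_reentrant=True.
out = checkpoint(compiled_small_ops, x, use_reentrant=False)
out.sum().backward()
```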
Versions
cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519