compiled_autograd deadlock when using checkpoint #135298
Comments
Besides the deadlock, "inductor" hits a separate error; use "eager" to reproduce the deadlock.
The repro you linked doesn't seem to run anything when you don't pass any args; what's the expected usage? I also see HPU in the repro script. If it fails with inductor only, what's the right way to run it?
nvm, this repros the internal assert for me:
script
Thanks for the report. As of the nightly, the deadlock is replaced by an assert (the one @bdhirsh ran into), and this will ship in PyTorch 2.5. Compiled Autograd only supports the non-reentrant checkpointing API, e.g. Btw, your workaround, shown below, disables compiled autograd! It'll be as if you never enabled it in the first place.
with torch._dynamo.compiled_autograd.disable():
    torch.autograd.backward(...)
As for the triton_key issue, it's due to an old triton version: #127271. Installing PyTorch should have pulled in the correct version, but it may have conflicted with, or been overwritten by, another triton installation.
Yes, so it's just a workaround, not a solution.

if (compiled_autograd != nullptr) {
  // see [Note: Compiled Autograd]
  TORCH_CHECK(
      !create_graph, "compiled_autograd does not support create_graph");
  _thread_check.release();
  TORCH_CHECK(
      !AnomalyMode::is_enabled(),
      "compiled_autograd does not support AnomalyMode");
  return (*compiled_autograd)(
      graph_root, *graph_task, accumulate_grad, outputs);
}
It fails with "eager", "inductor", and also "hpu_backend", because the deadlock comes from the checkpoint function, not the compiler. Checkpoint falls back to eager on any graph break in Module.forward, and then the backward falls back to eager too. Btw, maybe this needs a separate issue: why is torch.utils.checkpoint decorated with @torch._disable_dynamo? If I remove it, Module.forward can still be traced into a graph, which is faster.
hi @xmfan, the call chain is:
`output.sum().backward()`
-> `compiled_autograd()` call in C++
-> `CheckpointFunction.backward`
-> `torch.autograd.backward` with compiled_autograd disabled.
That is the equivalent of never calling

if device == "cpu":
    def enable_compiled_autograd():
        def compiler_fn(gm):
            # return torch.compile(gm, backend='inductor', fullgraph=False)
            return torch.compile(gm, backend='eager', fullgraph=False)

        import functools
        import torch
        from torch._dynamo import compiled_autograd

        torch._C._dynamo.compiled_autograd.set_autograd_compiler(
            functools.partial(compiled_autograd.AutogradCompilerInstance, compiler_fn)
        )
        torch._dynamo.reset()
        torch._dynamo.config.optimize_ddp = "python_reducer"
else:
    import habana_frameworks.torch.dynamo.compile_backend
    from habana_frameworks.torch.dynamo.compile_backend.experimental import enable_compiled_autograd

enable_compiled_autograd()

Note that
@xmfan I haven't investigated "python_reducer" much; it may be worth trying other strategies. Thanks for the reminder.
So my suggestion:
@majian4work I'd be happy to review a PR for 1; this can be implemented by using:
@xmfan
Fixes pytorch#135298 Pull Request resolved: pytorch#135795 Approved by: https://github.com/xmfan
🐛 Describe the bug
We are trying to use compiled_autograd to improve backward performance (a graph break makes checkpoint fall back to eager, and running all of forward and backward in eager gives bad performance).
When we call `out.sum().backward()`, `CheckpointFunction.backward` calls `torch.autograd.backward`, which leads to a deadlock. I have a workaround for now, but I'm not sure whether it can be taken as a final solution.
Another option is to disable it in `torch::dynamo::autograd::compiled_autograd(...)` before taking the lock.
Error logs
No response
Minified repro
Versions
python pytorch-fork/torch/utils/collect_env.py
Collecting environment information...
PyTorch version: 2.4.0a0+git12138a8
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.5 (ssh://gerrit.habana-labs.com:29418/tpc_llvm10 0c338a08ec38e26cef18216f902a07849a108a56)
CMake version: version 3.28.1
Libc version: glibc-2.35
Python version: 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-117-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 12
On-line CPU(s) list: 0-11
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Gold 6132 CPU @ 2.60GHz
CPU family: 6
Model: 85
Thread(s) per core: 1
Core(s) per socket: 6
Socket(s): 2
Stepping: 0
BogoMIPS: 5187.81
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon nopl xtopology tsc_reliable nonstop_tsc cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti ssbd ibrs ibpb stibp tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 invpcid avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xsaves arat pku ospke md_clear flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: VMware
Virtualization type: full
L1d cache: 384 KiB (12 instances)
L1i cache: 384 KiB (12 instances)
L2 cache: 12 MiB (12 instances)
L3 cache: 38.5 MiB (2 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-11
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit: KVM: Mitigation: VMX disabled
Vulnerability L1tf: Mitigation; PTE Inversion; VMX flush not necessary, SMT disabled
Vulnerability Mds: Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; IBRS; IBPB conditional; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI SW loop, KVM SW loop
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] flake8==7.0.0
[pip3] habana-torch-dataloader==1.18.0+git0092472b4
[pip3] habana-torch-plugin==1.18.0+git0092472b4
[pip3] mypy==1.10.1
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.23.5
[pip3] pylsp-mypy==0.6.8
[pip3] pytorch-lightning==2.2.0.post0
[pip3] torch==2.4.0a0+git12138a8
[pip3] torch-debug==2.2.0a0+git10d65c0
[pip3] torch_tb_profiler==0.4.0
[pip3] torchaudio==2.4.0a0+69d4077
[pip3] torchdata==0.7.1+5e6f7b7
[pip3] torchmetrics==1.3.2
[pip3] torchtext==0.18.0a0+9bed85d
[pip3] torchvision==0.19.0a0+48b1edf
[pip3] triton==2.2.0
[conda] Could not collect
cc @ezyang @chauhang @penguinwu @xmfan