compiled_autograd deadlock when using checkpoint · Issue #135298 · pytorch/pytorch · GitHub

compiled_autograd deadlock when using checkpoint #135298

Closed

majian4work opened this issue Sep 6, 2024 · 14 comments
Labels
module: compiled autograd · oncall: pt2 · triaged

Comments

majian4work (Contributor) commented Sep 6, 2024

🐛 Describe the bug

We are trying to use compiled_autograd to improve backward performance (a graph break makes checkpoint fall back to eager, and running all of forward and backward in eager gives poor performance). When we call out.sum().backward(), CheckpointFunction.backward calls torch.autograd.backward, which then leads to a deadlock.

I have a workaround for now, but I'm not sure it can be taken as a final solution.
Another option is to disable compiled autograd in torch::dynamo::autograd::compiled_autograd(...) before taking the lock.

with torch._dynamo.compiled_autograd.disable():
    torch.autograd.backward(...)

Error logs

No response

Minified repro

import os
import sys
import numpy as np

os.environ["TORCH_LOGS"] = "dynamo,graph_breaks,recompiles,graph_code,aot_joint_graph,aot_graphs"
os.environ["TORCHDYNAMO_VERBOSE"] = "1"
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
import torch
import torch._dynamo

os.environ["PT_HPU_LAZY_MODE"] = "0"
os.environ["ENABLE_CONSOLE"] = "1"
os.environ["LOG_LEVEL_ALL"] = "6"
os.environ["LOG_LEVEL_PT_EAGER"] = "1"

class MyModule(torch.nn.Module):
    def __init__(self, device):
        super(MyModule, self).__init__()
        self.layer = MyModule2(device)

    def forward(self, x):
        torch._dynamo.graph_break()
        res = x.cos() - x.sin()
        return self.layer(res)

class MyModule2(torch.nn.Module):
    def __init__(self, device):
        super(MyModule2, self).__init__()
        self.linear = torch.nn.Linear(2, 2, device=device)

    def forward(self, x):
        return self.linear(x).cos()


# stock pytorch checkpoint
def stock(args):
    device = args[0] if args else "cpu"

    from torch.utils.checkpoint import checkpoint
    
    if device == "cpu":
        def enable_compiled_autograd():
            def compiler_fn(gm):
                #return torch.compile(gm, backend='inductor', fullgraph=False)
                return torch.compile(gm, backend='eager', fullgraph=False)

            import functools
            import torch
            from torch._dynamo import compiled_autograd
            torch._C._dynamo.compiled_autograd.set_autograd_compiler(
                    functools.partial(compiled_autograd.AutogradCompilerInstance, compiler_fn)
                )

            torch._dynamo.reset()
            torch._dynamo.config.optimize_ddp = "python_reducer"
    else:
        import habana_frameworks.torch.dynamo.compile_backend
        from habana_frameworks.torch.dynamo.compile_backend.experimental import enable_compiled_autograd

    enable_compiled_autograd()

    m = MyModule(device)
    def fn(x):
        x = checkpoint(m.__call__, x)
        out = x + 2 
        out = out * 2 
        return out


    # fn = torch.compile(fn, backend="inductor" if device == "cpu" else "hpu_backend")
    fn = torch.compile(fn, backend="hpu_backend" if device == "hpu" else "inductor")
    for i in range(2):
        print(f"run {i} ..., ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
        x = torch.ones(2, 2, device=device, requires_grad=True)
        out = fn(x)
        print(out.cpu())
        print(f"backward {i} ..., bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb")
        out.sum().backward()


if __name__ == "__main__":

    seed = 1000
    torch.manual_seed(seed)
    # habana_frameworks.torch.hpu.random.manual_seed(seed)
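    # Usage: e.g. `python <this script> stock` runs on CPU, and
    # `python <this script> stock hpu` targets HPU (if the Habana backend is
    # installed); the eval below dispatches to the named function with the
    # remaining argv as its argument list.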

    if len(sys.argv) > 1:
        try:
            eval(f"{sys.argv[1]}({sys.argv[2:]})")
        except Exception as e:
            print(e)
            l = []
            for key, value in dict(locals()).items():
                if callable(value) and value.__module__ == __name__:
                    l.append(key)
            print(f"you can run one of the following functions: {l}")
        exit(0)

Versions

python pytorch-fork/torch/utils/collect_env.py
Collecting environment information...
PyTorch version: 2.4.0a0+git12138a8
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.5 (ssh://gerrit.habana-labs.com:29418/tpc_llvm10 0c338a08ec38e26cef18216f902a07849a108a56)
CMake version: version 3.28.1
Libc version: glibc-2.35

Python version: 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-117-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 12
On-line CPU(s) list: 0-11
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Gold 6132 CPU @ 2.60GHz
CPU family: 6
Model: 85
Thread(s) per core: 1
Core(s) per socket: 6
Socket(s): 2
Stepping: 0
BogoMIPS: 5187.81
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon nopl xtopology tsc_reliable nonstop_tsc cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti ssbd ibrs ibpb stibp tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 invpcid avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xsaves arat pku ospke md_clear flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: VMware
Virtualization type: full
L1d cache: 384 KiB (12 instances)
L1i cache: 384 KiB (12 instances)
L2 cache: 12 MiB (12 instances)
L3 cache: 38.5 MiB (2 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-11
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit: KVM: Mitigation: VMX disabled
Vulnerability L1tf: Mitigation; PTE Inversion; VMX flush not necessary, SMT disabled
Vulnerability Mds: Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; IBRS; IBPB conditional; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI SW loop, KVM SW loop
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] flake8==7.0.0
[pip3] habana-torch-dataloader==1.18.0+git0092472b4
[pip3] habana-torch-plugin==1.18.0+git0092472b4
[pip3] mypy==1.10.1
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.23.5
[pip3] pylsp-mypy==0.6.8
[pip3] pytorch-lightning==2.2.0.post0
[pip3] torch==2.4.0a0+git12138a8
[pip3] torch-debug==2.2.0a0+git10d65c0
[pip3] torch_tb_profiler==0.4.0
[pip3] torchaudio==2.4.0a0+69d4077
[pip3] torchdata==0.7.1+5e6f7b7
[pip3] torchmetrics==1.3.2
[pip3] torchtext==0.18.0a0+9bed85d
[pip3] torchvision==0.19.0a0+48b1edf
[pip3] triton==2.2.0
[conda] Could not collect

cc @ezyang @chauhang @penguinwu @xmfan

majian4work (Contributor Author) commented:

Besides the deadlock, the "inductor" backend hits another error; use "eager" to reproduce the deadlock.

backend='inductor' raised:
LoweringException: ImportError: cannot import name 'triton_key' from 'triton.compiler.compiler' (/home/jianma/venv/lib/python3.10/site-packages/triton/compiler/compiler.py)
  target: aten.addmm.default
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='primals_2', layout=FixedLayout('cpu', torch.float32, size=[2], stride=[1]))
  ))
  args[1]: TensorBox(StorageBox(
    ComputedBuffer(name='buf0', layout=FixedLayout('cpu', torch.float32, size=[2, 2], stride=[2, 1]), data=Pointwise(
      'cpu',
      torch.float32,
      def inner_fn(index):
          i0, i1 = index
          tmp0 = ops.load(primals_3, i1 + 2 * i0)
          tmp1 = ops.cos(tmp0)
          tmp2 = ops.load(primals_3, i1 + 2 * i0)
          tmp3 = ops.sin(tmp2)
          tmp4 = tmp1 - tmp3
          return tmp4
      ,
      ranges=[2, 2],
      origin_node=sub,
      origins={sin, cos, sub}
    ))
  ))
  args[2]: TensorBox(
    ReinterpretView(
      StorageBox(
        InputBuffer(name='primals_1', layout=FixedLayout('cpu', torch.float32, size=[2, 2], stride=[2, 1]))
      ),
      FixedLayout('cpu', torch.float32, size=[2, 2], stride=[1, 2]),
      origins={permute}
    )
  )


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

bdhirsh added the module: compiled autograd label Sep 6, 2024
bdhirsh (Contributor) commented Sep 6, 2024

The repro you linked doesn't seem to run anything when you don't pass in any args - what's the expected usage?

I also see HPU in the repro script - if it fails with inductor-only, what's the right way to run it?

bdhirsh (Contributor) commented Sep 6, 2024

nvm, this repros the internal assert for me:

  File "/home/hirsheybar/local/b/pytorch/torch/autograd/graph.py", line 818, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: mtx_.try_lock() INTERNAL ASSERT FAILED at "/home/hirsheybar/local/b/pytorch/torch/csrc/dynamo/python_compiled_autograd.cpp":694, please report a bug to PyTorch. Trying to run compiled autograd within another compiled autograd call (e.g. reentrant checkpointing), this is not supported yet.

script

import os
import sys
import numpy as np

os.environ["TORCH_LOGS"] = "dynamo,graph_breaks,recompiles,graph_code,aot_joint_graph,aot_graphs"
os.environ["TORCHDYNAMO_VERBOSE"] = "1"
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
import torch
import torch._dynamo

os.environ["PT_HPU_LAZY_MODE"] = "0"
os.environ["ENABLE_CONSOLE"] = "1"
os.environ["LOG_LEVEL_ALL"] = "6"
os.environ["LOG_LEVEL_PT_EAGER"] = "1"

class MyModule(torch.nn.Module):
    def __init__(self, device):
        super(MyModule, self).__init__()
        self.layer = MyModule2(device)

    def forward(self, x):
        torch._dynamo.graph_break()
        res = x.cos() - x.sin()
        return self.layer(res)

class MyModule2(torch.nn.Module):
    def __init__(self, device):
        super(MyModule2, self).__init__()
        self.linear = torch.nn.Linear(2, 2, device=device)

    def forward(self, x):
        return self.linear(x).cos()


# stock pytorch checkpoint
def stock():
    device = "cpu"

    from torch.utils.checkpoint import checkpoint

    def enable_compiled_autograd():
        def compiler_fn(gm):
            #return torch.compile(gm, backend='inductor', fullgraph=False)
            return torch.compile(gm, backend='eager', fullgraph=False)

        import functools
        import torch
        from torch._dynamo import compiled_autograd
        torch._C._dynamo.compiled_autograd.set_autograd_compiler(
                functools.partial(compiled_autograd.AutogradCompilerInstance, compiler_fn)
            )

        torch._dynamo.reset()
        torch._dynamo.config.optimize_ddp = "python_reducer"

    enable_compiled_autograd()

    m = MyModule(device)
    def fn(x):
        x = checkpoint(m.__call__, x)
        out = x + 2
        out = out * 2
        return out


    # fn = torch.compile(fn, backend="inductor" if device == "cpu" else "hpu_backend")
    fn = torch.compile(fn)
    for i in range(2):
        print(f"run {i} ..., ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
        x = torch.ones(2, 2, device=device, requires_grad=True)
        out = fn(x)
        print(out.cpu())
        print(f"backward {i} ..., bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb")
        out.sum().backward()


stock()

bdhirsh added the triaged label Sep 6, 2024
xmfan (Member) commented Sep 6, 2024

Thanks for the report. As of the nightly, the deadlock will instead be an assert (the one @bdhirsh ran into), and this will be in PyTorch 2.5.

Compiled Autograd only supports the non-reentrant checkpointing API, e.g. checkpoint(use_reentrant=False).
However, you'll likely not see memory savings yet; that is coming soon.
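
A minimal sketch of that supported pattern, reusing the module m from the repro above (only the use_reentrant flag is the point here; everything else is carried over from the repro):

from torch.utils.checkpoint import checkpoint

def fn(x):
    # Non-reentrant checkpointing recomputes the forward without a nested call
    # into the autograd engine, so compiled autograd can handle it.
    x = checkpoint(m.__call__, x, use_reentrant=False)
    return (x + 2) * 2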

Btw, your workaround disables compiled autograd! It'll be as if you never enabled it in the first place.

I have a workaround for now, but I'm not sure it can be taken as a final solution.
Another option is to disable compiled autograd in torch::dynamo::autograd::compiled_autograd(...) before taking the lock.

with torch._dynamo.compiled_autograd.disable():
    torch.autograd.backward(...)

xmfan (Member) commented Sep 6, 2024

As for the triton_key issue, it's due to an old Triton version: #127271. The PyTorch installation should have pulled in the correct version for you, but it might have conflicted with or been overwritten by another Triton installation.
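
A quick way to check which Triton is being picked up; the import below is the one that fails inside Inductor on a mismatched install (a sanity-check sketch, not an official diagnostic):

import triton
print(triton.__version__, triton.__file__)
# On a mismatched install this raises the ImportError from the log above:
from triton.compiler.compiler import triton_key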

majian4work (Contributor Author) commented:

Thanks for the report. As of the nightly, the deadlock will instead be an assert (the one @bdhirsh ran into), and this will be in PyTorch 2.5.

Compiled Autograd only supports the non-reentrant checkpointing API, e.g. checkpoint(use_reentrant=False). However, you'll likely not see memory savings yet; that is coming soon.

Btw, your workaround disables compiled autograd! It'll be as if you never enabled it in the first place.

I have a workaround for now, but I'm not sure it can be taken as a final solution.
Another option is to disable compiled autograd in torch::dynamo::autograd::compiled_autograd(...) before taking the lock.

with torch._dynamo.compiled_autograd.disable():
    torch.autograd.backward(...)

Yes, so it's just a workaround, not a solution.
I don't think the assert is a good solution either; it doesn't resolve the deadlock issue, it just surfaces it.
I think it may be better to disable compiled autograd before it is called, e.g. here:

  if (compiled_autograd != nullptr) {
    // see [Note: Compiled Autograd]
    TORCH_CHECK(
        !create_graph, "compiled_autograd does not support create_graph");
    _thread_check.release();
    TORCH_CHECK(
        !AnomalyMode::is_enabled(),
        "compiled_autograd does not support AnomalyMode")
    return (*compiled_autograd)(
        graph_root, *graph_task, accumulate_grad, outputs);
  }

majian4work (Contributor Author) commented Sep 9, 2024

The repro you linked doesn't seem to run anything when you don't pass in any args - what's the expected usage?

I also see HPU in the repro script - if it fails with inductor-only, what's the right way to run it?

It fails with "eager", "inductor", and also "hpu_backend", because the deadlock comes from the checkpoint function, not from the compiler.
Where do args need to be passed? It seems to work fine with the WA.

"what's the expected usage"

Checkpoint falls back to eager on any graph break in Module.forward, and then the backward also falls back to eager.
We want to use compiled_autograd to capture the backward into a graph; even though it can't be captured as one whole big graph, the split small graphs are still faster than eager.

Btw, maybe this needs a separate issue: why is torch.utils.checkpoint decorated with @torch._disable_dynamo? If I remove it, Module.forward can still be traced into a graph, which is faster.

majian4work (Contributor Author) commented:

@xmfan

Compiled Autograd only supports the non-reentrant checkpointing API e.g. checkpoint(use_reentrant=False).

There are other issues when graph breaks happen: #121966.

majian4work (Contributor Author) commented:

Hi @xmfan,
I added this WA in CheckpointFunction.backward, and it works well.

`output.sum().backward()`
    -> `compile_autograd()` call in cpp
        -> `CheckpointFunction.backward`
           -> `torch.autograd.backward` under compiled_autograd disabled.
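
For reference, here is a user-level sketch of the same workaround that avoids patching the PyTorch source. It is only an illustration: CheckpointFunction is an internal class, and the monkey-patch below is hypothetical, relying only on torch._dynamo.compiled_autograd.disable() as used above.

import functools
import torch
from torch._dynamo import compiled_autograd
from torch.utils.checkpoint import CheckpointFunction

_orig_backward = CheckpointFunction.backward

@functools.wraps(_orig_backward)
def _backward_without_compiled_autograd(ctx, *grad_outputs):
    # Run the reentrant checkpoint's nested torch.autograd.backward call with
    # compiled autograd disabled, so the engine is not re-entered while the
    # compiled-autograd lock is held.
    with compiled_autograd.disable():
        return _orig_backward(ctx, *grad_outputs)

CheckpointFunction.backward = staticmethod(_backward_without_compiled_autograd)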


xmfan (Member) commented Sep 9, 2024

That is equivalent to never calling enable_compiled_autograd; you can just remove this code block:

if device == "cpu":
        def enable_compiled_autograd():
            def compiler_fn(gm):
                #return torch.compile(gm, backend='inductor', fullgraph=False)
                return torch.compile(gm, backend='eager', fullgraph=False)

            import functools
            import torch
            from torch._dynamo import compiled_autograd
            torch._C._dynamo.compiled_autograd.set_autograd_compiler(
                    functools.partial(compiled_autograd.AutogradCompilerInstance, compiler_fn)
                )

            torch._dynamo.reset()
            torch._dynamo.config.optimize_ddp = "python_reducer"
    else:
        import habana_frameworks.torch.dynamo.compile_backend
        from habana_frameworks.torch.dynamo.compile_backend.experimental import enable_compiled_autograd

    enable_compiled_autograd()

Note that torch._dynamo.config.optimize_ddp = "python_reducer" forces DDP to use the high-overhead Python implementation of DDP gradient bucketing. Under compiled autograd, this logic is traced into the graph and Inductor optimizes it. But when you disable compiled autograd, this high-overhead implementation is not captured in the graph and runs directly in eager.

majian4work (Contributor Author) commented:

@xmfan
backend="eager" here just for reproducing the deadlock issue, because "inductor"'s "triton_key" issue.
actually, we use another "hpu_backend" for hpu device, then enable_compiled_autograd can help to compile the bwd pass after the WA bypassing the deadlock issue, and finally we can get +35% performance boost.

for "python_reducer", I have not investigated too much, maybe it's valuable to try other strategies. thanks for your reminder.

majian4work (Contributor Author) commented:

So my suggestions:

  1. We could resolve the deadlock issue, but not with the assert. My proposal is to disable compiled autograd before calling into compiled_autograd in C++. Then we don't need any workaround in case the user calls torch.autograd.backward in the context of compiled_autograd.
  2. Maybe we should remove @torch._disable_dynamo from torch.utils.checkpoint.checkpoint, so it can still try to compile sub-graphs if any graph break happens.

xmfan (Member) commented Sep 11, 2024

@majian4work I'd be happy to review a PR for 1. It can be implemented by using:

  • torch._dynamo.utils.in_compiled_autograd_region, which is True during the execution of a compiled autograd graph
  • torch.autograd.backward and torch.autograd.grad, the two entrypoints to autograd, where we could disable compiled autograd when we are already executing a compiled autograd graph (see the sketch below)
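
A rough sketch of that idea (not the actual PyTorch source; the attribute locations are the ones quoted above and may differ between versions):

import contextlib
import torch
from torch._dynamo import compiled_autograd, utils as dynamo_utils

def _maybe_disable_compiled_autograd():
    # If we are already executing a compiled autograd graph, a nested engine
    # call (e.g. from reentrant checkpointing) should not re-enter it.
    if getattr(dynamo_utils, "in_compiled_autograd_region", False):
        return compiled_autograd.disable()
    return contextlib.nullcontext()

# Illustrative wrapper; in the real fix this guard would live inside
# torch.autograd.backward / torch.autograd.grad themselves.
def backward_with_guard(*args, **kwargs):
    with _maybe_disable_compiled_autograd():
        return torch.autograd.backward(*args, **kwargs)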

majian4work (Contributor Author) commented Sep 12, 2024

@xmfan
Thanks, I'll try to submit a patch for 1.
Maybe I need to create a separate issue for 2.

tolleybot pushed a commit to tolleybot/pytorch that referenced this issue Sep 14, 2024
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this issue Sep 20, 2024
pytorchmergebot pushed a commit that referenced this issue Sep 24, 2024
skotapati pushed a commit to kulinseth/pytorch that referenced this issue Sep 24, 2024
BoyuanFeng pushed a commit to BoyuanFeng/pytorch that referenced this issue Sep 25, 2024