torch.compile fails when used together with activation checkpointing · Issue #121966 · pytorch/pytorch · GitHub

torch.compile fails when used together with activation checkpointing #121966

Closed

0x804d8000 opened this issue Mar 15, 2024 · 10 comments
Assignees
Chillee
Labels
module: activation checkpointing · module: dynamo · oncall: pt2 · triaged

Comments

0x804d8000 (Contributor) commented Mar 15, 2024

🐛 Describe the bug

import time
import torch
import torch.utils.checkpoint

def add_and_drop(x):
    return torch.nn.functional.dropout(x * 5, 0.5)

x = torch.rand((50001, 3072), device='cuda', dtype=torch.bfloat16)

x.requires_grad_(True)
x.retain_grad()

with torch.cuda.amp.autocast(dtype=torch.bfloat16):
    f = torch.compile(add_and_drop)
    out = torch.utils.checkpoint.checkpoint(f, x, use_reentrant=False)
    out.backward(torch.rand_like(out))
    print(x.grad)

fails with

# python3 ./repro.py
/usr/local/lib/python3.9/dist-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/usr/local/lib/python3.9/dist-packages/torch/autograd/__init__.py:411: UserWarning: Error detected in NativeDropoutBackward0. Traceback of forward call that caused the error:
  File "//./repro.py", line 6, in add_and_drop
    return torch.nn.functional.dropout(x * 5, 0.5)
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:113.)
  result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "//./repro.py", line 15, in <module>
    out = torch.utils.checkpoint.checkpoint(f, x, use_reentrant=False)
  File "/usr/local/lib/python3.9/dist-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/checkpoint.py", line 489, in checkpoint
    ret = function(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/eval_frame.py", line 655, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/convert_frame.py", line 727, in _convert_frame
    result = inner_convert(frame, cache_entry, hooks, frame_state)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/convert_frame.py", line 383, in _convert_frame_assert
    compiled_product = _compile(
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/convert_frame.py", line 646, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/convert_frame.py", line 562, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/convert_frame.py", line 151, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/convert_frame.py", line 527, in transform
    tracer.run()
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/symbolic_convert.py", line 2128, in run
    super().run()
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/symbolic_convert.py", line 2243, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/output_graph.py", line 919, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/usr/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/output_graph.py", line 1087, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/output_graph.py", line 1159, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/output_graph.py", line 1140, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/usr/local/lib/python3.9/dist-packages/torch/__init__.py", line 1668, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/compile_fx.py", line 1168, in compile_fx
    return aot_autograd(
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/backends/common.py", line 55, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/aot_autograd.py", line 887, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/aot_autograd.py", line 600, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 425, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 630, in aot_wrapper_synthetic_base
    return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 151, in aot_dispatch_autograd
    fx_g, joint_inputs, maybe_subclass_meta = aot_dispatch_autograd_graph(  # type: ignore[misc]
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 159, in aot_dispatch_autograd_graph
    fx_g = _create_graph(joint_fn_to_trace, updated_joint_inputs, aot_config=aot_config)
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 32, in _create_graph
    fx_g = make_fx(f, decomposition_table=aot_config.decompositions)(*args)
  File "/usr/local/lib/python3.9/dist-packages/torch/fx/experimental/proxy_tensor.py", line 871, in wrapped
    t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
  File "/usr/local/lib/python3.9/dist-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/fx/experimental/proxy_tensor.py", line 483, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/fx/_symbolic_trace.py", line 821, in trace
    (self.create_arg(fn(*args)),),
  File "/usr/local/lib/python3.9/dist-packages/torch/fx/_symbolic_trace.py", line 688, in flatten_fn
    tree_out = root_fn(*tree_args)
  File "/usr/local/lib/python3.9/dist-packages/torch/fx/experimental/proxy_tensor.py", line 519, in wrapped
    out = f(*tensors)
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 463, in joint_helper
    return _functionalized_f_helper(primals, tangents)
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 355, in _functionalized_f_helper
    f_outs = fn(*f_args)
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 250, in inner_fn_with_anomaly
    return inner_fn(*args)
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 235, in inner_fn
    backward_out = torch.autograd.grad(
  File "/usr/local/lib/python3.9/dist-packages/torch/autograd/__init__.py", line 411, in grad
    result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/checkpoint.py", line 1112, in unpack_hook
    frame.recompute_fn(*args)
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/checkpoint.py", line 1401, in recompute_fn
    fn(*args, **kwargs)
  File "/usr/lib/python3.9/contextlib.py", line 135, in __exit__
    self.gen.throw(type, value, traceback)
  File "/usr/local/lib/python3.9/dist-packages/torch/random.py", line 175, in fork_rng
    device_mod.set_rng_state(device_rng_state, device)
  File "/usr/local/lib/python3.9/dist-packages/torch/cuda/random.py", line 61, in set_rng_state
    new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
  File "/usr/local/lib/python3.9/dist-packages/torch/_subclasses/functional_tensor.py", line 309, in __torch_dispatch__
    outs_unwrapped = func(*args_unwrapped, **kwargs_unwrapped)
  File "/usr/local/lib/python3.9/dist-packages/torch/_ops.py", line 513, in __call__
    return self._op(*args, **(kwargs or {}))
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/fx/experimental/proxy_tensor.py", line 596, in __torch_dispatch__
    return self.inner_torch_dispatch(func, types, args, kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/fx/experimental/proxy_tensor.py", line 631, in inner_torch_dispatch
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/fx/experimental/proxy_tensor.py", line 376, in proxy_call
    out = func(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_ops.py", line 513, in __call__
    return self._op(*args, **(kwargs or {}))
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_subclasses/fake_tensor.py", line 1392, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_subclasses/fake_tensor.py", line 1529, in dispatch
    (flat_args, flat_arg_fake_tensors) = self.validate_and_convert_non_fake_tensors(
  File "/usr/local/lib/python3.9/dist-packages/torch/_subclasses/fake_tensor.py", line 1776, in validate_and_convert_non_fake_tensors
    validated_args = [validate(a) for a in flat_args]
  File "/usr/local/lib/python3.9/dist-packages/torch/_subclasses/fake_tensor.py", line 1776, in <listcomp>
    validated_args = [validate(a) for a in flat_args]
  File "/usr/local/lib/python3.9/dist-packages/torch/_subclasses/fake_tensor.py", line 1766, in validate
    raise Exception(
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
Exception: Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.clone.default(tensor([...], size=(16,), dtype=torch.uint8), memory_format=torch.contiguous_format)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

This error only appears if use_reentrant is set to False.
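
For completeness, a minimal variant of the snippet above that sidesteps the error by switching back to the reentrant implementation (reusing f and x from the repro; whether giving up the use_reentrant=False semantics is acceptable depends on the use case):

# Same setup as above, but with the reentrant checkpoint implementation,
# which does not trigger the FakeTensor error reported here.
out = torch.utils.checkpoint.checkpoint(f, x, use_reentrant=True)
out.backward(torch.rand_like(out))
print(x.grad)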

Versions

PyTorch version: 2.2.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 11 (bullseye) (x86_64)
GCC version: (Debian 10.2.1-6) 10.2.1 20210110
Clang version: Could not collect
CMake version: version 3.26.3
Libc version: glibc-2.31

Python version: 3.9.2 (default, Feb 28 2021, 17:03:44)  [GCC 10.2.1 20210110] (64-bit runtime)
Is CUDA available: True

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519

Chillee (Collaborator) commented Mar 17, 2024

I think the issue in this case is that compile(checkpoint) works but checkpoint(compile) does not.

It's somewhat difficult for us to support compile under checkpoint, particularly in the presence of graph breaks, but we definitely shouldn't run into this error.
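
For reference, a minimal sketch of the working ordering based on the repro above, with the checkpoint call placed inside the compiled region rather than the other way around (same function and shapes as the original snippet; exact behavior may vary by version):

import torch
import torch.utils.checkpoint

def add_and_drop(x):
    return torch.nn.functional.dropout(x * 5, 0.5)

# compile(checkpoint) rather than checkpoint(compile): the whole
# checkpointed call sits under torch.compile.
@torch.compile
def checkpointed(x):
    return torch.utils.checkpoint.checkpoint(add_and_drop, x, use_reentrant=False)

x = torch.rand((50001, 3072), device='cuda', dtype=torch.bfloat16, requires_grad=True)
out = checkpointed(x)
out.backward(torch.rand_like(out))
print(x.grad)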

Chillee self-assigned this Mar 17, 2024
0x804d8000 (Contributor, Author) commented

Thanks for the insights!

In our use case, we checkpoint the Transformer layer by layer and use torch.compile to optimize small ops inside each layer, so reversing the nesting would be a bit hard.

I wonder if there is any chance we could get it to work?

Chillee (Collaborator) commented Mar 17, 2024

The fundamental thing that's tricky here is the order of events under a checkpointed region: 1) the initial forward pass, 2) the forward pass that's recomputed during the backward pass, and then 3) the backward pass itself.

When the entire checkpoint region can sit under compile, this can be treated like normal autograd, with a "forward" graph and a "backward" graph (the latter containing the recomputed forward plus the backward itself).

However, if you're only compiling a subset of the graph, there are now 3 distinct compiled regions that must run, with other operators in between.

Now, I don't think this is impossible to support, particularly if you're fine with the forward graph and the recomputed-forward graph sharing the same compiled region (and thus losing some of the potential performance benefit of not needing to write out activations).

Another way to support this is to use torch.compile explicitly on the forward and backward "ops" (see the rough sketch below), although this is arguably a bit more annoying.

Perhaps we should just support graph breaks under activation checkpointing; conceptually, that may not be that difficult if we just reuse the graph.
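
For illustration only, a rough sketch of that second option: compile the forward explicitly and drive the recompute-plus-backward from a custom autograd.Function. This is not how torch.utils.checkpoint is implemented, and it uses a ReLU instead of dropout to sidestep the RNG-state handling a real implementation would need:

import torch

def fwd(x):
    return torch.relu(x * 5)

compiled_fwd = torch.compile(fwd)

class CompiledCheckpoint(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Initial forward: run the compiled op, but save only the input.
        ctx.save_for_backward(x)
        with torch.no_grad():
            return compiled_fwd(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # Recompute the forward through the same compiled region,
        # then differentiate through it.
        with torch.enable_grad():
            x = x.detach().requires_grad_(True)
            out = compiled_fwd(x)
        return torch.autograd.grad(out, x, grad_out)[0]

x = torch.rand(8, 8, requires_grad=True)
y = CompiledCheckpoint.apply(x)
y.backward(torch.ones_like(y))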

cc: @albanD @soulitzer @bdhirsh @zou3519 ?

jansel added the triaged label and removed the triage review label Mar 26, 2024
ppwwyyxx (Collaborator) commented May 8, 2024

Also seeing this issue. use_reentrant=False is now the recommended setting in PyTorch, but this issue is a deal breaker.

JsBlueCat commented

How can this issue be solved? Just use special ops to do checkpoint and compile?

majian4work (Contributor) commented

I see the same issue when there is a graph break and I comment out @torch._disable_dynamo (I comment it out because I don't want to fall back to eager).
In this scenario, compile(checkpoint) effectively turns into checkpoint(compile).

Gy-Lu commented Nov 19, 2024

@yf225 @Chillee Hi, is this issue the same as #97436?
I find that the code runs correctly with torch 2.5.1.

Chillee (Collaborator) commented Nov 19, 2024

@Gy-Lu We actually have fixed this issue. Going to close it now.

Chillee closed this as completed Nov 19, 2024
Gy-Lu commented Dec 4, 2024

> @Gy-Lu We actually have fixed this issue. Going to close it now.

Hi, could you please give me a commit/pull request which fixes this issue?

soulitzer (Contributor) commented Dec 4, 2024

#123196

@Gy-Lu This PR fixed something unrelated, but should indirectly support composing AC + compile this way

soulitzer added the module: activation checkpointing label and removed the activation-checkpointing label Mar 3, 2025