Flex Attention is incompatible with selective AC · Issue #147879 · pytorch/pytorch · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content
Flex Attention is incompatible with selective AC #147879
Open
@fegin

Description

@fegin

🐛 Describe the bug

When FlexAttention is used together with selective activation checkpointing (AC), the following error is raised:

  traceback : Traceback (most recent call last):
    File "/data/users/chienchin/mywork/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 354, in wrapper
      return f(*args, **kwargs)
    File "/data/users/chienchin/fbsource/fbcode/pytorch/torchtitan/train.py", line 306, in main
      pred = model(input_ids)
    File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1857, in _call_impl
      return inner()
    File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1805, in inner
      result = forward_call(*args, **kwargs)
    File "/data/users/chienchin/fbsource/fbcode/pytorch/torchtitan/torchtitan/models/llama/model.py", line 478, in forward
      h = layer(h, self.freqs_cis)
    File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1857, in _call_impl
      return inner()
    File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1805, in inner
      result = forward_call(*args, **kwargs)
    File "/data/users/chienchin/mywork/pytorch/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forward
      return self.checkpoint_fn(  # type: ignore[misc]
    File "/data/users/chienchin/mywork/pytorch/torch/_compile.py", line 51, in inner
      return disable_fn(*args, **kwargs)
    File "/data/users/chienchin/mywork/pytorch/torch/_dynamo/eval_frame.py", line 764, in _fn
      return fn(*args, **kwargs)
    File "/data/users/chienchin/mywork/pytorch/torch/utils/checkpoint.py", line 495, in checkpoint
      ret = function(*args, **kwargs)
    File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1762, in _call_impl
      return forward_call(*args, **kwargs)
    File "/data/users/chienchin/fbsource/fbcode/pytorch/torchtitan/torchtitan/models/llama/model.py", line 359, in forward
      h = x + self.attention(self.attention_norm(x), freqs_cis)
    File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1762, in _call_impl
      return forward_call(*args, **kwargs)
    File "/data/users/chienchin/fbsource/fbcode/pytorch/torchtitan/torchtitan/models/llama/model.py", line 230, in forward
      output = flex_attention(xq, xk, xv, block_mask=self.block_mask)
    File "/data/users/chienchin/mywork/pytorch/torch/nn/attention/flex_attention.py", line 1357, in flex_attention
      out, lse = torch.compile(
    File "/data/users/chienchin/mywork/pytorch/torch/_dynamo/eval_frame.py", line 585, in _fn
      return fn(*args, **kwargs)
    File "/data/users/chienchin/mywork/pytorch/torch/nn/attention/flex_attention.py", line 1345, in _flex_attention_hop_wrapper
      return flex_attention_hop(*args, **kwargs)
    File "/data/users/chienchin/mywork/pytorch/torch/_higher_order_ops/flex_attention.py", line 92, in __call__
      return super().__call__(
    File "/data/users/chienchin/mywork/pytorch/torch/_ops.py", line 471, in __call__
      return wrapper()
    File "/data/users/chienchin/mywork/pytorch/torch/_ops.py", line 467, in wrapper
      return self.dispatch(
    File "/data/users/chienchin/mywork/pytorch/torch/_ops.py", line 455, in dispatch
      return kernel(*args, **kwargs)
    File "/data/users/chienchin/mywork/pytorch/torch/_higher_order_ops/flex_attention.py", line 744, in flex_attention_autograd
      out, logsumexp = FlexAttentionAutogradOp.apply(
    File "/data/users/chienchin/mywork/pytorch/torch/autograd/function.py", line 575, in apply
      return super().apply(*args, **kwargs)  # type: ignore[misc]
    File "/data/users/chienchin/mywork/pytorch/torch/_higher_order_ops/flex_attention.py", line 610, in forward
      out, logsumexp = flex_attention(
    File "/data/users/chienchin/mywork/pytorch/torch/_higher_order_ops/flex_attention.py", line 92, in __call__
      return super().__call__(
    File "/data/users/chienchin/mywork/pytorch/torch/_ops.py", line 471, in __call__
      return wrapper()
    File "/data/users/chienchin/mywork/pytorch/torch/_ops.py", line 462, in wrapper
      return torch.overrides.handle_torch_function(
    File "/data/users/chienchin/mywork/pytorch/torch/overrides.py", line 1721, in handle_torch_function
      result = mode.__torch_function__(public_api, types, args, kwargs)
    File "/data/users/chienchin/mywork/pytorch/torch/_dynamo/_trace_wrapped_higher_order_op.py", line 142, in __torch_function__
      return func(*args, **(kwargs or {}))
    File "/data/users/chienchin/mywork/pytorch/torch/_higher_order_ops/flex_attention.py", line 92, in __call__
      return super().__call__(
    File "/data/users/chienchin/mywork/pytorch/torch/_ops.py", line 471, in __call__
      return wrapper()
    File "/data/users/chienchin/mywork/pytorch/torch/_ops.py", line 467, in wrapper
      return self.dispatch(
    File "/data/users/chienchin/mywork/pytorch/torch/_ops.py", line 365, in dispatch
      raise NotImplementedError(
  NotImplementedError: There was no rule registered for HOP flex_attention and mode <torch.utils.checkpoint._CachingTorchDispatchMode object at 0x7f3e5cc0fac0>. We recommend filing an issue.

This issue can be reproduced by applying pytorch/torchtitan#887 and running: CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --model.use_flex_attn

Note that full activation checkpointing doesn't cause this issue.

Versions

nightly

cc @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng

Metadata

Assignees

No one assigned

    Labels

    module: flex attention · module: higher order operators (torch.cond and similar) · module: pt2-dispatcher (PT2 dispatcher-related issues, e.g., aotdispatch, functionalization, faketensor, custom-op) · oncall: pt2 · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions

      0