🐛 Describe the bug
When using FlexAttention with selective activation checkpointing (SAC), we get the following error:
Traceback (most recent call last):
  File "/data/users/chienchin/mywork/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 354, in wrapper
    return f(*args, **kwargs)
  File "/data/users/chienchin/fbsource/fbcode/pytorch/torchtitan/train.py", line 306, in main
    pred = model(input_ids)
  File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1857, in _call_impl
    return inner()
  File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1805, in inner
    result = forward_call(*args, **kwargs)
  File "/data/users/chienchin/fbsource/fbcode/pytorch/torchtitan/torchtitan/models/llama/model.py", line 478, in forward
    h = layer(h, self.freqs_cis)
  File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1857, in _call_impl
    return inner()
  File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1805, in inner
    result = forward_call(*args, **kwargs)
  File "/data/users/chienchin/mywork/pytorch/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forward
    return self.checkpoint_fn(  # type: ignore[misc]
  File "/data/users/chienchin/mywork/pytorch/torch/_compile.py", line 51, in inner
    return disable_fn(*args, **kwargs)
  File "/data/users/chienchin/mywork/pytorch/torch/_dynamo/eval_frame.py", line 764, in _fn
    return fn(*args, **kwargs)
  File "/data/users/chienchin/mywork/pytorch/torch/utils/checkpoint.py", line 495, in checkpoint
    ret = function(*args, **kwargs)
  File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/users/chienchin/fbsource/fbcode/pytorch/torchtitan/torchtitan/models/llama/model.py", line 359, in forward
    h = x + self.attention(self.attention_norm(x), freqs_cis)
  File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/users/chienchin/fbsource/fbcode/pytorch/torchtitan/torchtitan/models/llama/model.py", line 230, in forward
    output = flex_attention(xq, xk, xv, block_mask=self.block_mask)
  File "/data/users/chienchin/mywork/pytorch/torch/nn/attention/flex_attention.py", line 1357, in flex_attention
    out, lse = torch.compile(
  File "/data/users/chienchin/mywork/pytorch/torch/_dynamo/eval_frame.py", line 585, in _fn
    return fn(*args, **kwargs)
  File "/data/users/chienchin/mywork/pytorch/torch/nn/attention/flex_attention.py", line 1345, in _flex_attention_hop_wrapper
    return flex_attention_hop(*args, **kwargs)
  File "/data/users/chienchin/mywork/pytorch/torch/_higher_order_ops/flex_attention.py", line 92, in __call__
    return super().__call__(
  File "/data/users/chienchin/mywork/pytorch/torch/_ops.py", line 471, in __call__
    return wrapper()
  File "/data/users/chienchin/mywork/pytorch/torch/_ops.py", line 467, in wrapper
    return self.dispatch(
  File "/data/users/chienchin/mywork/pytorch/torch/_ops.py", line 455, in dispatch
    return kernel(*args, **kwargs)
  File "/data/users/chienchin/mywork/pytorch/torch/_higher_order_ops/flex_attention.py", line 744, in flex_attention_autograd
    out, logsumexp = FlexAttentionAutogradOp.apply(
  File "/data/users/chienchin/mywork/pytorch/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/data/users/chienchin/mywork/pytorch/torch/_higher_order_ops/flex_attention.py", line 610, in forward
    out, logsumexp = flex_attention(
  File "/data/users/chienchin/mywork/pytorch/torch/_higher_order_ops/flex_attention.py", line 92, in __call__
    return super().__call__(
  File "/data/users/chienchin/mywork/pytorch/torch/_ops.py", line 471, in __call__
    return wrapper()
  File "/data/users/chienchin/mywork/pytorch/torch/_ops.py", line 462, in wrapper
    return torch.overrides.handle_torch_function(
  File "/data/users/chienchin/mywork/pytorch/torch/overrides.py", line 1721, in handle_torch_function
    result = mode.__torch_function__(public_api, types, args, kwargs)
  File "/data/users/chienchin/mywork/pytorch/torch/_dynamo/_trace_wrapped_higher_order_op.py", line 142, in __torch_function__
    return func(*args, **(kwargs or {}))
  File "/data/users/chienchin/mywork/pytorch/torch/_higher_order_ops/flex_attention.py", line 92, in __call__
    return super().__call__(
  File "/data/users/chienchin/mywork/pytorch/torch/_ops.py", line 471, in __call__
    return wrapper()
  File "/data/users/chienchin/mywork/pytorch/torch/_ops.py", line 467, in wrapper
    return self.dispatch(
  File "/data/users/chienchin/mywork/pytorch/torch/_ops.py", line 365, in dispatch
    raise NotImplementedError(
NotImplementedError: There was no rule registered for HOP flex_attention and mode <torch.utils.checkpoint._CachingTorchDispatchMode object at 0x7f3e5cc0fac0>. We recommend filing an issue.
This issue can be reproduced with pytorch/torchtitan#887 by running CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --model.use_flex_attn.
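For reference, below is a minimal standalone sketch (not the torchtitan code) of the same combination: flex_attention called eagerly inside torch.utils.checkpoint with a selective-checkpointing context_fn. The shapes, the causal mask, and the "save matmuls" policy are illustrative assumptions rather than the torchtitan settings, and the sketch assumes a CUDA device.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# Small shapes chosen only for illustration.
B, H, S, D = 2, 4, 128, 64
block_mask = create_block_mask(causal, B, H, S, S, device="cuda")

def policy_fn(ctx, op, *args, **kwargs):
    # An arbitrary SAC policy: save matmuls, recompute everything else.
    if op == torch.ops.aten.mm.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def attn(q, k, v):
    return flex_attention(q, k, v, block_mask=block_mask)

q, k, v = (
    torch.randn(B, H, S, D, device="cuda", requires_grad=True) for _ in range(3)
)

# checkpoint(..., context_fn=...) pushes the SAC _CachingTorchDispatchMode, so
# the flex_attention HOP is dispatched while that mode is active, which is
# where the "no rule registered for HOP flex_attention" error surfaces.
out = checkpoint(
    attn,
    q,
    k,
    v,
    use_reentrant=False,
    context_fn=lambda: create_selective_checkpoint_contexts(policy_fn),
)
out.sum().backward()
```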
Note that full activation checkpointing doesn't cause this issue.
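For contrast, full activation checkpointing corresponds to calling checkpoint without a context_fn, so no SAC dispatch mode wraps the HOP call. A sketch reusing attn, q, k, and v from the snippet above:

```python
# Full AC: no context_fn, so no _CachingTorchDispatchMode is pushed and the
# flex_attention HOP dispatches normally.
out = checkpoint(attn, q, k, v, use_reentrant=False)
out.sum().backward()
```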
Versions
PyTorch nightly build.
cc @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng