Description
🐛 Describe the bug
I found that `flex_attention` with NJT inputs yields wrong results. I compared them with the following baselines, all of which are consistent with each other:
- `flex_attention` in a padded data layout, using `padding_mask_mod`
- `flex_attention` in a concatenated data layout, using `document_mask_mod`
- my naive implementation of MHA in the padded layout

This happens on both CPU and CUDA devices (A100, RTX 6000 Ada, RTX 2080 Ti).
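For reference, baselines (1) and (2) should agree because the document mask on concatenated indices selects exactly the same query/key pairs as the padding mask on padded indices. The plain-Python sketch below (no torch; the helper names `offsets_from_lengths`, `batch_ids_from_lengths`, `padded_pairs` and `concat_pairs` are hypothetical stand-ins for the `torch.cumsum` / `torch.repeat_interleave` bookkeeping in the reproducer) checks this for the sequence lengths used here. It only checks the mask semantics, not `flex_attention` itself.

```python
from itertools import accumulate


def offsets_from_lengths(lengths):
    """Prefix sums with a leading 0, like torch.cat([0, torch.cumsum(lengths)])."""
    return [0] + list(accumulate(lengths))


def batch_ids_from_lengths(lengths):
    """Batch index of each token in the concatenated layout, like repeat_interleave."""
    return [b for b, n in enumerate(lengths) for _ in range(n)]


def padded_pairs(lengths_q, lengths_kv):
    """(b, q, kv) triples allowed by a padding mask; q and kv are local to sequence b."""
    n_q_max, n_kv_max = max(lengths_q), max(lengths_kv)
    return {(b, q, kv)
            for b in range(len(lengths_q))
            for q in range(n_q_max)
            for kv in range(n_kv_max)
            if q < lengths_q[b] and kv < lengths_kv[b]}


def concat_pairs(lengths_q, lengths_kv):
    """Triples allowed by a document mask, mapped from global back to local indices."""
    ids_q, ids_kv = batch_ids_from_lengths(lengths_q), batch_ids_from_lengths(lengths_kv)
    off_q, off_kv = offsets_from_lengths(lengths_q), offsets_from_lengths(lengths_kv)
    return {(ids_q[q], q - off_q[ids_q[q]], kv - off_kv[ids_kv[kv]])
            for q in range(len(ids_q))
            for kv in range(len(ids_kv))
            if ids_q[q] == ids_kv[kv]}


lengths_Q, lengths_KV = [11, 13, 15], [10, 15, 11]
print(offsets_from_lengths(lengths_Q))   # [0, 11, 24, 39]
print(offsets_from_lengths(lengths_KV))  # [0, 10, 25, 36]
print(padded_pairs(lengths_Q, lengths_KV) == concat_pairs(lengths_Q, lengths_KV))  # True
```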
A second, potentially related issue is that I need to call `.clone()` on the concatenated tensors, e.g. after `Q_cat = Q_njt.values().unsqueeze(0)`. If these `.clone()` calls are commented out, I get the following errors:
Commenting out the clone for Q (`Q_cat = Q_cat.clone()`):
AssertionError: Source of '16*s9' is None when lifting it to input of top-level.
Commenting out the clone for K (`K_cat = K_cat.clone()`):
AssertionError: s7 (could be from ["L['args'][1].size()[2]"]) not in {s69: ["L['args'][0].size()[2]"], s12: ["L['args'][1].size()[1]", "L['args'][2].size()[1]"], s13: ["L['args'][1].size()[2]", "L['args'][2].size()[2]"], s14: ["L['args'][1].size()[3]", "L['args'][1].stride()[2]"], s15: ["L['args'][1].stride()[0]"], s70: ["L['args'][1]._base._values.size()[1]", "L['args'][1]._base._values.size()[1]"], s80: ["L['args'][1]._base._min_seqlen_tensor.size()[0]", "L['args'][1]._base._min_seqlen_tensor.size()[0]"], s66: ["L['args'][1]._base._max_seqlen_tensor.size()[0]", "L['args'][1]._base._max_seqlen_tensor.size()[0]"], s0: ["L['args'][1]._base.size()[2]"], s16: ["L['args'][2].size()[3]", "L['args'][2].stride()[2]"], s17: ["L['args'][2].stride()[0]"], s76: ["L['args'][2]._base._values.size()[1]", "L['args'][2]._base._values.size()[1]"], s72: ["L['args'][2]._base._min_seqlen_tensor.size()[0]", "L['args'][2]._base._min_seqlen_tensor.size()[0]"], s51: ["L['args'][2]._base._max_seqlen_tensor.size()[0]", "L['args'][2]._base._max_seqlen_tensor.size()[0]"], s63: ["L['args'][2]._base.size()[2]"], s87: ["L['args'][4][0]"], s10: ["L['args'][4][1]"], s65: ["L['args'][4][10]"], s98: ["L['args'][4][11]"]}. If this assert is failing, it could be due to the issue described in https://github.com/pytorch/pytorch/pull/90665
Commenting out the clone for V (`V_cat = V_cat.clone()`):
AssertionError: s40 (could be from ["L['args'][2].size()[2]"]) not in {s69: ["L['args'][0].size()[2]"], s7: ["L['args'][1].size()[2]"], s12: ["L['args'][2].size()[1]"], s13: ["L['args'][2].size()[2]"], s14: ["L['args'][2].size()[3]", "L['args'][2].stride()[2]"], s15: ["L['args'][2].stride()[0]"], s76: ["L['args'][2]._base._values.size()[1]", "L['args'][2]._base._values.size()[1]"], s71: ["L['args'][2]._base._min_seqlen_tensor.size()[0]", "L['args'][2]._base._min_seqlen_tensor.size()[0]"], s51: ["L['args'][2]._base._max_seqlen_tensor.size()[0]", "L['args'][2]._base._max_seqlen_tensor.size()[0]"], s63: ["L['args'][2]._base.size()[2]"], s87: ["L['args'][4][0]"], s10: ["L['args'][4][1]"], s65: ["L['args'][4][10]"], s98: ["L['args'][4][11]"]}. If this assert is failing, it could be due to the issue described in https://github.com/pytorch/pytorch/pull/90665
Here is a reproducer running and comparing attention in all four versions mentioned above.
```python
import torch
from torch.nested import as_nested_tensor, to_padded_tensor
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

device = 'cuda' if torch.cuda.is_available() else 'cpu'

H = 8
E_inner = 16
E_value = 32
lengths_Q = torch.tensor([11, 13, 15], device=device)   # (B,)
lengths_KV = torch.tensor([10, 15, 11], device=device)  # (B,)
B = len(lengths_Q)
offsets_Q = torch.cat([torch.tensor([0], device=device), torch.cumsum(lengths_Q, dim=0)])    # (B+1,)
offsets_KV = torch.cat([torch.tensor([0], device=device), torch.cumsum(lengths_KV, dim=0)])  # (B+1,)
batch_ids_Q = torch.repeat_interleave(torch.arange(B, device=device), repeats=lengths_Q)     # (NQ_sum,)
batch_ids_KV = torch.repeat_interleave(torch.arange(B, device=device), repeats=lengths_KV)   # (NKV_sum,)


# Manual implementation of MHA for padded sequences, used as baseline.
def attention_padded(Q,           # (B, H, NQ_max, E_inner)
                     K,           # (B, H, NKV_max, E_inner)
                     V,           # (B, H, NKV_max, E_value)
                     lengths_Q,   # (B,)
                     lengths_KV,  # (B,)
                     ) -> torch.Tensor:  # (B, H, NQ_max, E_value)
    # Compute mask (True for valid query/key pairs)
    B, H, NQ_max, E_inner, NKV_max = *Q.shape, K.size(2)
    Q_idxs = torch.arange(NQ_max, device=Q.device).view(1, NQ_max, 1)
    KV_idxs = torch.arange(NKV_max, device=K.device).view(1, 1, NKV_max)
    mask = torch.logical_and(Q_idxs < lengths_Q.view(B, 1, 1),
                             KV_idxs < lengths_KV.view(B, 1, 1)
                             ).unsqueeze(1)        # (B, 1, NQ_max, NKV_max)
    mask = mask.expand(B, H, NQ_max, NKV_max)      # (B, H, NQ_max, NKV_max)
    # Compute attention (scale out of place so the caller's Q is not modified)
    Q = Q / (E_inner ** 0.5)
    scores = torch.einsum('bhic,bhjc->bhij', Q, K)  # (B, H, NQ_max, NKV_max)
    scores[~mask] = -torch.inf
    scores = torch.nn.functional.softmax(scores, dim=-1)
    scores[~mask] = 0  # 2nd masking: fully masked (all -inf) rows are NaN after softmax; zero them (zero-padded output)
    attn_out = torch.einsum('bhij,bhjc->bhic', scores, V)  # (B, H, NQ_max, E_value)
    return attn_out


# mask_mod and BlockMask for flex_attention in padded data layout
def padding_mask_mod(b_idx, h_idx, q_idx, kv_idx):
    return torch.logical_and(q_idx < lengths_Q[b_idx],
                             kv_idx < lengths_KV[b_idx])

block_mask_pad = create_block_mask(mask_mod=padding_mask_mod,
                                   B=B,
                                   H=None,                          # independent of head index
                                   Q_LEN=lengths_Q.max().item(),    # padded layout, NQ_max
                                   KV_LEN=lengths_KV.max().item(),  # padded layout, NKV_max
                                   device=device)

# mask_mod and BlockMask for flex_attention in concatenated data layout
def document_mask_mod(b_idx, h_idx, q_idx, kv_idx):
    return batch_ids_Q[q_idx] == batch_ids_KV[kv_idx]

block_mask_cat = create_block_mask(mask_mod=document_mask_mod,
                                   B=None,                          # concat data layout
                                   H=None,                          # independent of head index
                                   Q_LEN=lengths_Q.sum().item(),    # concat layout, NQ_sum
                                   KV_LEN=lengths_KV.sum().item(),  # concat layout, NKV_sum
                                   device=device)

# Queries/keys/values in NJT layout
Q_njt = as_nested_tensor([torch.randn(Ni, H, E_inner) for Ni in lengths_Q], layout=torch.jagged)
K_njt = as_nested_tensor([torch.randn(Ni, H, E_inner) for Ni in lengths_KV], layout=torch.jagged)
V_njt = as_nested_tensor([torch.randn(Ni, H, E_value) for Ni in lengths_KV], layout=torch.jagged)
Q_njt = Q_njt.moveaxis(1, 2).contiguous().to(device)  # (B, H, NQ_jag, E_inner)
K_njt = K_njt.moveaxis(1, 2).contiguous().to(device)  # (B, H, NKV_jag, E_inner)
V_njt = V_njt.moveaxis(1, 2).contiguous().to(device)  # (B, H, NKV_jag, E_value)

# Queries/keys/values in concatenated layout
Q_cat = Q_njt.values().unsqueeze(0)  # (1, H, NQ_sum, E_inner)
K_cat = K_njt.values().unsqueeze(0)  # (1, H, NKV_sum, E_inner)
V_cat = V_njt.values().unsqueeze(0)  # (1, H, NKV_sum, E_value)

####################################################################################################
####################################################################################################
# Necessary cloning to prevent assertions, even though *_cat tensors are contiguous
Q_cat = Q_cat.clone()  # when commented out individually: AssertionError: Source of '16*s9' is None when lifting it to input of top-level.
K_cat = K_cat.clone()  # when commented out individually: AssertionError: s7 (could be from ["L['args'][1].size()[2]"]) not in {...} (full message above)
V_cat = V_cat.clone()  # when commented out individually: AssertionError: s40 (could be from ["L['args'][2].size()[2]"]) not in {...} (full message above)
####################################################################################################
####################################################################################################

# Queries/keys/values in padded layout
Q_pad = torch.nested.to_padded_tensor(Q_njt, padding=0.)  # (B, H, NQ_max, E_inner)
K_pad = torch.nested.to_padded_tensor(K_njt, padding=0.)  # (B, H, NKV_max, E_inner)
V_pad = torch.nested.to_padded_tensor(V_njt, padding=0.)  # (B, H, NKV_max, E_value)

# Compute attention
# (1) flex_attention in padded layout, with block_mask_pad
attn_out_pad = flex_attention(Q_pad, K_pad, V_pad, block_mask=block_mask_pad)
# (2) flex_attention in concatenated layout, with block_mask_cat
attn_out_cat = flex_attention(Q_cat, K_cat, V_cat, block_mask=block_mask_cat)
# (3) flex_attention in NJT layout
attn_out_njt = flex_attention(Q_njt, K_njt, V_njt)
# (4) manual implementation in padded layout
attn_out_manual = attention_padded(Q_pad, K_pad, V_pad, lengths_Q, lengths_KV)

# Convert NJT result to padded layout for comparison
attn_out_njt_pad = torch.nested.to_padded_tensor(attn_out_njt, padding=0.)
# Convert concatenated result to padded layout for comparison
attn_out_cat_pad = torch.zeros_like(attn_out_pad)
for i, Ni in enumerate(lengths_Q):
    attn_out_cat_pad[i, :, :Ni] = attn_out_cat[0, :, offsets_Q[i]:offsets_Q[i+1]]

# Print results. (1), (2), (4) agree, (3) is different
print('flex-padded:\n', attn_out_pad[:, 0, -8:, 0])
print('flex-concat:\n', attn_out_cat_pad[:, 0, -8:, 0])
print('flex-NJT:\n', attn_out_njt_pad[:, 0, -8:, 0])
print('manual-padded:\n', attn_out_manual[:, 0, -8:, 0])

# Compare results via allclose
print('flex-padded == flex-concat: ', torch.allclose(attn_out_pad, attn_out_cat_pad, atol=1e-5))  # True
print('flex-padded == flex-NJT:    ', torch.allclose(attn_out_pad, attn_out_njt_pad, atol=1e-5))  # False
print('flex-padded == manual-padded:', torch.allclose(attn_out_pad, attn_out_manual, atol=1e-5))  # True

# The NJT result (3) does have the same padding structure as (1), though
padding_mask_pad = (attn_out_pad == 0)
padding_mask_njt_pad = (attn_out_njt_pad == 0)
print('padding masks equal:', torch.all(padding_mask_pad == padding_mask_njt_pad).item())
```
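The second masking step in `attention_padded` is needed because a padded query row has all of its scores set to `-inf`, and a max-subtracting softmax (as `torch.nn.functional.softmax` computes it) turns such a row into NaN rather than a uniform distribution. The plain-Python sketch below reproduces that effect for a single row; `naive_softmax` and `masked_softmax` are hypothetical helpers for illustration, not PyTorch APIs.

```python
import math


def naive_softmax(row):
    """Softmax with max-subtraction; an all -inf row gives (-inf) - (-inf) = NaN."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def masked_softmax(scores, mask):
    """Two-step masking as in attention_padded: -inf before softmax, 0 after it."""
    masked = [s if keep else -math.inf for s, keep in zip(scores, mask)]
    probs = naive_softmax(masked)
    # Masked entries (including the NaNs of fully masked rows) are forced to 0.
    return [p if keep else 0.0 for p, keep in zip(probs, mask)]


print(naive_softmax([-math.inf, -math.inf]))        # [nan, nan]
print(masked_softmax([0.5, 1.5], [False, False]))   # [0.0, 0.0]
print(masked_softmax([1.0, 2.0, 3.0], [True, True, False]))
```

Without the second step, every padded query row of the manual baseline would contain NaNs, and the `allclose` comparisons against it would fail regardless of the NJT issue.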
Versions
Collecting environment information...
PyTorch version: 2.8.0.dev20250521+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A
OS: Ubuntu 24.04.2 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.39
Python version: 3.13.1 (main, Dec 19 2024, 14:32:25) [Clang 18.1.8 ] (64-bit runtime)
Python platform: Linux-6.8.0-58-generic-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100 80GB PCIe
GPU 1: NVIDIA A100 80GB PCIe
GPU 2: NVIDIA A100 80GB PCIe
GPU 3: NVIDIA A100 80GB PCIe
GPU 4: NVIDIA A100 80GB PCIe
GPU 5: NVIDIA A100 80GB PCIe
GPU 6: NVIDIA A100 80GB PCIe
GPU 7: NVIDIA A100 80GB PCIe
Nvidia driver version: 550.144.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7532 32-Core Processor
CPU family: 23
Model: 49
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 2
Stepping: 0
Frequency boost: enabled
CPU(s) scaling MHz: 71%
CPU max MHz: 2400.0000
CPU min MHz: 1500.0000
BogoMIPS: 4800.03
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sev sev_es
Virtualization: AMD-V
L1d cache: 2 MiB (64 instances)
L1i cache: 2 MiB (64 instances)
L2 cache: 32 MiB (64 instances)
L3 cache: 512 MiB (32 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-31,64-95
NUMA node1 CPU(s): 32-63,96-127
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec rstack overflow: Mitigation; Safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] Could not collect
[conda] Could not collect
cc @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng