flex_attention + NJT output inconsistent with non-NJT results · Issue #154554 · pytorch/pytorch · GitHub
Closed
@mauriceweiler

Description


🐛 Describe the bug

I found that flex_attention with NJT (nested jagged tensor) inputs yields wrong results. I compared its output with the following baselines, all of which agree with each other:

  • flex_attention in a padded data layout, using padding_mask_mod
  • flex_attention in a concatenated data layout, using document_mask_mod
  • my naive implementation of MHA in the padded layout
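The padded and concatenated baselines address the same tokens differently: the padded layout uses a (batch, position) pair, while the concatenated layout uses the flat index offsets[b] + position, with batch ids recording which sequence each flat index belongs to. The plain-Python sketch below (helper names are hypothetical, not from the reproducer) rebuilds offsets and batch ids from the sequence lengths used in the reproducer and checks that a padding mask and a document mask admit exactly the same query/key pairs, which is why those two baselines are expected to agree.

```python
from itertools import accumulate


def offsets_from_lengths(lengths):
    """Exclusive prefix sum: offsets[b] is the flat index of sequence b's first token."""
    return [0, *accumulate(lengths)]


def batch_ids_from_lengths(lengths):
    """Batch index of every token in the concatenated layout."""
    return [b for b, n in enumerate(lengths) for _ in range(n)]


def padded_pairs(lengths_q, lengths_kv):
    """(b, q, kv) triples allowed by a padding mask in the padded layout."""
    n_q_max, n_kv_max = max(lengths_q), max(lengths_kv)
    return {
        (b, q, kv)
        for b in range(len(lengths_q))
        for q in range(n_q_max)
        for kv in range(n_kv_max)
        if q < lengths_q[b] and kv < lengths_kv[b]
    }


def concatenated_pairs(lengths_q, lengths_kv):
    """Pairs allowed by a document mask over flat indices, mapped back to (b, q, kv)."""
    ids_q, ids_kv = batch_ids_from_lengths(lengths_q), batch_ids_from_lengths(lengths_kv)
    off_q, off_kv = offsets_from_lengths(lengths_q), offsets_from_lengths(lengths_kv)
    pairs = set()
    for q_flat, bq in enumerate(ids_q):
        for kv_flat, bkv in enumerate(ids_kv):
            if bq == bkv:  # document mask: query and key belong to the same sequence
                pairs.add((bq, q_flat - off_q[bq], kv_flat - off_kv[bkv]))
    return pairs


lengths_Q, lengths_KV = [11, 13, 15], [10, 15, 11]
print(offsets_from_lengths(lengths_Q))  # [0, 11, 24, 39]
print(padded_pairs(lengths_Q, lengths_KV) == concatenated_pairs(lengths_Q, lengths_KV))  # True
```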

This happens on both CPU and CUDA devices (A100, 6000 Ada, 2080 Ti).

A second, potentially related issue is that I need to call .clone() on the concatenated-layout tensors, e.g. Q_cat = Q_njt.values().unsqueeze(0), even though they are already contiguous. If any one of these clone lines is commented out, I get the following errors:
Commenting out the line for Q (line 82):

AssertionError: Source of '16*s9' is None when lifting it to input of top-level.

Commenting out the line for K (line 83):

AssertionError: s7 (could be from ["L['args'][1].size()[2]"]) not in {s69: ["L['args'][0].size()[2]"], s12: ["L['args'][1].size()[1]", "L['args'][2].size()[1]"], s13: ["L['args'][1].size()[2]", "L['args'][2].size()[2]"], s14: ["L['args'][1].size()[3]", "L['args'][1].stride()[2]"], s15: ["L['args'][1].stride()[0]"], s70: ["L['args'][1]._base._values.size()[1]", "L['args'][1]._base._values.size()[1]"], s80: ["L['args'][1]._base._min_seqlen_tensor.size()[0]", "L['args'][1]._base._min_seqlen_tensor.size()[0]"], s66: ["L['args'][1]._base._max_seqlen_tensor.size()[0]", "L['args'][1]._base._max_seqlen_tensor.size()[0]"], s0: ["L['args'][1]._base.size()[2]"], s16: ["L['args'][2].size()[3]", "L['args'][2].stride()[2]"], s17: ["L['args'][2].stride()[0]"], s76: ["L['args'][2]._base._values.size()[1]", "L['args'][2]._base._values.size()[1]"], s72: ["L['args'][2]._base._min_seqlen_tensor.size()[0]", "L['args'][2]._base._min_seqlen_tensor.size()[0]"], s51: ["L['args'][2]._base._max_seqlen_tensor.size()[0]", "L['args'][2]._base._max_seqlen_tensor.size()[0]"], s63: ["L['args'][2]._base.size()[2]"], s87: ["L['args'][4][0]"], s10: ["L['args'][4][1]"], s65: ["L['args'][4][10]"], s98: ["L['args'][4][11]"]}.  If this assert is failing, it could be due to the issue described in https://github.com/pytorch/pytorch/pull/90665

Commenting out the line for V (line 84):

AssertionError: s40 (could be from ["L['args'][2].size()[2]"]) not in {s69: ["L['args'][0].size()[2]"], s7: ["L['args'][1].size()[2]"], s12: ["L['args'][2].size()[1]"], s13: ["L['args'][2].size()[2]"], s14: ["L['args'][2].size()[3]", "L['args'][2].stride()[2]"], s15: ["L['args'][2].stride()[0]"], s76: ["L['args'][2]._base._values.size()[1]", "L['args'][2]._base._values.size()[1]"], s71: ["L['args'][2]._base._min_seqlen_tensor.size()[0]", "L['args'][2]._base._min_seqlen_tensor.size()[0]"], s51: ["L['args'][2]._base._max_seqlen_tensor.size()[0]", "L['args'][2]._base._max_seqlen_tensor.size()[0]"], s63: ["L['args'][2]._base.size()[2]"], s87: ["L['args'][4][0]"], s10: ["L['args'][4][1]"], s65: ["L['args'][4][10]"], s98: ["L['args'][4][11]"]}.  If this assert is failing, it could be due to the issue described in https://github.com/pytorch/pytorch/pull/90665

Here is a reproducer that computes attention in all four variants mentioned above and compares the results.

import torch
from torch.nested import as_nested_tensor, to_padded_tensor
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

device = 'cuda' if torch.cuda.is_available() else 'cpu'
H = 8
E_inner = 16
E_value = 32
lengths_Q  = torch.tensor([11, 13, 15], device=device) # (B,)
lengths_KV = torch.tensor([10, 15, 11], device=device) # (B,)
B = len(lengths_Q)
offsets_Q  = torch.cat([torch.tensor([0], device=device), torch.cumsum(lengths_Q,  dim=0)]) # (B+1,)
offsets_KV = torch.cat([torch.tensor([0], device=device), torch.cumsum(lengths_KV, dim=0)]) # (B+1,)
batch_ids_Q =  torch.repeat_interleave(torch.arange(B, device=device), repeats=lengths_Q)  # (NQ_sum,)
batch_ids_KV = torch.repeat_interleave(torch.arange(B, device=device), repeats=lengths_KV) # (NKV_sum,)


# Manual implementation of MHA for padded sequences, used as baseline.
def attention_padded(Q,           # (B, H, NQ_max,  E_inner)
                     K,           # (B, H, NKV_max, E_inner)
                     V,           # (B, H, NKV_max, E_value)
                     lengths_Q,   # (B, )
                     lengths_KV,  # (B, )
                    ) -> torch.Tensor: # (B, H, NQ_max,  E_value)
    # Compute mask (True for valid query/key pairs)
    B, H, NQ_max, E_inner, NKV_max = *Q.shape, K.size(2)
    Q_idxs  = torch.arange(NQ_max,  device=Q.device).view(1,NQ_max,1)
    KV_idxs = torch.arange(NKV_max, device=K.device).view(1,1,NKV_max)
    mask = torch.logical_and(Q_idxs  < lengths_Q.view( B,1,1),
                             KV_idxs < lengths_KV.view(B,1,1)
                            ).unsqueeze(1) # (B, 1, NQ_max, NKV_max)
    mask = mask.expand(B,H,NQ_max,NKV_max) # (B, H, NQ_max, NKV_max)
    # Compute attention
    Q = Q / (E_inner ** 0.5)  # out-of-place, so the caller's Q_pad is not modified
    scores = torch.einsum('bhic,bhjc->bhij', Q, K)        # (B, H, NQ_max, NKV_max)
    scores[~mask] = -torch.inf
    scores = torch.nn.functional.softmax(scores, dim=-1)
    scores[~mask] = 0 # 2nd masking: fully masked (padded query) rows are NaN after softmax; zeroing them ensures zero-padded output
    attn_out = torch.einsum('bhij,bhjc->bhic', scores, V) # (B, H, NQ_max, E_value)
    return attn_out


# mask_mod and BlockMask for flex_attention in padded data layout
def padding_mask_mod(b_idx, h_idx, q_idx, kv_idx):
    return torch.logical_and((q_idx  < lengths_Q[b_idx]),
                             (kv_idx < lengths_KV[b_idx]))

block_mask_pad = create_block_mask(mask_mod   = padding_mask_mod,
                                   B          = B,
                                   H          = None, # independent of head index
                                   Q_LEN      = lengths_Q.max().item(),  # padded layout, NQ_max
                                   KV_LEN     = lengths_KV.max().item(), # padded layout, NKV_max
                                   device     = device)

# mask_mod and BlockMask for flex_attention in concatenated data layout
def document_mask_mod(b_idx, h_idx, q_idx, kv_idx):
    return batch_ids_Q[q_idx] == batch_ids_KV[kv_idx]

block_mask_cat = create_block_mask(mask_mod   = document_mask_mod,
                                   B          = None, # concat data layout
                                   H          = None, # independent of head index
                                   Q_LEN      = lengths_Q.sum().item(),  # concat layout, NQ_sum
                                   KV_LEN     = lengths_KV.sum().item(), # concat layout, NKV_sum
                                   device     = device)


# Queries/keys/values in NJT layout
Q_njt = as_nested_tensor([torch.randn(Ni,H,E_inner) for Ni in lengths_Q],  layout=torch.jagged)
K_njt = as_nested_tensor([torch.randn(Ni,H,E_inner) for Ni in lengths_KV], layout=torch.jagged)
V_njt = as_nested_tensor([torch.randn(Ni,H,E_value) for Ni in lengths_KV], layout=torch.jagged)
Q_njt = Q_njt.moveaxis(1,2).contiguous().to(device)           # (B, H, NQ_jag,  E_inner)
K_njt = K_njt.moveaxis(1,2).contiguous().to(device)           # (B, H, NKV_jag, E_inner)
V_njt = V_njt.moveaxis(1,2).contiguous().to(device)           # (B, H, NKV_jag, E_value)

# Queries/keys/values in concatenated layout
Q_cat = Q_njt.values().unsqueeze(0)                           # (1, H, NQ_sum,  E_inner)
K_cat = K_njt.values().unsqueeze(0)                           # (1, H, NKV_sum, E_inner)
V_cat = V_njt.values().unsqueeze(0)                           # (1, H, NKV_sum, E_value)
####################################################################################################
####################################################################################################
# Necessary cloning to prevent assertions, even though *_cat tensors are contiguous
Q_cat = Q_cat.clone() # when commented out individually:  AssertionError: Source of '16*s9' is None when lifting it to input of top-level.
K_cat = K_cat.clone() # when commented out individually:  AssertionError: s7 (could be from ["L['args'][1].size()[2]"]) not in {...} (full message quoted above)
V_cat = V_cat.clone() # when commented out individually:  AssertionError: s40 (could be from ["L['args'][2].size()[2]"]) not in {...} (full message quoted above)
####################################################################################################
####################################################################################################

# Queries/keys/values in padded layout
Q_pad = torch.nested.to_padded_tensor(Q_njt, padding=0.) # (B, H, NQ_max,  E_inner)
K_pad = torch.nested.to_padded_tensor(K_njt, padding=0.) # (B, H, NKV_max, E_inner)
V_pad = torch.nested.to_padded_tensor(V_njt, padding=0.) # (B, H, NKV_max, E_value)


# Compute attention
# (1) flex_attention in padded layout, with block_mask_pad
attn_out_pad = flex_attention(Q_pad, K_pad, V_pad, block_mask=block_mask_pad)
# (2) flex_attention in concatenated layout, with block_mask_cat
attn_out_cat = flex_attention(Q_cat, K_cat, V_cat, block_mask=block_mask_cat)
# (3) flex_attention in NJT layout
attn_out_njt = flex_attention(Q_njt, K_njt, V_njt)
# (4) manual implementation in padded layout
attn_out_manual = attention_padded(Q_pad, K_pad, V_pad, lengths_Q, lengths_KV)

# Cast NJT result to padded for comparison
attn_out_njt_pad = torch.nested.to_padded_tensor(attn_out_njt, padding=0.)
# Cast concatenated result to padded for comparison
attn_out_cat_pad = torch.zeros_like(attn_out_pad)
for i,Ni in enumerate(lengths_Q):
    attn_out_cat_pad[i,:,:Ni] = attn_out_cat[0,:,offsets_Q[i]:offsets_Q[i+1]]


# Print results. (1),(2),(4) agree, (3) is different
print('flex-padded:\n',   attn_out_pad[:,0,-8:,0])
print('flex-concat:\n',   attn_out_cat_pad[:,0,-8:,0])
print('flex-NJT:\n',      attn_out_njt_pad[:,0,-8:,0])
print('manual-padded:\n', attn_out_manual[:,0,-8:,0])
# Compare results via allclose
print('flex-padded == flex-concat:  ', torch.allclose(attn_out_pad, attn_out_cat_pad, atol=1e-5)) # True
print('flex-padded == flex-NJT:     ', torch.allclose(attn_out_pad, attn_out_njt_pad, atol=1e-5)) # False
print('flex-padded == manual-padded:', torch.allclose(attn_out_pad, attn_out_manual,  atol=1e-5)) # True
# The padding structure of (3) is fine though: its zero entries match those of (1)
padding_mask_pad     = (attn_out_pad     == 0)
padding_mask_njt_pad = (attn_out_njt_pad == 0)
print('padding masks equal:', torch.all(padding_mask_pad == padding_mask_njt_pad).item())
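A layout-independent ground truth can help pin down which variant is off: attention computed separately for each sequence, with no padding and no mask. The plain-Python sketch below (single head, small random inputs; all names are hypothetical and not part of the reproducer) computes that reference and checks that a padded computation, masked the way attention_padded does it, reproduces it on the valid query rows and returns zeros on the padded ones. It also illustrates why attention_padded masks twice: with a numerically stable softmax, a fully masked row evaluates -inf - (-inf), which is NaN.

```python
import math
import random


def masked_softmax(row, valid):
    """Numerically stable softmax; entries with valid == False are set to -inf first."""
    scores = [s if ok else -math.inf for s, ok in zip(row, valid)]
    m = max(scores)
    # For a fully masked row m is -inf, and -inf - (-inf) is NaN, so every entry becomes NaN.
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [x / total for x in exps]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attend(q_rows, k_rows, v_rows, mask):
    """Single-head scaled dot-product attention; mask[i][j] is True if query i may see key j."""
    scale = 1.0 / math.sqrt(len(q_rows[0]))
    out = []
    for i, q in enumerate(q_rows):
        probs = masked_softmax([scale * dot(q, k) for k in k_rows], mask[i])
        # Second masking (as in attention_padded): zero masked probabilities,
        # which also clears the NaNs of fully masked query rows.
        probs = [p if ok else 0.0 for p, ok in zip(probs, mask[i])]
        out.append([sum(p * v[c] for p, v in zip(probs, v_rows)) for c in range(len(v_rows[0]))])
    return out


def pad(rows, n):
    """Append zero rows until there are n rows (zero padding)."""
    return rows + [[0.0] * len(rows[0]) for _ in range(n - len(rows))]


def rand_rows(n, e):
    return [[random.gauss(0.0, 1.0) for _ in range(e)] for _ in range(n)]


def close(rows_a, rows_b, tol=1e-9):
    return len(rows_a) == len(rows_b) and all(
        math.isclose(x, y, abs_tol=tol) for ra, rb in zip(rows_a, rows_b) for x, y in zip(ra, rb)
    )


random.seed(0)
lengths_q, lengths_kv, e = [3, 5], [4, 2], 4  # small hypothetical sizes, single head
seqs = [(rand_rows(nq, e), rand_rows(nkv, e), rand_rows(nkv, e))
        for nq, nkv in zip(lengths_q, lengths_kv)]

# Ground truth: attend within each sequence separately; no padding, every pair allowed.
ref = [attend(q, k, v, [[True] * len(k) for _ in q]) for q, k, v in seqs]

# Padded layout: pad to the maximum lengths and mask invalid query/key pairs.
nq_max, nkv_max = max(lengths_q), max(lengths_kv)
padded = []
for b, (q, k, v) in enumerate(seqs):
    mask = [[i < lengths_q[b] and j < lengths_kv[b] for j in range(nkv_max)]
            for i in range(nq_max)]
    padded.append(attend(pad(q, nq_max), pad(k, nkv_max), pad(v, nkv_max), mask))

matches = all(close(padded[b][:lengths_q[b]], ref[b]) for b in range(len(seqs)))
zero_pad = all(x == 0.0 for b in range(len(seqs)) for row in padded[b][lengths_q[b]:] for x in row)
print(masked_softmax([1.0, 2.0], [False, False]))  # [nan, nan]
print(matches, zero_pad)  # True True
```

The same per-sequence reference is what each of the four variants in the reproducer is expected to match on its valid query rows after conversion to the padded layout.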

@BoyuanFeng

Versions

Collecting environment information...
PyTorch version: 2.8.0.dev20250521+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.2 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.39

Python version: 3.13.1 (main, Dec 19 2024, 14:32:25) [Clang 18.1.8 ] (64-bit runtime)
Python platform: Linux-6.8.0-58-generic-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100 80GB PCIe
GPU 1: NVIDIA A100 80GB PCIe
GPU 2: NVIDIA A100 80GB PCIe
GPU 3: NVIDIA A100 80GB PCIe
GPU 4: NVIDIA A100 80GB PCIe
GPU 5: NVIDIA A100 80GB PCIe
GPU 6: NVIDIA A100 80GB PCIe
GPU 7: NVIDIA A100 80GB PCIe

Nvidia driver version: 550.144.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7532 32-Core Processor
CPU family: 23
Model: 49
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 2
Stepping: 0
Frequency boost: enabled
CPU(s) scaling MHz: 71%
CPU max MHz: 2400.0000
CPU min MHz: 1500.0000
BogoMIPS: 4800.03
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sev sev_es
Virtualization: AMD-V
L1d cache: 2 MiB (64 instances)
L1i cache: 2 MiB (64 instances)
L2 cache: 32 MiB (64 instances)
L3 cache: 512 MiB (32 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-31,64-95
NUMA node1 CPU(s): 32-63,96-127
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec rstack overflow: Mitigation; Safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] Could not collect
[conda] Could not collect

cc @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng

Labels

  • module: flex attention
  • module: higher order operators
  • module: pt2-dispatcher
  • oncall: pt2
  • triaged
