[Inductor] Permuted memory access pattern for tiled pointwise kernels · Issue #148718 · pytorch/pytorch · GitHub

[Inductor] Permuted memory access pattern for tiled pointwise kernels #148718

Closed
blaine-rister opened this issue Mar 7, 2025 · 11 comments
Assignees: blaine-rister
Labels: module: inductor, oncall: pt2, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

blaine-rister (Contributor) commented Mar 7, 2025

🐛 Describe the bug

On this example program:

import torch
import torch._inductor.config as config

config.triton.prefer_nd_tiling = True
config.triton.use_block_ptr = True

full_size = (21, 32)
view_size = (21, 19)
def get_input():
    full = torch.rand(full_size, device="cuda")
    view = torch.as_strided(full, view_size, full.stride())
    return view

inps = [get_input() for _ in range(2)]

compiled = torch.compile(torch.add)
compiled(*inps)

Inductor generates the following Triton kernel:

import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 32, 'x': 32}, tile_hint=TileHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '54E46422D5DB2E55B804C8E038A4A0E2ECEED6FCC5402DED453936C14F5DFA13', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_0(in_ptr0, in_ptr1, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 21
    xnumel = 19
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    x1 = xindex
    y0 = yindex
    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[19, 21], strides=[1, 32], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), boundary_check=[0, 1], eviction_policy='evict_last')
    tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[19, 21], strides=[1, 32], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), boundary_check=[0, 1], eviction_policy='evict_last')
    tmp2 = tmp0 + tmp1
    tl.store(tl.make_block_ptr(out_ptr0, shape=[19, 21], strides=[1, 19], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), tl.broadcast_to(tmp2, [XBLOCK, YBLOCK]).to(tl.float32), boundary_check=[0, 1])
''', device_str='cuda')

This is technically correct, but the shapes and strides appear to be permuted for both the load and the store. Normally, I would expect the trailing stride to be 1 for both. On certain systems, loads and stores with permuted strides are more expensive because you end up doing a DMA in the natural order and then permuting on chip.

I want to see if we can flip the dimensions back so strides=[32, 1]. Or does it even matter? Is this just a matter of convention?
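
For reference, this is roughly what I would have expected instead: the same kernel, but with y as the leading tile dimension so every block pointer keeps a trailing stride of 1. (Hand-written sketch for illustration only, not actual Inductor output; shapes and strides are taken from the test case above, with a contiguous (21, 19) output.)

import triton
import triton.language as tl

@triton.jit
def triton_poi_fused_add_0_expected(in_ptr0, in_ptr1, out_ptr0, YBLOCK: tl.constexpr, XBLOCK: tl.constexpr):
    # Tile dimensions ordered (y, x), so the trailing stride is 1 for all pointers.
    yoffset = tl.program_id(1) * YBLOCK
    xoffset = tl.program_id(0) * XBLOCK
    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[21, 19], strides=[32, 1],
                                     block_shape=[YBLOCK, XBLOCK], order=[1, 0],
                                     offsets=[yoffset, xoffset]), boundary_check=[0, 1])
    tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[21, 19], strides=[32, 1],
                                     block_shape=[YBLOCK, XBLOCK], order=[1, 0],
                                     offsets=[yoffset, xoffset]), boundary_check=[0, 1])
    tmp2 = tmp0 + tmp1
    # The output view is contiguous (21, 19), so its row stride is 19.
    tl.store(tl.make_block_ptr(out_ptr0, shape=[21, 19], strides=[19, 1],
                               block_shape=[YBLOCK, XBLOCK], order=[1, 0],
                               offsets=[yoffset, xoffset]),
             tmp2, boundary_check=[0, 1])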

Versions

Collecting environment information...
PyTorch version: 2.7.0a0+gitbd78b54
Is debug build: True
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: CentOS Stream 9 (x86_64)
GCC version: (GCC) 11.5.0 20240719 (Red Hat 11.5.0-5)
Clang version: Could not collect
CMake version: version 3.30.4
Libc version: glibc-2.34

Python version: 3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.4.3-0_fbk14_hardened_2601_gcd42476b84e9-x86_64-with-glibc2.34
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA PG509-210
Nvidia driver version: 550.90.07
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: False

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 22
On-line CPU(s) list: 0-21
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8339HC CPU @ 1.80GHz
CPU family: 6
Model: 85
Thread(s) per core: 1
Core(s) per socket: 22
Socket(s): 1
Stepping: 11
BogoMIPS: 3591.57
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 arat vnmi umip pku ospke avx512_vnni md_clear flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 704 KiB (22 instances)
L1i cache: 704 KiB (22 instances)
L2 cache: 88 MiB (22 instances)
L3 cache: 16 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-21
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Vulnerable: eIBRS with unprivileged eBPF
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] flake8==6.1.0
[pip3] flake8-bugbear==23.3.23
[pip3] flake8-comprehensions==3.15.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==0.9.0
[pip3] flake8-pyi==23.3.1
[pip3] flake8-simplify==0.19.3
[pip3] mypy==1.13.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] optree==0.13.0
[pip3] pytorch_sphinx_theme==0.0.24
[pip3] pytorch-triton==3.0.0+dedb7bdf33
[pip3] torch==2.7.0a0+gitbd78b54
[pip3] torchaudio==2.5.0a0+79047bf
[pip3] torchbench==0.1
[pip3] torchdata==0.10.0a0+c64801f
[pip3] torchtext==0.17.0a0+1d4ce73
[pip3] triton==3.1.0
[conda] blas 1.0 mkl
[conda] magma-cuda116 2.6.1 1 pytorch
[conda] magma-cuda121 2.6.1 1 pytorch
[conda] mkl 2023.1.0 h213fc3f_46344
[conda] mkl-include 2023.1.0 h06a4308_46344
[conda] mkl-service 2.4.0 py310h5eee18b_1
[conda] mkl_fft 1.3.8 py310h5eee18b_0
[conda] mkl_random 1.2.4 py310hdb19cb5_0
[conda] numpy 2.1.2 pypi_0 pypi
[conda] numpy-base 1.26.4 py310hb5e798b_0
[conda] nvidia-cublas-cu12 12.4.5.8 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.1.0.70 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.2.1.3 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.5.147 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.6.1.9 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.3.1.170 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.21.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.4.127 pypi_0 pypi
[conda] optree 0.13.0 pypi_0 pypi
[conda] pytorch-sphinx-theme 0.0.24 pypi_0 pypi
[conda] pytorch-triton 3.0.0+dedb7bdf33 pypi_0 pypi
[conda] torch 2.7.0a0+gitbd78b54 dev_0
[conda] torchaudio 2.5.0a0+79047bf dev_0
[conda] torchbench 0.1 pypi_0 pypi
[conda] torchdata 0.7.0a0+11bb5b8 pypi_0 pypi
[conda] torchfix 0.4.0 pypi_0 pypi
[conda] torchtext 0.17.0a0+1d4ce73 dev_0
[conda] triton 3.1.0 pypi_0 pypi

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @aakhundov

blaine-rister self-assigned this on Mar 7, 2025
blaine-rister (Contributor, Author) commented Mar 7, 2025

Looking back at this PR, it seems like the strides were flipped back then as well. Maybe this shouldn't matter? It still feels like it would be more intuitive to make the trailing stride 1, though.

blaine-rister changed the title from "[Inductor] Block pointers permuted for tiled tensors" to "[Inductor] Block pointers permuted for tiled kernels" on Mar 7, 2025
blaine-rister changed the title from "[Inductor] Block pointers permuted for tiled kernels" to "[Inductor] Permuted block pointers for tiled pointwise kernels" on Mar 7, 2025
blaine-rister (Contributor, Author) commented Mar 8, 2025

cc @jansel: Do you know if the data is being transposed by the triton kernel? Or am I just reading this wrong?

yf225 added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Mar 8, 2025
jansel (Contributor) commented Mar 9, 2025

This looks like a bug in loop ordering to me; the output is correct, but this is not a good loop order. cc @shunting314

What happens if you set this config to False:

pick_loop_orders = True
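
(In the repro above, that would look something like the following, set before calling torch.compile; illustrative snippet only.)

import torch._inductor.config as config

# Hypothetical tweak to the repro above: disable Inductor's loop-order picking
# to see whether the generated loop order changes.
config.pick_loop_orders = False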

shunting314 (Contributor) commented

Could this be an issue specific to the block pointer implementation?

If we disable block pointers, the generated Triton kernel has the correct loop order:

@triton.jit
def triton_poi_fused_add_0(in_ptr0, in_ptr1, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 21
    xnumel = 19
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    x1 = xindex
    y0 = yindex
    tmp0 = tl.load(in_ptr0 + (x1 + 32*y0), xmask & ymask, eviction_policy='evict_last')
    tmp1 = tl.load(in_ptr1 + (x1 + 32*y0), xmask & ymask, eviction_policy='evict_last')
    tmp2 = tmp0 + tmp1
    tl.store(out_ptr0 + (x1 + 19*y0), tmp2, xmask & ymask)

blaine-rister (Contributor, Author) commented Mar 13, 2025

Could this be an issue specific to the block pointer implementation?

If we disable block pointers, the generated Triton kernel has the correct loop order:

@triton.jit
def triton_poi_fused_add_0(in_ptr0, in_ptr1, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 21
    xnumel = 19
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    x1 = xindex
    y0 = yindex
    tmp0 = tl.load(in_ptr0 + (x1 + 32*y0), xmask & ymask, eviction_policy='evict_last')
    tmp1 = tl.load(in_ptr1 + (x1 + 32*y0), xmask & ymask, eviction_policy='evict_last')
    tmp2 = tmp0 + tmp1
    tl.store(out_ptr0 + (x1 + 19*y0), tmp2, xmask & ymask)

@shunting314 That was my thought as well, but after looking at it a little more I'm now thinking the same loop order is used for both block pointers and non-block pointers, and they're both getting flipped. There are cases where we can't emit a block pointer, and the kernel ends up mixing block and non-block pointers. Those cases seem to have compatible tensor dims, strides, etc. There's an example in https://github.com/pytorch/pytorch/issues/148725.

In your example, will the pointer in_ptr0 + (x1 + 32*y0) end up having a shape of [XBLOCK, YBLOCK]? I'm thinking this is also flipped, since xnumel = 19 corresponds to the innermost dimension of the input/output PyTorch tensors. Does that check out?

blaine-rister (Contributor, Author) commented Mar 13, 2025

This looks like a bug in loop ordering to me; the output is correct, but this is not a good loop order. cc @shunting314

What happens if you set this config to False:

pytorch/torch/_inductor/config.py, line 156 (at commit 17dbeb1):

pick_loop_orders = True

@jansel I tried this out and it seems like we get the same loop order with either setting of config.pick_loop_orders. Both the default and the strides-based algorithm result in loops with leading strides of 1.

I created a branch to experiment with reversing the order returned by pick_loop_orders. The code seems to work, but some optimizations like loop fusion aren't composing with the new order. I'm trying to see if we can get the same code as before, but flipped.

shunting314 (Contributor) commented

In your example, will the pointer in_ptr0 + (x1 + 32*y0) end up having a shape of [XBLOCK, YBLOCK]?

Yes. I think the reason it's [XBLOCK, YBLOCK] rather than [YBLOCK, XBLOCK] is due to how we construct the index:

xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
yindex = yoffset + tl.arange(0, YBLOCK)[None, :]

If we did it another way, like

xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
yindex = yoffset + tl.arange(0, YBLOCK)[:, None]

the tile shape will be [YBLOCK, XBLOCK].

But it does not matter for non-block pointers.
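
The broadcasting behind this is easy to check outside Triton with a plain NumPy stand-in (illustrative only; block sizes made up):

import numpy as np

XBLOCK, YBLOCK = 4, 8  # made-up block sizes for illustration

# x gets the [:, None] slice, as in the generated kernel: the tile shape is (XBLOCK, YBLOCK).
xindex = np.arange(XBLOCK)[:, None]   # shape (XBLOCK, 1)
yindex = np.arange(YBLOCK)[None, :]   # shape (1, YBLOCK)
print((xindex + 5 * yindex).shape)    # (4, 8)

# Swapping the slices flips the tile shape to (YBLOCK, XBLOCK).
xindex = np.arange(XBLOCK)[None, :]
yindex = np.arange(YBLOCK)[:, None]
print((xindex + 5 * yindex).shape)    # (8, 4)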

blaine-rister (Contributor, Author) commented Mar 14, 2025

In your example, will the pointer in_ptr0 + (x1 + 32*y0) end up having a shape of [XBLOCK, YBLOCK]?

Yes. I think the reason it's [XBLOCK, YBLOCK] rather than [YBLOCK, XBLOCK] is due to how we construct the index:

xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
yindex = yoffset + tl.arange(0, YBLOCK)[None, :]

If we did it another way, like

xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
yindex = yoffset + tl.arange(0, YBLOCK)[:, None]

the tile shape will be [YBLOCK, XBLOCK].

But it does not matter for non-block pointers.

@shunting314 @jansel I was doubting whether the existing order is really flipped, so I decided to run some tests on handwritten triton kernels. I've attached the scripts and their output here:

test_loop_order.zip

Kernel 1: X first (what Inductor does today)

The first script, print_loop_order_x_first.py, is what Inductor generates today for the elementwise add test case above. Here, X is the leading dimension. (To make things easier to read, I modified the shape in the test case to (4, 4) and the strides to (5, 1).) The only modification to the script is that we flatten the index tensor to 1D and then print it. The rationale is that tl.load should be invariant to reshaping the pointer, so by flattening the index we can tell which addresses will be loaded by which threads.

@triton.jit
def triton_poi_fused_add_0(in_ptr0, in_ptr1, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 4
    xnumel = 4
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    x1 = xindex
    y0 = yindex
    combined_idx = tl.reshape(x1 + 5 * y0, [XBLOCK * YBLOCK]) # Debugging
    tl.device_print("linear_idx", combined_idx)
    tmp0 = tl.load(in_ptr0 + (x1 + 5*y0), xmask & ymask)
    tmp1 = tl.load(in_ptr1 + (x1 + 5*y0), xmask & ymask)
    tmp2 = tmp0 + tmp1
    tl.store(out_ptr0 + (x1 + 4*y0), tmp2, xmask & ymask)

When I ran this script, the printed indices were discontiguous. This is essentially accessing the matrix in column-major order, although the underlying data is row-major.

pid (0, 0, 0) idx ( 0) linear_idx: 0
pid (0, 0, 0) idx ( 1) linear_idx: 5
pid (0, 0, 0) idx ( 2) linear_idx: 10
pid (0, 0, 0) idx ( 3) linear_idx: 15
pid (0, 0, 0) idx ( 4) linear_idx: 1
pid (0, 0, 0) idx ( 5) linear_idx: 6
pid (0, 0, 0) idx ( 6) linear_idx: 11
pid (0, 0, 0) idx ( 7) linear_idx: 16
pid (0, 0, 0) idx ( 8) linear_idx: 2
pid (0, 0, 0) idx ( 9) linear_idx: 7
pid (0, 0, 0) idx (10) linear_idx: 12
pid (0, 0, 0) idx (11) linear_idx: 17
pid (0, 0, 0) idx (12) linear_idx: 3
pid (0, 0, 0) idx (13) linear_idx: 8
pid (0, 0, 0) idx (14) linear_idx: 13
pid (0, 0, 0) idx (15) linear_idx: 18
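
For anyone reproducing the printout without the attached zip, a harness along the following lines launches the kernel above (a sketch only; the actual scripts in test_loop_order.zip may differ):

import torch
import triton

def run(kernel):
    # (4, 4) views into buffers with stride (5, 1), matching the modified test case above.
    full0 = torch.rand(4, 5, device="cuda")
    full1 = torch.rand(4, 5, device="cuda")
    a = torch.as_strided(full0, (4, 4), (5, 1))
    b = torch.as_strided(full1, (4, 4), (5, 1))
    out = torch.empty(4, 4, device="cuda")
    XBLOCK = YBLOCK = 4
    grid = (triton.cdiv(4, XBLOCK), triton.cdiv(4, YBLOCK))  # (x programs, y programs)
    kernel[grid](a, b, out, 4, 4, YBLOCK=YBLOCK, XBLOCK=XBLOCK)
    torch.testing.assert_close(out, a + b)  # sanity-check correctness

run(triton_poi_fused_add_0)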

Kernel 2: Y first

In the second file, print_loop_order_y_first.py, I've modified the Triton kernel to make Y the leading dimension, instead of X. This simply means swapping the [:, None] slices as @shunting314 suggested.

@triton.jit
def triton_poi_fused_add_0(in_ptr0, in_ptr1, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 4
    xnumel = 4
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = xindex < xnumel
    x1 = xindex
    y0 = yindex
    combined_idx = tl.reshape(x1 + 5 * y0, [XBLOCK * YBLOCK]) # Debugging
    tl.device_print("linear_idx", combined_idx)
    tmp0 = tl.load(in_ptr0 + (x1 + 5*y0), xmask & ymask)
    tmp1 = tl.load(in_ptr1 + (x1 + 5*y0), xmask & ymask)
    tmp2 = tmp0 + tmp1
    tl.store(out_ptr0 + (x1 + 4*y0), tmp2, xmask & ymask)

After swapping the slices, the indices became contiguous.

pid (0, 0, 0) idx ( 0) linear_idx: 0
pid (0, 0, 0) idx ( 1) linear_idx: 1
pid (0, 0, 0) idx ( 2) linear_idx: 2
pid (0, 0, 0) idx ( 3) linear_idx: 3
pid (0, 0, 0) idx ( 4) linear_idx: 5
pid (0, 0, 0) idx ( 5) linear_idx: 6
pid (0, 0, 0) idx ( 6) linear_idx: 7
pid (0, 0, 0) idx ( 7) linear_idx: 8
pid (0, 0, 0) idx ( 8) linear_idx: 10
pid (0, 0, 0) idx ( 9) linear_idx: 11
pid (0, 0, 0) idx (10) linear_idx: 12
pid (0, 0, 0) idx (11) linear_idx: 13
pid (0, 0, 0) idx (12) linear_idx: 15
pid (0, 0, 0) idx (13) linear_idx: 16
pid (0, 0, 0) idx (14) linear_idx: 17
pid (0, 0, 0) idx (15) linear_idx: 18

Kernel 3: X first, permuted loop order

The bad news is that changing Inductor's loop order doesn't generate the code from file 2. The third script, print_loop_order_permuted.py contains the code that Inductor generates if we reverse the output of the the pick_loop_order function. (The code changes are in #149176.

@triton.jit
def triton_poi_fused_add_0(in_ptr0, in_ptr1, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 4
    xnumel = 4
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    x1 = xindex
    y0 = yindex
    combined_idx = tl.reshape(y0 + 5 * x1, [XBLOCK * YBLOCK]) # Debugging
    tl.device_print("linear_idx", combined_idx)
    tmp0 = tl.load(in_ptr0 + (y0 + 5*x1), xmask & ymask)
    tmp1 = tl.load(in_ptr1 + (y0 + 5*x1), xmask & ymask)
    tmp2 = tmp0 + tmp1
    tl.store(out_ptr0 + (y0 + 4*x1), tmp2, xmask & ymask)

In this case, the strides are permuted compared to kernel #1, but X is still the leading dimension. This gives us contiguous indices, the same as in kernel #2, but IMO the code is less intuitive to reason about than #2. Permuting the strides also seems to affect some of Inductor's other loop optimizations, like loop fusion. So it may be more desirable to generate #2 by changing the Triton backend rather than by swapping the loop order higher in the stack.

pid (0, 0, 0) idx ( 0) linear_idx: 0
pid (0, 0, 0) idx ( 1) linear_idx: 1
pid (0, 0, 0) idx ( 2) linear_idx: 2
pid (0, 0, 0) idx ( 3) linear_idx: 3
pid (0, 0, 0) idx ( 4) linear_idx: 5
pid (0, 0, 0) idx ( 5) linear_idx: 6
pid (0, 0, 0) idx ( 6) linear_idx: 7
pid (0, 0, 0) idx ( 7) linear_idx: 8
pid (0, 0, 0) idx ( 8) linear_idx: 10
pid (0, 0, 0) idx ( 9) linear_idx: 11
pid (0, 0, 0) idx (10) linear_idx: 12
pid (0, 0, 0) idx (11) linear_idx: 13
pid (0, 0, 0) idx (12) linear_idx: 15
pid (0, 0, 0) idx (13) linear_idx: 16
pid (0, 0, 0) idx (14) linear_idx: 17
pid (0, 0, 0) idx (15) linear_idx: 18

What do you think? Do these tests seem reasonable? I also tried some performance tests with larger sizes, but the results were inconclusive. I'll keep looking to see if I can find a test case that shows a clear difference.

shunting314 (Contributor) commented

For kernel 1, I think the memory access pattern looks bad mainly because the block size is too small. For reasonably large block sizes, our memory access pattern should still be good.

blaine-rister (Contributor, Author) commented Mar 14, 2025

For kernel 1, I think the memory access pattern looks bad mainly because the block size is too small. For reasonably large block sizes, our memory access pattern should still be good.

That's a good point. Is the idea that we eventually get YBLOCK=1 for large enough inputs? In that case, the addresses would become contiguous again.

As long as YBLOCK > 1, it seems like the stride of kernel 1's (linear) memory accesses is equal to xnumel, so we'd have a larger stride between accesses the larger our tensor is. But maybe the block sizes interact with this in a less predictable way.

Anyways, would there be any objection to swapping the memory access pattern to look like kernel 2, if it's at least perf neutral on GPU? That would work better on systems that rely on DMAs for memory accesses, where you're hoping that the memory will be fetched in large contiguous spans.

blaine-rister changed the title from "[Inductor] Permuted block pointers for tiled pointwise kernels" to "[Inductor] Permuted memory access pattern for tiled pointwise kernels" on Mar 14, 2025
shunting314 (Contributor) commented

Is the idea that we eventually get YBLOCK=1 for large enough inputs? In that case, the addresses would become contiguous again.

I was thinking that Triton would be able to find a more reasonable layout for non-trivial block sizes.

svekars pushed a commit that referenced this issue Mar 21, 2025
…re x (#149339)

# Feature

Fixes #148718 by reordering the tensor dims to `(z, y, x)`.

As a bonus refactor, block pointers no longer needed the `reorder=True` argument to `self.active_range_trees()`. Since this argument is no longer used anywhere, this PR simply deletes it as opposed to updating the logic for the new iteration order.

# Perf impact

It looks like there's a decent perf bump on A100, with cudagraphs enabled. Granted, perf runs seem to have some noise between commits. ([Workflow run](https://github.com/pytorch/pytorch/actions/runs/13914815576).)

Training (all neutral or positive):
![image](https://github.com/user-attachments/assets/57f1ef1d-60b4-446f-baf3-aca87a26b81b)

Inference (one positive, one very small negative):
![image](https://github.com/user-attachments/assets/679aa057-af23-47f1-8d8e-8520daf1bd92)

As reported in #148718, this PR makes consecutive threads access consecutive memory addresses. This should theoretically give the GPU more opportunities to coalesce loads and stores. From Nvidia's [kernel profiling guide](https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html):

> Local memory is private storage for an executing thread and is not visible outside of that thread. It is intended for thread-local data like thread stacks and register spills. Local memory addresses are translated to global virtual addresses by the AGU unit. Local memory has the same latency as global memory. One difference between global and local memory is that local memory is arranged such that consecutive 32-bit words are accessed by consecutive thread IDs. Accesses are therefore fully coalesced as long as all threads in a warp access the same relative address (e.g., same index in an array variable, same member in a structure variable, etc.).

I couldn't find any information on how coalescing works for other kinds of memory, but the guide mentions it is also supported for accesses to the L2 cache.

> The L2 Request Coalescer (LRC) processes incoming requests for L2 and tries to coalesce read requests before forwarding them to the L2 cache. It also serves programmatic multicast requests from the SM and supports compression for writes.

The [answer to this Stack Overflow post](https://stackoverflow.com/a/5044424) also explains coalescing in a straightforward way. Inductor's current iteration order corresponds to the first (uncoalesced) example in that answer, while the order after this PR corresponds to the second (coalesced) example.

Besides GPUs, this order of accessing data is highly advantageous for systems relying on DMAs, as those are designed to access contiguous spans of memory. This change improves the performance of an elementwise add kernel on an internal model, using internal hardware, by 1.76x. I will share the details with reviewers who are Meta employees via a private channel.

# Test plan
 - Updated expected code on CI tests.
 - Added a new test checking the {x,y,z}indices and block pointers on a 3D pointwise kernel.

Pull Request resolved: #149339
Approved by: https://github.com/jansel
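
To make the coalescing argument above concrete, here is a small standalone illustration (not part of the PR; sizes made up) of the linear addresses one tile touches under the old and new iteration orders. It matches the linear_idx printouts earlier in this thread.

# Illustration only: linear addresses touched by one tile of a row-major tensor
# with row stride W, under the old (x leading) and new (x fastest) iteration orders.
W = 5
XBLOCK, YBLOCK = 4, 4

def tile_addresses(x_fastest: bool):
    if x_fastest:
        order = [(y, x) for y in range(YBLOCK) for x in range(XBLOCK)]
    else:
        order = [(y, x) for x in range(XBLOCK) for y in range(YBLOCK)]
    return [x + W * y for (y, x) in order]

print(tile_addresses(x_fastest=False))  # [0, 5, 10, 15, 1, 6, ...]  strided, harder to coalesce
print(tile_addresses(x_fastest=True))   # [0, 1, 2, 3, 5, 6, ...]    contiguous runs, coalesced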