Lower linalg.copy to direct global load #20568

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · lialan wants to merge 28 commits into main from users/lialan/global_load_lds
Conversation

@lialan (Contributor) commented on Apr 17, 2025

Summary

This PR lays the foundation for using the global_load_lds instruction to load values directly from global memory to LDS. The pipeline is as follows:

  • Only linalg.copy ops emitted by GPUPromoteMatmulOperands are converted. When the pass sees fit, it attaches the #iree_gpu.use_global_load_dma attribute to the linalg.copy to tag it through the pipeline (see the sketch after this list).
  • A tagged linalg.copy is not decomposed or tiled until bufferization.
  • After distribution to threads and bufferization, the tagged linalg.copy is lowered to a sequence of code built around the subgroup-coalesced load op iree_gpu.global_load_dma.
  • iree_gpu.global_load_dma is mapped to the amdgpu.gather_to_lds op, which in turn lowers to the corresponding ROCDL op.
  • The pass that pads shared memory to reduce bank conflicts is disabled, because the destination workgroup memory has to be contiguous.
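
As a rough illustration (a hypothetical sketch only: SSA names and shapes are invented, and the exact IR produced by the pass may differ), the tagged copy at the tensor level could look like:

// Hypothetical sketch: an operand copy produced by GPUPromoteMatmulOperands,
// tagged so that generic decomposition/tiling leaves it alone until bufferization.
%empty = tensor.empty() : tensor<64x128xi8>
%copy = linalg.copy {lowering_config = #iree_gpu.use_global_load_dma}
    ins(%promoted_operand : tensor<64x128xi8>)
    outs(%empty : tensor<64x128xi8>) -> tensor<64x128xi8>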

Lowering linalg.copy

After bufferization and distribution to threads, the tagged linalg.copy still exists in the IR:

linalg.copy {lowering_config = #iree_gpu.use_global_load_dma}
  ins(%subview_12 : memref<64x128xi8, strided<[256, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>)
  outs(%alloc_4 : memref<64x128xi8, #gpu.address_space<workgroup>>)

Note that this linalg.copy is kept in the per-thread code. The op itself is then converted into a for loop in which each subgroup of threads loads coalesced chunks of values. For example, assume there are N subgroups loading from tensor<a x b x c>:

  • The i-th subgroup will load a sub-tensor of size [a/N, b, c], so each slice is contiguous.
    • At the moment, assume row-major layout and tile only the outermost dim.
    • The reason we currently only handle linalg.copy ops emitted by GPUPromoteMatmulOperands is that we know their destination is allocated contiguously.
    • TODO: expand to arbitrary memref slices.
  • Given gpu.subgroup_id and gpu.lane_id, each thread calculates the contiguous data chunk that its subgroup is responsible for loading:
    • the chunk indices are the delinearized indices of the input tensor, ranging from:
      • affine.delinearize_index[gpu.subgroup_id * (num_elems_of(tensor) / num_subgroups)], to
      • affine.delinearize_index[(gpu.subgroup_id + 1) * (num_elems_of(tensor) / num_subgroups) - 1]
  • Assume each subgroup loads n values from the linearized index range [N_f, N_b]; then the thread with lane id i loads, for iter = 0 to n: N_f + subgroup_size * iter + (i - 1).
    Then it will be converted to something like the following (in the example, assume workgroup size = 256, subgroup_size = 64, loading 64x128xi8):
scf.for %indvar = %c0 to %c32 step %c1 {
  // thread-specific gathering address from the global address
  %17 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 2048 + s2 * 64)>()[%lane_id, %subgroup_id, %indvar]
  %18:2 = affine.delinearize_index %17 into (128, 64) : index, index
  // this iteration's base storing index
  %19 = affine.apply affine_map<()[s0, s1] -> (s0 * 2048 + s1 * 64)>()[%subgroup_id, %indvar]
  %20:2 = affine.delinearize_index %19 into (128, 64) : index, index
  iree_gpu.global_load_dma %subview_13[%18#0, %18#1] -> %alloc_5[%20#0, %20#1] : memref<128x64xi8, strided<[256, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>> -> memref<128x64xi8, #gpu.address_space<workgroup>>
}
// if there are residual elements (subgroup_copy_region_size % subgroup_size != 0), copy them here
gpu.barrier
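
For concreteness, the constants above follow from the example sizes: a 256-thread workgroup with subgroup_size = 64 gives 4 subgroups; the copied buffer holds 128 * 64 = 8192 i8 elements, so each subgroup copies 8192 / 4 = 2048 contiguous elements (the s1 * 2048 term in the affine maps), and each lane issues 2048 / 64 = 32 loads of one element (the %c32 loop bound, with the subgroup advancing by 64 elements per iteration via the s2 * 64 term).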

Dependent PRs:

@krzysz00 (Contributor) commented:

Side note, I'm still poking at getting the buffer fat pointer to LDS intrinsic set up - it's caught up in bikeshedding on the compiler team

@lialan force-pushed the users/lialan/global_load_lds branch from e4ce145 to 335319d on April 22, 2025
@krzysz00 (Contributor) left a comment:

High-level observation: at the point this is being called, shouldn't we know the subgroup size, so that we don't need the subgroup_id op?

Like, you can just look at the workgroup sizes to see which subgroup you're in
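
A minimal sketch of that alternative (illustrative only; it assumes a 1-D workgroup along x and a statically known subgroup size of 64):

// Hypothetical: derive the subgroup and lane ids from the flat thread id and a
// known subgroup size instead of emitting gpu.subgroup_id / gpu.lane_id.
%tid = gpu.thread_id x
%subgroup_id = affine.apply affine_map<()[s0] -> (s0 floordiv 64)>()[%tid]
%lane_id = affine.apply affine_map<()[s0] -> (s0 mod 64)>()[%tid]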

//===----------------------------------------------------------------------===//

SmallVector<int64_t>
UseGlobalLoadDMAAttr::getStaticTilingLevelSizes(unsigned level,
Reviewer comment (Contributor):

Why's there an unsigned in here?

Reply (Contributor):

This is how the LoweringConfigAttrInterface exposes tiling levels. It's up to the backend + lowering config to interpret the level consistently.

@lialan force-pushed the users/lialan/global_load_lds branch from bbe9565 to f6c7290 on April 28, 2025
@lialan lialan changed the title Implement iree_gpu.global_load_dma op Lower linalg.copy to direct global load Apr 28, 2025
@lialan force-pushed the users/lialan/global_load_lds branch 5 times, most recently from 275e72b to 97161a8, on May 1, 2025
@krzysz00 (Contributor) left a comment:

A few notes, some of which I apparently failed to submit on Friday x.x

@lialan force-pushed the users/lialan/global_load_lds branch 3 times, most recently from ee462fd to c2b31d4, on May 13, 2025
@lialan (Contributor, Author) commented on May 13, 2025

Getting this error in CI:

iree/runtime/src/iree/hal/drivers/hip/native_executable.c:358: FAILED_PRECONDITION;
HIP driver error 'hipErrorSharedObjectInitFailed' (303): shared object initialization failed;
mismatched target chip? missing/wrong bitcode directory?;
while invoking native function hal.executable.create; while calling import; 

Needs to wait until llvm/llvm-project#137425 is integrated.

fixed.

@lialan force-pushed the users/lialan/global_load_lds branch 2 times, most recently from 8152e33 to f7a8bda, on May 14, 2025
@lialan lialan marked this pull request as ready for review May 14, 2025 19:31
@lialan force-pushed the users/lialan/global_load_lds branch 2 times, most recently from e48e8b2 to 0b6d30c, on May 19, 2025
@hanhanW (Contributor) left a comment:

Thanks, @lialan. It looks better! Here is my final round of the review. (I can skim through the code again after you address the comments.)

@krzysz00 (Contributor) left a comment:

Some notes

return numElements;
}

static bool distributeLinalgCopyToThreads(RewriterBase &rewriter,
Reviewer comment (Contributor):

Also, this might want to be a LogicalResult?

@lialan force-pushed the users/lialan/global_load_lds branch from c66f276 to c5206f1 on May 28, 2025
@lialan force-pushed the users/lialan/global_load_lds branch from c5206f1 to cb3bad7 on May 28, 2025
@lialan lialan requested a review from qedawkins May 28, 2025 02:37