[Mosaic GPU] Check in WIP grouped GEMM #26997
base: main
Conversation
Currently the kernel hangs in the second iteration because the mbarrier state is not carried over correctly. In the last iteration of the k-loop we arrive on mma_done_barrier instead of ab_empty_barrier, so the phase of ab_empty_barrier is not flipped.
We are not using clusters, and the extra code was only hurting readability.
We accomplish this by:
1. Round-robin iterating through input buffer slots based on `persistent_ki`, which tracks k-loop iterations over the entire lifetime of a CTA worker, not just a single work item.
2. Arriving on an `ab_empty_barrier` even in the last iteration of a k-loop, so that the barrier phase is flipped correctly across work items.
3. Waiting on an `ab_empty_barrier` in all iterations of all work items' k-loops, except for the first `max_concurrent_steps` iterations of the very first work item's k-loop.
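A plain-Python sketch of the bookkeeping described in the three points above. The work list is made up, and `wait`/`arrive` are stand-ins for the mbarrier operations; the real kernel splits this across the TMA and MMA warps and emits Mosaic GPU IR.

```python
# Illustrative only: names follow the commit message, everything else is a stub.
max_concurrent_steps = 2
work_items = [("item0", 4), ("item1", 3)]  # (name, k_steps) -- hypothetical

def wait(barrier, slot):
  print(f"wait   {barrier}[{slot}]")

def arrive(barrier, slot):
  print(f"arrive {barrier}[{slot}]")

persistent_ki = 0  # counts k-loop steps over the CTA worker's whole lifetime
for _, k_steps in work_items:
  for ki in range(k_steps):
    # (1) Round-robin slot choice based on persistent_ki, not ki.
    slot = persistent_ki % max_concurrent_steps
    # (3) Wait on ab_empty everywhere except the first few steps overall.
    if persistent_ki >= max_concurrent_steps:
      wait("ab_empty_barrier", slot)
    # ... TMA loads fill the slot, arrive on ab_full, MMA consumes the slot ...
    # (2) Arrive on ab_empty even on the last ki, so the phase keeps flipping.
    arrive("ab_empty_barrier", slot)
    persistent_ki += 1
```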
Otherwise when TMA box coordinates are not aligned with a subtile, the 4D TMA box goes out of bounds of the 4D TMA tensor in a non-contiguous fashion.
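A tiny arithmetic sketch of why the alignment matters. The tile size and the (M // TILE_M, K // TILE_K, TILE_M, TILE_K) tiled layout are assumed purely for illustration; the kernel's actual constants may differ.

```python
TILE_M = 64  # hypothetical tile size

def box_start(row_offset):
  # 4D coordinates at which a TMA box covering rows
  # [row_offset, row_offset + TILE_M) would have to start.
  return (row_offset // TILE_M, 0, row_offset % TILE_M, 0)

print(box_start(128))  # (2, 0, 0, 0): aligned, the box stays inside one tile row
print(box_start(100))  # (1, 0, 36, 0): unaligned, a TILE_M-tall box would need
                       # in-tile rows 36..99, past the tile extent of 64
```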
Committing with all the messy debug code so I can come back to it later if needed. Caveats/hacks:
- We store group chunks at inexact offsets (aligned to `tile_m` size), so the computation is not fully correct/compliant. We should modify the kernel to store at exact offsets. We could do this by storing directly from registers, or by using `tensormap.replace` before each TMA store. Storing from registers feels simpler.
- We compute the grouped GEMM schedule separately on the host. We should either do this in a single optimized kernel, fuse it into the grouped GEMM, or compute it on the fly somehow.
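For the second caveat, a hedged numpy sketch of what such a host-side schedule could look like. The real schedule's contents aren't shown in this PR, so the function name, outputs, and shapes below are assumptions.

```python
import numpy as np

def make_schedule(group_sizes, tile_m):
  # Hypothetical: map each flat work item to (group id, m-tile within the
  # group, starting row), rounding every group up to whole m-tiles.
  tiles_per_group = -(-np.asarray(group_sizes) // tile_m)  # ceil division
  group_ids = np.repeat(np.arange(len(group_sizes)), tiles_per_group)
  tile_in_group = np.concatenate([np.arange(n) for n in tiles_per_group])
  group_starts = np.cumsum([0, *group_sizes[:-1]])
  row_offsets = np.concatenate(
      [start + np.arange(n) * tile_m
       for start, n in zip(group_starts, tiles_per_group)])
  return group_ids, tile_in_group, row_offsets

print(make_schedule([300, 100, 250], tile_m=128))
```

Note that with exact group starts (300 and 400 in this example) the row offsets are not multiples of `tile_m`, which is exactly what runs into the alignment issue described above.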
Co-authored-by: Adam Paszke <apaszke@google.com>
It's broken; it fails with an illegal instruction. Something is off in the `subgroup_for` loop: `work_id` is not carried correctly. I probably need a while loop instead.
This needs to be moved to Mosaic GPU utilities.
# TODO(andportnoy) move into tcgen05.py
def tmem_dealloc(tmem: ir.Value, ncols: int, collective: bool = False):
I think this can be deleted now, as it's done automatically by MGPU
import jax.numpy as jnp
import jax.random as jr
import numpy as np
from cuda.bindings.runtime import cudaDeviceGetAttribute, cudaDeviceAttr, cudaGetDeviceCount, cudaError_t
I don't think we can depend on the CUDA Python APIs in those files. Is there some way to do what you want using the JAX device APIs? If not, we should consider exposing it
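Not sure this covers everything the CUDA calls were used for, but a possible JAX-only sketch (I'm assuming CUDA devices in JAX expose `core_count` and `compute_capability`; if they don't, that's probably the attribute worth exposing):

```python
import jax

gpus = jax.devices("gpu")
print(len(gpus))             # device count, instead of cudaGetDeviceCount
d = gpus[0]
print(d.core_count)          # SM count, instead of cudaDeviceGetAttribute
print(d.compute_capability)  # e.g. "9.0"
```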
def bytecount(shape, dtype):
  return int(np.prod(shape) * dtype.dtype.itemsize)
nit: whitespace between top-level declarations
@contextlib.contextmanager
def single_warp_thread():
You already have mgpu.single_thread(scope=mgpu.ThreadSubset.WARP)
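i.e. roughly (assuming the same context-manager usage as the helper being replaced):

```python
with mgpu.single_thread(scope=mgpu.ThreadSubset.WARP):
  ...  # emitted for a single thread of each warp
```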
jnp.float16)
smem_buffers = mgpu.Union([compute_buffers, epilogue_buffer])
tmem_cols = tmem_slot_count * tile_n
assert tmem_cols.bit_count() == 1
That should not be necessary
TMA_WARP = 1
MMA_WARP = 0
Those seem unused
warpgroup = allocator.alloc_warpgroup

tma_warp = warp()
mma_warp = warp()
What's the benefit of the whole WarpAllocator indirection? They're still allocated statically, so you can just hardcode them to `TMA_WARP = 0` and `MMA_WARP = 1` as you did before, right?
shape=(tile_m, tile_n),
layout=acc.layout,
dtype=acc.dtype,
)
This is just slicing the TMEM ref, right? Perhaps extend `TMEMRef.slice` to support this instead?
return value


def subgroup_for(work_id, work_id_group_end, worker_count):
What does this do? Docs would be helpful
)

a_smem_slot = mgpu.memref_slice(a_smem, slot)
a_smem_slot_2d_shape, _ = mma_utils.tiled_memref_shape(a_smem_slot)
Yeah you should not use mma_utils in this kernel. It's a private file. What's happening here?
work_id_group_start = cx(0)
initial_work_id = worker_id
group_offset = cx(0)
@mgpu.fori(cx(expert_count), [initial_work_id, work_id_group_start, group_offset])
I'm not sure if I understand the iteration scheme of this kernel. Could you please describe it a bit better?
No description provided.