[Mosaic GPU] Check in WIP grouped GEMM by andportnoy · Pull Request #26997 · jax-ml/jax

Draft · wants to merge 29 commits into base branch main
Conversation

andportnoy (Contributor) commented on Mar 7, 2025

No description provided.

Currently the kernel hangs in the second iteration because the
mbarrier state is not carried over correctly. In the last iteration of
the k-loop we arrive on mma_done_barrier instead of ab_empty_barrier,
so the phase of ab_empty_barrier is not flipped.
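For readers unfamiliar with mbarrier phase semantics, here is a toy model (plain Python, not CUDA or Mosaic GPU API) of why skipping the arrive makes the next wait hang:

class Barrier:
  """Toy mbarrier: the phase bit flips once all expected arrivals land."""

  def __init__(self):
    self.phase = 0

  def arrive(self):
    self.phase ^= 1

  def would_complete(self, parity):
    # A real mbarrier wait spins until the phase moves past `parity`; here we
    # only report whether that wait would ever return.
    return self.phase != parity

ab_empty_barrier = Barrier()
# Buggy last k-iteration of work item 0: the kernel arrives on mma_done_barrier
# instead, so arrive() never runs here and ab_empty_barrier.phase stays 0.
# Work item 1 then waits for ab_empty_barrier to leave phase 0:
print("wait completes:", ab_empty_barrier.would_complete(parity=0))  # False, i.e. hang
# With the fix (arrive on ab_empty_barrier even in the last k-iteration):
ab_empty_barrier.arrive()
print("wait completes:", ab_empty_barrier.would_complete(parity=0))  # True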
We are not using clusters; the extra code was only hurting readability.
We fix the mbarrier carry-over by (a plain-Python sketch of the resulting
indexing follows this list):

1. Round-robin iterating through input buffer slots based on
`persistent_ki`, which tracks k-loop iterations over the entire
lifetime of a CTA worker, not just a single work item.

2. Arriving on an ab_empty_barrier even in the last iteration of a
k-loop, so that the barrier phase is flipped correctly across work items.

3. Waiting on an ab_empty_barrier in all iterations of all work items'
k-loops, except for the first `max_concurrent_steps` iterations of the
very first work item's k-loop.
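The sketch below models the indexing with made-up sizes; `persistent_ki` and `max_concurrent_steps` mirror the names in the list above, and the barrier operations appear only as comments, so this is an illustration of the scheme rather than the kernel's Mosaic GPU code.

max_concurrent_steps = 4            # number of A/B SMEM buffer slots
k_steps_per_work_item = [3, 5, 2]   # made-up k-loop lengths

persistent_ki = 0                   # k-iterations over the CTA's whole lifetime
for work_item, k_steps in enumerate(k_steps_per_work_item):
  for ki in range(k_steps):
    slot = persistent_ki % max_concurrent_steps   # item 1: round-robin slot
    if persistent_ki >= max_concurrent_steps:     # item 3: skip the wait only
      pass  # wait on ab_empty_barrier[slot]      # for the first global steps
    # ... TMA load into slot, wait for data, issue MMA on slot ...
    # item 2: arrive on ab_empty_barrier[slot] even when ki == k_steps - 1,
    # so its phase keeps flipping consistently across work items.
    persistent_ki += 1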
andportnoy changed the title from "[Mosaic GPU] Check in WIP persistent matmul kernel" to "[Mosaic GPU] Check in WIP grouped GEMM" on Mar 19, 2025
andportnoy and others added 15 commits March 20, 2025 19:30
Otherwise, when TMA box coordinates are not aligned with a subtile, the
4D TMA box goes out of bounds of the 4D TMA tensor in a non-contiguous
fashion.
Committing with all the messy debug code so I can come back to it
later if needed.

Caveats/hacks:

- We store group chunks at inexact offsets (aligned to tile_m size),
  so the computation is not fully correct/compliant. We should modify
  the kernel to store at exact offsets. We could do this by storing
  directly from registers, or by using tensormap.replace before each
  TMA store. Storing from registers feels simpler.

- We compute the grouped GEMM schedule separately on the host. We should
  either do this in a single optimized kernel, fuse it into the grouped
  GEMM, or compute it on the fly somehow (a toy host-side sketch follows
  below).
Co-authored-by: Adam Paszke <apaszke@google.com>
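For context, a minimal sketch of what such a host-side schedule could look like. The helper name, the schedule layout, and the sizes are assumptions for illustration, not the PR's actual code: one (group, m_tile, n_tile) work item per output tile, plus each group's row offset in the stacked A matrix.

import numpy as np

def grouped_gemm_schedule(group_sizes, n, tile_m, tile_n):
  """Hypothetical host-side schedule for a grouped GEMM (illustration only)."""
  group_sizes = np.asarray(group_sizes)
  m_tiles = -(-group_sizes // tile_m)   # ceil-div: m tiles per group
  n_tiles = n // tile_n
  row_offsets = np.concatenate([[0], np.cumsum(group_sizes)[:-1]])
  work_items = [
      (g, mt, nt)
      for g, count in enumerate(m_tiles)
      for mt in range(count)
      for nt in range(n_tiles)
  ]
  return work_items, row_offsets

items, offsets = grouped_gemm_schedule([300, 17, 512], n=256, tile_m=128, tile_n=128)
print(len(items), offsets)  # 16 work items; group row offsets [0, 300, 317]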
It's broken; it fails with an illegal instruction. Something is off in
the subgroup_for loop: work_id is not carried correctly. I probably
need a while loop instead.
This needs to be moved to Mosaic GPU utilities.


# TODO(andportnoy) move into tcgen05.py
def tmem_dealloc(tmem: ir.Value, ncols: int, collective: bool = False):
Member: I think this can be deleted now, as it's done automatically by MGPU.

import jax.numpy as jnp
import jax.random as jr
import numpy as np
from cuda.bindings.runtime import cudaDeviceGetAttribute, cudaDeviceAttr, cudaGetDeviceCount, cudaError_t
Member: I don't think we can depend on the CUDA Python APIs in those files. Is there some way to do what you want using the JAX device APIs? If not, we should consider exposing it.
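For reference, a sketch of what the public JAX device API exposes today (device count; on CUDA devices, a compute-capability string). Whether that covers what the cudaDeviceGetAttribute calls are used for here (e.g. an SM count) depends on which attribute the kernel actually needs.

import jax

# Device count without cudaGetDeviceCount:
gpus = jax.devices("gpu")
print(len(gpus))

# CUDA devices expose a compute-capability string (e.g. "9.0"); whether that
# substitutes for the cudaDeviceGetAttribute queries is an open question, and
# an SM count may not be exposed at all.
dev = gpus[0]
print(dev.device_kind, getattr(dev, "compute_capability", None))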



def bytecount(shape, dtype):
  return int(np.prod(shape) * dtype.dtype.itemsize)
Member: nit: whitespace between top-level declarations.



@contextlib.contextmanager
def single_warp_thread():
Member: You already have mgpu.single_thread(scope=mgpu.ThreadSubset.WARP)

jnp.float16)
smem_buffers = mgpu.Union([compute_buffers, epilogue_buffer])
tmem_cols = tmem_slot_count * tile_n
assert tmem_cols.bit_count() == 1
Member: That should not be necessary.



TMA_WARP = 1
MMA_WARP = 0
Member: Those seem unused.

warpgroup = allocator.alloc_warpgroup

tma_warp = warp()
mma_warp = warp()
Member: What's the benefit of the whole WarpAllocator indirection? They're still allocated statically, so you can just hardcode them to TMA_WARP = 0 and MMA_WARP = 1 as you did before, right?

shape=(tile_m, tile_n),
layout=acc.layout,
dtype=acc.dtype,
)
Member: This is just slicing the TMEM ref, right? Perhaps extend TMEMRef.slice to support this, instead?

return value


def subgroup_for(work_id, work_id_group_end, worker_count):
Member: What does this do? Docs would be helpful.
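Judging from the signature and the later commit note that work_id is not carried correctly across iterations, the intended semantics appear to be roughly the stride loop below; this is a guess for illustration, not the kernel's code.

def subgroup_for_reference(work_id, work_id_group_end, worker_count):
  """Presumed semantics: yield every work item owned by this worker within
  the current group, striding by the total number of persistent workers."""
  while work_id < work_id_group_end:
    yield work_id
    work_id += worker_count

# e.g. the worker holding work_id=2 out of 4 workers, group ending at item 11:
print(list(subgroup_for_reference(2, 11, 4)))  # [2, 6, 10]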

)

a_smem_slot = mgpu.memref_slice(a_smem, slot)
a_smem_slot_2d_shape, _ = mma_utils.tiled_memref_shape(a_smem_slot)
Member: You should not use mma_utils in this kernel. It's a private file. What's happening here?

work_id_group_start = cx(0)
initial_work_id = worker_id
group_offset = cx(0)
@mgpu.fori(cx(expert_count), [initial_work_id, work_id_group_start, group_offset])
Member: I'm not sure if I understand the iteration scheme of this kernel. Could you please describe it a bit better?
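A tentative reading of the scheme, pieced together from the carried values above and the commit messages (persistent workers, round-robin work items): each persistent worker scans every group and claims the work items it owns, striding by the total worker count. The plain-Python model below is illustrative only and its inputs are made up.

def worker_schedule(worker_id, worker_count, work_items_per_group):
  """Hypothetical reading of the carried state in the fori loop above."""
  work_id = worker_id                  # initial_work_id in the kernel
  work_id_group_start = 0
  for group, count in enumerate(work_items_per_group):
    work_id_group_end = work_id_group_start + count
    while work_id < work_id_group_end:
      yield group, work_id - work_id_group_start  # work item within the group
      work_id += worker_count
    work_id_group_start = work_id_group_end

# e.g. 3 groups with 5, 2, 4 work items and 4 persistent workers:
for worker in range(4):
  print(worker, list(worker_schedule(worker, 4, [5, 2, 4])))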
