Skip to content

Commit

Permalink
[Mosaic GPU] Take TMEM as a TMEMRef in tcgen05.mma, not as a raw address
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 723936021
  • Loading branch information
apaszke authored and Google-ML-Automation committed Feb 6, 2025
1 parent efbb0af commit 026b6c9
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
6 changes: 2 additions & 4 deletions jax/experimental/mosaic/gpu/examples/matmul_blackwell.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import gpu
from jax._src.lib.mlir.dialects import llvm
from jax._src.lib.mlir.dialects import nvvm
from jax.experimental.mosaic import gpu as mgpu
from jax.experimental.mosaic.gpu import c, ds, utils
Expand Down Expand Up @@ -47,7 +46,6 @@ def build_kernel(
i32 = ir.IntegerType.get_signless(32)
f32 = ir.F32Type.get()
index = ir.IndexType.get()
ptr6 = ir.Type.parse("!llvm.ptr<6>") # TMEM

swizzle = 128
tile_k = 64 # TODO(apaszke): I think we need to tile TMA to change this.
Expand Down Expand Up @@ -111,13 +109,13 @@ def _tma_body(ki, _):
tmem_addr_addr = utils.memref_ptr(tmem_addr, memory_space=3)
tcgen05.tmem_alloc(tmem_addr_addr, tile_n)
tcgen05.tmem_relinquish_alloc_permit()
tmem_ref = tcgen05.TMEMRef.from_alloc(tmem_addr, tcgen05.TMEMLayout.D, tile_n, f32)
with mgpu.when(warp_leader):
tmem_addr_value = llvm.load(ptr6, tmem_addr_addr)
@mgpu.fori(c(k_loop_iter, index), arith.constant(i1, 0))
def _mma_body(ki, accumulate):
ab_full_barrier.wait()
tcgen05.mma(
tmem_addr_value,
tmem_ref,
a_smem,
mgpu.memref_transpose(b_smem, (0, 1, 3, 2)),
a_swizzle=swizzle,
Expand Down
23 changes: 18 additions & 5 deletions jax/experimental/mosaic/gpu/tcgen05.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.
# ==============================================================================

from __future__ import annotations

import dataclasses
import enum

Expand Down Expand Up @@ -86,7 +88,7 @@ def create_instr_descriptor(


def mma(
d: ir.Value,
d: TMEMRef,
a: ir.Value,
b: ir.Value,
*,
Expand Down Expand Up @@ -124,15 +126,24 @@ def mma(
descriptor_const_init=TCGEN05_SMEM_DESCRIPTOR_BIT,
)

if m_tiling != 128:
raise ValueError(f"A must have rows tiled by 128, got: {m_tiling}")
# TODO(apaszke): It's enough to make this a multiple of d.num_rows, but it
# would need more code below.
if m_tiling != d.num_rows:
raise ValueError(
f"A's row tiling must be a multiple of {d.num_rows} (inferred from"
f" accumulator's TMEM layout), got: {m_tiling}"
)

a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
a_m_byte_stride = a_strides[0] * utils.bytewidth(element_type)

groups_k = k // kn_tiling
groups_m = m // m_tiling

# TODO(apaszke): Verify ACC shape.
if d.shape != (m, n):
raise ValueError(
f"Accumulator shape mismatch: expected {(m, n)}, got {d.shape}"
)

i64 = ir.IntegerType.get_signless(64)
for mi in range(groups_m):
Expand All @@ -142,8 +153,10 @@ def mma(
utils.c(_wgmma.wgmma_encode(mi * a_m_byte_stride + ki * a_k_byte_stride), i64),
)
b_k = arith.addi(b_desc_base, utils.c(_wgmma.wgmma_encode(ki * b_k_byte_stride), i64))
if groups_m != 1:
raise NotImplementedError("D needs to be sliced")
accumulate = _do_mma(
d,
d.address,
a_mk,
b_k,
d_type=ir.F32Type.get(),
Expand Down

0 comments on commit 026b6c9

Please sign in to comment.