The MoonMath.ai team has released a bf16 forward attention kernel for AMD's MI300X GPU. It is written in HIP rather than hand-written assembly and is open-source under the MIT license. The team reports that it beats AITER v3, AMD's own optimized attention kernel, on every tested shape. Bare-metal access came from HotAisle, a cloud provider offering AMD GPUs.
Attention is the fused softmax(QKᵀ/√d)·V operation inside every transformer, where d is the head dimension. The MI300X is AMD's CDNA3 data-center GPU, whose ISA target is gfx942, and this kernel runs only on that hardware.
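For reference, the math the kernel fuses can be written out naively. The pure-Python sketch below computes softmax(QKᵀ/√d)·V for a single head; the function name `reference_attention` is hypothetical, and this is a semantics check, not how the MI300X kernel is structured.

```python
import math


def reference_attention(q, k, v):
    """Naive single-head attention: softmax(Q K^T / sqrt(d)) V.

    q is a list of Sq rows; k and v are lists of Skv rows. Every q and k row
    has length d. Sq and Skv may differ, which covers cross-attention.
    """
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for qi in q:
        # Scaled dot-product score of this query against every key.
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        # Subtract the max before exponentiating for numerical stability.
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        probs = [w / total for w in weights]
        # Probability-weighted sum of the value rows.
        out.append([sum(p * vj[c] for p, vj in zip(probs, v)) for c in range(len(v[0]))])
    return out
```

Because the number of key/value rows need not match the number of query rows, the same function also describes cross-attention.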
TL;DR
- MoonMath.ai open-sources a bf16 forward attention kernel for AMD MI300X, written in HIP, not assembly (MIT).
- It beats AMD's AITER v3 on every tested shape and rounding mode: geomean 1.18× (RTNE), 1.15× (RTNA), and 1.08× (RTZ), with per-shape gains up to 1.26×.
- The core trick: one-instruction asm wrappers let you pick the opcode while the compiler allocates registers.
- Most of the speedup is memory placement — K in LDS, V hot in L1, Q and accumulators in registers.
- A real SGLang PR used it to speed up Wan2.1 video diffusion by 1.23×, with no quality regression.
Understanding the Kernel
A kernel is a small program that runs directly on the GPU's many cores to perform one specific computation, here the attention math, as fast as the hardware allows. This kernel computes forward attention in bf16 on MI300X only. It accepts inputs in either BSHD (batch, sequence, heads, head dim) or BHSD layout without requiring a transpose. Head dimension is fixed at 128. It supports any sequence length, including cross-attention, where the query and key/value lengths differ.
The limits are real: there is no causal mask, no GQA (grouped-query attention), and no varlen batching. Outputs are bf16, and the kernel runs exclusively on gfx942 hardware.
Numerics are tightly controlled. In each of the three rounding modes, outputs follow AITER's rounding rule for that mode. Every finite output lies within 1 bf16 ULP of AITER's, NaN and Inf handling is bit-identical, and results are deterministic.
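BSHD and BHSD hold the same elements and differ only in axis order, so a kernel can read either one with stride arithmetic instead of a transpose. The sketch below assumes contiguous tensors with axes stored in the order the layout name spells out; the helpers `strides` and `offset` are hypothetical, and the source does not describe how the kernel handles the two layouts internally.

```python
def strides(layout: str, S: int, H: int, D: int) -> tuple:
    """Element strides for (batch, seq, head, dim) in a contiguous tensor."""
    if layout == "bshd":
        # Consecutive sequence positions of one head are H * D elements apart.
        return (S * H * D, H * D, D, 1)
    if layout == "bhsd":
        # Consecutive sequence positions of one head are adjacent rows of D.
        return (H * S * D, D, S * D, 1)
    raise ValueError(f"unknown layout: {layout}")


def offset(layout: str, b: int, s: int, h: int, d: int, S: int, H: int, D: int) -> int:
    """Flat element index of logical element (b, s, h, d) in the given layout."""
    sb, ss, sh, sd = strides(layout, S, H, D)
    return b * sb + s * ss + h * sh + d * sd
```

The only practical difference is the sequence stride: H·D elements in BSHD versus D in BHSD, which changes how K and V rows are gathered but not what is computed.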
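The "within 1 bf16 ULP" criterion can be checked on raw bit patterns: order the 16-bit patterns so that integer distance equals the number of representable steps, and require NaN/Inf patterns to match exactly. The source does not publish its test harness, so this is a sketch under that assumption, and the helper names are hypothetical.

```python
import math
import struct


def bf16_to_float(bits: int) -> float:
    """Decode a bf16 pattern; bf16 is the top 16 bits of an fp32."""
    return struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))[0]


def ulp_key(bits: int) -> int:
    """Map a bf16 pattern to an integer that increases with its value."""
    bits &= 0xFFFF
    return -(bits & 0x7FFF) if bits & 0x8000 else bits  # +0 and -0 both map to 0


def ulp_distance(a_bits: int, b_bits: int) -> int:
    """Number of representable bf16 steps between two finite values."""
    for bits in (a_bits, b_bits):
        if not math.isfinite(bf16_to_float(bits)):
            raise ValueError("ULP distance is only defined for finite values")
    return abs(ulp_key(a_bits) - ulp_key(b_bits))


def check_outputs(ours, ref, max_ulp=1):
    """Finite outputs within max_ulp; NaN/Inf patterns must match bit-for-bit."""
    for x, y in zip(ours, ref):
        if math.isfinite(bf16_to_float(y)):
            if not math.isfinite(bf16_to_float(x)) or ulp_distance(x, y) > max_ulp:
                return False
        elif x != y:
            return False
    return True
```

For example, 0x3F80 (1.0) and 0x3F81 are one ULP apart, while +Inf (0x7F80) and −Inf (0xFF80) fail the bit-identical check.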
The Core Trick: One-Instruction asm Wrappers
The core technique avoids a familiar dilemma. Compiler intrinsics keep code tidy but let the compiler reorder or rename operands. Raw inline assembly gives control but forces manual register and address management.
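MoonMath wraps exactly one instruction in each __device__ __forceinline__ function, using extended asm constraints to describe the operands. The team picks the opcode; the compiler still allocates registers and tracks data flow.
```cpp
#include <hip/hip_runtime.h>

// Assumed vector types (not shown in the original snippet): bf16 values packed as 16-bit lanes.
typedef short bf16x4_t __attribute__((ext_vector_type(4)));
typedef float fp32x4_t __attribute__((ext_vector_type(4)));

// in/out tied to the SAME VGPR → no accumulator rename, no v_mov copy.
__device__ __forceinline__ void asm_mfma(bf16x4_t a, bf16x4_t b, fp32x4_t& c) {
    asm volatile("v_mfma_f32_16x16x16_bf16 %0, %1, %2, %0"
                 : "+v"(c)             // %0: accumulator, read and written in place
                 : "v"(a), "v"(b));    // %1, %2: bf16 input fragments
}
```
The "+v"(c) constraint ties the accumulator's input and output to the same VGPR, so no copy instruction is emitted. The code stays close to ordinary HIP while still steering the machine one instruction at a time.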
The Architecture: Eight Waves, Two Groups, Two Barriers
A CDNA3 compute unit has four SIMD units, so the textbook block is four waves, one per SIMD. MoonMath instead runs eight waves per block, in two groups of four.
Both groups run the same QKᵀ, softmax, O += PV sequence, offset by one phase. While one group saturates the matrix core, the other runs softmax and issues loads. Then they swap, so the matrix core never idles.
Each iteration uses only two s_barrier instructions: one at the phase handoff and one at the iteration boundary. Per-counter waits handle the rest of the synchronization.
This echoes FlashAttention-3's alternation of matmul and softmax but does not copy FA3's producer/consumer warp split. On CDNA3, every memory move is already asynchronous, so a dedicated producer wave is unnecessary.
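A toy timeline makes the phase offset visible. This sketch assumes each group's iteration reduces to one "matrix" phase (MFMA work) followed by one "vector" phase (softmax plus load issue), with group B starting one phase after group A; the real kernel's phase boundaries and durations are not published in this form, and `pingpong_schedule` is a hypothetical name.

```python
def pingpong_schedule(iterations: int) -> list:
    """Phase-by-phase activity of two wave groups offset by one phase.

    Toy model (an assumption): each iteration is one "matrix" phase
    (QK^T and PV MFMAs) followed by one "vector" phase (softmax plus
    issuing the next loads). Group B starts one phase after group A.
    """
    schedule = []
    for phase in range(2 * iterations + 1):
        row = {}
        for group, start in (("A", 0), ("B", 1)):
            step = phase - start
            if 0 <= step < 2 * iterations:
                row[group] = "matrix" if step % 2 == 0 else "vector"
            else:
                row[group] = "idle"  # B's prologue or A's epilogue
        schedule.append(row)
    return schedule


def matrix_core_busy(schedule: list) -> list:
    """True for each phase in which some group is issuing MFMAs."""
    return ["matrix" in row.values() for row in schedule]
```

With three iterations the matrix core is busy in six of seven phases, and the two groups never contend for it; only group B's final vector phase is exposed, a fraction that shrinks as the number of KV iterations grows.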
Where Data Lives, and Why 16×16×16
Most of the speedup comes from memory placement. K streams from HBM into LDS, double-buffered, shared by all eight waves. V stays hot in L1, read on every PV matmul. Q and accumulators live in registers.
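This placement works because attention can be computed by streaming K and V tile by tile while Q and a small running state stay put, as in FlashAttention's online softmax. In the pure-Python sketch below, for one query row, `q`, the running max, the normalizer and the accumulator play the role of register-resident state, and each K/V tile is read once per iteration. The tile size and function name are assumptions, not the kernel's.

```python
import math


def streaming_attention_row(q, k, v, tile=4):
    """One query row of softmax(q K^T / sqrt(d)) V, visiting K/V in tiles.

    Only the running max m, normalizer l and accumulator acc persist across
    tiles, mirroring Q and accumulators kept in registers.
    """
    scale = 1.0 / math.sqrt(len(q))
    m = -math.inf
    l = 0.0
    acc = [0.0] * len(v[0])
    for start in range(0, len(k), tile):
        k_tile = k[start:start + tile]  # in the kernel: streamed through LDS
        v_tile = v[start:start + tile]  # in the kernel: read from L1
        s = [scale * sum(a * b for a, b in zip(q, kj)) for kj in k_tile]
        m_new = max(m, max(s))
        # Rescale the previous state to the new running max (exp(-inf) == 0.0).
        correction = math.exp(m - m_new)
        p = [math.exp(x - m_new) for x in s]
        l = l * correction + sum(p)
        acc = [a * correction + sum(pj * vj[c] for pj, vj in zip(p, v_tile))
               for c, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]
```

The result matches a full softmax over all keys, but no step ever needs more than one K/V tile at a time.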
The team chose the 16×16×16 MFMA over 32×32×8. Both shapes have identical throughput, but the smaller tile accumulates into 4 fp32 elements per lane versus 16. Lower accumulator pressure leaves room for deeper prefetch and a third Q tile.
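The 4-versus-16 figure is simply one MFMA output block spread across a 64-lane CDNA3 wavefront. The sketch below reproduces that arithmetic; the second helper additionally assumes, hypothetically, that a wave's Q tile height equals the MFMA M dimension, which the source does not state.

```python
WAVE_SIZE = 64  # lanes in a CDNA3 wavefront


def acc_per_lane(m: int, n: int, wave_size: int = WAVE_SIZE) -> int:
    """fp32 accumulator elements per lane for one M x N MFMA output block."""
    if (m * n) % wave_size:
        raise ValueError("output block must divide evenly across the wave")
    return m * n // wave_size


def o_acc_per_lane(q_rows: int, head_dim: int = 128, wave_size: int = WAVE_SIZE) -> int:
    """fp32 elements per lane for an O accumulator of q_rows x head_dim (hypothetical tiling)."""
    if (q_rows * head_dim) % wave_size:
        raise ValueError("O tile must divide evenly across the wave")
    return q_rows * head_dim // wave_size
```

Under that assumption, a 16-row Q tile at head dimension 128 needs 32 fp32 accumulator registers per lane, versus 64 for the 32-row minimum of the 32×32×8 shape.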
| Decision | Choice | Reason |
|---|---|---|
| Waves per block | 8 (two groups of 4) | Pipeline scheduled explicitly within one block; one shared K copy |
| MFMA shape | 16×16×16 bf16 | Same throughput, lower VGPR pressure, better power efficiency |
| K placement | LDS, double-buffered, 32 KiB | Shared by all 8 waves, swapped per iteration |
| V placement | L1, resident, prefetched | Reread across PV, kept hot deliberately |
| Q + accumulators | VGPRs | Read every iteration, never reloaded |
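The 32 KiB K budget in the table can be sanity-checked against the 64 KiB of LDS per CDNA3 compute unit. This sketch assumes the budget holds only unpadded bf16 K rows at head dimension 128 (the source does not give the tile height), and the helper names are hypothetical.

```python
BF16_BYTES = 2
LDS_PER_CU = 64 * 1024  # bytes of LDS per CDNA3 (MI300X) compute unit


def k_rows_per_buffer(k_lds_bytes: int, head_dim: int = 128, buffers: int = 2) -> int:
    """Key rows that fit in one buffer, assuming unpadded bf16 rows."""
    row_bytes = head_dim * BF16_BYTES  # 128 * 2 = 256 bytes per key row
    return (k_lds_bytes // buffers) // row_bytes


def lds_remaining(k_lds_bytes: int) -> int:
    """LDS bytes left on the CU once K's buffers are allocated."""
    if k_lds_bytes > LDS_PER_CU:
        raise ValueError("K buffers exceed the LDS of one CU")
    return LDS_PER_CU - k_lds_bytes
```

With the table's 32 KiB, this gives 64 key rows per buffer and 32 KiB of LDS left for other uses, under the stated no-padding assumption.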
Two later optimizations close the remaining gap. A third Q tile (3Q) raises data reuse per loaded K and V tile. A Flash-Decoding-style tail KV split fills the fractional final round of work that would otherwise leave many of MI300X's 304 CUs idle. The wins cascade: moving V to L1 freed the LDS that the third Q tile then fills.
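Two pieces of arithmetic sit behind the tail split. When the number of work items is not a multiple of 304, the last round leaves CUs idle; splitting the KV range of those tail items is correct only because per-chunk softmax partials can be merged exactly, as in Flash-Decoding. The sketch below shows both, assuming one work item per CU per round; how MoonMath actually sizes and schedules its split is not described, and all names are hypothetical.

```python
import math


def last_round_fill(work_items: int, num_cus: int = 304) -> float:
    """Fraction of CUs busy in the final round, one work item per CU."""
    rem = work_items % num_cus
    return 1.0 if rem == 0 else rem / num_cus


def partial_attention(q, k, v):
    """Unnormalized attention over one KV chunk: returns (max, sum, acc)."""
    scale = 1.0 / math.sqrt(len(q))
    s = [scale * sum(a * b for a, b in zip(q, kj)) for kj in k]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    acc = [sum(pj * vj[c] for pj, vj in zip(p, v)) for c in range(len(v[0]))]
    return m, sum(p), acc


def merge_partials(parts):
    """Rescale each chunk to the global max, then combine and normalize."""
    m = max(pm for pm, _, _ in parts)
    l = sum(pl * math.exp(pm - m) for pm, pl, _ in parts)
    width = len(parts[0][2])
    acc = [sum(pa[c] * math.exp(pm - m) for pm, _, pa in parts) for c in range(width)]
    return [a / l for a in acc]


def split_attention_row(q, k, v, num_splits):
    """Split KV into chunks, attend to each independently, then merge."""
    chunk = math.ceil(len(k) / num_splits)
    parts = [partial_attention(q, k[i:i + chunk], v[i:i + chunk])
             for i in range(0, len(k), chunk)]
    return merge_partials(parts)
```

For example, 684 work items on 304 CUs leave the third round only 25% occupied; splitting those 76 tail items' KV ranges across otherwise idle CUs, then merging, recovers that capacity without changing the result.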
Benchmark
Tests ran on MI300X in bf16 with head dimension 128. Each shape was measured in three rounding modes: RTNE rounds to nearest, ties to even; RTNA rounds to nearest, ties away from zero; RTZ truncates toward zero.
| Shape (B, H, S, D) | Round | MoonMath (ms) | AITER v3 (ms) | vs AITER | vs Modular MAX |
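To make the three modes concrete: converting an fp32 result to bf16 keeps the top 16 bits of the fp32 pattern, and the modes differ only in how the discarded low 16 bits are handled. This bit-level sketch assumes the modes govern the final fp32-to-bf16 output conversion; it is illustrative rather than AITER's or MoonMath's code, and its NaN policy (keep the sign, force a quiet NaN) is an assumption.

```python
import struct


def fp32_bits(x: float) -> int:
    """Raw bit pattern of x rounded to fp32 (x must fit in the fp32 range)."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def round_to_bf16(bits: int, mode: str = "rtne") -> int:
    """Round an fp32 bit pattern to a bf16 bit pattern."""
    upper, lower = bits >> 16, bits & 0xFFFF
    if (bits & 0x7F800000) == 0x7F800000 and (bits & 0x007FFFFF):
        return upper | 0x0040  # NaN: keep sign, force quiet bit (assumed policy)
    if mode == "rtz":
        return upper  # drop the low half: truncate the magnitude toward zero
    if mode == "rtna":
        round_up = lower >= 0x8000  # ties round the magnitude up, away from zero
    elif mode == "rtne":
        round_up = lower > 0x8000 or (lower == 0x8000 and upper & 1)  # ties to even
    else:
        raise ValueError(f"unknown rounding mode: {mode}")
    # Adding 1 to the pattern grows the magnitude; a carry into the exponent
    # is correct, and past the largest finite bf16 it produces Inf.
    return (upper + int(round_up)) & 0xFFFF
```

The value 1 + 2⁻⁸ sits exactly halfway between bf16 1.0 (0x3F80) and its successor (0x3F81): RTNE and RTZ return 0x3F80, while RTNA returns 0x3F81.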
|---|---|---|---|---|---|
| (2, 24, 8192, 128) | RTNE | 3.083 | 3.792 | 1.23× | 1.37× |
| (2, 24, 16384, 128) | RTNE | 11.670 | 14.691 | 1.26× | 1.54× |
| (4, 16, 16384, 128) | RTZ | 15.055 | 16.183 | 1.07× | 1.47× |
| (2, 24, 32768, 128) | RTNA | 44.440 | 52.363 | 1.18× | 1.57× |
| (1, 16, 131072, 128) | RTNE | 232.517 | 269.278 | 1.16× | 1.46× |
Geomeans across the full sweep favor MoonMath: versus AITER, 1.18× (RTNE), 1.15× (RTNA), and 1.08× (RTZ). Versus Modular MAX, geomeans range from 1.44× to 1.49×, with per-shape speedups up to 1.59×.
RTZ is AITER's fastest mode and therefore the tightest race. The (4, 16, 16384) RTZ shape moved from 0.95× to 1.07×, and the tail KV split is what closed that final gap.
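Speedup here is AITER time divided by MoonMath time, and a geomean averages such ratios multiplicatively. The sketch below recomputes the "vs AITER" column from the listed timings. The reported geomeans cover the full sweep, which is not reproduced in the table, so this does not attempt to recompute them.

```python
import math


def speedup(baseline_ms: float, ours_ms: float) -> float:
    """How many times faster `ours` is than the baseline (>1 means faster)."""
    return baseline_ms / ours_ms


def geomean(values) -> float:
    """Geometric mean: the usual way to average ratios such as speedups."""
    return math.exp(sum(math.log(x) for x in values) / len(values))


# (MoonMath ms, AITER v3 ms) for the rows in the benchmark table.
ROWS = [
    (3.083, 3.792),
    (11.670, 14.691),
    (15.055, 16.183),
    (44.440, 52.363),
    (232.517, 269.278),
]
table_speedups = [round(speedup(aiter, ours), 2) for ours, aiter in ROWS]
```

Each recomputed ratio matches the table's "vs AITER" column (1.23×, 1.26×, 1.07×, 1.18×, 1.16×).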
Use Cases
The kernel installs with pip and exposes a small API. It launches on the caller's stream, so it can overlap with other work inside larger pipelines.
```python
import torch
import moonmath_attention as ma

# PyTorch's ROCm build uses the "cuda" device string on AMD GPUs.
# BSHD layout: (batch, sequence, heads, head_dim); head_dim must be 128.
q = torch.randn(2, 8192, 24, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(2, 8192, 24, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(2, 8192, 24, 128, dtype=torch.bfloat16, device="cuda")

out = ma.forward(q, k, v, layout="bshd")                        # rounding mode left at its default
out_rtz = ma.forward(q, k, v, layout="bshd", round_mode="rtz")  # round toward zero
```
One concrete use case is video diffusion. The team added LiteAttention support and opened a PR against SGLang diffusion. On Wan2.1-T2V-1.3B-Diffusers, switching attention from AITER to liteattention_rocm improved end-to-end generation by 1.23× on MI300X, with no visible quality regression.
BSHD support lets the kernel take diffusion tensors directly, without a transpose. Cross-attention works with any KV length and needs no padding.
Key Takeaways
- The kernel is bf16 forward attention for MI300X, written in HIP under MIT.
- It beats AITER v3 on every shape and rounding mode, geomean 1.18×/1.15×/1.08×.
- One-instruction asm wrappers give opcode control while the compiler allocates registers.
- Memory placement drove most of the gain: K in LDS, V hot in L1, Q in registers.
- A real SGLang PR sped up Wan2.1 video diffusion by 1.23× with no quality regression.
