diff --git a/examples/qwen3/ark_attention.py b/examples/qwen3/ark_attention.py index 6616fc31..ee48a65d 100644 --- a/examples/qwen3/ark_attention.py +++ b/examples/qwen3/ark_attention.py @@ -5,7 +5,8 @@ All ops (QKV projection, QK-norm, RoPE, attention, output projection) use torch. ARK ops (``ark_rmsnorm``, ``precompute_ark_rope_freqs``) are kept -dormant for re-enablement after the upstream composed-graph fix lands (Q6). +dormant for re-enablement after the upstream composed-graph fix lands. +TODO(upstream): re-enable after ark planner composed-graph fix. """ import math @@ -138,12 +139,12 @@ def torch_rmsnorm(x, weight, eps): eps: epsilon for numerical stability. Returns: - ``torch.Tensor`` (fp16) with the same shape as *x*. + ``torch.Tensor`` with the same shape and dtype as *x*. """ x_f32 = x.float() rms = torch.sqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + eps) x_normed = x_f32 / rms - return (x_normed * weight.float()).half() + return (x_normed * weight.float()).to(x.dtype) # --------------------------------------------------------------------------- @@ -206,7 +207,7 @@ def ark_gqa_attention( # ---- Mask + softmax (torch) ---- if mask is not None: scores = scores + mask - attn_w = torch.softmax(scores.float(), dim=-1).half() + attn_w = torch.softmax(scores.float(), dim=-1).to(x.dtype) # ---- Stage 5: Weighted sum (torch matmul) ---- out = torch.matmul(attn_w, v) diff --git a/examples/qwen3/ark_mlp.py b/examples/qwen3/ark_mlp.py new file mode 100644 index 00000000..fd8499c2 --- /dev/null +++ b/examples/qwen3/ark_mlp.py @@ -0,0 +1,116 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Qwen3 SwiGLU MLP: torch matmul + torch silu·gate fallback. + +Computes: down_proj(SiLU(gate_proj(x)) * up_proj(x)) + +All matmul uses torch.matmul (full-ARK matmul deferred to Q10). +SiLU·gate fusion uses torch ops (F.silu(gate) * up) because the upstream +ARK composed-graph planner bug crashes at intermediate_dim=12288 (same +bug class as the 4-D shape crash documented in Q4). + +The ARK path (``ark_silu_gate``) is retained dormant for re-enablement +after the upstream fix lands. +""" + +import torch +import torch.nn.functional as F + +import ark + +from .qwen3_config import Qwen3Config + +# --------------------------------------------------------------------------- +# SiLU·gate implementations +# --------------------------------------------------------------------------- + + +def ark_silu_gate(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: + """SiLU(gate) * up using ARK primitives. + + NOTE: Dormant — crashes with the upstream ARK composed-graph planner + bug at shapes like (2048, 12288). Kept for re-enablement after + the upstream fix lands. + + Args: + gate: (N, intermediate_dim) fp16 tensor on CUDA. + up: (N, intermediate_dim) fp16 tensor on CUDA. + + Returns: + (N, intermediate_dim) fp16 tensor. + """ + ark.init() + # SiLU(x) = x * sigmoid(x) + sig = ark.sigmoid(gate) + silu = ark.mul(gate, sig) + result = ark.mul(silu, up) + return result.eval() + + +def torch_silu_gate(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: + """SiLU(gate) * up using pure torch ops. + + Replaces ``ark_silu_gate`` which crashes with the upstream + composed-graph planner bug at intermediate_dim=12288. + + Args: + gate: (N, intermediate_dim) fp16 tensor on CUDA. + up: (N, intermediate_dim) fp16 tensor on CUDA. + + Returns: + (N, intermediate_dim) fp16 tensor. + """ + return F.silu(gate) * up + + +# --------------------------------------------------------------------------- +# Full SwiGLU MLP +# --------------------------------------------------------------------------- + + +def ark_swiglu_mlp( + x: torch.Tensor, + gate_w: torch.Tensor, + up_w: torch.Tensor, + down_w: torch.Tensor, + cfg: Qwen3Config, +) -> torch.Tensor: + """SwiGLU MLP: down_proj(SiLU(gate_proj(x)) * up_proj(x)). + + All weight/input arguments are torch tensors on CUDA. + Matmul uses torch.matmul; silu·gate uses torch fallback. + + Args: + x: (B, S, hidden_dim) fp16 input tensor. + gate_w: (intermediate_dim, hidden_dim) gate projection weight. + up_w: (intermediate_dim, hidden_dim) up projection weight. + down_w: (hidden_dim, intermediate_dim) down projection weight. + cfg: Qwen3Config instance. + + Returns: + (B, S, hidden_dim) fp16 output tensor wrapped in ark.copy + for .eval() API consistency. + """ + orig_shape = x.shape # (B, S, hidden_dim) + batch_seq = orig_shape[0] * orig_shape[1] + + # Flatten to 2D for matmul + x_2d = x.reshape(batch_seq, cfg.hidden_dim) + + # Gate and up projections (torch matmul) + gate = torch.matmul(x_2d, gate_w.t()) # (B*S, intermediate_dim) + up = torch.matmul(x_2d, up_w.t()) # (B*S, intermediate_dim) + + # SiLU·gate fusion (torch fallback — ARK crashes at intermediate_dim=12288) + hidden = torch_silu_gate(gate, up) # (B*S, intermediate_dim) + + # Down projection (torch matmul) + out_2d = torch.matmul(hidden, down_w.t()) # (B*S, hidden_dim) + + # Reshape back to (B, S, hidden_dim) + result = out_2d.reshape(orig_shape) + + # Wrap as trivial ARK graph so callers can use .eval() + ark.init() + return ark.copy(result) diff --git a/examples/qwen3/bench_mlp.py b/examples/qwen3/bench_mlp.py new file mode 100644 index 00000000..cb42ab3d --- /dev/null +++ b/examples/qwen3/bench_mlp.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Microbenchmark: ARK SwiGLU MLP vs torch eager SwiGLUMLP. + +Torch-only pipeline (matmul + F.silu·gate). ARK silu·gate deferred to +upstream fix (same composed-graph planner bug class as Q4). + +Shapes: S=2048 (prefill) and S=1 (decode) at Qwen3-8B dimensions. +Run out-of-band on A100: ``python -m examples.qwen3.bench_mlp`` +""" + +import torch + +from .qwen3_config import Qwen3Config +from .qwen3_ref import SwiGLUMLP +from .ark_mlp import ark_swiglu_mlp +from .microbench import microbench + +# --------------------------------------------------------------------------- +# Benchmark +# --------------------------------------------------------------------------- + + +def _run(seq_len, label): + cfg = Qwen3Config() # 8B defaults + torch.manual_seed(42) + mlp = SwiGLUMLP(cfg).cuda().half() + + B = 1 + x = torch.randn( + B, seq_len, cfg.hidden_dim, device="cuda", dtype=torch.float16 + ) + + # --- Torch eager --- + def run_torch(): + with torch.no_grad(): + mlp(x) + + torch_res = microbench( + run_torch, + use_cuda_graph=False, + flush_l2=False, + ) + + # --- ARK (torch-only fallback) --- + gate_w = mlp.gate_proj.weight.detach() + up_w = mlp.up_proj.weight.detach() + down_w = mlp.down_proj.weight.detach() + + def run_ark(): + with torch.no_grad(): + ark_swiglu_mlp(x, gate_w, up_w, down_w, cfg).eval() + + ark_res = microbench( + run_ark, + use_cuda_graph=False, + flush_l2=False, + ) + + return label, torch_res, ark_res + + +def main(): + print("NOTE: torch-only (ARK silu·gate deferred to upstream fix / Q10).") + print( + f"{'Shape':<20} {'Torch (us)':>16} {'ARK-wrap (us)':>20} {'Speedup':>10}" + ) + print("-" * 70) + for seq, label in [(2048, "prefill S=2048"), (1, "decode S=1")]: + name, t, a = _run(seq, label) + sp = t["mean_us"] / a["mean_us"] if a["mean_us"] > 0 else float("nan") + print( + f"{name:<20} " + f"{t['mean_us']:>10.1f} ± {t['std_us']:<5.1f}" + f"{a['mean_us']:>14.1f} ± {a['std_us']:<5.1f}" + f"{sp:>8.2f}x" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/qwen3/test_attention.py b/examples/qwen3/test_attention.py index 0abdab28..29798509 100644 --- a/examples/qwen3/test_attention.py +++ b/examples/qwen3/test_attention.py @@ -78,7 +78,8 @@ def test_qk_norm(): weight = norm.weight.detach().half().cuda() out = torch_rmsnorm(x, weight, 1e-6) - assert_close(out, ref, atol=1e-6, rtol=1e-6, msg="QK-norm mismatch") + # fp16 precision; tight enough to catch real regressions + assert_close(out, ref, atol=1e-3, rtol=1e-3, msg="QK-norm mismatch") # --------------------------------------------------------------------------- @@ -264,7 +265,7 @@ def _mha_cfg(): ) -def _run_attention_equivalence(cfg, B, S, seed_offset=10, mask="causal"): +def _run_attention_equivalence(cfg, B, S, seed_offset=10, causal=True): """Run ARK vs torch attention equivalence for given cfg, B, S.""" from .ark_attention import ark_gqa_attention, precompute_torch_rope_freqs from .equiv import assert_close @@ -276,9 +277,7 @@ def _run_attention_equivalence(cfg, B, S, seed_offset=10, mask="causal"): torch.manual_seed(_SEED + seed_offset) x = torch.randn(B, S, cfg.hidden_dim, device="cuda", dtype=torch.float16) - if mask == "causal": - mask = _causal_mask(S, "cuda", torch.float16) - # else mask is already None or a user-supplied tensor + mask = _causal_mask(S, "cuda", torch.float16) if causal else None with torch.no_grad(): ref = attn(x, rope_freqs, mask) @@ -323,7 +322,7 @@ def test_attention_mha(): def test_attention_no_mask(): """ARK attention matches torch with mask=None (no causal mask).""" _run_attention_equivalence( - _small_cfg(), B=1, S=16, seed_offset=23, mask=None + _small_cfg(), B=1, S=16, seed_offset=23, causal=False ) diff --git a/examples/qwen3/test_mlp.py b/examples/qwen3/test_mlp.py new file mode 100644 index 00000000..50467b4e --- /dev/null +++ b/examples/qwen3/test_mlp.py @@ -0,0 +1,198 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Equivalence tests: ARK SwiGLU MLP vs torch reference. + +All GPU tests are gated with ``skipif(not cuda)``. +""" + +import subprocess +import sys +import os + +import pytest +import torch + +_CUDA = torch.cuda.is_available() +requires_cuda = pytest.mark.skipif(not _CUDA, reason="CUDA not available") + +_SEED = 42 + + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + + +def _small_cfg(): + """Return a small Qwen3Config suitable for unit tests.""" + from .qwen3_config import Qwen3Config + + return Qwen3Config( + n_layers=1, + hidden_dim=128, + n_q_heads=4, + n_kv_heads=2, + head_dim=32, + intermediate_dim=256, + rms_norm_eps=1e-6, + rope_theta=1e6, + max_seq_len=256, + ) + + +def _build_ref_mlp(cfg): + """Instantiate a torch SwiGLUMLP with fixed seed on CUDA.""" + from .qwen3_ref import SwiGLUMLP + + torch.manual_seed(_SEED) + return SwiGLUMLP(cfg).cuda().half().eval() + + +def _run_mlp_equivalence(cfg, B, S, seed_offset=0): + """Run ARK vs torch MLP equivalence for given cfg, B, S.""" + from .ark_mlp import ark_swiglu_mlp + from .equiv import assert_close + + mlp = _build_ref_mlp(cfg) + + torch.manual_seed(_SEED + seed_offset) + x = torch.randn(B, S, cfg.hidden_dim, device="cuda", dtype=torch.float16) + + with torch.no_grad(): + ref = mlp(x) + + with torch.no_grad(): + ark_out = ark_swiglu_mlp( + x, + mlp.gate_proj.weight.detach(), + mlp.up_proj.weight.detach(), + mlp.down_proj.weight.detach(), + cfg, + ).eval() + + assert ark_out.shape == ref.shape + assert_close(ark_out, ref, atol=5e-3, rtol=5e-3, msg=f"MLP B={B} S={S}") + + +# --------------------------------------------------------------------------- +# Intermediate check: SiLU·gate +# --------------------------------------------------------------------------- + + +@requires_cuda +def test_silu_gate(): + """torch_silu_gate matches F.silu(gate) * up.""" + import torch.nn.functional as F + + from .ark_mlp import torch_silu_gate + from .equiv import assert_close + + torch.manual_seed(_SEED) + gate = torch.randn(64, 256, device="cuda", dtype=torch.float16) + up = torch.randn(64, 256, device="cuda", dtype=torch.float16) + + ref = F.silu(gate) * up + out = torch_silu_gate(gate, up) + + assert_close(out, ref, atol=1e-6, rtol=1e-6, msg="SiLU·gate mismatch") + + +# --------------------------------------------------------------------------- +# Full MLP equivalence tests +# --------------------------------------------------------------------------- + + +@requires_cuda +def test_mlp_small(): + """ARK MLP matches SwiGLUMLP at B=1, S=16 (small shape).""" + _run_mlp_equivalence(_small_cfg(), B=1, S=16, seed_offset=10) + + +@requires_cuda +def test_mlp_prefill(): + """ARK MLP matches SwiGLUMLP at B=1, S=128 (prefill shape).""" + _run_mlp_equivalence(_small_cfg(), B=1, S=128, seed_offset=11) + + +@requires_cuda +def test_mlp_decode(): + """ARK MLP matches SwiGLUMLP at B=1, S=1 (decode step).""" + _run_mlp_equivalence(_small_cfg(), B=1, S=1, seed_offset=12) + + +@requires_cuda +def test_mlp_batch(): + """ARK MLP matches SwiGLUMLP at B=2, S=16 (multi-batch).""" + _run_mlp_equivalence(_small_cfg(), B=2, S=16, seed_offset=13) + + +# --------------------------------------------------------------------------- +# Output shape and dtype +# --------------------------------------------------------------------------- + + +@requires_cuda +def test_mlp_output_shape(): + """ARK MLP output has correct shape and dtype.""" + from .ark_mlp import ark_swiglu_mlp + + cfg = _small_cfg() + mlp = _build_ref_mlp(cfg) + + B, S = 2, 32 + torch.manual_seed(_SEED + 20) + x = torch.randn(B, S, cfg.hidden_dim, device="cuda", dtype=torch.float16) + + with torch.no_grad(): + out = ark_swiglu_mlp( + x, + mlp.gate_proj.weight.detach(), + mlp.up_proj.weight.detach(), + mlp.down_proj.weight.detach(), + cfg, + ).eval() + + assert out.shape == (B, S, cfg.hidden_dim) + assert out.dtype == torch.float16 + + +# --------------------------------------------------------------------------- +# xfail: ARK silu·gate at (2048, 12288) — upstream ARK planner bug +# --------------------------------------------------------------------------- + + +@requires_cuda +@pytest.mark.xfail( + reason="ARK planner bug: composed graph crashes at (2048, 12288)", + strict=False, +) +def test_ark_silu_gate_large_xfail(): + """Document that ark_silu_gate crashes at Qwen3-8B intermediate_dim. + + Same class of upstream ARK composed-graph planner bug as Q4's + 4-D shape crash. Runs in a subprocess to avoid poisoning the + CUDA context. + """ + script = ( + "import torch, ark\n" + "torch.manual_seed(42)\n" + "gate = torch.randn(2048, 12288, device='cuda', dtype=torch.float16)\n" + "up = torch.randn(2048, 12288, device='cuda', dtype=torch.float16)\n" + "ark.init()\n" + "sig = ark.sigmoid(gate)\n" + "silu = ark.mul(gate, sig)\n" + "result = ark.mul(silu, up)\n" + "out = result.eval()\n" + "assert out.shape == (2048, 12288)\n" + ) + result = subprocess.run( + [sys.executable, "-c", script], + capture_output=True, + text=True, + timeout=120, + env=os.environ.copy(), + ) + assert ( + result.returncode == 0 + ), f"Subprocess exited {result.returncode}: {result.stderr[-500:]}"