Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions examples/qwen3/ark_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

All ops (QKV projection, QK-norm, RoPE, attention, output projection) use
torch. ARK ops (``ark_rmsnorm``, ``precompute_ark_rope_freqs``) are kept
dormant for re-enablement after the upstream composed-graph fix lands (Q6).
dormant for re-enablement after the upstream composed-graph fix lands.
TODO(upstream): re-enable after ark planner composed-graph fix.
"""

import math
Expand Down Expand Up @@ -138,12 +139,12 @@ def torch_rmsnorm(x, weight, eps):
eps: epsilon for numerical stability.

Returns:
``torch.Tensor`` (fp16) with the same shape as *x*.
``torch.Tensor`` with the same shape and dtype as *x*.
"""
x_f32 = x.float()
rms = torch.sqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + eps)
x_normed = x_f32 / rms
return (x_normed * weight.float()).half()
return (x_normed * weight.float()).to(x.dtype)


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -206,7 +207,7 @@ def ark_gqa_attention(
# ---- Mask + softmax (torch) ----
if mask is not None:
scores = scores + mask
attn_w = torch.softmax(scores.float(), dim=-1).half()
attn_w = torch.softmax(scores.float(), dim=-1).to(x.dtype)

# ---- Stage 5: Weighted sum (torch matmul) ----
out = torch.matmul(attn_w, v)
Expand Down
116 changes: 116 additions & 0 deletions examples/qwen3/ark_mlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Qwen3 SwiGLU MLP: torch matmul + torch silu·gate fallback.

Computes: down_proj(SiLU(gate_proj(x)) * up_proj(x))

All matmuls use torch.matmul (full-ARK matmul deferred to Q10).
SiLU·gate fusion uses torch ops (F.silu(gate) * up) because the upstream
ARK composed-graph planner bug crashes at intermediate_dim=12288 (same
bug class as the 4-D shape crash documented in Q4).

The ARK path (``ark_silu_gate``) is retained dormant for re-enablement
after the upstream fix lands.
"""

import torch
import torch.nn.functional as F

import ark

from .qwen3_config import Qwen3Config

# ---------------------------------------------------------------------------
# SiLU·gate implementations
# ---------------------------------------------------------------------------


def ark_silu_gate(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    """Compute ``SiLU(gate) * up`` from ARK elementwise primitives.

    NOTE: Dormant. This path crashes with the upstream ARK composed-graph
    planner bug at shapes like (2048, 12288); it is retained only so it
    can be re-enabled once the upstream fix lands.

    Args:
        gate: (N, intermediate_dim) fp16 tensor on CUDA.
        up: (N, intermediate_dim) fp16 tensor on CUDA.

    Returns:
        (N, intermediate_dim) fp16 tensor, obtained by calling ``.eval()``
        on the final ARK product.
    """
    ark.init()
    # SiLU(x) = x * sigmoid(x); the activated gate then scales ``up``.
    gated = ark.mul(ark.mul(gate, ark.sigmoid(gate)), up)
    return gated.eval()


def torch_silu_gate(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    """Compute ``SiLU(gate) * up`` with plain torch ops.

    Stand-in for ``ark_silu_gate``, which crashes with the upstream
    composed-graph planner bug at intermediate_dim=12288.

    Args:
        gate: (N, intermediate_dim) fp16 tensor on CUDA.
        up: (N, intermediate_dim) fp16 tensor on CUDA.

    Returns:
        (N, intermediate_dim) fp16 tensor: the elementwise product of
        ``F.silu(gate)`` and ``up``.
    """
    activated = F.silu(gate)
    return activated * up


# ---------------------------------------------------------------------------
# Full SwiGLU MLP
# ---------------------------------------------------------------------------


def ark_swiglu_mlp(
    x: torch.Tensor,
    gate_w: torch.Tensor,
    up_w: torch.Tensor,
    down_w: torch.Tensor,
    cfg: Qwen3Config,
) -> torch.Tensor:
    """Run the SwiGLU MLP: ``down_proj(SiLU(gate_proj(x)) * up_proj(x))``.

    Every projection is a torch matmul and the SiLU·gate step goes through
    the torch fallback; ARK is only used to wrap the finished result.

    Args:
        x: (B, S, hidden_dim) fp16 input tensor on CUDA.
        gate_w: (intermediate_dim, hidden_dim) gate projection weight.
        up_w: (intermediate_dim, hidden_dim) up projection weight.
        down_w: (hidden_dim, intermediate_dim) down projection weight.
        cfg: Qwen3Config instance; only ``cfg.hidden_dim`` is read here.

    Returns:
        The (B, S, hidden_dim) result passed through ``ark.copy`` so that
        callers can invoke ``.eval()`` on it.
        NOTE(review): the signature is annotated ``-> torch.Tensor`` but the
        returned object is whatever ``ark.copy`` produces — confirm the
        intended annotation.
    """
    in_shape = x.shape  # (B, S, hidden_dim)
    n_tokens = in_shape[0] * in_shape[1]

    # Collapse batch and sequence so each projection is a single 2-D matmul.
    tokens = x.reshape(n_tokens, cfg.hidden_dim)

    gate_proj = tokens @ gate_w.t()  # (B*S, intermediate_dim)
    up_proj = tokens @ up_w.t()  # (B*S, intermediate_dim)

    # Torch fallback: the ARK silu·gate path crashes at intermediate_dim=12288.
    activated = torch_silu_gate(gate_proj, up_proj)

    # Project back down and restore the caller's (B, S, hidden_dim) layout.
    out = (activated @ down_w.t()).reshape(in_shape)

    # Trivial ARK graph around the torch result so callers can use .eval().
    ark.init()
    return ark.copy(out)
83 changes: 83 additions & 0 deletions examples/qwen3/bench_mlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Microbenchmark: ARK SwiGLU MLP vs torch eager SwiGLUMLP.

Torch-only pipeline (matmul + F.silu·gate). ARK silu·gate deferred to
upstream fix (same composed-graph planner bug class as Q4).

Shapes: S=2048 (prefill) and S=1 (decode) at Qwen3-8B dimensions.
Run out-of-band on A100: ``python -m examples.qwen3.bench_mlp``
"""

import torch

from .qwen3_config import Qwen3Config
from .qwen3_ref import SwiGLUMLP
from .ark_mlp import ark_swiglu_mlp
from .microbench import microbench

# ---------------------------------------------------------------------------
# Benchmark
# ---------------------------------------------------------------------------


def _run(seq_len, label):
    """Time the reference ``SwiGLUMLP`` and ``ark_swiglu_mlp`` at one length.

    Both paths share the same seeded module weights and the same random
    fp16 input of shape (1, seq_len, hidden_dim) on CUDA.

    Args:
        seq_len: sequence length S of the benchmark input.
        label: display name for this shape; returned unchanged.

    Returns:
        ``(label, torch_res, ark_res)`` where the two results are whatever
        ``microbench`` returns for the torch and ARK-wrapped callables.
    """
    cfg = Qwen3Config()  # 8B defaults
    torch.manual_seed(42)
    mlp = SwiGLUMLP(cfg).cuda().half()

    batch = 1
    x = torch.randn(
        batch, seq_len, cfg.hidden_dim, device="cuda", dtype=torch.float16
    )

    # Identical measurement settings for both candidates.
    bench_opts = dict(use_cuda_graph=False, flush_l2=False)

    # --- Torch eager ---
    def run_torch():
        with torch.no_grad():
            mlp(x)

    torch_res = microbench(run_torch, **bench_opts)

    # --- ARK (torch-only fallback) ---
    # Reuse the reference module's weights so both paths do the same math.
    gate_w = mlp.gate_proj.weight.detach()
    up_w = mlp.up_proj.weight.detach()
    down_w = mlp.down_proj.weight.detach()

    def run_ark():
        with torch.no_grad():
            ark_swiglu_mlp(x, gate_w, up_w, down_w, cfg).eval()

    ark_res = microbench(run_ark, **bench_opts)

    return label, torch_res, ark_res


def main():
    """Print a torch-vs-ARK-wrapper timing table for prefill and decode.

    Runs ``_run`` for S=2048 and S=1 and prints one row per shape with the
    mean ± std (in microseconds) of each path and the torch/ARK speedup.
    The speedup is reported as NaN when the ARK mean is not positive.
    """
    print("NOTE: torch-only (ARK silu·gate deferred to upstream fix / Q10).")
    header = (
        f"{'Shape':<20} {'Torch (us)':>16} {'ARK-wrap (us)':>20} {'Speedup':>10}"
    )
    print(header)
    print("-" * 70)

    cases = [(2048, "prefill S=2048"), (1, "decode S=1")]
    for seq, label in cases:
        name, t, a = _run(seq, label)

        # Guard the ratio against a zero (or negative) ARK mean.
        if a["mean_us"] > 0:
            sp = t["mean_us"] / a["mean_us"]
        else:
            sp = float("nan")

        row = (
            f"{name:<20} "
            f"{t['mean_us']:>10.1f} ± {t['std_us']:<5.1f}"
            f"{a['mean_us']:>14.1f} ± {a['std_us']:<5.1f}"
            f"{sp:>8.2f}x"
        )
        print(row)


# Script entry point: run the benchmark table only when executed directly
# (e.g. ``python -m examples.qwen3.bench_mlp``), not on import.
if __name__ == "__main__":
    main()
11 changes: 5 additions & 6 deletions examples/qwen3/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def test_qk_norm():
weight = norm.weight.detach().half().cuda()
out = torch_rmsnorm(x, weight, 1e-6)

assert_close(out, ref, atol=1e-6, rtol=1e-6, msg="QK-norm mismatch")
# fp16 precision; tight enough to catch real regressions
assert_close(out, ref, atol=1e-3, rtol=1e-3, msg="QK-norm mismatch")


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -264,7 +265,7 @@ def _mha_cfg():
)


def _run_attention_equivalence(cfg, B, S, seed_offset=10, mask="causal"):
def _run_attention_equivalence(cfg, B, S, seed_offset=10, causal=True):
"""Run ARK vs torch attention equivalence for given cfg, B, S."""
from .ark_attention import ark_gqa_attention, precompute_torch_rope_freqs
from .equiv import assert_close
Expand All @@ -276,9 +277,7 @@ def _run_attention_equivalence(cfg, B, S, seed_offset=10, mask="causal"):

torch.manual_seed(_SEED + seed_offset)
x = torch.randn(B, S, cfg.hidden_dim, device="cuda", dtype=torch.float16)
if mask == "causal":
mask = _causal_mask(S, "cuda", torch.float16)
# else mask is already None or a user-supplied tensor
mask = _causal_mask(S, "cuda", torch.float16) if causal else None

with torch.no_grad():
ref = attn(x, rope_freqs, mask)
Expand Down Expand Up @@ -323,7 +322,7 @@ def test_attention_mha():
def test_attention_no_mask():
"""ARK attention matches torch with mask=None (no causal mask)."""
_run_attention_equivalence(
_small_cfg(), B=1, S=16, seed_offset=23, mask=None
_small_cfg(), B=1, S=16, seed_offset=23, causal=False
)


Expand Down
Loading