Metal GPU + CoreML/ANE support for transformer nets (SiLU, GQA, learnable RoPE) #1205
Open
ChinChangYang wants to merge 12 commits into
…(v15)

Implement the LLaMA-style transformer-hybrid forward pass (RMSNorm, multi-head attention with learnable 2D RoPE, SwiGLU FFN) plus ACTIVATION_SILU across the Metal GPU (MPSGraph) and CoreML/ANE (MIL) backends, so the v15 b10c384h6nbttflrs model runs end-to-end.

Metal GPU (MPSGraph), verified via testgpuerror vs Eigen reference at sizes 9/13/19 (winrate error ~0.0001%, well under threshold):
- metallayers.swift: TransformerRMSNormLayer, TrunkRMSNormLayer, TransformerAttentionBlock, TransformerFFNBlock, silu() activation, SWTransformer*/SWRMSNorm descriptors; Trunk branches on trunkNormKind
- metalbackend.cpp: SILU bridge + transformer/RMSNorm desc bridges, wired into residualBlocksToSwift and trunkDescToSwift

CoreML/ANE (katagocoreml MIL), implemented end-to-end; fp32 model logically correct and consistent across CPU/ANE/GPU. fp16 ANE path is numerically precision-limited (~5%) due to fp16 matmul accumulation in the deep attention stack:
- types/parser: ActivationType::Silu, trunk_norm_kind, transformer block kinds 4/5, RMSNorm/attention/FFN descriptors
- MILBuilder: addSiluOps, RMSNorm ops, transformer attention/FFN blocks. Fixes 4 CoreML bugs: reshape-after-transpose, fp16 mask overflow, fp16 RMSNorm reduce_sum overflow (reduce_mean)

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
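The reduce_sum-to-reduce_mean change for RMSNorm mentioned above can be illustrated with a plain C++ reference. This is only a sketch, not the MIL builder code: the function name `rmsNormMean`, the `eps` default, and the flat per-vector layout are assumptions. It accumulates a mean of squares so the running value never exceeds the largest single square, which is the property that matters for a narrow-range accumulator such as fp16 (provided each individual square is representable).

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical reference RMSNorm over one channel vector.
// eps is an assumed stabilizer value, not taken from the PR.
std::vector<float> rmsNormMean(const std::vector<float>& x,
                               const std::vector<float>& gamma,
                               float eps = 1e-6f) {
  assert(!x.empty() && x.size() == gamma.size());
  // Scale each square by 1/N before accumulating, so every partial sum is
  // bounded by max(x^2) instead of growing with the number of channels.
  const float invN = 1.0f / static_cast<float>(x.size());
  float meanSq = 0.0f;
  for (float v : x) meanSq += (v * v) * invN;
  const float scale = 1.0f / std::sqrt(meanSq + eps);
  std::vector<float> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) y[i] = x[i] * scale * gamma[i];
  return y;
}
```

In float32 the two forms agree up to rounding; the difference only shows up when the accumulator itself is half precision.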
# Conflicts: # cpp/external/katagocoreml/src/parser/KataGoParser.cpp
GQA models (numKVHeads != numHeads, e.g. b7c96h6kv3qk32v16tflrs) crashed on the Metal GPU path with: MPSNDArray.mm: NDArray dimension length > INT_MAX repeatKVHeads expanded the KV heads via reshape -> broadcast -> reshape, passing -1 for the batch dim in the broadcast target shape. Unlike reshape, MPSGraph.broadcast(_:shape:) does not infer -1 and treats it as a literal (near-INT_MAX) dimension, tripping the NDArray assertion. Replace the broadcast with a shape-safe slice + concat: slice each KV head (dim 1) and concatenate groupSize copies consecutively, so query head h uses kv = h / groupSize, matching the Eigen reference (kvh = h / kvGroupSize). No -1 broadcast. Verified: testgpuerror GPU vs Eigen reference at 9/13/19 now passes (~0.00003% winrate); non-GQA models (incl. b10c384h6) unaffected since the GQA branch is gated on numKVHeads != numHeads. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
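The head ordering produced by the slice + concat approach can be sketched in plain C++. The layout below (one `std::vector<float>` per head, batch and sequence dimensions omitted) and the free function are illustrative assumptions, not the MPSGraph code; only the mapping "query head h uses kv head h / groupSize" comes from the commit message.

```cpp
#include <cassert>
#include <vector>

// Hypothetical layout: kv[head] holds that head's features.
using Head = std::vector<float>;

// Expand numKVHeads KV heads to numHeads heads by emitting groupSize
// consecutive copies of each KV head, so out[h] == kv[h / groupSize].
std::vector<Head> repeatKVHeads(const std::vector<Head>& kv, int numHeads) {
  const int numKVHeads = static_cast<int>(kv.size());
  assert(numKVHeads > 0 && numHeads % numKVHeads == 0);
  const int groupSize = numHeads / numKVHeads;
  std::vector<Head> out;
  out.reserve(static_cast<std::size_t>(numHeads));
  for (int kvh = 0; kvh < numKVHeads; ++kvh) {   // "slice" one KV head
    for (int g = 0; g < groupSize; ++g) {        // "concat" groupSize copies
      out.push_back(kv[static_cast<std::size_t>(kvh)]);
    }
  }
  return out;
}
```

With 3 KV heads and 6 query heads, the result is ordered kv0, kv0, kv1, kv1, kv2, kv2, and no shape value is ever left to be inferred.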
The MIL builder's inline activation dispatch (buildValueHead v2, policy-head pass activation, and both SGF metadata encoder layers) handled only ReLU and Mish; SiLU silently fell through to the else branch and applied NO activation at all. This corrupted the value-head pool -> v2 -> v3 scalar path for every SiLU model, producing large errors in winrate/score/lead while ownership (which branches off v1, before v2) stayed correct.

Add an ActivationType::Silu branch (addSiluOps) at all four sites. The generic conv/BN activation path already handled SiLU, which is why the trunk and v1/ownership were fine.

Root-caused via systematic debugging: CoreML-CPU(fp32) error was identical to ANE (-> logical bug, not fp16), and perfect ownership with wrong scalars localized it to the value-head post-pooling path. This corrects the earlier "ANE is fp16-precision-limited (~5%)" conclusion -- that 5.66% on b10c384h6 was this bug.

After the fix, testgpuerror ANE vs Eigen drops to GPU-level accuracy for all models:
- b10c384h6: 5.66% -> ~0.00005-0.0002%
- cnorm: 11-13% -> ~0.00007%
- rsnh: 22-29% -> ~0.00004-0.0001%

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
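One way to keep a new activation from silently falling through is an exhaustive switch that fails loudly on anything unhandled. The sketch below is illustrative only: the `Activation` enum and `applyActivation` are hypothetical scalar stand-ins, not the MILBuilder dispatch, which emits graph ops rather than computing values.

```cpp
#include <cassert>
#include <cmath>
#include <stdexcept>

// Hypothetical enum standing in for the converter's activation type.
enum class Activation { Identity, Relu, Mish, Silu };

// Scalar reference for each activation; no default/else branch, so an
// unhandled value reaches the throw instead of becoming a silent no-op.
float applyActivation(Activation a, float x) {
  switch (a) {
    case Activation::Identity: return x;
    case Activation::Relu:     return x > 0.0f ? x : 0.0f;
    case Activation::Mish:     return x * std::tanh(std::log1p(std::exp(x)));
    case Activation::Silu:     return x / (1.0f + std::exp(-x));
  }
  throw std::invalid_argument("unhandled activation type");
}
```

Most compilers can also warn when a switch over an enum class misses an enumerator, which would flag this class of bug at build time.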
The CoreML MIL builder threw "GQA (numKVHeads != numHeads) not supported" for grouped-query-attention transformer models, while the Metal GPU (MPSGraph) path already handled GQA. Port that support to the MIL builder. In buildTransformerAttentionBlock, remove the throw guard and, after the RoPE block and before the scores matmul, repeat each KV head groupSize (= numHeads/numKVHeads) times along the head axis via slice_by_size + concat (interleave=false), so query head h consumes kv head h/groupSize. This matches the Eigen reference (kvh = h/kvGroupSize) and the GPU repeatKVHeads ordering. RoPE stays before the repeat (its cos/sin tables are numKVHeads-shaped). The block is gated by numKVHeads != numHeads, so the standard MHA path is unchanged. Verified on b7c96h6kv3qk32v16tflrs-fson-bnh (6 query / 3 KV heads, qk32/v16) vs Eigen reference: ANE testgpuerror 9/13/19 = 0.00002-0.00003% winrate (previously a hard throw); GPU unchanged; non-GQA model ANE error identical to pre-change; runtests and runnnlayertests pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Transformer models failed testgpuerror on the CoreML/ANE FP16 path: the ANE accumulates FP16 matmuls AND convs in FP16 (unlike OpenCL/CUDA/TRT, which accumulate in FP32), so wide/deep transformers lose too much precision and miss the thresholds at larger board sizes. BF16 is not an option (no compute path in CoreML: cast op, ArrayFeatureType and MLMultiArray all lack bf16; coremltools confirms FLOAT16/FLOAT32 only).

Follow KataGo's FP16 convention (spatial convs FP16, non-spatial FP32), channel-gated for the ANE since every FP32 op runs off the FP16-only ANE:
- RMSNorm reduction cores: FP32 in FP16 mode (always).
- Non-spatial (FFN/Q-K-V proj/pooling/matmul): FP32 (always). MIL `linear` needs const weight/bias so it can't runtime-cast; only `matmul` is wrapped.
- Convs: FP32 only for wide trunks (>= 320ch); narrower keep convs on-ANE.
- Narrow trunks (< 256ch) sit on the testgpuerror thresholds and no partial FP32 config passes all board sizes (islands cast back to FP16 leave a noisy FP16 spatial stream); build them fully FP32 (off-ANE, cheap since small).

Weights stay FP16-stored via runtime up-casts, except full-FP32 models. Add per-weight FP32 serialization (WeightEntry.is_fp32) so a const declared FP32 inside an otherwise-FP16 model is stored FP32 (fixes the load-time "storage and type have different number of elements" abort and enables the full-FP32 tier). Also fixes addFloatScalarConstOp keying storage off m_use_fp16 instead of the declared m_weight_dtype.

Result: all 4 transformer test models (b10c384h6/b4c256h4/b7c96h3/b7c96h6kv3-GQA) pass testgpuerror on ANE FP16 at sizes 9/13/19; runtests and runnnlayertests pass. All changes gated on m_use_fp16; FP32 mode unchanged. The 256/320 channel thresholds are width heuristics validated on these models.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Three near-identical blocks wrapped global pooling in FP32 (policy head, value head, gpool residual block): cast input/mask up to FP32, flip m_weight_dtype, pool, restore, cast pooled features back to FP16 - with inconsistent save-variable names and one site using castFixed vs addCastOp for the output cast. Extract a single addGlobalPoolingFp32(input, mask, channels, output, valueVariant) helper and a small RAII ScopedFp32 guard for the temporary m_weight_dtype flip. The three call sites become one-liners. Behavior-preserving: same emitted op sequence; testgpuerror output is byte-identical across all precision tiers (partial-FP32 b10c384h6, full-FP32 b7c96h3, non-spatial-FP32 b4c256h4), all 12 transformer gate runs pass, runtests and runnnlayertests pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
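The RAII guard idea can be sketched as follows. This is a minimal illustration, assuming a `DType` enum and a plain reference to the dtype slot; the real ScopedFp32 lives inside the MIL builder and its exact signature is not shown in this PR page.

```cpp
#include <cassert>

// Hypothetical stand-in for the builder's weight dtype.
enum class DType { Float16, Float32 };

// Flips the referenced dtype to Float32 for the guard's lifetime and
// restores the previous value on scope exit, including on exceptions.
class ScopedFp32 {
 public:
  explicit ScopedFp32(DType& slot) : slot_(slot), saved_(slot) {
    slot_ = DType::Float32;
  }
  ~ScopedFp32() { slot_ = saved_; }
  ScopedFp32(const ScopedFp32&) = delete;
  ScopedFp32& operator=(const ScopedFp32&) = delete;

 private:
  DType& slot_;
  DType saved_;
};

// Illustrative call site: returns the dtype observed while the guard is live.
DType dtypeInsideScope(DType& weightDtype) {
  ScopedFp32 guard(weightDtype);
  assert(weightDtype == DType::Float32);
  return weightDtype;
}
```

Compared with manual save/flip/restore, the guard removes the per-site save variable, so there is nothing to name inconsistently and no restore to forget on an early return.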
The width-keyed precision tiers (commit 3839e52) forced FP32 ops off the FP16-only ANE on plain production convnets, not just transformers. b18c384nbt ran ~2.6x slower on the ANE path (160 vs 416 visits/s) with no accuracy benefit. The dominant cost is the per-block global-pooling FP32 (non-spatial), which breaks the ANE pipeline once per gpool-residual block; conv-FP32 is secondary. Add a recursive blocksContainTransformer() helper and gate all three escalations (full-FP32, non-spatial-FP32, conv-FP32) on transformer-block presence. Convnets now run pure FP16 on the ANE (the long-standing pre-tier path); for transformer models the added "&& hasTransformer" is always true, so their emitted MIL is byte-identical and behavior is unchanged. Verified on the ANE FP16 path: b18c384nbt testgpuerror passes (winrate 99%=0.57%, max=0.87%) and recovers full throughput (424 visits/s); b28c512nbt passes (99%=0.41%); all 4 transformer test models x sizes 9/13/19 pass with numbers byte-identical to before; runtests and runnnlayertests pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
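The gating described above can be sketched as a recursive presence check feeding a tier choice. Everything here is an illustrative assumption except the name blocksContainTransformer, the 256/320 channel thresholds, and the three escalation kinds: the `Block` struct, the `Tier` enum, and `chooseTier` are hypothetical, and the real code gates three separate escalations rather than returning one enum.

```cpp
#include <cassert>
#include <vector>

// Hypothetical block tree; nested models nested-bottleneck sub-blocks.
struct Block {
  bool isTransformer = false;
  std::vector<Block> nested;
};

// True if any block, at any nesting depth, is a transformer block.
bool blocksContainTransformer(const std::vector<Block>& blocks) {
  for (const Block& b : blocks) {
    if (b.isTransformer || blocksContainTransformer(b.nested)) return true;
  }
  return false;
}

enum class Tier { PureFp16, NonSpatialFp32, ConvFp32, FullFp32 };

// Width thresholds only apply when a transformer block is present.
Tier chooseTier(int trunkChannels, const std::vector<Block>& blocks) {
  assert(trunkChannels > 0);
  if (!blocksContainTransformer(blocks)) return Tier::PureFp16;  // convnets
  if (trunkChannels < 256) return Tier::FullFp32;                // narrow trunk
  if (trunkChannels >= 320) return Tier::ConvFp32;  // non-spatial + conv FP32
  return Tier::NonSpatialFp32;
}
```

Under these assumptions a 384-channel convnet stays pure FP16, while 96-, 256- and 384-channel transformer nets land in the full-FP32, non-spatial-FP32 and conv-FP32 tiers respectively.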
…former-silu-b10c384h6
makeRopeTables allocated cosBuf/sinBuf with UnsafeMutablePointer.allocate and handed them to Data(floatsNoCopy:), which uses deallocator: .none, so the buffers were never freed -- a leak on every graph build (per attention block, per board size). Unlike the other floatsNoCopy callers (weights/gamma/beta), which point at C++-descriptor memory that lives for the model's lifetime, these tables have no persistent owner. Switch to managed [Float32] arrays and copy into the Data via Data(buffer:) so MPSGraph owns the bytes -- avoids both the leak and a use-after-free that a naive deallocate() on the no-copy path would cause. Output-neutral: testgpuerror on the GQA + learnable-RoPE model (b7c96h6kv3qk32v16tflrs, board 19) vs Eigen FP32 reference matches to 0.00028% max winrate error over 2247 positions. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The Metal forward pass (metallayers.swift TransformerFFNBlock) only implements the SwiGLU path (SiLU(linear1) * gate). A non-SwiGLU model carries no gate weights, so building the Swift descriptor from the empty linearGate would crash obscurely (or silently misbehave). Eigen (eigenbackend.cpp) and CoreML (katagocoreml MILBuilder) both throw a clear "non-SwiGLU transformer FFN not supported" error in this case; the Metal GPU path had no such guard. Add the matching StringError at the FFN descriptor conversion so all three backends fail loudly and consistently. No behavior change for any current model (all use useSwiGLU=true): the guard sits on an untaken path. Verified the SwiGLU model b10c384h6nbttflrs still passes testgpuerror on both GPU (0.00005% winrate) and ANE (unchanged from baseline); runtests and runnnlayertests pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
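A guard of this kind, together with the per-element SwiGLU product it protects, might look like the sketch below. The `FfnDesc` struct and both function names are hypothetical, and `std::runtime_error` stands in for the codebase's StringError; only the error text and the SiLU(linear1) * gate form come from the commit message.

```cpp
#include <cassert>
#include <cmath>
#include <stdexcept>

// Hypothetical descriptor; the real one also carries the weight tensors.
struct FfnDesc {
  bool useSwiGLU = true;
};

// Fail loudly before any descriptor conversion touches missing gate weights.
void requireSwiGLU(const FfnDesc& desc) {
  if (!desc.useSwiGLU) {
    throw std::runtime_error("non-SwiGLU transformer FFN not supported");
  }
}

// Element-wise SwiGLU: SiLU(linear1 output) multiplied by the gate output.
float swiGluElement(float linear1Out, float gateOut) {
  const float silu = linear1Out / (1.0f + std::exp(-linear1Out));
  return silu * gateOut;
}
```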
The trunk-tip dispatch in metallayers.swift compared trunkNormKind against the literal 1, while the rest of the codebase uses the named constants from desc.h (TRUNK_NORM_KIND_STANDARD/_RMSNORM). Add matching Swift constants and use TRUNK_NORM_KIND_RMSNORM at the comparison site. Pure literal-to-named-constant rename; no behavior change. Verified both branches still pass testgpuerror at GPU-level accuracy: RMSNorm tip (b10c384h6nbttflrs) 0.00005% winrate, BatchNorm tip (b7c96h6kv3 GQA) 0.00003%; runtests and runnnlayertests pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Contributor
Author
I have reviewed and approved this.
ChinChangYang
added a commit
to ChinChangYang/KataGo
that referenced
this pull request
Jun 4, 2026
…into feature/coreml-conversion-memory-levers (PR lightvector#1202) Brings Metal GPU + CoreML/ANE transformer-net support (SiLU, GQA, learnable RoPE, RMSNorm, SwiGLU) onto the CoreML-conversion memory branch (non-owning FloatView weight views, streaming parser, ANE weight release). All four conflicts were in the katagocoreml converter, from one root cause: lightvector#1202 reworked weight registration (non-owning FloatView views, registerOwnedWeight, deleted rvalue overloads) while lightvector#1205 added per-weight FP32 precision tiers (is_fp32). Resolved to keep both: - Operations.{hpp,cpp}: registerWeight stamps entry.is_fp32; keep the non-owning view plus registerOwnedWeight. - WeightSerializer.cpp: keep the FloatView-based count plus per-weight store_fp16 = use_fp16 && !is_fp32. - MILBuilder.cpp addConstOp: keep the emitConstOp refactor and pass is_fp32 = (m_weight_dtype == FLOAT32). A follow-up commit retrofits lightvector#1205's transformer derived consts to lightvector#1202's owned-weight + FP32-marking contract (required for the FP16 ANE path). Verified after the follow-up fix: katago runtests and runnnlayertests pass, and testgpuerror vs fresh Eigen FP32 references passes on convnet + all three transformer nets on both the Metal GPU and CoreML/ANE paths. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
ChinChangYang
added a commit
to ChinChangYang/KataGo
that referenced
this pull request
Jun 4, 2026
…contract The transformer attention builder emits four function-local std::vector<float> tensors: RoPE cos/sin tables, the rotation matrix R, and per-head out-projection weight slices. After merging the transformer support onto the FloatView branch, these needed two fixes: 1. Dangling view. lightvector#1202 made WeightEntry::data a non-owning FloatView, so addConstOp registers a view whose backing buffer must outlive serialization. These locals were passed to addConstOp and would dangle once the build function returns (serialization runs afterwards). Route them through addOwnedConstOp so KataGoOps owns the buffer until serialization. (Under lightvector#1205's owning WeightEntry they were copied, so this only surfaces post-merge.) 2. dtype mismatch. emitConstOp declares each const's dtype as m_weight_dtype, but addOwnedConstOp / registerOwnedWeight stored at the global mode (is_fp32 hardcoded false). In an FP16 model these derived consts land in the attention / value-head FP32 sub-region (m_weight_dtype == FLOAT32), so they were declared FP32 but stored FP16. CoreML/ANE then rejects the model at load ("Metadata data type does not match requested type", BNNS error -14), which SIGABRT'd every FP16 ANE transformer. Thread is_fp32 through registerOwnedWeight and have addOwnedConstOp pass is_fp32 = (m_weight_dtype == FLOAT32), mirroring addConstOp so the stored dtype always matches the declared dtype. This also fixes the same latent mismatch for addLinearOp's transposed value-head weights. Verified with testgpuerror against fresh Eigen FP32 references: b7c96h3tfrs and b7c96h6gqa, which previously SIGABRT'd on the FP16 ANE path, now load and match to <0.0005% winrate; convnet ANE output is byte-identical and the Metal GPU path is unchanged. katago runtests and runnnlayertests also pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
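Both fixes can be illustrated with a small self-contained registry. This is a sketch under stated assumptions: `WeightRegistry`, its `std::deque` storage, and the field layout are hypothetical, while the names FloatView, WeightEntry, registerOwnedWeight, is_fp32, and the rule store_fp16 = use_fp16 && !is_fp32 are taken from the commit messages on this page.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <string>
#include <utility>
#include <vector>

// Non-owning view: valid only while the backing buffer is alive.
struct FloatView {
  const float* data;
  std::size_t size;
};

struct WeightEntry {
  std::string name;
  FloatView view;
  bool is_fp32;  // declared dtype of the const, recorded per weight
};

// Hypothetical registry that owns derived buffers until serialization.
class WeightRegistry {
 public:
  // Takes the buffer by value so function-local tensors are moved or copied
  // into registry-owned storage instead of being viewed in place.
  void registerOwnedWeight(std::string name, std::vector<float> values,
                           bool is_fp32) {
    owned_.push_back(std::move(values));  // deque keeps element references stable
    const std::vector<float>& buf = owned_.back();
    entries_.push_back({std::move(name), {buf.data(), buf.size()}, is_fp32});
  }

  // Storage dtype follows the declared dtype, not just the global mode.
  static bool storeFp16(bool use_fp16, const WeightEntry& e) {
    return use_fp16 && !e.is_fp32;
  }

  const std::vector<WeightEntry>& entries() const { return entries_; }

 private:
  std::deque<std::vector<float>> owned_;
  std::vector<WeightEntry> entries_;
};
```

Registering a plain view of a function-local vector would leave `view.data` dangling once the build function returns; routing it through the owning call keeps the bytes alive, and passing `is_fp32` keeps an FP32-declared const from being stored as FP16 in an FP16 model.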
ChinChangYang
added a commit
to ChinChangYang/KataGo
that referenced
this pull request
Jun 4, 2026
…contract The transformer attention builder emits four function-local std::vector<float> tensors: RoPE cos/sin tables, the rotation matrix R, and per-head out-projection weight slices. After merging the transformer support onto the FloatView branch, these needed two fixes: 1. Dangling view. lightvector#1202 made WeightEntry::data a non-owning FloatView, so addConstOp registers a view whose backing buffer must outlive serialization. These locals were passed to addConstOp and would dangle once the build function returns (serialization runs afterwards). Route them through addOwnedConstOp so KataGoOps owns the buffer until serialization. (Under lightvector#1205's owning WeightEntry they were copied, so this only surfaces post-merge.) 2. dtype mismatch. emitConstOp declares each const's dtype as m_weight_dtype, but addOwnedConstOp / registerOwnedWeight stored at the global mode (is_fp32 hardcoded false). In an FP16 model these derived consts land in the attention / value-head FP32 sub-region (m_weight_dtype == FLOAT32), so they were declared FP32 but stored FP16. CoreML/ANE then rejects the model at load ("Metadata data type does not match requested type", BNNS error -14), which SIGABRT'd every FP16 ANE transformer. Thread is_fp32 through registerOwnedWeight and have addOwnedConstOp pass is_fp32 = (m_weight_dtype == FLOAT32), mirroring addConstOp so the stored dtype always matches the declared dtype. This also fixes the same latent mismatch for addLinearOp's transposed value-head weights. Verified with testgpuerror against fresh Eigen FP32 references: b7c96h3tfrs and b7c96h6gqa, which previously SIGABRT'd on the FP16 ANE path, now load and match to <0.0005% winrate; convnet ANE output is byte-identical and the Metal GPU path is unchanged. katago runtests and runnnlayertests also pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> (cherry picked from commit 8481a94)
Summary
Adds Metal GPU and CoreML/ANE (Apple Neural Engine) inference support for the v15+ transformer trunk architectures (e.g. `b10c384h6nbttflrs`), bringing the macOS backends to parity with the new transformer features, and fixes a convnet performance regression introduced along the way.

This targets the macOS Metal 4 backend and the bundled `katagocoreml` MIL converter only; it does not touch the CUDA / TensorRT / OpenCL / Eigen backends.
What's included
- … `NDArray` dimension (INT_MAX); reshaped to avoid it.
- … accuracy-sensitive ops (non-spatial pooling, RMSNorm, heads, and, for narrow trunks, the whole graph) are escalated to FP32 while attention stays FP16. Mirrors the TensorRT precision split.
- … d052d2a1): the precision tiers were originally keyed on trunk width, which also caught plain production convnets (b18c384 / b28c512 / b60c320) and forced FP32 ops off the FP16-only ANE (~2.6× slower). The escalations are now gated on actual transformer-block presence (recursing into nested-bottleneck blocks), so convnets run pure FP16 on the ANE (identical to prior behavior) and transformer graphs are byte-for-byte unchanged.
Verification (macOS, Apple Silicon)
- … `metalDeviceToUseThread0=100`, `metalUseFP16=true`): `cpp/rungpuerrortest.sh` run two-phase (Eigen FP32 references → Metal compare): 32/32 CoreML-supported configs pass across board sizes 9 / 13 / 19 / 10×14 / rectangle, including the rectbuffer and weird-settings cases. Worst ANE-FP16-vs-Eigen winrate max = 1.13% (threshold 5.0%); `fp32-vs-reference` ≈ 0.0005% confirms valid references. Models exercised: b18c384nbt ×2, b28c512nbt, b18c384nbt-humanv0, b5c192nbt, g170e-b10c128, and 3 transformer test nets (fixed-RoPE / learnable-RoPE+SiLU / GQA).
- … `gpuIdx=0` = `METAL_MUX_GPU` → MPSGraph; no `metalDeviceToUseThread0=100` override): `testgpuerror` against the same Eigen FP32 references for all 3 transformer test nets, 13/13 configs pass across board sizes 9 / 13 / 19 / rectangle plus the rectbuffer case. Worst MPSGraph-vs-Eigen winrate max ≈ 0.0006% (threshold 5.0%); `fp32-vs-reference` matches, i.e. pure FP32-vs-FP32 numerical noise. Covers the fixed-RoPE + RMSNorm-tip (b7c96h3tfrs), learnable-RoPE + SiLU + spatial-RMSNorm-tip (b4c256h4nbttflrs), and GQA + ReLU + batchnorm-tip (b7c96h6kv3qk32v16tflrs) architectures. The MPSGraph path is FP32-only by design.
- `./katago runtests` and `./katago runnnlayertests` pass.

Notes
- … `MIN_SUPPORTED_VERSION = 8`); such nets are expected to run on the GPU / other backends.
🤖 Generated with Claude Code