Skip to content

Add MLX backend for Apple Silicon (GPU + ANE/CoreML dispatch, incl. transformer nets)#1199

Draft
ChinChangYang wants to merge 34 commits into
lightvector:masterfrom
ChinChangYang:mlx-backend-squash
Draft

Add MLX backend for Apple Silicon (GPU + ANE/CoreML dispatch, incl. transformer nets)#1199
ChinChangYang wants to merge 34 commits into
lightvector:masterfrom
ChinChangYang:mlx-backend-squash

Conversation

@ChinChangYang

@ChinChangYang ChinChangYang commented May 23, 2026

Copy link
Copy Markdown
Contributor

This PR adds a new neural-net backend (USE_BACKEND=MLX) targeting Apple
Silicon via Apple's MLX framework,
with two dispatch paths sharing a single backend:

  • GPU — MLX/Metal with an F(2×2, 3×3) Winograd path and an adaptive
    per-shape tuner.
  • ANE — CoreML on CPU + Apple Neural Engine, mirroring the Metal
    backend's gpuIdx = 100 convention. Usable standalone or muxed with
    the GPU path on the same model to overlap forward passes.

It implements the full nninterface.h contract (model load, batched
evaluation, FP16/FP32 paths) and reuses the existing CoreML conversion
pipeline shared with the Metal backend. Both dispatch paths support
the v15+ transformer trunk
(attention incl. GQA, learnable RoPE, SiLU,
RMSNorm/batchnorm tips, SwiGLU FFN) in addition to convnets.

What's new

Backend (cpp/neuralnet/)

  • mlxbackend.cpp — backend implementation.
    • GPU path: Variable board sizes via input masking (same
      nnXLen / nnYLen contract as other backends; the global
      COMPILE_MAX_BOARD_LEN bound still applies), FP16/FP32 selected
      by mlxUseFP16 (default auto → fp16). Mish runs FP16-safe; the
      code asserts on ACTIVATION_MISH_SCALE8 so out-of-range variants
      fail loudly rather than truncate silently.
    • ANE path: Selected per server thread via
      mlxDeviceToUseThread<N> = 100 (or the backend-agnostic
      deviceToUseThread<N>); shares the model+converter cache with
      Metal. Feeds the real spatial mask (channel 0 NHWC → buffer) so
      rectangular / sub-NN-frame boards predict correctly, transposes
      NHWC → NCHW into a per-batch staging buffer for the Swift
      MLMultiArray contract, and uses path-correct strides in the
      policy-optimism postprocessor for v12+ models.
    • Mux (GPU + ANE): Serializes ComputeHandle construction with
      a file-static mutex (CoreML converter is not concurrent), and
      eagerly evaluates FP16 weight casts so secondary MLX/GPU server
      threads don't trip MLX 0.31.2's thread-local command-encoder map.
  • mlxwinograd.h — F(2×2, 3×3) Winograd transform with fused
    activation + residual add.
  • mlxwinotuner.{cpp,h} — per-shape Winograd tuner with adaptive
    scoring (rotates the candidate set per shape, scores by median-time
    delta against a baked-default baseline). Logs the conv-3x3 shape
    distribution at model load. At model load it loads a valid cache or,
    on a miss, coarse-tunes once — it never re-tunes a valid cache and
    never sweeps the wide grid; explicit re-tuning is command-only via
    ./katago tuner (see Build / wiring). KATAGO_MLX_WINOTUNER=0
    skips tuning entirely and uses baked-default launch geometry, while
    KATAGO_MLX_WINOGRAD=0 falls 3×3 convs back to mx::conv2d.
  • mlxtests.cpp — Winograd + tuner numeric-consistency tests, gated
    under runnnlayertests.

Transformer trunk support (both dispatch paths)

  • GPU path (native MLX): transformer trunk/tip implemented directly
    in mlxbackend.cpp — scaled-dot-product attention with grouped-query
    attention (GQA), learnable RoPE, RMSNorm and batchnorm trunk tips,
    SwiGLU FFN, and the SiLU activation. Mirrors eigenbackend.cpp.
  • ANE path (CoreML/katagocoreml): the same transformer trunk in
    the shared MIL converter, with FP16 accuracy precision tiers gated
    on actual transformer-block presence (recursing into nested-bottleneck
    blocks) so plain convnets stay pure FP16 on the FP16-only ANE. For
    transformer trunks: narrow (<256ch) build fully FP32; wider ones
    escalate the accuracy-sensitive non-spatial ops (matmuls + global
    pooling) to FP32; very wide (≥320ch) also escalate convs; RMSNorm
    reductions run FP32 in FP16 mode while attention scores/softmax and the
    per-head out-projection stay FP16. Mirrors the TensorRT precision split
    and the Metal/CoreML backend's tiers. The converter's public API is
    unchanged, so the MLX ANE call site needs no edits.

Build / wiring

  • cpp/CMakeLists.txtUSE_BACKEND=MLX target; pulls in the
    Metal/Swift CoreML bridge so the ANE path links cleanly. MLX needs
    CMake 3.27; cmake_minimum_required stays at 3.18.2 so other
    backends keep building on older CMake. Links Homebrew's prebuilt
    libmlx.dylib; OSX deployment target is intentionally not pinned
    so the executable's minos matches the linked dylib.
  • cpp/main.cpp, cpp/program/setup.cpp, cpp/command/benchmark.cpp
    — wire MLX into backend selection / benchmark.
  • cpp/command/tune.cpp — wires the MLX winotuner into the ./katago tuner subcommand (previously OpenCL-only), mirroring OpenCL's
    command-only tuning. ./katago tuner -model <net> always re-runs the
    sweep and overwrites the cache the backend reads at model load; add
    -full for the wide candidate grid. Args: -output (default: shared
    MLX cache), -xsize / -ysize (default 19), -batchsize (default
    8), -testFP16 (true|false|auto, default auto → FP16, the
    precision the engine and the cache filename key use), -full. The
    OpenCL-only FP16 sub-knobs (storage / compute / tensorcores) have no
    MLX analog and are omitted.
  • cpp/configs/{gtp,analysis,match,contribute}_example.cfg
    document mlxUseFP16 (default auto → fp16) and the
    numNNServerThreadsPerModel / mlxDeviceToUseThread<N> dispatch
    knobs (GPU-only / ANE-only / mux), with the note that
    mlxUseFP16=false on an ANE thread falls back to CPU FP32.
  • cpp/rungpuerrortest.sh — backend-agnostic
    deviceToUseThread0=100 for the ANE mode, so the same script
    drives whichever backend the binary was built with.
  • Compiling.md — build instructions.

Conversion + ANE steady-state memory

The ANE path carries the on-device memory levers from #1202 (cut CoreML
conversion peak and ANE steady-state RSS, with byte-identical converter
output), cherry-picked onto this branch and wired into the MLX backend:

  • Non-owning weight views — the converter's weight tensors are
    non-owning FloatViews into the parsed model instead of owning extra
    FP32 copies (derived/transposed tensors keep an owned buffer). The
    rvalue registerWeight / addConstOp overloads are = deleted so a
    temporary can't bind to the view path and dangle, and CoreML
    serialization is made deterministic for byte-stable output.
  • Streaming parser — the KataGo model parser streams the gzip
    through a bounded ~1 MB refill buffer (RAII gzFile) instead of
    decompressing the whole file into memory, preserving the existing
    NaN/Inf weight validation.
  • Release weights on the ANE-only pathModelDesc::releaseWeights()
    frees the in-memory FP32 weight arrays (keeping scalar shape metadata),
    gated by a new ComputeContext::aneOnly flag that is true only when
    every configured device index is MLX_MUX_ANE. On the MLX backend it
    fires in convertAndCreateCoreMLOnlyHandleMLX after the model has been
    converted to CoreML on disk; it is safe because the ANE path re-reads
    the model from modelPath (not the in-memory arrays), the
    ComputeHandle ctor takes the ANE early-return before building any
    MLX/GPU model (the only weight-array consumer), only scalar dims are
    read afterward, and it runs under computeHandleMutex. The GPU/MLX
    path keeps its weights (aneOnly is false whenever any thread uses the
    GPU). releaseWeights() covers the transformer descriptors too
    (attention incl. ropeFreqs, FFN, RMSNorm tips).

Measured (19×19, MLX ANE path)

Resident-memory A/B (mlxDeviceToUse=100, cold conversion every run),
before vs after the levers. Peak RSS is the /usr/bin/time -l maximum
over load+convert; idle steady-state is ps-sampled after load+genmove.

Model Peak RSS (load+convert), before → after Idle steady-state RSS, before → after
Classical 40b (b40c256) 1.61 → 1.05 GB (−35%) 0.41 → 0.16 GB (−62%)
18b nbt (b18c384nbt) 0.94 → 0.63 GB (−33%) 0.64 → 0.32 GB (−49%)
Transformer (b10c384h6nbttflrs) 1.69 → 1.52 GB (−10%) 1.69 → 1.52 GB (−10%)

The peak win is largest on the classical 40b (the duplicate owning weight
copies + whole-gzip decompress the levers target dominate there); the
transformer gains least because most of its converter-side weight is
derived (RoPE tables, rotation matrix, per-head out-proj slices) which
keeps an owned buffer, and its absolute peak is dominated by ANE
compilation of the MIL graph downstream of these host-side levers.

The levers don't change inference numerics: testgpuerror ANE output is
bit-identical before vs after on a convnet and all three transformer nets
(release happens after conversion; the MLX/GPU path never invokes the
converter), and benchmark throughput is within run-to-run noise on both
the GPU and ANE paths.

How to build

cd cpp
cmake -G Ninja -DUSE_BACKEND=MLX
ninja

Requires CMake ≥ 3.27 and brew install mlx.

Validation

Cross-backend testgpuerror vs an Eigen FP32 reference, via
cpp/rungpuerrortest.sh (all 12 nets across 9×9 / 13×13 / 19×19 / 10×14 /
rectangular / rect-buffer / weird-settings). MLX FP16 (auto) vs Eigen
FP32 — winrate error per net (99%-ile / worst-case max across that
net's configs):

Network Ver GPU FP16 (99% / max) ANE FP16 (99% / max)
run4-b6c96 v3 0.22% / 0.43% — pre-v8
grun50-b6c96 v4 0.31% / 1.11% — pre-v8
g103-b6c96 v5 0.28% / 1.31% — pre-v8
g170e-b10c128 v8 0.60% / 2.35% 0.40% / 1.45%
kata1-b18c384nbt (s5832) v11 0.45% / 1.03% 0.39% / 0.65%
kata1-b18c384nbt (s9996) v14 0.40% / 1.36% 0.76% / 1.38%
kata1-b28c512nbt v15 0.35% / 1.30% 0.56% / 2.40%
b18c384nbt-humanv0 SL 0.14% / 0.42% 0.22% / 0.52%
b5c192nbt-v16test v16 0.11% / 0.18% 0.19% / 0.25%
b7c96h3tfrs · RoPE+RMSNorm v17 0.59% / 2.49% 0.0003% / 0.0004% ᵗ
b4c256h4nbttflrs · SiLU+sp-RMSNorm v17 0.20% / 0.30% 0.48% / 0.79%
b7c96h6kv3qk32v16tflrs · GQA+bnorm v17 0.24% / 0.59% 0.0001% / 0.0003% ᵗ

Pass: GPU 37/37 configs, ANE 31/31 supported configs. FP32-vs-Eigen
sanity (numerical noise): GPU ≤ 0.00092%, ANE ≤ 0.00082%.
./katago runtests and ./katago runnnlayertests pass.

ᵗ 96ch transformers hit the converter's full-FP32 tier on ANE (FP16 == FP32);
the 256ch SiLU transformer uses the non-spatial-FP32 tier (real FP16 residual).

The 6 pre-v8 ANE configs (v3/v4/v5) are skipped — the CoreML converter
supports model versions 8–17 only. Run ANE testgpuerror serially or with a
per-process TMPDIR: the converter stages weights in a shared
$TMPDIR/katagocoreml_weights dir, so concurrent processes race.

GPU throughput: MLX vs Metal

Single-GPU throughput of the two Apple-Silicon GPU backends, both configured
GPU-only (no ANE; MLX gpuIdx 0, Metal default device 0), on an Apple
M3 Max
. The two backends take different GPU strategies, so this compares
each backend's actual GPU-only behavior rather than an identical-precision
kernel:

  • MLX GPU — F(2×2, 3×3) Winograd path in FP16 (mlxUseFP16 auto).
  • Metal GPUMPSGraph in FP32 (this backend's GPU path, incl. the
    transformer trunk, is FP32-only).

The Metal binary is built from feature/metal-transformer-silu-b10c384h6
(the branch whose Metal GPU path supports the v15+ transformer trunk). Each
net was benchmarked with benchmark -half-batch-size -t 8,16,32 -v 800 at
19×19, run sequentially — one backend / one net at a time. Cells are
visits/s at each thread count; the human-SL net used
humanSLProfile=rank_9d (same override on both backends). The three v17
transformer MLX rows were re-measured after the FP16-attention fix
(52240987, median of 3 runs); the Metal column and all other rows are
unchanged.

Network Ver MLX GPU · FP16
visits/s @ t=8 / 16 / 32
Metal GPU · FP32
visits/s @ t=8 / 16 / 32
MLX best ÷ Metal best
run4-b6c96 v3 3808 / 5797 / 6930 3112 / 3804 / 3232 1.82×
grun50-b6c96 v4 3986 / 6160 / 7771 3614 / 5357 / 6400 1.21×
g103-b6c96 v5 4166 / 6354 / 7860 3458 / 5691 / 6254 1.26×
g170e-b10c128 v8 2710 / 3773 / 4538 2479 / 2943 / 3796 1.20×
kata1-b18c384nbt (s5832) v11 529 / 662 / 711 359 / 506 / 476 1.41×
kata1-b18c384nbt (s9996) v14 548 / 666 / 712 359 / 466 / 447 1.53×
kata1-b28c512nbt v15 221 / 236 / 221 138 / 164 / 164 1.44×
b18c384nbt-humanv0 SL 487 / 576 / 552 333 / 430 / 413 1.34×
b5c192nbt-v16test v16 2066 / 3372 / 3969 1926 / 3009 / 3790 1.05×
b7c96h3tfrs · RoPE+RMSNorm v17 1698 / 2418 / 2925 1532 / 1907 / 1551 1.53×
b4c256h4nbttflrs · SiLU v17 1226 / 1693 / 1848 900 / 1175 / 1007 1.57×
b7c96h6kv3qk32v16tflrs · GQA v17 1364 / 1817 / 2005 1119 / 1240 / 1275 1.57×

MLX-GPU is faster than Metal-GPU on every net (best-thread throughput
1.05×–1.82×), with the largest margins on the deep NBT nets
(b18c384nbt ~1.4–1.5×, b28c512nbt ~1.44×) and the v17 transformer nets
(~1.5–1.6×, now that the GPU attention path runs end-to-end in FP16) where
the FP16 Winograd path has the most to exploit. The FP16-vs-FP32 precision gap is part of that lead;
accuracy of each path vs an Eigen FP32 reference is in the Validation
table above.

Status

Draft — opening for early feedback on the backend's structure, the
tuner approach, the GPU/ANE dispatch wiring, and the transformer-trunk
support (incl. the ANE FP16 precision tiers) before promoting to
ready-for-review.

Introduces a new neural-net backend (USE_BACKEND=MLX) targeting Apple
Silicon via Apple's MLX framework. The backend implements the full
nninterface contract (model load, batched evaluation, FP16/FP32 paths)
and ships with a Winograd 3x3 convolution path plus an adaptive
per-shape tuner that picks the fastest implementation for each
conv-3x3 shape at model load.

Backend
- cpp/neuralnet/mlxbackend.cpp: backend implementation. Supports
  variable board sizes via input masking (same nnXLen/nnYLen
  contract as other backends; the global COMPILE_MAX_BOARD_LEN
  bound still applies). FP16/FP32 selected by the mlxUseFP16 config
  (default auto -> fp16); same input feature layout as the other
  backends. Mish activation runs FP16-safe (asserts on
  ACTIVATION_MISH_SCALE8 so out-of-range variants are caught
  explicitly rather than silently truncated).
- cpp/neuralnet/mlxwinograd.h: F(4x4, 3x3) Winograd transform with
  fused activation + residual add.
- cpp/neuralnet/mlxwinotuner.{cpp,h}: per-shape Winograd tuner with
  adaptive scoring (rotates the candidate set per shape, scores by
  median-time delta against a baked-default baseline). Logs the
  conv-3x3 shape distribution at model load.
- cpp/neuralnet/mlxtests.cpp: unit tests for the Winograd path
  and tuner numeric-consistency, gated under runnnlayertests.

Build / wiring
- cpp/CMakeLists.txt: USE_BACKEND=MLX target. MLX requires CMake
  3.27 (cmake_minimum_required stays at 3.18.2 so other backends
  keep building on older CMake). Links Homebrew's prebuilt
  libmlx.dylib; OSX deployment target intentionally not pinned so
  the executable's minos matches the dylib it was linked against.
- cpp/main.cpp, cpp/program/setup.cpp, cpp/command/benchmark.cpp:
  wire MLX into backend selection / benchmark.
- cpp/configs/{gtp,analysis,match,contribute}_example.cfg: document
  mlxUseFP16 (auto / true / false), default auto -> fp16.
- Compiling.md: build instructions for the MLX backend.

Validation
- Cross-backend validation against an Eigen reference (testgpuerror)
  for b18c384nbt, b40v8, and humanv0 nets shows FP32 max winrate
  error 0.00095% and FP16 max 2.63%, well within the existing
  backend tolerances.

This is the squash of 130 commits from feature/mlx-backend.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@ChinChangYang ChinChangYang changed the title Add MLX backend for Apple Silicon Add MLX backend for Apple Silicon (GPU + ANE/CoreML dispatch) May 26, 2026
ChinChangYang and others added 4 commits June 2, 2026 23:47
master consolidated createComputeContext's trailing params (openCLTunerFile,
openCLReTunePerBoardSize, useNHWCMode) into a single ConfigParser& cfg. The
Metal backend was already updated; update the MLX backend to match so it
compiles against the merged interface. NHWC is still enforced per-handle via
inputsUseNHWC in createComputeHandle.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The MLX backend implemented only the convnet path; transformer nets crashed
(empty-(0)-tensor broadcast on rmsnorm/SiLU tips) or produced garbage (GQA).
Implement the transformer trunk/tip path in mlxbackend.cpp, mirroring
eigenbackend.cpp:
- ACTIVATION_SILU (x * sigmoid(x))
- TransformerRMSNormLayer (spatial rmsnorm tip) + TransformerTrunkRMSNormLayer (pre-LN)
- GQA TransformerAttentionBlock + SwiGLU FFNBlock
- branch the trunk tip on trunkNormKind; wire the new block kinds into the
  block-variant and nested-bottleneck loops; thread nnX/nnY through

Verified via testgpuerror against fresh Eigen references (boardsize 19):
fp32 winrateError max — rope 0.00094%, silu 0.00046%, gqa 0.00029% (bar 0.10%);
convnet g170-b6c96 unregressed (0.00036%); runtests + runnnlayertests pass.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The MLX backend's ANE mux path (mlxDeviceToUseThread0=100) drives inference
through the shared external/katagocoreml converter -- the same library the
Metal backend uses. Bring PR#1205's CoreML/ANE transformer work into that
converter so the MLX ANE path supports the v15+ transformer trunk:

- Transformer MIL support: attention (incl. grouped-query attention),
  learnable RoPE, SiLU, RMSNorm/batchnorm tips, SwiGLU FFN.
- FP16 accuracy precision tiers, gated on actual transformer-block presence
  (blocksContainTransformer, recursing into nested-bottleneck blocks):
  narrow trunks (<256ch) build fully FP32; wider ones escalate non-spatial
  matmuls + global pooling to FP32; very wide (>=320ch) also escalate convs;
  RMSNorm reductions FP32 in FP16 mode. Plain convnets stay pure FP16 on the
  ANE (the d052d2a regression-fix behavior is preserved).

The converter's public API is unchanged, so the MLX call site
(CoreMLConversion::convertModelToTemp) needs no edits. The Metal-GPU/MPSGraph
portions of PR#1205 (metalbackend.cpp, metallayers.swift) are intentionally
not ported -- the MLX backend's native GPU path already has transformer
support.

Verified on the MLX ANE mux (testgpuerror vs fresh Eigen FP32 references):
all 3 transformer test nets pass FP16 thresholds across board sizes/buffer
configs (7 configs), a plain convnet stays pure FP16 (non-regression), and
runtests + runnnlayertests pass.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@ChinChangYang ChinChangYang changed the title Add MLX backend for Apple Silicon (GPU + ANE/CoreML dispatch) Add MLX backend for Apple Silicon (GPU + ANE/CoreML dispatch, incl. transformer nets) Jun 3, 2026
ChinChangYang and others added 18 commits June 3, 2026 18:13
applyGlobalPooling / applyValueHeadPooling summed in fp16 but produced an
fp32 mean (division by the fp32 maskSum), which also leaked fp32 into the
downstream gpool-bias and value-v2 head matmuls. Cast maskSum to the input
dtype so the whole pooling and the heads stay in the compute dtype (fp16
when useFP16), maximizing fp16 utilization rather than escalating to fp32
for negligible accuracy gain.

The masked-max keeps its 1e9 constant in fp32 (1e9 overflows fp16 ->
inf -> 0*inf=NaN), then casts the max result back to the compute dtype.
The fp32 path is unaffected (the astype casts are no-ops in fp32).

Verified via testgpuerror vs fresh Eigen fp32 references on all 3
transformer nets (7 board-size/buffer configs): fp16 winrate error max
<= 2.07% (within tolerance, winrate unchanged vs baseline), fp32 path
byte-identical, ownership output bit-identical, runnnlayertests pass.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…nt warmup

M2: A candidate whose threadgroup exceeds the pipeline's register-pressure-
dependent maxTotalThreadsPerThreadgroup (can be < 1024), or that hits a
transient GPU error, throws out of mx::eval during the flat sweep. Previously
this propagated out of loadOrAutoTune and aborted model load with no fallback.
Now each candidate's scoring is wrapped in try/catch: a throw is counted and
skipped (mirroring the OpenCL tuner's mark-bad-and-continue), and best/bestTime
are seeded with the baked default so even a fully-failing sweep returns a valid
result. A separate "flatSweep{Input,Output} skipped=N" log line is emitted only
when skips occur; it intentionally omits the colon after the function name so it
cannot collide with the regex-tested "flatSweepInput: considered" log line.

M3: timeOneInputTransform/timeOneOutputUntransform ran an untimed warmup eval on
every call, but the scoring functions already warmed up once before the measured
loop -- so every measured rep paid an extra full warmup (~doubling tuning cost).
Add a doWarmup parameter gating the internal warmup; the scoring functions drop
their explicit warmup and pass (r == 0), so each shape warms exactly once on its
first measured rep.

Verified by triggering autotuning: gated flat-sweep tests (convergence,
log-format, baseline-consistency, per-shape) pass; an end-to-end re-tune via
loadOrAutoTune runs a fresh sweep and saves valid fp16/fp32 caches; testgpuerror
output is unchanged (tuner params are numerically inert); runtests passes.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The METAL and MLX backend branches each ran set(CMAKE_OSX_DEPLOYMENT_TARGET
13.0) *after* project(), where it is a silent no-op: for this Swift project
the deployment target is fixed during project()/enable_language, so a later
set() never affects the produced binary. Both shipped binaries already carry
minos 26.0 (the build host / libmlx's floor), not 13.0, confirming the pins
were inert dead code that contradicted the pre-project comment explaining why
the deployment target is deliberately not pinned.

Delete both pins so code, comment, and reality agree; the comment becomes
literally true. Add a guard note documenting that a post-project pin is a
no-op so it is not reintroduced. No behavior change: binaries still build at
minos 26.0, matching libmlx's minos and MLX's macOS >= 14 requirement.

Verified: MLX reconfigure+build clean; METAL branch configures clean; binary
minos unchanged (26.0); runtests pass; testgpuerror unchanged (fp32 max
0.00036%, fp16 max 0.863%).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The katagocoreml parser read the q/k/v/out projection matmuls of a
transformer attention block without checking their declared dimensions
against the head geometry or trunk width. The CoreML graph builder
(MILBuilder) then reshapes each flat projection into a [seq, heads,
headDim] grid and reshapes the out-projection result back into the trunk,
so a mismatched dimension would either read past the weight buffer or
build a graph that compiles but computes nonsense - the exact failure
mode the existing checkBlockChannels() guard was added to prevent for
conv blocks.

Thread trunk_num_channels into parseTransformerAttentionBlock (mirroring
parseNestedBottleneckBlock) and add a checkAttentionProjDim() helper in
the style of checkBlockChannels(), then validate all four projections:

  qProj.outChannels  == numHeads   * qHeadDim   (master desc.cpp:1129)
  kProj.outChannels  == numKVHeads * qHeadDim   (master desc.cpp:1131)
  vProj.outChannels  == numKVHeads * vHeadDim   (master desc.cpp:1133)
  outProj.inChannels == numHeads   * vHeadDim   (master desc.cpp:1135)
  qProj.inChannels   == trunkNumChannels        (master desc.cpp:1430)
  outProj.outChannels== trunkNumChannels        (master desc.cpp:1437)
  k/vProj.inChannels == trunkNumChannels        (gap master leaves
                                                 implicit to the backend)

Six checks mirror master desc.cpp's transformer attention consistency
checks exactly; the k/v inChannels checks additionally close a gap master
leaves to the backend (all three QKV projections consume the same
normed-trunk input, so their inChannels must equal the trunk width). K
pairs with Q in the QK^T dot product, so kProj uses qHeadDim; only V
carries vHeadDim.

Purely additive: throws std::runtime_error on a malformed model, no-op on
valid ones, no numerics touched. Verified: MLX build clean, runtests
pass, and ANE-path testgpuerror on all three transformer nets (incl. the
GQA net with 6 heads/3 KV heads, qk=32/v=16) loads and converts with zero
false-positive throws and unchanged numerics.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Cuts memory during the on-device KataGo -> CoreML conversion and while
running the ANE/CoreML path, with byte-identical converter output:

- The converter's weight tensors become non-owning views into the parsed
  model instead of owning extra FP32 copies; derived/transposed tensors keep
  an owned buffer. This drops redundant resident weight copies during
  conversion. CoreML model serialization is made deterministic
  (SetSerializationDeterministic) so the output is byte-stable.

- The KataGo model parser streams the gzip through a bounded ~1 MB refill
  buffer instead of decompressing the whole file into memory, while
  preserving the existing NaN/Inf weight validation.

- ModelDesc gains releaseWeights(), which frees the in-memory weight arrays
  (keeping scalar shape metadata). The Metal backend calls it on the ANE
  (CoreML) path after converting from the model file on disk, gated by a new
  ComputeContext::aneOnly flag so it only fires when every configured device
  is ANE -- the GPU/MPSGraph path keeps its weights. The call is serialized
  under computeHandleMutex and only scalar dims are read afterward.

Measured on b18c384nbt (19x19) over the ANE path: idle steady-state RSS
0.59 GB -> 0.19 GB; peak (load+convert) 0.87 GB -> 0.48 GB. Cross-backend
parity vs an Eigen reference is unchanged on both the GPU and ANE paths.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
(cherry picked from commit b05f559)
WeightEntry stores a non-owning view (const float*, count) into the live
KataGoModelDesc, so the backing std::vector must outlive serialization.
addConstOp/registerWeight took the data by const& and silently stored a
pointer to it; a caller passing a temporary would bind to that const& and
leave the view dangling, read much later during serialization.

Delete the rvalue overloads of both so any such call fails to compile,
forcing temporaries through addOwnedConstOp/registerOwnedWeight (which take
ownership). Named lvalues (the model-member call sites) still bind to the
const& overload, so no existing caller changes.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
(cherry picked from commit 971fa9d)
Own the gzFile with a custom-deleter unique_ptr so it closes on every
exit path (normal return, exception, bad_alloc); removes the manual
try/catch+gzclose in parse() and the ordering caveat on buffer allocation.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
(cherry picked from commit eeefc97)
Introduce a KataGo-local non-owning FloatView for WeightEntry::data instead
of a raw const float*/size_t pair; convert to MILBlob::Util::Span only inside
WeightSerializer, keeping the MILBlob dependency out of Operations.hpp.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
(cherry picked from commit 6bfa617)
The ComputeHandle member-order comment claimed that declaring
mpsGraphOnlyHandle before coremlOnlyHandle is what prevents a GPU handle
from reading freed weights. That overstates the ordering's role: within a
single ComputeHandle exactly one handle is built (mutually exclusive on
gpuIdx, enforced by the ctor's exactly-one check), and releaseWeights()
only fires on an aneOnly context where no MPSGraph handle is ever built.
Reframe the declaration order as belt-and-suspenders and point at
ComputeContext::aneOnly as the actual invariant. Comment-only change.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
(cherry picked from commit 4159930)
Replace the file-local releaseXXX free functions in desc.cpp (which
reached into each desc struct's internals from outside) with
releaseWeights() member methods on each weight-bearing struct, matching
the existing OO convention used by applyScale8ToReduceActivations() and
iterConvLayers(). Each container delegates to its members; type-erased
block dispatch is inlined with the same cast pattern those methods use.

Behavior-preserving: same set of freed vectors, same block recursion,
same metaEncoderVersion guard. ModelDesc::releaseWeights() keeps its
signature, so the metalbackend.cpp call site is unchanged.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
(cherry picked from commit 44342a3)
Move the 11 leaf/container releaseWeights() definitions in desc.cpp out
of the bottom cluster (inherited from the old free-function layout) and
place each immediately after its struct's last existing method, matching
the file's per-struct grouping convention used by every other method.
ModelDesc::releaseWeights() stays put, already adjacent to its siblings.

Pure relocation: function bodies and desc.h are unchanged; only two
stray double-blank lines were normalized to single. Verified clean Metal
build, testgpuerror vs Eigen reference (g170-b6c96) at <0.0004% winrate
error, and runtests all pass.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
(cherry picked from commit 98b17eb)
…contract

The transformer attention builder emits four function-local std::vector<float>
tensors: RoPE cos/sin tables, the rotation matrix R, and per-head out-projection
weight slices. After merging the transformer support onto the FloatView branch,
these needed two fixes:

1. Dangling view. lightvector#1202 made WeightEntry::data a non-owning FloatView, so
   addConstOp registers a view whose backing buffer must outlive serialization.
   These locals were passed to addConstOp and would dangle once the build
   function returns (serialization runs afterwards). Route them through
   addOwnedConstOp so KataGoOps owns the buffer until serialization. (Under
   lightvector#1205's owning WeightEntry they were copied, so this only surfaces post-merge.)

2. dtype mismatch. emitConstOp declares each const's dtype as m_weight_dtype, but
   addOwnedConstOp / registerOwnedWeight stored at the global mode (is_fp32
   hardcoded false). In an FP16 model these derived consts land in the attention /
   value-head FP32 sub-region (m_weight_dtype == FLOAT32), so they were declared
   FP32 but stored FP16. CoreML/ANE then rejects the model at load ("Metadata data
   type does not match requested type", BNNS error -14), which SIGABRT'd every
   FP16 ANE transformer. Thread is_fp32 through registerOwnedWeight and have
   addOwnedConstOp pass is_fp32 = (m_weight_dtype == FLOAT32), mirroring addConstOp
   so the stored dtype always matches the declared dtype. This also fixes the same
   latent mismatch for addLinearOp's transposed value-head weights.

Verified with testgpuerror against fresh Eigen FP32 references: b7c96h3tfrs and
b7c96h6gqa, which previously SIGABRT'd on the FP16 ANE path, now load and match
to <0.0005% winrate; convnet ANE output is byte-identical and the Metal GPU path
is unchanged. katago runtests and runnnlayertests also pass.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
(cherry picked from commit 8481a94)
The cherry-picked per-struct releaseWeights() refactor (44342a3/98b17ebb)
predates this branch's MLX transformer port, so it only added releaseWeights()
to the non-transformer descriptors. Extend the coverage to the transformer
descriptors present on this branch (RMSNormLayerDesc, TransformerRMSNormDesc,
TransformerAttentionDesc incl. ropeFreqs, TransformerFFNDesc) and handle
TRANSFORMER_ATTENTION_BLOCK_KIND / TRANSFORMER_FFN_BLOCK_KIND plus
trunkTipRMSNorm in the trunk release walk. Without this, releasing weights on
a transformer model would hit ASSERT_UNREACHABLE. This makes desc.cpp/desc.h
byte-identical to the lightvector#1202 feature branch.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Port the lightvector#1202 ANE steady-state memory lever to the MLX backend. Add
ComputeContext::aneOnly, set in createComputeContext when every configured
device index is MLX_MUX_ANE, and call ModelDesc::releaseWeights() in
convertAndCreateCoreMLOnlyHandleMLX after the model has been converted to
CoreML on disk.

Safe because: the ANE path re-reads the model from modelPath (not the
in-memory weight arrays); the ComputeHandle ctor takes the MLX_MUX_ANE
early-return before building any MLX/GPU model (the only weight-array
consumer); only scalar dims are read afterward, which releaseWeights()
preserves; and it runs under computeHandleMutex. Mirrors the Metal backend's
aneOnly release. GPU path unaffected (aneOnly is false whenever any thread
uses MLX_MUX_GPU).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
parseTransformerFFNBlock only checked num_channels/ffn_channels > 0, while the
attention block validates all of its projection dimensions. Thread
trunk_num_channels through and add the mirror checks: num_channels must equal the
trunk width (the block adds its output back into the trunk residually) and the
linear layers must chain numChannels -> ffnChannels -> numChannels (with the
SwiGLU gate also numChannels -> ffnChannels). A malformed FFN block now fails at
parse time instead of producing an opaque CoreML compile error or silently-wrong
activations. Reuses the existing checkAttentionProjDim helper.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Six conv/matmul/RMSNorm/FFN sites hand-rolled save/restore of m_weight_dtype
around their FP32 escalation windows. An exception thrown inside a window would
leave m_weight_dtype stuck at FLOAT32, causing later FP16 consts to be tagged
FP32 -> the BNNS "Metadata data type does not match" SIGABRT on the FP16 ANE.
Give ScopedFp32 an active flag (so a conditional window needs no construction-time
branch) and an idempotent restore() (to end the window before a trailing
cast-down while keeping the dtor's exception-safe restore), then route all six
sites through it. The guard is constructed exactly where the manual flip was and
restore() called exactly where the manual restore was, so op-emission order -- and
thus the serialized converter output -- is unchanged.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The raw-output forward path had no callers: production inference goes through
getOutput() -> Model::applyCompiled(). It also duplicated applyCompiled's input
setup and output copy. Delete it.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
… test comment

CMakeLists.txt: uppercase USE_BACKEND into USE_BACKEND_NORMALIZED before the
pre-project() MLX version guard and the Swift language selection, mirroring the
post-project() string(TOUPPER). Previously a lowercase -DUSE_BACKEND=mlx skipped
the CMake 3.27 guard and Swift enablement, then still tried to build the Swift
sources later, producing a confusing failure instead of a clear message.

rungpuerrortest.sh: the gpu/ane modes drive whichever backend the binary was
built with (backend-agnostic deviceToUseThread0), so reword the usage comment
from "the Metal backend" to "the active backend (Metal or MLX)".

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
ChinChangYang and others added 10 commits June 5, 2026 21:13
The F(2x2,3x3) input transform and output untransform did all their add/sub arithmetic in T, so in fp16 mode every intermediate of B^T d B and A^T M A rounded to fp16 -- a precision sink independent of the matmul (which already accumulates in fp32 via the steel GEMM). Compute the transform arithmetic in float and cast only the final stored V/M/Y back to T, leaving fp16 storage and memory traffic unchanged (no-op on the fp32 path).

Cuts fp16 Winograd kernel error ~33% (runnnlayertests ConvLayer fp16 winograd maxErr 0.0107 -> 0.0071) while the fp32 path stays bit-exact (maxErr=0). testgpuerror FP16-vs-Eigen avg/90%/99% errors drop across configs; g170e-b10c128 worst case 2.35% -> 2.13%. Benchmark throughput within run-to-run noise on both a convnet and a deep NBT net.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…uning

The model-load auto-tune swept a nearly-full candidate grid (~2000 configs, ~16s) even with full=false -- only tg1 differed between non-full and full. Measured on this hardware the winning configs form a broad plateau (many within ~7% of each other, all ~25-40% better than the baked default) and geometry moves end-to-end throughput <=1.5%, so that 16s was spent discriminating run-to-run noise: three forced re-tunes of the same net picked three entirely different winners, all within ~7%.

Mirror OpenCL's split: full=false (auto) now sweeps a coarse grid (tg0 {8,16,32,64,128}, tg1 {1,2,4,8,16}, wpt {1,2,4}) -- ~2.7s, still landing ~21% above the default and within ~6% of the wide-sweep winner. full=true keeps the wide grid as the deliberate command-tune, opt-in via KATAGO_MLX_WINOTUNER_FULL=1 (the analog of './katago tuner --full', which openclbackend.cpp pins to full=false at load). Cache format is unchanged; existing caches still load, and FULL+FORCE overwrites with the wide-swept winner.

Also loosen two gated tuner stress-test budgets (baseline-anchor 0.25->0.50, convergence 1.10->1.30) that compared single sub-millisecond timing samples and flaked ~1-in-4 on both this and the pre-change binary -- the same dispatch/sync-overhead noise the tuner itself tolerates. They remain gross-error sanity checks. runtests + runnnlayertests pass, gated tuner tests 16/16, testgpuerror fp32 bit-exact and fp16 winrate max 0.55% on the coarse-tuned config.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
M4 - silent fallthroughs in mlxbackend.cpp, now loud (ASSERT_UNREACHABLE,
the file's existing idiom):
- BlockVariant::apply switch default returned the input unchanged, a silent
  identity no-op block. Now asserts; kept as a default: label so the active
  -Wswitch-default (CMakeLists.txt) stays clean.
- The trunk block-construction loop silently dropped unknown kinds. Added an
  else { ASSERT_UNREACHABLE; }.
- The nested-bottleneck construction loop was a genuine latent bug, not just a
  missing guard: parseResidualBlockStack (desc.cpp, shared by trunk and nested)
  accepts nested_bottleneck_block inside a nested bottleneck and the desc layer
  handles it, but the MLX nested loop omitted NESTED_BOTTLENECK_BLOCK_KIND and
  silently dropped such a block. Added the missing case (mirroring the trunk
  loop and Eigen's shared BlockStack) plus an else-assert.

M3 - Compiling.md implied the MLX backend builds with make, but CMakeLists.txt
hard-fails MLX without the Ninja generator (same Swift/C++ interop requirement
as Metal). Added -G Ninja to the MLX cmake example, listed MLX alongside Metal
for the Ninja prerequisite, and noted MLX uses ninja to build.

Verification: build clean; runtests + runnnlayertests pass; testgpuerror
g170-b6c96 vs eigen_reference.json fp32 near-exact (winrate max 0.00036%) /
fp16 max 0.55% (unchanged); testgpuerror on the b4c256h4nbttflrs nested-
bottleneck+transformer model loads through the modified nested loop and runs
forward with no assert (fp16 ~0.27%).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The `tuner` command previously did nothing on non-OpenCL backends. Add an
MLX branch that loads the model, builds its conv-3x3 shape histograms, and
runs MLXWinogradTuner::loadOrAutoTune with reTune=true so the Winograd
input/output transform search runs and overwrites the cache the backend
reads at model load.

This is the first-class "command tune" path; the load-time auto-tune stays
coarse/fast. -full selects the wide candidate grid (the env-var
KATAGO_MLX_WINOTUNER_FULL=1 behavior, which still works for triggering a
full tune through benchmark/gtp). -testFP16 (auto->FP16) matches the
engine's useFP16 default and the cache-filename key. The default output
path is MLXWinogradTuner::defaultDirectory/defaultFileName - the exact file
the backend loads - verified end-to-end: after a default-path tune, a
benchmark model-load logs "Loaded MLX Winograd tuning parameters from"
that same file. The OpenCL-only FP16 sub-knobs have no MLX analog and are
omitted.

The backend guard is restructured from #ifndef USE_OPENCL_BACKEND to a
three-way #if/#elif/#else; the OpenCL body is unchanged.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The env var was an explicit stopgap "command tune analog" for the wide
Winograd sweep before `./katago tuner` supported MLX. Now that the literal
tuner subcommand exists, drop the env var and pin the model-load path to the
coarse grid (full=false), so the wide sweep is reached only through
`./katago tuner -full`. This mirrors openclbackend.cpp, which pins full=false
at load and passes full only from the explicit tuner command, and reinforces
the coarse-auto-tune design.

No capability is lost: `tuner -full` writes the same cache the backend reads
at load, so the prior one-shot workflow (FULL=1 FORCE=1 benchmark) becomes
the cleaner two-step (tune once, then run) - the OpenCL workflow.

KATAGO_MLX_WINOTUNER (enable/disable) and KATAGO_MLX_WINOTUNER_FORCE (force
re-tune) are unchanged; FORCE now only ever drives a coarse re-tune at load.

Verified: FULL=1 FORCE=1 benchmark now re-tunes coarse (considered=288 for
b10c128, was 2176), while `tuner -full` still sweeps the wide grid
(considered=2176). runtests pass; testgpuerror unchanged (fp32 0.0005%,
fp16 ~1.0%).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Like KATAGO_MLX_WINOTUNER_FULL, this env var predates MLX support in the
tuner subcommand. It set reTune=true on the model-load path so a normal
benchmark/gtp run would ignore the cache, re-tune, and overwrite.

OpenCL has no load-time force-retune analog: OpenCLTuner::loadOrAutoTune
loads the cached params if present (falling back to the full-size cache) and
only auto-tunes on a complete miss; the only re-tune paths are the explicit
tuner command or deleting the cache file. Now that `./katago tuner` works on
MLX and always re-runs + overwrites (reTune=true), the env var is redundant
with that established pattern, so drop it and pin the load-time reTune=false.

The model-load path now loads a valid cache or coarse-tunes once on a miss,
never re-tuning a valid cache and never sweeping the wide grid - both are
reached only through `./katago tuner` (-full for the wide grid). To refresh
the cache, run the tuner command (or delete the cache file). No reachable
end-state is lost; only the inline-during-benchmark/gtp convenience, which
OpenCL never had.

KATAGO_MLX_WINOTUNER (disable tuning, use baked defaults) is unchanged - a
distinct switch the command does not cover.

Verified: FORCE=1 benchmark now loads the cache instead of re-tuning;
KATAGO_MLX_WINOTUNER=0 still disables tuning; runtests pass; testgpuerror
unchanged (fp32 0.0005%).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The MLX backend pulls in external/katagocoreml, which requires protobuf and abseil, but Compiling.md listed them only under the Metal backend. `brew install mlx` does not provide them transitively, so an MLX-only build following the docs failed at find_package(Protobuf REQUIRED).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The tuner cache filename hardcoded "AppleSilicon", so every Apple chip shared one cache file: a cache tuned on e.g. an M1 would be loaded verbatim on an M4 Max, where the optimal Winograd launch geometry differs. Add a shared MLXWinogradTuner::detectGpuName() that reads the chip brand string (sysctl machdep.cpu.brand_string, e.g. "Apple M3 Max") with an "AppleSilicon" fallback, and call it from both the backend model-load path and the `tuner` command so their cache keys always match.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
In MLX C++ there are no weak scalars: mx::array(scale) is a strong
float32 array, so `matmul(q,kT) * mx::array(scale)` promoted the
attention scores -- and the whole transformer residual stream, since
each block adds its output into the trunk -- to fp32. Cast the scale to
the compute dtype so the fp16 path stays fp16 end-to-end, matching the
pooling/BN/RMSNorm layers (which already astype their fp32 intermediates
back) and the maximize-fp16 goal for the GPU path.

Verified by full rungpuerrortest.sh: GPU 37/37 pass, ANE 31/31 supported
pass (6 pre-v8 SIGABRTs expected/unrelated); worst-case fp16 winrate
2.50% / topPolicy 1.25% (thresholds 5.0% / 6.0%). runtests and
runnnlayertests pass.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
… param, transformer test

Address four follow-up items from the PR#1199 review (test/tuner-only;
no inference forward-pass changes):

1. Atomic tuner-cache save. MLXWinogradTuneParams::save now writes to a
   per-process temp path (filename + ".tmp.<pid>") and FileUtils::rename's
   it onto the final path, so two processes that cache-miss and tune the
   same model concurrently can no longer tear the shared cache file.

2. Independent Winograd oracle. The GPU and FP16 Winograd metal_kernel
   tests previously asserted only against cpuConv2d3x3, itself a Winograd
   F(2,3) impl sharing the kernel's B/G/A transform matrices -- a shared
   sign/transpose error would cancel and pass. They now also assert against
   the independent naive direct-conv oracle.

3. Remove the dead seedOverride parameter from MLXWinogradTuner::
   loadOrAutoTune (declaration, definition, and both call sites). It was
   documented "reserved ... currently ignored" and always passed nullptr.

4. Transformer-layer numeric test (runMLXTransformerLayerFP16Test): the
   transformer path (RMSNorm / attention / RoPE) had no layer-level
   coverage -- only end-to-end via testgpuerror. Adds RMSNorm fp32-vs-CPU
   correctness, attention fp16 output-dtype preservation (the regression
   guard for the just-fixed fp16->fp32 attention-scale promotion), fp16/
   fp32 closeness, and a zero-outProj residual-identity anchor; covers
   fixed-RoPE on/off and mask on/off.

Verified: build clean; runtests and runnnlayertests pass; the three
transformer nets pass testgpuerror on the MLX GPU path within thresholds.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant