Add MLX backend for Apple Silicon (GPU + ANE/CoreML dispatch, incl. transformer nets)#1199
Draft
ChinChangYang wants to merge 34 commits into
b544c66 to dcf296a
Introduces a new neural-net backend (USE_BACKEND=MLX) targeting Apple
Silicon via Apple's MLX framework. The backend implements the full
nninterface contract (model load, batched evaluation, FP16/FP32 paths)
and ships with a Winograd 3x3 convolution path plus an adaptive
per-shape tuner that picks the fastest implementation for each
conv-3x3 shape at model load.
Backend
- cpp/neuralnet/mlxbackend.cpp: backend implementation. Supports
variable board sizes via input masking (same nnXLen/nnYLen
contract as other backends; the global COMPILE_MAX_BOARD_LEN
bound still applies). FP16/FP32 selected by the mlxUseFP16 config
(default auto -> fp16); same input feature layout as the other
backends. Mish activation runs FP16-safe (asserts on
ACTIVATION_MISH_SCALE8 so out-of-range variants are caught
explicitly rather than silently truncated).
- cpp/neuralnet/mlxwinograd.h: F(4x4, 3x3) Winograd transform with
fused activation + residual add.
- cpp/neuralnet/mlxwinotuner.{cpp,h}: per-shape Winograd tuner with
adaptive scoring (rotates the candidate set per shape, scores by
median-time delta against a baked-default baseline). Logs the
conv-3x3 shape distribution at model load.
- cpp/neuralnet/mlxtests.cpp: unit tests for the Winograd path
and tuner numeric-consistency, gated under runnnlayertests.
Build / wiring
- cpp/CMakeLists.txt: USE_BACKEND=MLX target. MLX requires CMake
3.27 (cmake_minimum_required stays at 3.18.2 so other backends
keep building on older CMake). Links Homebrew's prebuilt
libmlx.dylib; OSX deployment target intentionally not pinned so
the executable's minos matches the dylib it was linked against.
- cpp/main.cpp, cpp/program/setup.cpp, cpp/command/benchmark.cpp:
wire MLX into backend selection / benchmark.
- cpp/configs/{gtp,analysis,match,contribute}_example.cfg: document
mlxUseFP16 (auto / true / false), default auto -> fp16.
- Compiling.md: build instructions for the MLX backend.
Validation
- Cross-backend validation against an Eigen reference (testgpuerror)
for b18c384nbt, b40v8, and humanv0 nets shows FP32 max winrate
error 0.00095% and FP16 max 2.63%, well within the existing
backend tolerances.
This is the squash of 130 commits from feature/mlx-backend.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
dcf296a to 81b00db
…arity smoke test (#26)
31eec63 to 628e377
# Conflicts: # cpp/rungpuerrortest.sh
master consolidated createComputeContext's trailing params (openCLTunerFile, openCLReTunePerBoardSize, useNHWCMode) into a single ConfigParser& cfg. The Metal backend was already updated; update the MLX backend to match so it compiles against the merged interface. NHWC is still enforced per-handle via inputsUseNHWC in createComputeHandle. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The MLX backend implemented only the convnet path; transformer nets crashed (empty-(0)-tensor broadcast on rmsnorm/SiLU tips) or produced garbage (GQA). Implement the transformer trunk/tip path in mlxbackend.cpp, mirroring eigenbackend.cpp:
- ACTIVATION_SILU (x * sigmoid(x))
- TransformerRMSNormLayer (spatial rmsnorm tip) + TransformerTrunkRMSNormLayer (pre-LN)
- GQA TransformerAttentionBlock + SwiGLU FFNBlock
- branch the trunk tip on trunkNormKind; wire the new block kinds into the block-variant and nested-bottleneck loops; thread nnX/nnY through
Verified via testgpuerror against fresh Eigen references (boardsize 19): fp32 winrateError max: rope 0.00094%, silu 0.00046%, gqa 0.00029% (bar 0.10%); convnet g170-b6c96 unregressed (0.00036%); runtests + runnnlayertests pass.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
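For reference, the activation and normalization named in this commit can be written as scalar CPU code. This is a minimal sketch, not the MLX implementation: the function names, the epsilon value, the per-vector reduction (the real tip is spatial and masked), and the choice of which FFN branch receives the SiLU are assumptions made for illustration.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// SiLU as stated in the commit message: x * sigmoid(x).
inline float silu(float x) {
  return x / (1.0f + std::exp(-x));
}

// RMSNorm over one channel vector: x_i * gamma_i / sqrt(mean(x^2) + eps).
// eps and the per-channel gain `gamma` are illustrative assumptions.
inline std::vector<float> rmsNorm(const std::vector<float>& x,
                                  const std::vector<float>& gamma,
                                  float eps = 1e-6f) {
  assert(!x.empty() && x.size() == gamma.size());
  float sumSq = 0.0f;
  for (float v : x) sumSq += v * v;
  const float invRms =
      1.0f / std::sqrt(sumSq / static_cast<float>(x.size()) + eps);
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < x.size(); i++) out[i] = x[i] * invRms * gamma[i];
  return out;
}

// SwiGLU gating for one hidden unit in its common formulation:
// silu(gate) * value, where gate and value come from two linear layers.
inline float swiglu(float gate, float value) {
  return silu(gate) * value;
}
```

A scalar reference of this kind is mainly useful as a test oracle for a GPU path.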
The MLX backend's ANE mux path (mlxDeviceToUseThread0=100) drives inference through the shared external/katagocoreml converter -- the same library the Metal backend uses. Bring PR#1205's CoreML/ANE transformer work into that converter so the MLX ANE path supports the v15+ transformer trunk:
- Transformer MIL support: attention (incl. grouped-query attention), learnable RoPE, SiLU, RMSNorm/batchnorm tips, SwiGLU FFN.
- FP16 accuracy precision tiers, gated on actual transformer-block presence (blocksContainTransformer, recursing into nested-bottleneck blocks): narrow trunks (<256ch) build fully FP32; wider ones escalate non-spatial matmuls + global pooling to FP32; very wide (>=320ch) also escalate convs; RMSNorm reductions FP32 in FP16 mode. Plain convnets stay pure FP16 on the ANE (the d052d2a regression-fix behavior is preserved).
The converter's public API is unchanged, so the MLX call site (CoreMLConversion::convertModelToTemp) needs no edits. The Metal-GPU/MPSGraph portions of PR#1205 (metalbackend.cpp, metallayers.swift) are intentionally not ported -- the MLX backend's native GPU path already has transformer support.
Verified on the MLX ANE mux (testgpuerror vs fresh Eigen FP32 references): all 3 transformer test nets pass FP16 thresholds across board sizes/buffer configs (7 configs), a plain convnet stays pure FP16 (non-regression), and runtests + runnnlayertests pass.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
applyGlobalPooling / applyValueHeadPooling summed in fp16 but produced an fp32 mean (division by the fp32 maskSum), which also leaked fp32 into the downstream gpool-bias and value-v2 head matmuls. Cast maskSum to the input dtype so the whole pooling and the heads stay in the compute dtype (fp16 when useFP16), maximizing fp16 utilization rather than escalating to fp32 for negligible accuracy gain. The masked-max keeps its 1e9 constant in fp32 (1e9 overflows fp16 -> inf -> 0*inf=NaN), then casts the max result back to the compute dtype. The fp32 path is unaffected (the astype casts are no-ops in fp32). Verified via testgpuerror vs fresh Eigen fp32 references on all 3 transformer nets (7 board-size/buffer configs): fp16 winrate error max <= 2.07% (within tolerance, winrate unchanged vs baseline), fp32 path byte-identical, ownership output bit-identical, runnnlayertests pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
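The 1e9 overflow described above can be reproduced without an fp16 type. The sketch below is an illustration only: storeAsHalfRange is a hypothetical helper that crudely models fp16 range (it ignores mantissa rounding), and the penalty expression (mask - 1) * big is an assumed form of the masked-max term, not a quote from mlxbackend.cpp.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Largest finite IEEE binary16 value.
constexpr float kHalfMax = 65504.0f;

// Hypothetical helper: crude model of storing a float constant as fp16.
// Magnitudes beyond the fp16 range become +/-inf; in-range values are
// returned unchanged (mantissa rounding is ignored here).
inline float storeAsHalfRange(float v) {
  if (std::fabs(v) > kHalfMax)
    return std::copysign(std::numeric_limits<float>::infinity(), v);
  return v;
}

// Assumed masked-max penalty: 0 on-board (mask == 1), -big off-board
// (mask == 0), so off-board cells can never win the max.
inline float maskPenalty(float mask, float big) {
  assert(mask == 0.0f || mask == 1.0f);
  return (mask - 1.0f) * big;
}
```

With big kept in fp32, an on-board cell gets a penalty of exactly 0. With big pushed through the fp16 range it becomes inf, and the on-board product 0 * inf is NaN, which is why the constant stays in fp32 and only the max result is cast back.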
…nt warmup
M2: A candidate whose threadgroup exceeds the pipeline's register-pressure-
dependent maxTotalThreadsPerThreadgroup (can be < 1024), or that hits a
transient GPU error, throws out of mx::eval during the flat sweep. Previously
this propagated out of loadOrAutoTune and aborted model load with no fallback.
Now each candidate's scoring is wrapped in try/catch: a throw is counted and
skipped (mirroring the OpenCL tuner's mark-bad-and-continue), and best/bestTime
are seeded with the baked default so even a fully-failing sweep returns a valid
result. A separate "flatSweep{Input,Output} skipped=N" log line is emitted only
when skips occur; it intentionally omits the colon after the function name so it
cannot collide with the regex-tested "flatSweepInput: considered" log line.
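The skip-and-continue behavior described in M2 can be sketched as follows. The Candidate and SweepResult types and the sweep function are hypothetical stand-ins; the real tuner's parameter struct and its mx::eval-based scoring are not shown in this PR text.

```cpp
#include <functional>
#include <stdexcept>
#include <vector>

// Hypothetical candidate type; the real tuner's parameters are not shown here.
struct Candidate {
  int tg0;
  int tg1;
  int wpt;
};

struct SweepResult {
  Candidate best;
  double bestTime;
  int skipped;
};

// Scores every candidate, skipping any whose scoring throws. `best` and
// `bestTime` start at the baked default, so a sweep in which every
// candidate fails still returns a usable configuration.
inline SweepResult sweep(
    const std::vector<Candidate>& candidates,
    const Candidate& bakedDefault,
    double bakedDefaultTime,
    const std::function<double(const Candidate&)>& score) {
  SweepResult result{bakedDefault, bakedDefaultTime, 0};
  for (const Candidate& c : candidates) {
    try {
      const double t = score(c);
      if (t < result.bestTime) {
        result.best = c;
        result.bestTime = t;
      }
    } catch (const std::exception&) {
      result.skipped++;  // mark bad and continue with the next candidate
    }
  }
  return result;
}
```

A caller would log the skipped count only when it is nonzero, matching the separate "skipped=N" line the commit describes.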
M3: timeOneInputTransform/timeOneOutputUntransform ran an untimed warmup eval on
every call, but the scoring functions already warmed up once before the measured
loop -- so every measured rep paid an extra full warmup (~doubling tuning cost).
Add a doWarmup parameter gating the internal warmup; the scoring functions drop
their explicit warmup and pass (r == 0), so each shape warms exactly once on its
first measured rep.
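A minimal sketch of the doWarmup gating in M3, under stated assumptions: timeOne and scoreShape are hypothetical names, the callbacks stand in for the untimed and timed evals, and the median is used because the tuner is described elsewhere in this PR as scoring by median time.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <vector>

// Runs the untimed warmup only when doWarmup is set, then the measured run.
// `timedRun` stands in for the measured eval and returns its duration.
inline double timeOne(const std::function<void()>& warmupRun,
                      const std::function<double()>& timedRun,
                      bool doWarmup) {
  if (doWarmup) warmupRun();
  return timedRun();
}

// Median of `reps` measured runs; only the first rep (r == 0) warms up,
// so each shape pays for exactly one warmup.
inline double scoreShape(const std::function<void()>& warmupRun,
                         const std::function<double()>& timedRun,
                         int reps) {
  assert(reps > 0);
  std::vector<double> times;
  for (int r = 0; r < reps; r++)
    times.push_back(timeOne(warmupRun, timedRun, r == 0));
  std::sort(times.begin(), times.end());
  return times[times.size() / 2];
}
```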
Verified by triggering autotuning: gated flat-sweep tests (convergence,
log-format, baseline-consistency, per-shape) pass; an end-to-end re-tune via
loadOrAutoTune runs a fresh sweep and saves valid fp16/fp32 caches; testgpuerror
output is unchanged (tuner params are numerically inert); runtests passes.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The METAL and MLX backend branches each ran set(CMAKE_OSX_DEPLOYMENT_TARGET 13.0) *after* project(), where it is a silent no-op: for this Swift project the deployment target is fixed during project()/enable_language, so a later set() never affects the produced binary. Both shipped binaries already carry minos 26.0 (the build host / libmlx's floor), not 13.0, confirming the pins were inert dead code that contradicted the pre-project comment explaining why the deployment target is deliberately not pinned. Delete both pins so code, comment, and reality agree; the comment becomes literally true. Add a guard note documenting that a post-project pin is a no-op so it is not reintroduced. No behavior change: binaries still build at minos 26.0, matching libmlx's minos and MLX's macOS >= 14 requirement. Verified: MLX reconfigure+build clean; METAL branch configures clean; binary minos unchanged (26.0); runtests pass; testgpuerror unchanged (fp32 max 0.00036%, fp16 max 0.863%). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The katagocoreml parser read the q/k/v/out projection matmuls of a
transformer attention block without checking their declared dimensions
against the head geometry or trunk width. The CoreML graph builder
(MILBuilder) then reshapes each flat projection into a [seq, heads,
headDim] grid and reshapes the out-projection result back into the trunk,
so a mismatched dimension would either read past the weight buffer or
build a graph that compiles but computes nonsense - the exact failure
mode the existing checkBlockChannels() guard was added to prevent for
conv blocks.
Thread trunk_num_channels into parseTransformerAttentionBlock (mirroring
parseNestedBottleneckBlock) and add a checkAttentionProjDim() helper in
the style of checkBlockChannels(), then validate all four projections:
qProj.outChannels == numHeads * qHeadDim (master desc.cpp:1129)
kProj.outChannels == numKVHeads * qHeadDim (master desc.cpp:1131)
vProj.outChannels == numKVHeads * vHeadDim (master desc.cpp:1133)
outProj.inChannels == numHeads * vHeadDim (master desc.cpp:1135)
qProj.inChannels == trunkNumChannels (master desc.cpp:1430)
outProj.outChannels== trunkNumChannels (master desc.cpp:1437)
k/vProj.inChannels == trunkNumChannels (gap master leaves
implicit to the backend)
Six checks mirror master desc.cpp's transformer attention consistency
checks exactly; the k/v inChannels checks additionally close a gap master
leaves to the backend (all three QKV projections consume the same
normed-trunk input, so their inChannels must equal the trunk width). K
pairs with Q in the QK^T dot product, so kProj uses qHeadDim; only V
carries vHeadDim.
Purely additive: throws std::runtime_error on a malformed model, no-op on
valid ones, no numerics touched. Verified: MLX build clean, runtests
pass, and ANE-path testgpuerror on all three transformer nets (incl. the
GQA net with 6 heads/3 KV heads, qk=32/v=16) loads and converts with zero
false-positive throws and unchanged numerics.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
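The projection checks listed in this commit can be sketched as plain C++. The structs below are simplified, hypothetical descriptors whose field names follow the commit message; the real parser types and error text in katagocoreml are not reproduced here.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical, simplified descriptors; field names follow the commit message.
struct MatMulDims {
  int inChannels;
  int outChannels;
};
struct AttentionDims {
  int numHeads;
  int numKVHeads;
  int qHeadDim;
  int vHeadDim;
  MatMulDims qProj, kProj, vProj, outProj;
};

// Throws on a mismatch, does nothing on a valid model.
inline void checkAttentionProjDim(const std::string& what, int actual, int expected) {
  if (actual != expected)
    throw std::runtime_error(what + ": got " + std::to_string(actual) +
                             ", expected " + std::to_string(expected));
}

inline void validateAttention(const AttentionDims& a, int trunkNumChannels) {
  // Head geometry: K pairs with Q in QK^T, so kProj uses qHeadDim.
  checkAttentionProjDim("qProj.outChannels", a.qProj.outChannels, a.numHeads * a.qHeadDim);
  checkAttentionProjDim("kProj.outChannels", a.kProj.outChannels, a.numKVHeads * a.qHeadDim);
  checkAttentionProjDim("vProj.outChannels", a.vProj.outChannels, a.numKVHeads * a.vHeadDim);
  checkAttentionProjDim("outProj.inChannels", a.outProj.inChannels, a.numHeads * a.vHeadDim);
  // Trunk width: Q, K and V all read the normed trunk; outProj writes back to it.
  checkAttentionProjDim("qProj.inChannels", a.qProj.inChannels, trunkNumChannels);
  checkAttentionProjDim("kProj.inChannels", a.kProj.inChannels, trunkNumChannels);
  checkAttentionProjDim("vProj.inChannels", a.vProj.inChannels, trunkNumChannels);
  checkAttentionProjDim("outProj.outChannels", a.outProj.outChannels, trunkNumChannels);
}
```

For the GQA geometry quoted above (6 heads, 3 KV heads, qk=32, v=16) and an assumed trunk width of 96, the expected sizes are 192 for qProj output, 96 for kProj output, 48 for vProj output, and 96 for outProj input.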
Cuts memory during the on-device KataGo -> CoreML conversion and while running the ANE/CoreML path, with byte-identical converter output:
- The converter's weight tensors become non-owning views into the parsed model instead of owning extra FP32 copies; derived/transposed tensors keep an owned buffer. This drops redundant resident weight copies during conversion. CoreML model serialization is made deterministic (SetSerializationDeterministic) so the output is byte-stable.
- The KataGo model parser streams the gzip through a bounded ~1 MB refill buffer instead of decompressing the whole file into memory, while preserving the existing NaN/Inf weight validation.
- ModelDesc gains releaseWeights(), which frees the in-memory weight arrays (keeping scalar shape metadata). The Metal backend calls it on the ANE (CoreML) path after converting from the model file on disk, gated by a new ComputeContext::aneOnly flag so it only fires when every configured device is ANE -- the GPU/MPSGraph path keeps its weights. The call is serialized under computeHandleMutex and only scalar dims are read afterward.
Measured on b18c384nbt (19x19) over the ANE path: idle steady-state RSS 0.59 GB -> 0.19 GB; peak (load+convert) 0.87 GB -> 0.48 GB. Cross-backend parity vs an Eigen reference is unchanged on both the GPU and ANE paths.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
(cherry picked from commit b05f559)
WeightEntry stores a non-owning view (const float*, count) into the live KataGoModelDesc, so the backing std::vector must outlive serialization. addConstOp/registerWeight took the data by const& and silently stored a pointer to it; a caller passing a temporary would bind to that const& and leave the view dangling, read much later during serialization. Delete the rvalue overloads of both so any such call fails to compile, forcing temporaries through addOwnedConstOp/registerOwnedWeight (which take ownership). Named lvalues (the model-member call sites) still bind to the const& overload, so no existing caller changes. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> (cherry picked from commit 971fa9d)
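The deleted-overload technique can be sketched in isolation. The Ops class below is a hypothetical, simplified registry: the method names follow the commit message, but the bodies, the std::deque storage, and the CanAddConst trait are illustrative assumptions rather than the converter's code.

```cpp
#include <cstddef>
#include <deque>
#include <type_traits>
#include <utility>
#include <vector>

struct WeightEntry {
  const float* data;  // non-owning view into the caller's buffer
  std::size_t count;
};

class Ops {
 public:
  // Stores a view: the caller's vector must outlive serialization.
  void addConstOp(const std::vector<float>& w) {
    entries_.push_back({w.data(), w.size()});
  }
  // A temporary would leave the view dangling, so binding an rvalue is a
  // compile error rather than a latent bug.
  void addConstOp(std::vector<float>&&) = delete;

  // Temporaries go through the owning path instead.
  void addOwnedConstOp(std::vector<float> w) {
    owned_.push_back(std::move(w));
    entries_.push_back({owned_.back().data(), owned_.back().size()});
  }

  const std::vector<WeightEntry>& entries() const { return entries_; }

 private:
  std::deque<std::vector<float>> owned_;  // deque keeps element addresses stable
  std::vector<WeightEntry> entries_;
};

// Compile-time check that lvalues are accepted and temporaries rejected.
template <class Arg, class = void>
struct CanAddConst : std::false_type {};
template <class Arg>
struct CanAddConst<Arg, std::void_t<decltype(std::declval<Ops&>().addConstOp(
                            std::declval<Arg>()))>> : std::true_type {};

static_assert(CanAddConst<const std::vector<float>&>::value,
              "named lvalues bind to the view overload");
static_assert(!CanAddConst<std::vector<float>>::value,
              "temporaries are rejected at compile time");
```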
Own the gzFile with a custom-deleter unique_ptr so it closes on every exit path (normal return, exception, bad_alloc); removes the manual try/catch+gzclose in parse() and the ordering caveat on buffer allocation. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> (cherry picked from commit eeefc97)
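The same ownership pattern, shown with std::FILE as a stand-in so the sketch needs no zlib: the commit applies it to gzFile with gzclose as the deleter. The names FileCloser, UniqueFile, openForRead, and countBytes are hypothetical.

```cpp
#include <cstddef>
#include <cstdio>
#include <memory>
#include <stdexcept>
#include <string>

// Deleter that closes the handle. unique_ptr invokes it on every exit path,
// including exceptions thrown after the file was opened.
struct FileCloser {
  void operator()(std::FILE* f) const noexcept {
    if (f) std::fclose(f);
  }
};
using UniqueFile = std::unique_ptr<std::FILE, FileCloser>;

inline UniqueFile openForRead(const std::string& path) {
  UniqueFile f(std::fopen(path.c_str(), "rb"));
  if (!f) throw std::runtime_error("could not open " + path);
  return f;
}

// Reads through a small fixed buffer and counts bytes; no manual close and
// no try/catch is needed around the read loop.
inline std::size_t countBytes(const std::string& path) {
  UniqueFile f = openForRead(path);
  char buf[4096];
  std::size_t total = 0;
  while (std::size_t n = std::fread(buf, 1, sizeof(buf), f.get()))
    total += n;
  return total;
}
```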
Introduce a KataGo-local non-owning FloatView for WeightEntry::data instead of a raw const float*/size_t pair; convert to MILBlob::Util::Span only inside WeightSerializer, keeping the MILBlob dependency out of Operations.hpp. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> (cherry picked from commit 6bfa617)
The ComputeHandle member-order comment claimed that declaring mpsGraphOnlyHandle before coremlOnlyHandle is what prevents a GPU handle from reading freed weights. That overstates the ordering's role: within a single ComputeHandle exactly one handle is built (mutually exclusive on gpuIdx, enforced by the ctor's exactly-one check), and releaseWeights() only fires on an aneOnly context where no MPSGraph handle is ever built. Reframe the declaration order as belt-and-suspenders and point at ComputeContext::aneOnly as the actual invariant. Comment-only change. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> (cherry picked from commit 4159930)
Replace the file-local releaseXXX free functions in desc.cpp (which reached into each desc struct's internals from outside) with releaseWeights() member methods on each weight-bearing struct, matching the existing OO convention used by applyScale8ToReduceActivations() and iterConvLayers(). Each container delegates to its members; type-erased block dispatch is inlined with the same cast pattern those methods use. Behavior-preserving: same set of freed vectors, same block recursion, same metaEncoderVersion guard. ModelDesc::releaseWeights() keeps its signature, so the metalbackend.cpp call site is unchanged. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> (cherry picked from commit 44342a3)
Move the 11 leaf/container releaseWeights() definitions in desc.cpp out of the bottom cluster (inherited from the old free-function layout) and place each immediately after its struct's last existing method, matching the file's per-struct grouping convention used by every other method. ModelDesc::releaseWeights() stays put, already adjacent to its siblings. Pure relocation: function bodies and desc.h are unchanged; only two stray double-blank lines were normalized to single. Verified clean Metal build, testgpuerror vs Eigen reference (g170-b6c96) at <0.0004% winrate error, and runtests all pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> (cherry picked from commit 98b17eb)
…contract
The transformer attention builder emits four function-local std::vector<float> tensors: RoPE cos/sin tables, the rotation matrix R, and per-head out-projection weight slices. After merging the transformer support onto the FloatView branch, these needed two fixes:
1. Dangling view. lightvector#1202 made WeightEntry::data a non-owning FloatView, so addConstOp registers a view whose backing buffer must outlive serialization. These locals were passed to addConstOp and would dangle once the build function returns (serialization runs afterwards). Route them through addOwnedConstOp so KataGoOps owns the buffer until serialization. (Under lightvector#1205's owning WeightEntry they were copied, so this only surfaces post-merge.)
2. dtype mismatch. emitConstOp declares each const's dtype as m_weight_dtype, but addOwnedConstOp / registerOwnedWeight stored at the global mode (is_fp32 hardcoded false). In an FP16 model these derived consts land in the attention / value-head FP32 sub-region (m_weight_dtype == FLOAT32), so they were declared FP32 but stored FP16. CoreML/ANE then rejects the model at load ("Metadata data type does not match requested type", BNNS error -14), which SIGABRT'd every FP16 ANE transformer. Thread is_fp32 through registerOwnedWeight and have addOwnedConstOp pass is_fp32 = (m_weight_dtype == FLOAT32), mirroring addConstOp so the stored dtype always matches the declared dtype. This also fixes the same latent mismatch for addLinearOp's transposed value-head weights.
Verified with testgpuerror against fresh Eigen FP32 references: b7c96h3tfrs and b7c96h6gqa, which previously SIGABRT'd on the FP16 ANE path, now load and match to <0.0005% winrate; convnet ANE output is byte-identical and the Metal GPU path is unchanged. katago runtests and runnnlayertests also pass.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
(cherry picked from commit 8481a94)
The cherry-picked per-struct releaseWeights() refactor (44342a3/98b17ebb) predates this branch's MLX transformer port, so it only added releaseWeights() to the non-transformer descriptors. Extend the coverage to the transformer descriptors present on this branch (RMSNormLayerDesc, TransformerRMSNormDesc, TransformerAttentionDesc incl. ropeFreqs, TransformerFFNDesc) and handle TRANSFORMER_ATTENTION_BLOCK_KIND / TRANSFORMER_FFN_BLOCK_KIND plus trunkTipRMSNorm in the trunk release walk. Without this, releasing weights on a transformer model would hit ASSERT_UNREACHABLE. This makes desc.cpp/desc.h byte-identical to the lightvector#1202 feature branch. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Port the lightvector#1202 ANE steady-state memory lever to the MLX backend. Add ComputeContext::aneOnly, set in createComputeContext when every configured device index is MLX_MUX_ANE, and call ModelDesc::releaseWeights() in convertAndCreateCoreMLOnlyHandleMLX after the model has been converted to CoreML on disk. Safe because: the ANE path re-reads the model from modelPath (not the in-memory weight arrays); the ComputeHandle ctor takes the MLX_MUX_ANE early-return before building any MLX/GPU model (the only weight-array consumer); only scalar dims are read afterward, which releaseWeights() preserves; and it runs under computeHandleMutex. Mirrors the Metal backend's aneOnly release. GPU path unaffected (aneOnly is false whenever any thread uses MLX_MUX_GPU). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
parseTransformerFFNBlock only checked num_channels/ffn_channels > 0, while the attention block validates all of its projection dimensions. Thread trunk_num_channels through and add the mirror checks: num_channels must equal the trunk width (the block adds its output back into the trunk residually) and the linear layers must chain numChannels -> ffnChannels -> numChannels (with the SwiGLU gate also numChannels -> ffnChannels). A malformed FFN block now fails at parse time instead of producing an opaque CoreML compile error or silently-wrong activations. Reuses the existing checkAttentionProjDim helper. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Six conv/matmul/RMSNorm/FFN sites hand-rolled save/restore of m_weight_dtype around their FP32 escalation windows. An exception thrown inside a window would leave m_weight_dtype stuck at FLOAT32, causing later FP16 consts to be tagged FP32 -> the BNNS "Metadata data type does not match" SIGABRT on the FP16 ANE. Give ScopedFp32 an active flag (so a conditional window needs no construction-time branch) and an idempotent restore() (to end the window before a trailing cast-down while keeping the dtor's exception-safe restore), then route all six sites through it. The guard is constructed exactly where the manual flip was and restore() called exactly where the manual restore was, so op-emission order -- and thus the serialized converter output -- is unchanged. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
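A guard with the two properties described (an active flag and an idempotent restore()) might look like the sketch below. It is modeled on the commit's description only: the DType enum, the constructor signature, and the member names are assumptions, and the converter's actual ScopedFp32 may differ.

```cpp
enum class DType { FLOAT16, FLOAT32 };

// While active, forces the referenced dtype to FLOAT32. The saved value is
// restored by an explicit restore() or by the destructor, whichever comes
// first, so an exception inside the window cannot leave the dtype stuck.
class ScopedFp32 {
 public:
  ScopedFp32(DType& dtype, bool active)
      : dtype_(dtype), saved_(dtype), armed_(active) {
    if (armed_) dtype_ = DType::FLOAT32;
  }
  ~ScopedFp32() { restore(); }

  ScopedFp32(const ScopedFp32&) = delete;
  ScopedFp32& operator=(const ScopedFp32&) = delete;

  // Idempotent: a second call, or the destructor after an explicit call,
  // does nothing.
  void restore() {
    if (armed_) {
      dtype_ = saved_;
      armed_ = false;
    }
  }

 private:
  DType& dtype_;
  DType saved_;
  bool armed_;
};
```

Passing active = false makes the guard inert, which lets a conditional escalation window be written without branching around the guard's construction.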
The raw-output forward path had no callers: production inference goes through getOutput() -> Model::applyCompiled(). It also duplicated applyCompiled's input setup and output copy. Delete it. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
… test comment
CMakeLists.txt: uppercase USE_BACKEND into USE_BACKEND_NORMALIZED before the pre-project() MLX version guard and the Swift language selection, mirroring the post-project() string(TOUPPER). Previously a lowercase -DUSE_BACKEND=mlx skipped the CMake 3.27 guard and Swift enablement, then still tried to build the Swift sources later, producing a confusing failure instead of a clear message.
rungpuerrortest.sh: the gpu/ane modes drive whichever backend the binary was built with (backend-agnostic deviceToUseThread0), so reword the usage comment from "the Metal backend" to "the active backend (Metal or MLX)".
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The F(2x2,3x3) input transform and output untransform did all their add/sub arithmetic in T, so in fp16 mode every intermediate of B^T d B and A^T M A rounded to fp16 -- a precision sink independent of the matmul (which already accumulates in fp32 via the steel GEMM). Compute the transform arithmetic in float and cast only the final stored V/M/Y back to T, leaving fp16 storage and memory traffic unchanged (no-op on the fp32 path). Cuts fp16 Winograd kernel error ~33% (runnnlayertests ConvLayer fp16 winograd maxErr 0.0107 -> 0.0071) while the fp32 path stays bit-exact (maxErr=0). testgpuerror FP16-vs-Eigen avg/90%/99% errors drop across configs; g170e-b10c128 worst case 2.35% -> 2.13%. Benchmark throughput within run-to-run noise on both a convnet and a deep NBT net. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…uning
The model-load auto-tune swept a nearly-full candidate grid (~2000 configs, ~16s) even with full=false -- only tg1 differed between non-full and full. Measured on this hardware the winning configs form a broad plateau (many within ~7% of each other, all ~25-40% better than the baked default) and geometry moves end-to-end throughput <=1.5%, so that 16s was spent discriminating run-to-run noise: three forced re-tunes of the same net picked three entirely different winners, all within ~7%.
Mirror OpenCL's split: full=false (auto) now sweeps a coarse grid (tg0 {8,16,32,64,128}, tg1 {1,2,4,8,16}, wpt {1,2,4}) -- ~2.7s, still landing ~21% above the default and within ~6% of the wide-sweep winner. full=true keeps the wide grid as the deliberate command-tune, opt-in via KATAGO_MLX_WINOTUNER_FULL=1 (the analog of './katago tuner --full', which openclbackend.cpp pins to full=false at load). Cache format is unchanged; existing caches still load, and FULL+FORCE overwrites with the wide-swept winner.
Also loosen two gated tuner stress-test budgets (baseline-anchor 0.25->0.50, convergence 1.10->1.30) that compared single sub-millisecond timing samples and flaked ~1-in-4 on both this and the pre-change binary -- the same dispatch/sync-overhead noise the tuner itself tolerates. They remain gross-error sanity checks. runtests + runnnlayertests pass, gated tuner tests 16/16, testgpuerror fp32 bit-exact and fp16 winrate max 0.55% on the coarse-tuned config.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
M4 - silent fallthroughs in mlxbackend.cpp, now loud (ASSERT_UNREACHABLE,
the file's existing idiom):
- BlockVariant::apply switch default returned the input unchanged, a silent
identity no-op block. Now asserts; kept as a default: label so the active
-Wswitch-default (CMakeLists.txt) stays clean.
- The trunk block-construction loop silently dropped unknown kinds. Added an
else { ASSERT_UNREACHABLE; }.
- The nested-bottleneck construction loop was a genuine latent bug, not just a
missing guard: parseResidualBlockStack (desc.cpp, shared by trunk and nested)
accepts nested_bottleneck_block inside a nested bottleneck and the desc layer
handles it, but the MLX nested loop omitted NESTED_BOTTLENECK_BLOCK_KIND and
silently dropped such a block. Added the missing case (mirroring the trunk
loop and Eigen's shared BlockStack) plus an else-assert.
M3 - Compiling.md implied the MLX backend builds with make, but CMakeLists.txt
hard-fails MLX without the Ninja generator (same Swift/C++ interop requirement
as Metal). Added -G Ninja to the MLX cmake example, listed MLX alongside Metal
for the Ninja prerequisite, and noted MLX uses ninja to build.
Verification: build clean; runtests + runnnlayertests pass; testgpuerror
g170-b6c96 vs eigen_reference.json fp32 near-exact (winrate max 0.00036%) /
fp16 max 0.55% (unchanged); testgpuerror on the b4c256h4nbttflrs nested-
bottleneck+transformer model loads through the modified nested loop and runs
forward with no assert (fp16 ~0.27%).
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The `tuner` command previously did nothing on non-OpenCL backends. Add an MLX branch that loads the model, builds its conv-3x3 shape histograms, and runs MLXWinogradTuner::loadOrAutoTune with reTune=true so the Winograd input/output transform search runs and overwrites the cache the backend reads at model load. This is the first-class "command tune" path; the load-time auto-tune stays coarse/fast. -full selects the wide candidate grid (the env-var KATAGO_MLX_WINOTUNER_FULL=1 behavior, which still works for triggering a full tune through benchmark/gtp). -testFP16 (auto->FP16) matches the engine's useFP16 default and the cache-filename key. The default output path is MLXWinogradTuner::defaultDirectory/defaultFileName - the exact file the backend loads - verified end-to-end: after a default-path tune, a benchmark model-load logs "Loaded MLX Winograd tuning parameters from" that same file. The OpenCL-only FP16 sub-knobs have no MLX analog and are omitted. The backend guard is restructured from #ifndef USE_OPENCL_BACKEND to a three-way #if/#elif/#else; the OpenCL body is unchanged. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The env var was an explicit stopgap "command tune analog" for the wide Winograd sweep before `./katago tuner` supported MLX. Now that the literal tuner subcommand exists, drop the env var and pin the model-load path to the coarse grid (full=false), so the wide sweep is reached only through `./katago tuner -full`. This mirrors openclbackend.cpp, which pins full=false at load and passes full only from the explicit tuner command, and reinforces the coarse-auto-tune design. No capability is lost: `tuner -full` writes the same cache the backend reads at load, so the prior one-shot workflow (FULL=1 FORCE=1 benchmark) becomes the cleaner two-step (tune once, then run) - the OpenCL workflow. KATAGO_MLX_WINOTUNER (enable/disable) and KATAGO_MLX_WINOTUNER_FORCE (force re-tune) are unchanged; FORCE now only ever drives a coarse re-tune at load. Verified: FULL=1 FORCE=1 benchmark now re-tunes coarse (considered=288 for b10c128, was 2176), while `tuner -full` still sweeps the wide grid (considered=2176). runtests pass; testgpuerror unchanged (fp32 0.0005%, fp16 ~1.0%). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Like KATAGO_MLX_WINOTUNER_FULL, this env var predates MLX support in the tuner subcommand. It set reTune=true on the model-load path so a normal benchmark/gtp run would ignore the cache, re-tune, and overwrite. OpenCL has no load-time force-retune analog: OpenCLTuner::loadOrAutoTune loads the cached params if present (falling back to the full-size cache) and only auto-tunes on a complete miss; the only re-tune paths are the explicit tuner command or deleting the cache file. Now that `./katago tuner` works on MLX and always re-runs + overwrites (reTune=true), the env var is redundant with that established pattern, so drop it and pin the load-time reTune=false. The model-load path now loads a valid cache or coarse-tunes once on a miss, never re-tuning a valid cache and never sweeping the wide grid - both are reached only through `./katago tuner` (-full for the wide grid). To refresh the cache, run the tuner command (or delete the cache file). No reachable end-state is lost; only the inline-during-benchmark/gtp convenience, which OpenCL never had. KATAGO_MLX_WINOTUNER (disable tuning, use baked defaults) is unchanged - a distinct switch the command does not cover. Verified: FORCE=1 benchmark now loads the cache instead of re-tuning; KATAGO_MLX_WINOTUNER=0 still disables tuning; runtests pass; testgpuerror unchanged (fp32 0.0005%). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The MLX backend pulls in external/katagocoreml, which requires protobuf and abseil, but Compiling.md listed them only under the Metal backend. `brew install mlx` does not provide them transitively, so an MLX-only build following the docs failed at find_package(Protobuf REQUIRED). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The tuner cache filename hardcoded "AppleSilicon", so every Apple chip shared one cache file: a cache tuned on e.g. an M1 would be loaded verbatim on an M4 Max, where the optimal Winograd launch geometry differs. Add a shared MLXWinogradTuner::detectGpuName() that reads the chip brand string (sysctl machdep.cpu.brand_string, e.g. "Apple M3 Max") with an "AppleSilicon" fallback, and call it from both the backend model-load path and the `tuner` command so their cache keys always match. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
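Reading the brand string with a fallback could look like the sketch below. It is an assumption-laden illustration, not the body of MLXWinogradTuner::detectGpuName(): the free-function form, the two-call length query, and the trailing-NUL trimming are choices made here, and any sanitizing of the name for use in a filename is left out.

```cpp
#include <cstddef>
#include <string>
#ifdef __APPLE__
#include <sys/types.h>
#include <sys/sysctl.h>
#endif

// Returns the chip brand string (e.g. "Apple M3 Max") on macOS, or the
// "AppleSilicon" fallback when the query fails or on other platforms.
inline std::string detectGpuName() {
  const std::string fallback = "AppleSilicon";
#ifdef __APPLE__
  std::size_t len = 0;
  // First call asks only for the required buffer length.
  if (sysctlbyname("machdep.cpu.brand_string", nullptr, &len, nullptr, 0) != 0 || len == 0)
    return fallback;
  std::string name(len, '\0');
  if (sysctlbyname("machdep.cpu.brand_string", &name[0], &len, nullptr, 0) != 0)
    return fallback;
  name.resize(len);
  // The reported length includes the trailing NUL; trim it.
  while (!name.empty() && name.back() == '\0') name.pop_back();
  return name.empty() ? fallback : name;
#else
  return fallback;
#endif
}
```

Calling one shared function from both the model-load path and the tuner command is what keeps the two cache keys identical.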
In MLX C++ there are no weak scalars: mx::array(scale) is a strong float32 array, so `matmul(q,kT) * mx::array(scale)` promoted the attention scores -- and the whole transformer residual stream, since each block adds its output into the trunk -- to fp32. Cast the scale to the compute dtype so the fp16 path stays fp16 end-to-end, matching the pooling/BN/RMSNorm layers (which already astype their fp32 intermediates back) and the maximize-fp16 goal for the GPU path. Verified by full rungpuerrortest.sh: GPU 37/37 pass, ANE 31/31 supported pass (6 pre-v8 SIGABRTs expected/unrelated); worst-case fp16 winrate 2.50% / topPolicy 1.25% (thresholds 5.0% / 6.0%). runtests and runnnlayertests pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
… param, transformer test
Address four follow-up items from the PR#1199 review (test/tuner-only; no inference forward-pass changes):
1. Atomic tuner-cache save. MLXWinogradTuneParams::save now writes to a per-process temp path (filename + ".tmp.<pid>") and FileUtils::rename's it onto the final path, so two processes that cache-miss and tune the same model concurrently can no longer tear the shared cache file.
2. Independent Winograd oracle. The GPU and FP16 Winograd metal_kernel tests previously asserted only against cpuConv2d3x3, itself a Winograd F(2,3) impl sharing the kernel's B/G/A transform matrices -- a shared sign/transpose error would cancel and pass. They now also assert against the independent naive direct-conv oracle.
3. Remove the dead seedOverride parameter from MLXWinogradTuner::loadOrAutoTune (declaration, definition, and both call sites). It was documented "reserved ... currently ignored" and always passed nullptr.
4. Transformer-layer numeric test (runMLXTransformerLayerFP16Test): the transformer path (RMSNorm / attention / RoPE) had no layer-level coverage -- only end-to-end via testgpuerror. Adds RMSNorm fp32-vs-CPU correctness, attention fp16 output-dtype preservation (the regression guard for the just-fixed fp16->fp32 attention-scale promotion), fp16/fp32 closeness, and a zero-outProj residual-identity anchor; covers fixed-RoPE on/off and mask on/off.
Verified: build clean; runtests and runnnlayertests pass; the three transformer nets pass testgpuerror on the MLX GPU path within thresholds.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
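The write-to-temp-then-rename pattern from item 1 can be sketched with the standard library on a POSIX system. This is an illustration under assumptions: atomicSave is a hypothetical name, std::rename stands in for FileUtils::rename, and getpid supplies the per-process suffix.

```cpp
#include <cstdio>
#include <fstream>
#include <stdexcept>
#include <string>
#include <unistd.h>  // getpid (POSIX)

// Writes `contents` to a per-process temp file next to the target, then
// renames it over the final path. On POSIX the rename replaces the target
// atomically, so a concurrent reader sees either the old or the new file,
// never a partially written one.
inline void atomicSave(const std::string& finalPath, const std::string& contents) {
  const std::string tmpPath = finalPath + ".tmp." + std::to_string(getpid());
  {
    std::ofstream out(tmpPath, std::ios::binary | std::ios::trunc);
    if (!out) throw std::runtime_error("cannot open " + tmpPath);
    out << contents;
    out.flush();
    if (!out) throw std::runtime_error("write failed: " + tmpPath);
  }  // stream is closed before the rename
  if (std::rename(tmpPath.c_str(), finalPath.c_str()) != 0) {
    std::remove(tmpPath.c_str());
    throw std::runtime_error("rename failed: " + finalPath);
  }
}
```

Because each process writes its own temp file, two concurrent tuners can both finish; the last rename wins and the cache is always a complete file.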
This PR adds a new neural-net backend (`USE_BACKEND=MLX`) targeting Apple Silicon via Apple's MLX framework, with two dispatch paths sharing a single backend:

- … per-shape tuner.
- … backend's `gpuIdx = 100` convention. Usable standalone or muxed with the GPU path on the same model to overlap forward passes.

It implements the full `nninterface.h` contract (model load, batched evaluation, FP16/FP32 paths) and reuses the existing CoreML conversion pipeline shared with the Metal backend. Both dispatch paths support the v15+ transformer trunk (attention incl. GQA, learnable RoPE, SiLU, RMSNorm/batchnorm tips, SwiGLU FFN) in addition to convnets.
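The standalone-versus-muxed distinction above can be sketched as a classification over the per-thread device indices. This is a minimal sketch under stated assumptions: `kAneDeviceIdx`, `DispatchMode`, and `classifyDispatch` are hypothetical names, and the sketch assumes that index 100 selects the ANE path while any other index selects the GPU path.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Assumed sentinel: device index 100 selects the ANE/CoreML path.
constexpr int kAneDeviceIdx = 100;

enum class DispatchMode { GpuOnly, AneOnly, Mux };

// Classify a model's server threads by the device index each one was given.
inline DispatchMode classifyDispatch(const std::vector<int>& deviceIdxByThread) {
  assert(!deviceIdxByThread.empty());
  const auto isAne = [](int idx) { return idx == kAneDeviceIdx; };
  if (std::all_of(deviceIdxByThread.begin(), deviceIdxByThread.end(), isAne))
    return DispatchMode::AneOnly;
  if (std::none_of(deviceIdxByThread.begin(), deviceIdxByThread.end(), isAne))
    return DispatchMode::GpuOnly;
  return DispatchMode::Mux;  // at least one thread on each path
}
```

For example, two server threads configured with indices `{0, 100}` would classify as `Mux`, while `{100, 100}` would classify as `AneOnly`.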
What's new
Backend (`cpp/neuralnet/`)

- `mlxbackend.cpp`: backend implementation.
  - … variable board sizes via input masking (same `nnXLen`/`nnYLen` contract as other backends; the global `COMPILE_MAX_BOARD_LEN` bound still applies), FP16/FP32 selected by `mlxUseFP16` (default `auto` → fp16). Mish runs FP16-safe; the code asserts on `ACTIVATION_MISH_SCALE8` so out-of-range variants fail loudly rather than truncate silently.
  - … `mlxDeviceToUseThread<N> = 100` (or the backend-agnostic `deviceToUseThread<N>`); shares the model+converter cache with Metal. Feeds the real spatial mask (channel 0 NHWC → buffer) so rectangular / sub-NN-frame boards predict correctly, transposes NHWC → NCHW into a per-batch staging buffer for the Swift `MLMultiArray` contract, and uses path-correct strides in the policy-optimism postprocessor for v12+ models.
  - … `ComputeHandle` construction with a file-static mutex (CoreML converter is not concurrent), and eagerly evaluates FP16 weight casts so secondary MLX/GPU server threads don't trip MLX 0.31.2's thread-local command-encoder map.
- `mlxwinograd.h`: F(2×2, 3×3) Winograd transform with fused activation + residual add.
- `mlxwinotuner.{cpp,h}`: per-shape Winograd tuner with adaptive scoring (rotates the candidate set per shape, scores by median-time delta against a baked-default baseline). Logs the conv-3x3 shape distribution at model load. At model load it loads a valid cache or, on a miss, coarse-tunes once; it never re-tunes a valid cache and never sweeps the wide grid. Explicit re-tuning is command-only via `./katago tuner` (see Build / wiring). `KATAGO_MLX_WINOTUNER=0` skips tuning entirely and uses baked-default launch geometry, while `KATAGO_MLX_WINOGRAD=0` falls 3×3 convs back to `mx::conv2d`.
- `mlxtests.cpp`: Winograd + tuner numeric-consistency tests, gated under `runnnlayertests`.

Transformer trunk support (both dispatch paths)
- … in `mlxbackend.cpp`: scaled-dot-product attention with grouped-query attention (GQA), learnable RoPE, RMSNorm and batchnorm trunk tips, SwiGLU FFN, and the SiLU activation. Mirrors `eigenbackend.cpp`.
- … `katagocoreml`): the same transformer trunk in the shared MIL converter, with FP16 accuracy precision tiers gated on actual transformer-block presence (recursing into nested-bottleneck blocks) so plain convnets stay pure FP16 on the FP16-only ANE. For transformer trunks: narrow trunks (<256ch) are built fully in FP32; wider ones escalate the accuracy-sensitive non-spatial ops (matmuls + global pooling) to FP32; very wide ones (≥320ch) also escalate convs. RMSNorm reductions run in FP32 in FP16 mode, while attention scores/softmax and the per-head out-projection stay FP16. Mirrors the TensorRT precision split and the Metal/CoreML backend's tiers. The converter's public API is unchanged, so the MLX ANE call site needs no edits.
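The tier selection described for the converter's FP16 mode can be sketched as a small decision function. This is a hedged illustration rather than the converter's code: `PrecisionTier` and `chooseTier` are hypothetical names, and the sketch assumes that "ch" means the trunk channel count and that the thresholds are exactly the ones quoted above.

```cpp
#include <cassert>

// Hypothetical tier names for the FP16-mode precision split.
enum class PrecisionTier {
  PureFP16,               // plain convnets: everything stays FP16
  FullFP32,               // narrow transformer trunks (<256ch)
  NonSpatialFP32,         // matmuls + global pooling escalated to FP32
  NonSpatialAndConvFP32   // very wide trunks (>=320ch): convs escalated too
};

// Pick a tier from transformer-block presence and trunk width.
inline PrecisionTier chooseTier(bool hasTransformerBlocks, int trunkChannels) {
  if (!hasTransformerBlocks) return PrecisionTier::PureFP16;
  if (trunkChannels < 256) return PrecisionTier::FullFP32;
  if (trunkChannels >= 320) return PrecisionTier::NonSpatialAndConvFP32;
  return PrecisionTier::NonSpatialFP32;
}
```

Under these assumptions a 96-channel transformer lands in `FullFP32`, a 256-channel one in `NonSpatialFP32`, and a 384-channel convnet with no transformer blocks stays `PureFP16`.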
Build / wiring
- `cpp/CMakeLists.txt`: `USE_BACKEND=MLX` target; pulls in the Metal/Swift CoreML bridge so the ANE path links cleanly. MLX needs CMake 3.27; `cmake_minimum_required` stays at 3.18.2 so other backends keep building on older CMake. Links Homebrew's prebuilt `libmlx.dylib`; OSX deployment target is intentionally not pinned so the executable's `minos` matches the linked dylib.
- `cpp/main.cpp`, `cpp/program/setup.cpp`, `cpp/command/benchmark.cpp`: wire MLX into backend selection / benchmark.
- `cpp/command/tune.cpp`: wires the MLX winotuner into the `./katago tuner` subcommand (previously OpenCL-only), mirroring OpenCL's command-only tuning. `./katago tuner -model <net>` always re-runs the sweep and overwrites the cache the backend reads at model load; add `-full` for the wide candidate grid. Args: `-output` (default: shared MLX cache), `-xsize`/`-ysize` (default 19), `-batchsize` (default 8), `-testFP16` (`true|false|auto`, default `auto` → FP16, the precision the engine and the cache filename key use), `-full`. The OpenCL-only FP16 sub-knobs (storage / compute / tensorcores) have no MLX analog and are omitted.
- `cpp/configs/{gtp,analysis,match,contribute}_example.cfg`: document `mlxUseFP16` (default `auto` → fp16) and the `numNNServerThreadsPerModel` / `mlxDeviceToUseThread<N>` dispatch knobs (GPU-only / ANE-only / mux), with the note that `mlxUseFP16=false` on an ANE thread falls back to CPU FP32.
- `cpp/rungpuerrortest.sh`: backend-agnostic `deviceToUseThread0=100` for the ANE mode, so the same script drives whichever backend the binary was built with.
- `Compiling.md`: build instructions.

Conversion + ANE steady-state memory
The ANE path carries the on-device memory levers from #1202 (cut CoreML
conversion peak and ANE steady-state RSS, with byte-identical converter
output), cherry-picked onto this branch and wired into the MLX backend:

- … non-owning `FloatView`s into the parsed model instead of owning extra FP32 copies (derived/transposed tensors keep an owned buffer). The rvalue `registerWeight`/`addConstOp` overloads are `= delete`d so a temporary can't bind to the view path and dangle, and CoreML serialization is made deterministic for byte-stable output.
- … through a bounded ~1 MB refill buffer (RAII `gzFile`) instead of decompressing the whole file into memory, preserving the existing NaN/Inf weight validation.
- `ModelDesc::releaseWeights()` frees the in-memory FP32 weight arrays (keeping scalar shape metadata), gated by a new `ComputeContext::aneOnly` flag that is true only when every configured device index is `MLX_MUX_ANE`. On the MLX backend it fires in `convertAndCreateCoreMLOnlyHandleMLX` after the model has been converted to CoreML on disk; it is safe because the ANE path re-reads the model from `modelPath` (not the in-memory arrays), the `ComputeHandle` ctor takes the ANE early-return before building any MLX/GPU model (the only weight-array consumer), only scalar dims are read afterward, and it runs under `computeHandleMutex`. The GPU/MLX path keeps its weights (`aneOnly` is false whenever any thread uses the GPU). `releaseWeights()` covers the transformer descriptors too (attention incl. `ropeFreqs`, FFN, RMSNorm tips).

Measured (19×19, MLX ANE path)
Resident-memory A/B (`mlxDeviceToUse=100`, cold conversion every run), before vs after the levers. Peak RSS is the `/usr/bin/time -l` maximum over load+convert; idle steady-state is `ps`-sampled after load+genmove.

[Table not recovered; its rows covered `b40c256`, `b18c384nbt`, and `b10c384h6nbttflrs`.]

The peak win is largest on the classical 40b (the duplicate owning weight copies and the whole-gzip decompression that the levers target dominate there). The transformer gains least because most of its converter-side weight is derived (RoPE tables, rotation matrix, per-head out-proj slices), which keeps an owned buffer, and because its absolute peak is dominated by ANE compilation of the MIL graph, which happens downstream of these host-side levers.
The levers don't change inference numerics: `testgpuerror` ANE output is bit-identical before vs after on a convnet and all three transformer nets (release happens after conversion; the MLX/GPU path never invokes the converter), and benchmark throughput is within run-to-run noise on both the GPU and ANE paths.
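The deleted-rvalue-overload guard used by the non-owning view lever in this section can be illustrated in isolation. This is a sketch under assumptions, not the converter's code: the `FloatView` layout, `WeightRegistry`, and the `CanRegister` detection helper are hypothetical, and only the idiom itself (an lvalue overload that stores a view, with the rvalue overload deleted) comes from the description above.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <utility>
#include <vector>

// Non-owning view over float data (hypothetical layout).
struct FloatView {
  const float* data;
  std::size_t size;
};

struct WeightRegistry {
  std::vector<FloatView> views;

  // Lvalue overload: stores a view; the caller must keep `w` alive.
  void registerWeight(const std::vector<float>& w) {
    views.push_back(FloatView{w.data(), w.size()});
  }
  // Rvalue overload is deleted, so a temporary cannot silently bind to the
  // const& overload above and leave a dangling view behind.
  void registerWeight(std::vector<float>&&) = delete;
};

// Compile-time check of which argument categories registerWeight accepts.
template <typename Arg>
auto canRegisterImpl(int)
    -> decltype(std::declval<WeightRegistry&>().registerWeight(std::declval<Arg>()),
                std::true_type{});
template <typename Arg>
std::false_type canRegisterImpl(...);
template <typename Arg>
using CanRegister = decltype(canRegisterImpl<Arg>(0));

static_assert(CanRegister<std::vector<float>&>::value,
              "named vectors bind to the view overload");
static_assert(!CanRegister<std::vector<float>&&>::value,
              "temporaries are rejected at compile time");
```

With this shape, `reg.registerWeight(std::vector<float>{1.0f})` fails to compile, whereas registering a named vector stores a pointer into that vector's storage without copying.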
How to build
```
cd cpp
cmake -G Ninja -DUSE_BACKEND=MLX
ninja
```

Requires CMake ≥ 3.27 and `brew install mlx`.

Validation
Cross-backend `testgpuerror` vs an Eigen FP32 reference, via `cpp/rungpuerrortest.sh` (all 12 nets across 9×9 / 13×13 / 19×19 / 10×14 / rectangular / rect-buffer / weird-settings). MLX FP16 (`auto`) vs Eigen FP32, winrate error per net (99%-ile / worst-case max across that net's configs):

[Table not recovered.]

Pass: GPU 37/37 configs, ANE 31/31 supported configs. FP32-vs-Eigen sanity (numerical noise): GPU ≤ 0.00092%, ANE ≤ 0.00082%. `./katago runtests` and `./katago runnnlayertests` pass.

ᵗ 96ch transformers hit the converter's full-FP32 tier on ANE (FP16 == FP32); the 256ch SiLU transformer uses the non-spatial-FP32 tier (real FP16 residual).

The 6 pre-v8 ANE configs (v3/v4/v5) are skipped: the CoreML converter supports model versions 8–17 only. Run ANE `testgpuerror` serially or with a per-process `TMPDIR`: the converter stages weights in a shared `$TMPDIR/katagocoreml_weights` dir, so concurrent processes race.

GPU throughput: MLX vs Metal
Single-GPU throughput of the two Apple-Silicon GPU backends, both configured
GPU-only (no ANE; MLX `gpuIdx 0`, Metal default device 0), on an Apple M3 Max. The two backends take different GPU strategies, so this compares each backend's actual GPU-only behavior rather than an identical-precision kernel:

- … `mlxUseFP16 auto`).
- … transformer trunk, is FP32-only).

The Metal binary is built from `feature/metal-transformer-silu-b10c384h6` (the branch whose Metal GPU path supports the v15+ transformer trunk). Each net was benchmarked with `benchmark -half-batch-size -t 8,16,32 -v 800` at 19×19, run sequentially, one backend / one net at a time. Cells are visits/s at each thread count; the human-SL net used `humanSLProfile=rank_9d` (same override on both backends). The three v17 transformer MLX rows were re-measured after the FP16-attention fix (`52240987`, median of 3 runs); the Metal column and all other rows are unchanged.

[Table not recovered; it had two `visits/s @ t=8 / 16 / 32` columns.]
MLX-GPU is faster than Metal-GPU on every net (best-thread throughput 1.05×–1.82×), with the largest margins on the deep NBT nets (`b18c384nbt` ~1.4–1.5×, `b28c512nbt` ~1.44×) and the v17 transformer nets (~1.5–1.6×, now that the GPU attention path runs end-to-end in FP16) where the FP16 Winograd path has the most to exploit. The FP16-vs-FP32 precision gap is part of that lead; accuracy of each path vs an Eigen FP32 reference is in the Validation table above.
Status
Draft — opening for early feedback on the backend's structure, the
tuner approach, the GPU/ANE dispatch wiring, and the transformer-trunk
support (incl. the ANE FP16 precision tiers) before promoting to
ready-for-review.