Skip to content

[NNX] NNX migration (12/N): delete Linen code paths, classes, and NNX compatibility flags#4038

Draft
ecnal-cienet wants to merge 27 commits into
mainfrom
feat/nnx-delete-linen
Draft

[NNX] NNX migration (12/N): delete Linen code paths, classes, and NNX compatibility flags#4038
ecnal-cienet wants to merge 27 commits into
mainfrom
feat/nnx-delete-linen

Conversation

@ecnal-cienet

@ecnal-cienet ecnal-cienet commented Jun 2, 2026

Copy link
Copy Markdown
Collaborator

NNX Migration Route Map

  1. ✅ Add NNX scaffolding: pure_nnx flag, init_state_fn, TrainStateNNX, NNX utils. Linen workflow unchanged. (PR NNX migration prep (1/N): pure_nnx flag and init_state_fn scaffolding #3427)
  2. ✅ NNX sharding utilities: get_abstract_state_nnx, get_named_sharding_nnx, set_named_sharding_nnx, get_partition_spec_nnx, get_mesh_from_config. (PR NNX migration prep (2/N): NNX utils and sharding utilities #3470)
  3. ✅ NNX fully supported end-to-end: TrainStateNNX, model creation, gradient accumulation, checkpointing, and training loop dispatch. (PR NNX migration prep (3/N): TrainState, model creation, and end-to-end training loop #3500)
  4. ✅ Sharding diagnostics on NNX, plus post-training bugfixes that surfaced once the NNX path got exercised end-to-end. (PR [NNX] NNX migration prep (4/N): sharding tools and post-training fixes #3652)
  5. ✅ NNX correctness fixes, feature enablements, and vocab tiling on NNX.
  6. ✅ NNX-native DPO.
  7. ✅ NNX-native MaxEngine inference. (PR [NNX] NNX migration prep (7/N): NNX-native MaxEngine inference #3821)
  8. ✅ NNX-native LoRA + GRPO. (PR [NNX] NNX migration prep (8/N): NNX native lora grpo #3824)
  9. ✅ NNX-aware QK-Clip + remaining checkpoint utilities. (PR [NNX] NNX migration prep (9/N): NNX-aware QK-Clip + checkpoint utilities #3836)
    9.5. ✅ NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix. (PR [NNX] NNX migration prep (9.5/N): NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix #3844)
  10. ✅ Vocab tiling custom_vjp for NNX (+ output-head carve-out).
  11. ✅ Flip NNX defaults to True. ([NNX] NNX migration (11/N): set pure_nnx / enable_nnx / pure_nnx_decoder defaults to True #3526)
  12. 🔄 [This PR] Delete the Linen code paths, classes/helpers, and the pure_nnx / enable_nnx / pure_nnx_decoder flags. ([NNX] NNX migration (12/N): delete Linen code paths, classes, and NNX compatibility flags #4038)

Description

PR11 made NNX the default and PR6–PR10 promoted every routed-to-Linen feature to NNX-native, so NNX is the only production path. This PR deletes Linen: all flag/isinstance dispatch, the dead Linen classes and *_as_linen wrappers, the three compatibility flags, and the now-obsolete Linen-only tests.

Net: 64 files, +728 / −6,205 (−5,477 lines). Organized into 4 reviewable commits (ordered so each is self-consistent — flags removed last, after every reference is gone):

Commit Files What
1/4 collapse dispatch 19 Statically simplify every pure_nnx / enable_nnx / isinstance(model, nn.Module) branch to the NNX path across core training / utils / inference / RL / checkpoint-conversion. Zero executable flag reads remain in src.
2/4 delete Linen classes 13 Delete TransformerLinenPure; the Linen Decoder / DecoderLayer / SequentialBlockDecoderLayers stack (decoders.py 1525→47, only deepstack_process kept); and 28 dead *_as_linen ToLinen wrappers across the layer/model files. The wrapped NNX classes are untouched.
3/4 test cleanup 25 Delete obsolete Linen-only tests, drop redundant flag args from the rest, and fix hlo_diff_test (see below).
4/4 remove flags 7 Remove the three flags from types.py, base.yml, inference/vllm.yml, pyconfig, and the post-train distillation configs.

Stats

Diff: 64 files, +728 / −6,205 (net −5,477), in 4 commits — overwhelmingly deletion (it's a "remove dead Linen" PR). Production-vs-test split:

Area Files Insertions Deletions Net
Production code (src/maxtext) 39 +520 −4,662 −4,142
  • core / training / utils / inference / RL 19 +474 −1,896 −1,422
  • layers + models (Linen classes / *_as_linen) 13 +44 −2,746 −2,702
  • configs (flag removal) 7 +2 −20 −18
Tests (tests/) 25 +208 −1,543 −1,335
  • unit 13 +101 −1,261 −1,160
  • integration 9 +81 −225 −144
  • test utils / assets 3 +26 −57 −31

So ~75% of the line changes are production code and ~25% are tests (by deletions, 4,662 vs 1,543 ≈ 3:1). The largest single chunk is the Linen decoder stack + *_as_linen wrappers (−2,746 in layers/models). Note this is pure refactor/removal — no new feature code; the "+728" insertions are almost entirely de-indenting kept NNX branches and rewriting a handful of tests.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-delete-linen branch from 7c6c8d3 to 5ca18e9 Compare June 2, 2026 01:36
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-delete-linen branch 11 times, most recently from 2037f3f to 6c8c56f Compare June 9, 2026 14:48
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-delete-linen branch 7 times, most recently from 5dd6b93 to c8cb446 Compare June 12, 2026 18:43
ecnal-cienet and others added 10 commits June 12, 2026 18:57
…X default flip

Pre-flip safety: PR11 will flip pure_nnx/enable_nnx/pure_nnx_decoder from
False to True in base.yml. Some existing tests are Linen-coupled and would
either silently switch to NNX (and break) or silently SKIP after that flip.
Pin them to Linen explicitly so they keep exercising the Linen path, with
no behavior change today (the pin matches the current default).

tests/unit/tiling_test.py:
  LossAndGradientCorrectnessTest builds models via transformer_as_linen and
  exercises the Linen vocab_tiling path. Extend self.base_config in setUp
  with enable_nnx=False, pure_nnx=False, pure_nnx_decoder=False, then drop
  the 6 stale pytest.skip("We currently don't support vocab tiling on NNX
  module.") guards (NNX-side coverage lives in VocabTilingNNXTest in the
  same file, added in PR10).

tests/unit/pipeline_parallelism_test.py:
  Pipeline parallelism does not yet have an NNX path (deferred to PR11.5).
  Add _LINEN_PIN class const and append *self._LINEN_PIN to the 6
  train_main arg lists in test_full_train_circular,
  test_full_train_circular_pipeline_ag_per_repeat,
  test_full_train_non_circular, test_subset_layers, test_full_train_fp8,
  and test_full_train_nanoo_fp8. The unit-style
  assert_pipeline_same_output_and_grad tests bypass the dispatch by
  calling pipeline.create_pipeline + SimpleDecoderLayerToLinen directly,
  so they are flag-immune and need no change.
The PR6-PR10 sequence promoted every routed-to-Linen feature to
NNX-native (DPO/PR6, MaxEngine/PR7, LoRA+GRPO/PR8, QK-Clip + checkpoint
utilities/PR9, AQT + serve-mode/PR9.5, vocab tiling custom_vjp/PR10).
With those gaps closed, NNX is the production path; this commit makes
it the default.

Empirical break-test on CPU (pytest before/after the flip across
tiling_test, train_compile_test, sharding_compare_test,
maxtext_utils_test, maxengine_test) showed zero flip-induced failures
- every CPU unit-test failure pre-existed on PR10 tip. TPU smoke
verified end-to-end: gemma2-2b 3-step train under the new defaults
logged "pure_nnx: True" in pyconfig and produced loss
13.04 -> 12.32 -> 11.82 (decreasing, no NaN/inf, no Traceback).
Linen-only test files were already pinned in the prior commit so no
per-test breakage from the flip.

base.yml: enable_nnx, pure_nnx_decoder, pure_nnx all flip False -> True.

No use_nnx_pipeline flag is added: PR10 tip has no NNX pipeline path
to opt out of, so a one-valued flag would be dead weight. Pipeline
tests keep their Linen pin from the prior commit; the eventual NNX
pipeline work (PR11.5) will introduce its own opt-in if needed.

Sharding goldens not regenerated: tests/unit/sharding_compare_test.py
already pins enable_nnx=False, pure_nnx=False, pure_nnx_decoder=False
explicitly when invoking the dump utility, so existing goldens at
tests/utils/sharding_info/ stay valid against the flipped default.
…NX::test_nnx_model_dispatches_to_tree_map_with_path
1. Sanitize unmapped logical axes to None in maxtext_utils.py get_nnx_named_sharding_with_scan_axis to prevent compilation ValueError.

2. Fix qk_clip_utils.py broadcast shape mismatch (axis=0 to axis=-2) causing TypeError.

3. Update max_utils_test.py unscan utility to correctly parse TrainStateNNX and its parameters/sharding trees.

4. Fix muon_utils_test.py NNX dict mapping assertIsNone() against raw objects rather than .

5. Patch train_distill and train_sft to explicitly nnx.pop(Intermediate) to prevent GraphDef mutation ValueErrors.

6. Update diloco.py to use nnx.split instead of the deprecated filter() method for param extraction.

7. Update diloco_test.py to execute pure NNX training loop simulations instead of legacy Linen.
Sharon Yu and others added 17 commits June 12, 2026 18:57
After flipping pure_nnx/enable_nnx/pure_nnx_decoder to True, several
integration tests broke because their code paths assumed Linen. Fixes:

- maxengine_test: remove the Linen-only test_basic_prefill / test_basic_decode
  (they build the model with transformer_as_linen but the engine now expects
  NNX state). The NNX path is already covered by test_basic_prefill_nnx /
  test_basic_decode_nnx. Drop the now-unused imports and get_data helper.

- train_sft_deprecated: support the NNX train loop. Split the TrainStateNNX
  into GraphDef + flat state before jit, only pass a dropout rng on the Linen
  path (the NNX step takes (state, batch)), and read setup params via
  nnx.split on the NNX path.

- quantizations.maybe_quantize_model: qwix.quantize_model traces NNX modules
  and needs example inputs, so pass dummy decoder tokens/positions for the
  NNX path. Fixes the fp8 sparsity smoke test.

- generate_param_only_checkpoint (NNX param-only flow):
  - checkpointing._load_full_state_from_path: restore into a pure dict, since
    NNX checkpoints are saved as pure dicts; a boxed nnx.State did not match.
  - read opt_state from state.optimizer.opt_state on the NNX path.
  - save only nnx.Param leaves (the rng PRNGKeyArray can't be cast to bf16)
    and wrap each leaf as {"value": ...} so from_pretrained can read it back.
  - skip the int8 case: it is a convert-on-load scenario (the fp32 training
    checkpoint has no AqtDotGeneral state the int8 model expects); tracked as
    a follow-up alongside layerwise_quantization.
…product test

NNX int8 parameter-only generation requires a convert-on-load setup, which causes a ValueError since the fp32 training checkpoint lacks the AqtDotGeneral state that the target int8 model expects. This aligns the GPU/dot-product test with the existing skip in the TPU/autoselected test variant.
Linen Fp8DotGeneralBase.setup leaks intermediates inside an NNX context, so once NNX defaults flip to True (PR#11) the fp8 sparsity smoke and the fp8 GPU unit-test cases that go through Qwix/Linen quant break. Skip them until b/509790223 is fixed:
- tests/integration/sparsity_test.py: fp8_full, fp8_full_with_sparsity
- tests/unit/quantizations_test.py: test_fp8_gpu_quantization, test_fp8_nanoo_quantization
PR#11 flips the defaults to NNX, so the Linen reference engine in the prefill_multisampling/prefill_concat parity tests silently became NNX and crashed (device_put State-vs-dict), and test_stack_and_unstack_prefill_cache hit the NNX no-op branch. Drop the Linen comparisons and assert the NNX result shapes directly, rewrite the cache test for the NNX scan_layers=False path, and remove _build_linen_params and its imports.
PR #3929 moved src/maxtext/layers/train_state_nnx.py to
src/maxtext/common/train_state_nnx.py. Update remaining imports in
diloco.py and three test files so PR11 still imports correctly.
Under shard_optimizer_over_data, train_compile builds the AOT train-step input shardings by calling state_mesh_shardings.replace(params=params_shardings). That's a TrainState (flax.struct) method; with PR#11's NNX defaults, state_mesh_shardings is a flat nnx.State and the call dies with 'No attribute replace in State'. Add sharding.build_zero1_input_state_mesh_shardings that overlays params_shardings' Param leaves onto state_mesh_shardings.model for the NNX path while keeping the existing .replace behavior for Linen, and route both train_compile call sites through it. Fixes test_zero1_optimizer_sharding.
Under enable_diloco the state becomes a DiLoCoTrainState, but the pure_nnx path
still merged it against the plain-model graphdef (nnx.merge leaf mismatch +
segfault), and several downstream sites assumed a plain TrainStateNNX. Guard the
merge and surface the graphdef as model; fix get_first_step, jit_model,
params_shardings, setup_params, and the rng args in train_loop; match the diloco
sharding's params to_pure_dict; and handle the DiLoCoTrainState in
maybe_save_checkpoint by saving the synchronized global model. Train + checkpoint
save/restore validated end-to-end on CPU.
The NNX decoder has no pipeline path yet, so under pure_nnx the scanned-layers
axis is sharded by 'stage' and dies with a cryptic IndivisibleError at state
init. Raise a clear NotImplementedError at config validation pointing users to
ici_pipeline_parallelism=1 or the Linen path. NNX pipeline support is tracked
as PR11.5.
…patch to NNX-only

Across the core training/utils/inference/RL/checkpoint-conversion code, statically
collapse every pure_nnx / enable_nnx / isinstance(model, nn.Module) branch to the NNX
path (the model is always NNX now). No flag reads remain in these files.
…s_linen wrappers

Delete TransformerLinenPure, the Linen Decoder/DecoderLayer/SequentialBlockDecoderLayers
stack (decoders.py), and the dead *_as_linen ToLinen wrappers across the layer/model
files. The wrapped NNX classes are unchanged; transformer_as_linen (the NNX->Linen bridge)
is kept for the checkpoint-conversion tools.
Remove obsolete Linen-only tests, drop redundant flag args from the rest, and compile the
hlo_diff tests via base.yml + model_name so they exercise the real NNX path.
…oder config flags

Remove the three flags from types.py, base.yml, inference/vllm.yml, pyconfig, and the
post-train distillation configs. NNX is the only path; the flags no longer exist.
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-delete-linen branch from c8cb446 to 40575ff Compare June 12, 2026 18:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants