Skip to content

[NNX] NNX migration (11/N): set pure_nnx / enable_nnx / pure_nnx_decoder defaults to True#3526

Open
ecnal-cienet wants to merge 1 commit into
mainfrom
feat/nnx-set-defaults-true
Open

[NNX] NNX migration (11/N): set pure_nnx / enable_nnx / pure_nnx_decoder defaults to True#3526
ecnal-cienet wants to merge 1 commit into
mainfrom
feat/nnx-set-defaults-true

Conversation

@ecnal-cienet

@ecnal-cienet ecnal-cienet commented Mar 31, 2026

Copy link
Copy Markdown
Collaborator

NNX Migration Route Map

  1. ✅ Add NNX scaffolding: pure_nnx flag, init_state_fn, TrainStateNNX, NNX utils. Linen workflow unchanged. (PR NNX migration prep (1/N): pure_nnx flag and init_state_fn scaffolding #3427)
  2. ✅ NNX sharding utilities: get_abstract_state_nnx, get_named_sharding_nnx, set_named_sharding_nnx, get_partition_spec_nnx, get_mesh_from_config. (PR NNX migration prep (2/N): NNX utils and sharding utilities #3470)
  3. ✅ NNX fully supported end-to-end: TrainStateNNX, model creation, gradient accumulation, checkpointing, and training loop dispatch. (PR NNX migration prep (3/N): TrainState, model creation, and end-to-end training loop #3500)
  4. ✅ Sharding diagnostics on NNX, plus post-training bugfixes that surfaced once the NNX path got exercised end-to-end. (PR [NNX] NNX migration prep (4/N): sharding tools and post-training fixes #3652)
  5. ✅ NNX correctness fixes, feature enablements, and vocab tiling on NNX.
  6. ✅ NNX-native DPO.
  7. ✅ NNX-native MaxEngine inference. (PR [NNX] NNX migration prep (7/N): NNX-native MaxEngine inference #3821)
  8. ✅ NNX-native LoRA + GRPO. (PR [NNX] NNX migration prep (8/N): NNX native lora grpo #3824)
  9. ✅ NNX-aware QK-Clip + remaining checkpoint utilities. (PR [NNX] NNX migration prep (9/N): NNX-aware QK-Clip + checkpoint utilities #3836)
    9.5. ✅ NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix. (PR [NNX] NNX migration prep (9.5/N): NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix #3844)
  10. ✅ Vocab tiling custom_vjp for NNX. (PR [NNX] NNX migration prep (10/N): vocab tiling custom_vjp with output-head carve-out #3849)
  11. 🔄 [This PR] Flip enable_nnx, pure_nnx, pure_nnx_decoder from False to True in base.yml. Pin Linen-coupled tests so the flip doesn't silently swap their backend. Bundle the NNX-only fixes that only surface once pure_nnx=True (DiLoCo, Zero-1, MTP, param-only checkpoint, maxengine Linen-parity removal).
  12. ❌ Delete Linen-specific code paths and NNX compatibility flags.

Description

PR6–PR10 promoted every routed-to-Linen feature to NNX-native. This PR makes NNX the default by flipping three flags in base.yml, and bundles the NNX-only fixes that only surface once the default is True.

Diff: +631 / −445 across 29 files (1 squashed commit).

Changes

src/maxtext/configs/base.yml — flip defaults

  • enable_nnx: False → True, pure_nnx: False → True, pure_nnx_decoder: False → True.

src/maxtext/utils/sharding.py — Zero-1 on flat nnx.State

  • New build_zero1_input_state_mesh_shardings overlays Param-leaf shardings on the flat nnx.State. The Linen path called state_mesh_shardings.replace(params=...), which only exists on TrainState; the NNX flat state is now dispatched through the new builder.

src/maxtext/trainers/pre_train/train.py, train_compile.py — NNX dispatch + pipeline guard

  • AOT compile dispatches to the Zero-1 NNX builder when pure_nnx=True.
  • Pops Intermediate sown variables before grad so MTP auxiliary losses don't get differentiated as part of the main loss.
  • Fail-fast NotImplementedError for pipeline + pure_nnx=True (deferred to PR11.5).

src/maxtext/trainers/diloco/diloco.py, src/maxtext/common/checkpointing.py — DiLoCo under NNX

  • DiLoCoTrainState merge/split use nnx.split and guard against double-merging.
  • maybe_save_checkpoint derives actual_step from state.step when enable_diloco, otherwise from state.optimizer.step.

src/maxtext/utils/generate_param_only_checkpoint.py — NNX param-only restore

  • Pure-dict restore ({"value": ...} wrapping), opt_state path skipping, bf16 cast skipping rng leaves. Linen flow unchanged.

src/maxtext/inference/maxengine/maxengine.py — drop Linen-only parity tests

  • Removes Linen-vs-NNX parity asserts that no longer make sense when NNX is the default. NNX-only prefill/decode/cache assertions stay.

src/maxtext/utils/{muon_utils,qk_clip_utils,train_utils}.py — NNX-shape adjustments

  • muon_utils.get_muon_weight_dimension_numbers dispatches by NNX-vs-Linen state shape.
  • qk_clip_utils broadcasts over the correct axis under NNX.
  • train_utils.jit_train_step threads dropout_rng=None on the NNX path.

Tests

Pinned to Linen (default flip would otherwise silently swap their backend):

  • tests/unit/tiling_test.py::LossAndGradientCorrectnessTest — builds via transformer_as_linen; pin in setUp + drop 6 stale pytest.skip("vocab tiling on NNX") guards (now NNX-native via PR10).
  • tests/integration/pipeline_parallelism_test.py — class-level _LINEN_PIN appended to the 6 train_main tests; the 7 unit-style tests get the same pin via pyconfig.initialize kwargs.
  • tests/integration/sparsity_test.py — fp8 sparsity cases (b/509790223: Linen Fp8DotGeneralBase leaks intermediates inside the NNX context).
  • tests/unit/quantizations_test.pytest_fp8_gpu_quantization, test_fp8_nanoo_quantization (same Fp8 issue).

Linen-only tests removed:

  • tests/integration/maxengine_test.py — Linen-vs-NNX prefill/decode parity tests removed; NNX-only assertions kept.

NNX shape/dispatch fixups (no semantic change):

  • tests/unit/{max_utils,muon_utils,maxtext_utils,optimizers,state_dtypes,train_state_nnx_checkpoint}_test.py — adjusted for TrainStateNNX / flat nnx.State shapes.
  • tests/integration/diloco_test.py — NNX training-loop simulation + checkpoint coverage.
  • tests/integration/generate_param_only_checkpoint_test.py — NNX param-only restore coverage.

Stats

  • Diff: +631 / −445 across 29 files.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-set-defaults-true branch 5 times, most recently from bac289f to db75887 Compare April 6, 2026 21:09
@ecnal-cienet ecnal-cienet changed the title Feat/nnx set defaults true NNX migration prep (5/N): enable NNX by default Apr 6, 2026
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-set-defaults-true branch 17 times, most recently from 5a7f63b to 73213e0 Compare April 9, 2026 23:47
@ecnal-cienet ecnal-cienet changed the title NNX migration prep (5/N): enable NNX by default NNX migration prep (6/N): enable NNX by default Apr 16, 2026
@ecnal-cienet ecnal-cienet changed the title NNX migration prep (6/N): enable NNX by default NNX migration prep (5/N): enable NNX by default Apr 20, 2026
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-set-defaults-true branch 4 times, most recently from 7e33a09 to 2f34cfb Compare April 28, 2026 14:16
Comment thread src/maxtext/layers/nnx_wrappers.py Outdated
Comment thread src/maxtext/layers/nnx_decoders.py Outdated
Comment thread tests/unit/max_utils_test.py
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-set-defaults-true branch 3 times, most recently from 28ea6c4 to 8412184 Compare May 21, 2026 19:58
@hsuan-lun-chiang hsuan-lun-chiang force-pushed the feat/nnx-set-defaults-true branch from f4674bb to b7d1f6d Compare May 22, 2026 10:29
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-set-defaults-true branch 2 times, most recently from 9d05b96 to 450ef8d Compare May 25, 2026 15:26
@hsuan-lun-chiang hsuan-lun-chiang force-pushed the feat/nnx-set-defaults-true branch 5 times, most recently from e420909 to 8a27207 Compare May 26, 2026 09:27
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-set-defaults-true branch 10 times, most recently from 2d3b8a6 to 52219b5 Compare May 28, 2026 20:08
PR6-PR10 promoted every routed-to-Linen feature to NNX-native. This PR flips the three defaults in base.yml so NNX is the production path, pins Linen-coupled tests so the flip doesn't silently swap their backend, and bundles the NNX-only fixes that surface once pure_nnx=True (DiLoCo merge/checkpoint, Zero-1 input shardings on flat nnx.State, MTP sown-Variable handling, generate_param_only_checkpoint NNX flow, maxengine Linen-parity removal).

NNX pipeline parallelism deferred to PR11.5; train_compile fails fast under pure_nnx=True with pipeline configured.
@github-actions

Copy link
Copy Markdown

🤖 Hi @ecnal-cienet, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

Copy link
Copy Markdown

🤖 I'm sorry @ecnal-cienet, but I was unable to process your request. Please see the logs for more details.

1 similar comment
@github-actions

Copy link
Copy Markdown

🤖 I'm sorry @ecnal-cienet, but I was unable to process your request. Please see the logs for more details.

@github-actions github-actions Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📋 Review Summary

This Pull Request successfully transitions MaxText to use JAX's NNX API by default by flipping enable_nnx, pure_nnx, and pure_nnx_decoder to True. The changes are comprehensive and robust, addressing flat NNX state checkpointing, DiLoCo training, parameter-only restoration, and pinning Linen-coupled integration/unit tests.

🔍 General Feedback

  • High-Quality Code Migration: The migration of core features (including DiLoCo, param-only generation, and maxengine) to support NNX-native behavior is very well-structured and thoroughly covered by updated tests.
  • Robust Sharding Alignments: Adopting build_zero1_input_state_mesh_shardings to overlay Param-leaf shardings on the flat nnx.State ensures seamless compatibility with ZeRO-1 optimizers under NNX.
  • Preserved Parity: Pinning complex pipeline parallelism and fp8/sparsity tests to the Linen path is a pragmatic decision that avoids regressions while these features are migrated in subsequent iterations.

Comment on lines +295 to +300
if hasattr(s, "keys"):
leaves, treedef = jax.tree_util.tree_flatten(s)
new_model_leaves, _ = jax.tree_util.tree_flatten(new_model)
N = len(new_model_leaves)
new_leaves = new_model_leaves + leaves[N:]
return jax.tree_util.tree_unflatten(treedef, new_leaves)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Medium: Using JAX flattening (tree_flatten/tree_unflatten) to replace a key in a dictionary or nnx.State is brittle as it relies on the implicit alphabetical sorting order of the keys (where "model" is sorted alphabetically before "optimizer", guaranteeing it occupies the first N leaves). If other keys are added in the future or if JAX's flattening behavior changes, this assumption will silently break and corrupt the unflattened state.

A more robust and idiomatic approach is to copy the dictionary/State and assign the "model" key directly. This handles dictionary/State replacement cleanly and is independent of any flattening assumptions or leaf ordering.

Suggested change
if hasattr(s, "keys"):
leaves, treedef = jax.tree_util.tree_flatten(s)
new_model_leaves, _ = jax.tree_util.tree_flatten(new_model)
N = len(new_model_leaves)
new_leaves = new_model_leaves + leaves[N:]
return jax.tree_util.tree_unflatten(treedef, new_leaves)
if hasattr(s, "keys"):
result = type(s)({})
for k, v in s.items():
result[k] = new_model if k == "model" else v
return result

@github-actions

Copy link
Copy Markdown

🤖 I'm sorry @ecnal-cienet, but I was unable to process your request. Please see the logs for more details.

1 similar comment
@github-actions

Copy link
Copy Markdown

🤖 I'm sorry @ecnal-cienet, but I was unable to process your request. Please see the logs for more details.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants