[NNX] NNX migration (11/N): set pure_nnx / enable_nnx / pure_nnx_decoder defaults to True#3526
ecnal-cienet wants to merge 1 commit into
PR6-PR10 promoted every routed-to-Linen feature to NNX-native. This PR flips the three defaults in base.yml so NNX is the production path, pins Linen-coupled tests so the flip doesn't silently swap their backend, and bundles the NNX-only fixes that surface once pure_nnx=True (DiLoCo merge/checkpoint, Zero-1 input shardings on flat nnx.State, MTP sown-Variable handling, generate_param_only_checkpoint NNX flow, maxengine Linen-parity removal). NNX pipeline parallelism deferred to PR11.5; train_compile fails fast under pure_nnx=True with pipeline configured.
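The fail-fast behavior described for `train_compile` can be illustrated with a minimal sketch. The `Config` class and its `pipeline_stages` field are hypothetical stand-ins invented for this example; the actual check and config field names in the PR are not shown here and may differ.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Config:
  """Hypothetical stand-in for the parsed config; field names are assumptions."""
  pure_nnx: bool = True
  pipeline_stages: int = 1  # assumption: a value > 1 means pipeline parallelism is configured


def check_pipeline_supported(config: Config) -> None:
  """Raise early, before any compilation work, if the combination is unsupported."""
  if config.pure_nnx and config.pipeline_stages > 1:
    raise NotImplementedError(
        "Pipeline parallelism is not supported with pure_nnx=True yet "
        "(NNX pipeline parallelism is deferred to a later PR)."
    )


# Supported combinations pass silently.
check_pipeline_supported(Config(pure_nnx=True, pipeline_stages=1))
check_pipeline_supported(Config(pure_nnx=False, pipeline_stages=4))
```

Raising before compilation starts keeps the failure close to its cause, instead of surfacing as a shape or attribute error deep inside tracing.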
📋 Review Summary
This Pull Request successfully transitions MaxText to use JAX's NNX API by default by flipping enable_nnx, pure_nnx, and pure_nnx_decoder to True. The changes are comprehensive and robust, addressing flat NNX state checkpointing, DiLoCo training, parameter-only restoration, and pinning Linen-coupled integration/unit tests.
🔍 General Feedback
- High-Quality Code Migration: The migration of core features (including DiLoCo, param-only generation, and maxengine) to support NNX-native behavior is very well-structured and thoroughly covered by updated tests.
- Robust Sharding Alignments: Adopting `build_zero1_input_state_mesh_shardings` to overlay `Param`-leaf shardings on the flat `nnx.State` ensures seamless compatibility with ZeRO-1 optimizers under NNX.
- Preserved Parity: Pinning complex pipeline parallelism and fp8/sparsity tests to the Linen path is a pragmatic decision that avoids regressions while these features are migrated in subsequent iterations.
    if hasattr(s, "keys"):
      leaves, treedef = jax.tree_util.tree_flatten(s)
      new_model_leaves, _ = jax.tree_util.tree_flatten(new_model)
      N = len(new_model_leaves)
      new_leaves = new_model_leaves + leaves[N:]
      return jax.tree_util.tree_unflatten(treedef, new_leaves)
🟡 Medium: Using JAX flattening (tree_flatten/tree_unflatten) to replace a key in a dictionary or nnx.State is brittle as it relies on the implicit alphabetical sorting order of the keys (where "model" is sorted alphabetically before "optimizer", guaranteeing it occupies the first N leaves). If other keys are added in the future or if JAX's flattening behavior changes, this assumption will silently break and corrupt the unflattened state.
A more robust and idiomatic approach is to copy the dictionary/State and assign the "model" key directly. This handles dictionary/State replacement cleanly and is independent of any flattening assumptions or leaf ordering.
Suggested change:

    if hasattr(s, "keys"):
      result = type(s)({})
      for k, v in s.items():
        result[k] = new_model if k == "model" else v
      return result
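The ordering concern in this comment can be reproduced with plain dicts. The sketch below uses only `jax.tree_util.tree_flatten` and `tree_unflatten`, which flatten dict entries in sorted-key order. The helper names and the extra `"aux"` key are hypothetical, chosen to show what happens when a key that sorts before `"model"` is present; the sketch uses plain dicts rather than `nnx.State`.

```python
import jax


def replace_model_by_leaf_prefix(state, new_model):
  """Mirrors the reviewed snippet: assumes "model" owns the first N leaves."""
  leaves, treedef = jax.tree_util.tree_flatten(state)
  new_model_leaves, _ = jax.tree_util.tree_flatten(new_model)
  n = len(new_model_leaves)
  return jax.tree_util.tree_unflatten(treedef, new_model_leaves + leaves[n:])


def replace_model_by_key(state, new_model):
  """Replaces the "model" entry directly, independent of leaf order."""
  return {k: (new_model if k == "model" else v) for k, v in state.items()}


new_model = {"w": 10, "b": 20}

# With only "model" and "optimizer", the prefix assumption happens to hold.
state = {"model": {"w": 1, "b": 2}, "optimizer": {"step": 7}}
ok = replace_model_by_leaf_prefix(state, new_model)

# A key that sorts before "model" shifts the leaves and corrupts the result.
state_with_extra = {"aux": {"count": 99}, "model": {"w": 1, "b": 2}, "optimizer": {"step": 7}}
broken = replace_model_by_leaf_prefix(state_with_extra, new_model)
fixed = replace_model_by_key(state_with_extra, new_model)
```

In the `broken` case the first new leaf lands in `aux` and only part of `model` is overwritten, with no error raised, which is the silent corruption the comment warns about. The key-based version leaves `aux` and `optimizer` untouched.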
NNX Migration Route Map
1. `pure_nnx` flag, `init_state_fn`, `TrainStateNNX`, NNX utils. Linen workflow unchanged. (PR NNX migration prep (1/N): pure_nnx flag and init_state_fn scaffolding #3427)
2. `get_abstract_state_nnx`, `get_named_sharding_nnx`, `set_named_sharding_nnx`, `get_partition_spec_nnx`, `get_mesh_from_config`. (PR NNX migration prep (2/N): NNX utils and sharding utilities #3470)
3. `TrainStateNNX`, model creation, gradient accumulation, checkpointing, and training loop dispatch. (PR NNX migration prep (3/N): TrainState, model creation, and end-to-end training loop #3500)

9.5. ✅ NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix. (PR [NNX] NNX migration prep (9.5/N): NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix #3844)
10. `custom_vjp` for NNX. (PR [NNX] NNX migration prep (10/N): vocab tiling custom_vjp with output-head carve-out #3849)
11. `enable_nnx`, `pure_nnx`, `pure_nnx_decoder` from `False` to `True` in `base.yml`. Pin Linen-coupled tests so the flip doesn't silently swap their backend. Bundle the NNX-only fixes that only surface once `pure_nnx=True` (DiLoCo, Zero-1, MTP, param-only checkpoint, maxengine Linen-parity removal).

Description
PR6–PR10 promoted every routed-to-Linen feature to NNX-native. This PR makes NNX the default by flipping three flags in `base.yml`, and bundles the NNX-only fixes that only surface once the default is `True`.

Diff: +631 / −445 across 29 files (1 squashed commit).
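The effect of flipping the defaults, and of pinning a test against the flip, can be sketched with a toy override resolver. `DEFAULTS` lists the three flags named above with their new values. `LINEN_PIN` and `resolve` are hypothetical names for this illustration only; the real pin in the test suite and the real config loader are not reproduced here.

```python
# Defaults after this PR: the three flags named in the description.
DEFAULTS = {"enable_nnx": True, "pure_nnx": True, "pure_nnx_decoder": True}

# Hypothetical pin: explicit key=value overrides a Linen-coupled test would append
# so that its backend no longer depends on whatever the defaults happen to be.
LINEN_PIN = ["enable_nnx=False", "pure_nnx=False", "pure_nnx_decoder=False"]


def resolve(defaults, overrides):
  """Apply key=value overrides on top of defaults (toy parser, booleans only)."""
  resolved = dict(defaults)
  for item in overrides:
    key, _, raw = item.partition("=")
    if key not in resolved:
      raise KeyError(f"unknown config key: {key}")
    if raw not in ("True", "False"):
      raise ValueError(f"expected True or False for {key}, got {raw!r}")
    resolved[key] = raw == "True"
  return resolved


unpinned = resolve(DEFAULTS, [])       # follows the new defaults: NNX path
pinned = resolve(DEFAULTS, LINEN_PIN)  # stays on Linen regardless of defaults
```

A test that relied on the old defaults without a pin would have switched to the NNX path silently; an explicit override makes the intended backend part of the test itself.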
Changes
- `src/maxtext/configs/base.yml`: flip defaults
  - `enable_nnx: False → True`, `pure_nnx: False → True`, `pure_nnx_decoder: False → True`.
- `src/maxtext/utils/sharding.py`: Zero-1 on flat `nnx.State`
  - `build_zero1_input_state_mesh_shardings` overlays `Param`-leaf shardings on the flat `nnx.State`. The Linen path called `state_mesh_shardings.replace(params=...)`, which only exists on `TrainState`; the NNX flat state is now dispatched through the new builder.
- `src/maxtext/trainers/pre_train/train.py`, `train_compile.py`: NNX dispatch + pipeline guard
  - NNX dispatch under `pure_nnx=True`.
  - Handles `Intermediate` sown variables before grad so that MTP auxiliary losses are not differentiated as part of the main loss.
  - `NotImplementedError` for pipeline + `pure_nnx=True` (deferred to PR11.5).
- `src/maxtext/trainers/diloco/diloco.py`, `src/maxtext/common/checkpointing.py`: DiLoCo under NNX
  - `DiLoCoTrainState` merge/split use `nnx.split` and guard against double-merging.
  - `maybe_save_checkpoint` derives `actual_step` from `state.step` when `enable_diloco` is set, otherwise from `state.optimizer.step`.
- `src/maxtext/utils/generate_param_only_checkpoint.py`: NNX param-only restore
  - NNX flow handles `{"value": ...}` wrapping, opt_state path skipping, and a bf16 cast that skips rng leaves. Linen flow unchanged.
- `src/maxtext/inference/maxengine/maxengine.py`: drop Linen-only parity tests
- `src/maxtext/utils/{muon_utils,qk_clip_utils,train_utils}.py`: NNX-shape adjustments
  - `muon_utils.get_muon_weight_dimension_numbers` dispatches by NNX-vs-Linen state shape.
  - `qk_clip_utils` broadcasts over the correct axis under NNX.
  - `train_utils.jit_train_step` threads `dropout_rng=None` on the NNX path.

Tests
Pinned to Linen (default flip would otherwise silently swap their backend):
- `tests/unit/tiling_test.py::LossAndGradientCorrectnessTest`: builds via `transformer_as_linen`; pin in `setUp` + drop 6 stale `pytest.skip("vocab tiling on NNX")` guards (now NNX-native via PR10).
- `tests/integration/pipeline_parallelism_test.py`: class-level `_LINEN_PIN` appended to the 6 `train_main` tests; the 7 unit-style tests get the same pin via `pyconfig.initialize` kwargs.
- `tests/integration/sparsity_test.py`: fp8 sparsity cases (b/509790223: Linen `Fp8DotGeneralBase` leaks intermediates inside the NNX context).
- `tests/unit/quantizations_test.py`: `test_fp8_gpu_quantization`, `test_fp8_nanoo_quantization` (same Fp8 issue).

Linen-only tests removed:
- `tests/integration/maxengine_test.py`: Linen-vs-NNX prefill/decode parity tests removed; NNX-only assertions kept.

NNX shape/dispatch fixups (no semantic change):
- `tests/unit/{max_utils,muon_utils,maxtext_utils,optimizers,state_dtypes,train_state_nnx_checkpoint}_test.py`: adjusted for `TrainStateNNX` / flat `nnx.State` shapes.
- `tests/integration/diloco_test.py`: NNX training-loop simulation + checkpoint coverage.
- `tests/integration/generate_param_only_checkpoint_test.py`: NNX param-only restore coverage.

Stats