feat: migrate pipeline to nnx #2885

Open
mesakhcienet wants to merge 7 commits into AI-Hypercomputer:main from CIeNET-International:test/pipeline-scan-nnx

Conversation

@mesakhcienet

@mesakhcienet mesakhcienet commented Dec 24, 2025

Collaborator

Description

Implement an NNX-based pipeline.

This PR extends PR #2831.

Main changes:

  1. nnx_decoders.py: implement the pipeline logic that was previously missing.
  2. pipeline.py: add a new class, NNXPipeline, which is an NNX-based pipeline class.

Tests

We ran the pipeline with the command below:

MODEL_NAME=llama2-7b
python -m MaxText.train src/maxtext/configs/base.yml \
    run_name=pipeline_test_${MODEL_NAME}_nnx \
    base_output_directory=/dev/shm/pipeline_test_nnx \
    model_name=${MODEL_NAME} \
    dataset_type=synthetic \
    steps=15 \
    debug_sharding=true \
    per_device_batch_size=2 \
    max_target_length=32 \
    ici_pipeline_parallelism=2 \
    num_pipeline_microbatches=4 \
    num_layers_per_pipeline_stage=2 \
    enable_checkpointing=false \
    enable_nnx=true \
    pure_nnx_decoder=true \
    scan_layers_per_stage=false \
    async_checkpointing=false > nnx-porting-log/pipeline/custom_${MODEL_NAME}.log 2>&1
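
As a rough sanity check on the flags above, the sketch below computes how many pipeline repeats the run would need. It assumes that the repeat count is the number of decoder layers divided by (stages x layers per stage), and that llama2-7b has 32 decoder layers; the helper name `pipeline_repeats` is hypothetical and is not part of MaxText.

```python
def pipeline_repeats(num_decoder_layers: int, num_stages: int, layers_per_stage: int) -> int:
    """Return how many times the stages must be reused to cover all layers.

    Assumption: one repeat runs num_stages * layers_per_stage layers.
    """
    layers_per_repeat = num_stages * layers_per_stage
    if num_decoder_layers % layers_per_repeat != 0:
        raise ValueError(
            f"{num_decoder_layers} layers cannot be split evenly into "
            f"repeats of {layers_per_repeat} layers"
        )
    return num_decoder_layers // layers_per_repeat


# ici_pipeline_parallelism=2, num_layers_per_pipeline_stage=2, 32 layers (assumed)
repeats = pipeline_repeats(num_decoder_layers=32, num_stages=2, layers_per_stage=2)
print(repeats)  # 8
```

Under these assumptions each stage is reused more than once, so the command would exercise the multi-repeat path rather than a single pass through the stages.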

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@mesakhcienet mesakhcienet changed the title core: migrate pipeline to nnx feat: migrate pipeline to nnx Dec 24, 2025
@mesakhcienet mesakhcienet force-pushed the test/pipeline-scan-nnx branch 8 times, most recently from 6875da8 to f34b1a3 Compare January 15, 2026 23:43
@codecov

codecov Bot commented Jan 19, 2026


@mesakhcienet mesakhcienet force-pushed the test/pipeline-scan-nnx branch 4 times, most recently from 12a3907 to 2c16599 Compare January 28, 2026 08:04
@mesakhcienet mesakhcienet force-pushed the test/pipeline-scan-nnx branch 2 times, most recently from 64dc147 to 9e4518e Compare February 2, 2026 01:58
@mesakhcienet mesakhcienet force-pushed the test/pipeline-scan-nnx branch from 631a73e to ac97a1d Compare March 2, 2026 08:48
@mesakhcienet mesakhcienet changed the base branch from main to xibin/nnx_all March 2, 2026 08:48
@ecnal-cienet ecnal-cienet force-pushed the xibin/nnx_all branch 12 times, most recently from 1849f0b to 669dc01 Compare March 3, 2026 19:59
@mesakhcienet

mesakhcienet commented May 4, 2026

Collaborator Author

@bvandermoon
Update on the pipeline migration approach:
After further investigation, we've adjusted from the original Option 1 (remove Linen entirely) to a hybrid approach. Here's why:
The current branch keeps both Linen (PipelineLinen/CircularPipelineLinen) and NNX (NNXPipeline/NNXCircularPipeline) pipeline classes in pipeline.py, gated by a use_nnx_pipeline config flag (default: False).
Reasoning:

  • The NNX pipeline wrapped via ToLinen still needs validation against the full pipeline test suite (circular pipeline, FSDP+AG, multi-repeat, etc.) to confirm numerical equivalence with the original Linen version.
  • Keeping the Linen version as a reference allows us to A/B test and verify correctness before removing it.
  • The Linen classes are verified line-by-line identical to main branch -- no divergence risk.
  • nnx_decoders.py uses create_nnx_pipeline() directly for the pure NNX path.
  • decoders.py uses create_pipeline() which dispatches on the flag -- False gives native Linen, True gives NNX-wrapped-in-ToLinen.

Plan to converge to Option 1: Once the NNX pipeline passes all integration tests, we'll remove the Linen classes and flip the default, effectively completing Option 1.
Let me know if you'd prefer we prioritize removing the Linen code sooner. Thank you.
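
The flag-gated dispatch described in the bullets above can be sketched roughly as follows. This is an illustration only: the stub classes stand in for the real Linen, NNX, and ToLinen-wrapped pipelines, and the function signatures are assumptions rather than the actual ones in decoders.py or pipeline.py.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Config:
    # Mirrors the use_nnx_pipeline flag from the discussion (default: False).
    use_nnx_pipeline: bool = False


class LinenPipelineStub:
    """Stand-in for the native Linen pipeline class."""
    kind = "linen"


class NNXPipelineStub:
    """Stand-in for the NNX pipeline class."""
    kind = "nnx"


class ToLinenWrapperStub:
    """Stand-in for an NNX module wrapped so Linen code can call it."""
    kind = "nnx-wrapped-in-linen"

    def __init__(self, inner):
        self.inner = inner


def create_nnx_pipeline(config: Config) -> NNXPipelineStub:
    # Pure NNX path: the NNX decoder would call this directly.
    return NNXPipelineStub()


def create_pipeline(config: Config):
    # Linen decoder path: dispatch on the flag.
    if config.use_nnx_pipeline:
        return ToLinenWrapperStub(create_nnx_pipeline(config))
    return LinenPipelineStub()


print(create_pipeline(Config()).kind)
print(create_pipeline(Config(use_nnx_pipeline=True)).kind)
```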

Thanks @mesakhcienet. I feel we should push strongly to remove the Linen logic if possible. Even though the NNX and Linen versions are equivalent now, that could change if someone updates the Linen version and misses that the NNX one needs to be updated also. Is testing differences between Linen/NNX the main reason to fork these two? Maybe we could test before/after this commit if needed?

Thank you for your response. A few clarifications on the sequencing:

  1. The NNX path is opt-in, not the default. This PR adds a new use_nnx_pipeline flag specifically for choosing between the Linen and NNX pipeline implementations, and it defaults to false. Combined with enable_nnx and pure_nnx_decoder (both also false by default), existing users stay on Linen and aren't affected by NNX changes. Linen remains the source of truth until we explicitly flip the defaults, so divergence is much less risky than if NNX were the live path.

  2. Removing Linen fully right now carries unknown risk. I'm not yet confident what breaks downstream if the Linen pipeline classes are deleted outright — there may be call sites or configurations that depend on them in ways the current test coverage doesn't surface. Keeping both lets us de-risk the removal incrementally.

  3. NNX isn't unverified. I've run it against the train-compile configs Nuoj asked about plus the end-to-end run in the PR description. It's just not under unit-test coverage yet.

  4. Existing unit tests are intentionally left on the Linen path so the battle-tested version keeps rigorous coverage during the migration. Plan is to flip them to NNX once the decoder migration is done.

  5. Decoder NNX migration is still in flight (open PRs adding missing models + their unit tests). Adding NNX pipeline unit tests now would likely need rework once that lands.

  6. This PR is already large — I'd rather keep it focused on the migration and handle the test flip in a follow-up.

So my preference is to stick with the hybrid for now and converge to NNX-only once the decoder PRs land and we've validated the removal is safe. Open to revisiting if you'd rather sequence it differently. Thank you!

@bvandermoon

Collaborator

I can be open to this approach since some models are not integrated into the decoder layer. But can you please ensure there is a plan to avoid divergence between the NNX and Linen versions? I am concerned that if the Linen version changes, we could miss some functionality in the NNX version.

I am also concerned that if we don't run the unit tests on the NNX version now, it will end up being more painful when we go to make the real cutover. Is there anything we can do to be more confident here?
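
One possible way to gain that confidence is an equivalence check that runs the same inputs through both implementations and compares the outputs within a tolerance. The sketch below uses toy functions in place of the real Linen and NNX pipelines; the helper name `outputs_match` and the tolerances are assumptions chosen for illustration.

```python
import math


def outputs_match(reference, candidate, rel_tol=1e-5, abs_tol=1e-6):
    """Return True if both output sequences agree element-wise within tolerance."""
    if len(reference) != len(candidate):
        return False
    return all(
        math.isclose(r, c, rel_tol=rel_tol, abs_tol=abs_tol)
        for r, c in zip(reference, candidate)
    )


def reference_stage(xs):
    # Toy stand-in for the existing (reference) implementation.
    return [2.0 * x + 1.0 for x in xs]


def candidate_stage(xs):
    # Toy stand-in for the migrated implementation; same math, different form.
    return [x + x + 1.0 for x in xs]


inputs = [0.0, 0.5, -1.25, 3.0]
assert outputs_match(reference_stage(inputs), candidate_stage(inputs))
```

In a real test the two stand-ins would be replaced by forward passes that share identical weights and inputs, and the tolerance would need to reflect the dtype in use.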

@mesakhcienet

mesakhcienet commented May 6, 2026

Collaborator Author

@bvandermoon

Thank you sir. Thinking it through, I'd actually like to go with your original suggestion — remove the Linen pipeline entirely in this PR rather than keep both. Maintaining the hybrid with proper safeguards (dual-path CI, equivalence tests, etc.) ends up being a similar amount of work to just doing the cutover with solid NNX tests, and the cutover leaves us in a cleaner place.

Proposed plan for this PR:

  1. Remove the Linen pipeline classes as originally proposed.
  2. Update the existing unit tests (mostly in pipeline_parallelism_test.py) to run against the NNX path so the pipeline stays under unit-test coverage through the migration.

Does this work for you? I'll need a bit of time to get the NNX tests passing before this is ready for merge. On our side, we'll also need #3114 to land into main branch first before we can merge these changes — happy to adjust the plan if you'd sequence the test work differently.

Comment thread src/maxtext/configs/base.yml Outdated
@bvandermoon

Collaborator

That sounds great. Thank you @mesakhcienet

Comment thread src/maxtext/layers/decoders.py Outdated
remat_policy = self.get_remat_policy()
nnx_blocks = self._get_nnx_decoder_block_classes()

def stage_factory(rngs):

Collaborator

Maybe give it a more descriptive name and add some comments.

Comment thread src/maxtext/layers/decoders.py
model_mode=self.model_mode,
)
return stage_module

Collaborator

@gobbleturk could you review the changes in decoders.py as well?

out = layer(y, *layer_args, **layer_kwargs)
y = out[0] if isinstance(out, tuple) else out

else:

Collaborator

Below is a huge if-else block with many branches. I wonder if there is a way to simplify the logic for better readability.
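
One common way to flatten a long if-else chain is a dispatch table that maps a key to a handler function. The sketch below is generic: the keys and handler names are hypothetical, and the real branches in the PR may depend on more than a single key.

```python
from typing import Callable, Dict


def run_default(x: str) -> str:
    # Fallback used when no specialised handler is registered.
    return f"default({x})"


def run_variant_a(x: str) -> str:
    return f"variant_a({x})"


def run_variant_b(x: str) -> str:
    return f"variant_b({x})"


# One table entry replaces one if/elif branch.
HANDLERS: Dict[str, Callable[[str], str]] = {
    "variant_a": run_variant_a,
    "variant_b": run_variant_b,
}


def run_layers(block_kind: str, x: str) -> str:
    handler = HANDLERS.get(block_kind, run_default)
    return handler(x)


print(run_layers("variant_a", "y"))
print(run_layers("unknown", "y"))
```

Adding a new case then means registering one more entry instead of extending the conditional.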

Comment thread src/maxtext/layers/pipeline.py Outdated
stages on the same physical devices. To hide the communication overhead of Fully
Sharded Data Parallelism (FSDP), this module utilizes a Buffer Sliding Window (BSW)
to prefetch and all-gather the weights for the *next* pipeline repeat while the
*current* repeat is executing.

Collaborator

Why were the previous comments deleted?

Comment thread src/maxtext/layers/pipeline.py Outdated
Comment thread src/maxtext/layers/pipeline.py
Comment thread src/maxtext/layers/pipeline.py Outdated
Comment thread src/maxtext/layers/pipeline.py Outdated
Comment thread src/maxtext/layers/pipeline.py
Comment thread src/maxtext/layers/pipeline.py
Comment thread src/maxtext/layers/nnx_decoders.py Outdated
return layer_cls(config=config, mesh=mesh, quant=quant, model_mode=model_mode, rngs=rng)

# Workaround for Deepseek MTP test failure.
# TODO: Handle this properly.

Collaborator

Please specify which error this catches instead of using a general except.
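
A minimal sketch of the difference, assuming the workaround falls back when a layer constructor rejects a keyword argument. The thread does not say which exception the workaround actually needs, so `TypeError` here is an assumption, and all class names are hypothetical.

```python
class WithRngs:
    def __init__(self, rngs):
        self.rngs = rngs


class WithoutRngs:
    def __init__(self):
        self.rngs = None


class Broken:
    def __init__(self, rngs):
        raise RuntimeError("unrelated failure")


def build_layer(layer_cls, rngs):
    try:
        return layer_cls(rngs=rngs)
    except TypeError:
        # Only the "unexpected keyword argument" case is handled here.
        # Caveat: a TypeError raised inside the constructor body would
        # also land in this branch.
        return layer_cls()


print(build_layer(WithRngs, 0).rngs)
print(build_layer(WithoutRngs, 0).rngs)
```

With a bare `except:` the `RuntimeError` from `Broken` would be silently swallowed and retried; with the narrow clause it propagates to the caller, which is usually what a reviewer wants to see.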

Comment thread src/maxtext/layers/pipeline.py Outdated
layers: nn.Module
mesh: Mesh
remat_policy: Any = None
class NNXPipelineBase(nnx.Module):

Collaborator

If the goal is to deprecate the Linen version, we could just call it PipelineBase.

Comment thread src/maxtext/layers/pipeline.py Outdated
Saves two named tensors during jax.checkpoint recomputation:
- "iteration_input": routed microbatch data entering the decoder
- "decoder_layer_input": input to the decoder layer itself
Everything else is recomputed during backward to save memory.

Collaborator

This docstring is not accurate. The remat policy, even using pipeline parallelism, should be controlled by maxtext flags, e.g. remat_policy=custom, context=device...

- enable_nnx: false
- pure_nnx_decoder: false
+ enable_nnx: true
+ pure_nnx_decoder: true

Collaborator

does this change (using nnx decoder) have to be a part of this PR?

@mesakhcienet mesakhcienet Jun 12, 2026

Collaborator Author

This change belongs in a different PR; it is here only to show reviewers that the unit tests pass with the NNX decoder enabled.

FYI, we are currently in the process of pushing the nnx_decoders.py update in this PR. We expect that PR to be merged first, since it contains changes similar to the current branch's nnx_decoders.py.

@NuojCheng NuojCheng left a comment

Collaborator

LGTM


5 participants