feat: migrate pipeline to nnx #2885
Thank you for your response. A few clarifications on the sequencing:
So my preference is to stick with the hybrid for now and converge to NNX-only once the decoder PRs land and we've validated the removal is safe. Open to revisiting if you'd rather sequence it differently. Thank you!
I can be open to this approach since some models are not integrated into the decoder layer. But can you please ensure there is a plan to avoid divergence between the NNX and Linen versions? I am concerned that if the Linen version changes, we could miss some functionality in the NNX version. I am also concerned that if we don't run the unit tests on the NNX version now, it will end up being more painful when we go to make the real cutover. Is there anything we can do to be more confident here?
Thank you, sir. Thinking it through, I'd actually like to go with your original suggestion: remove the Linen pipeline entirely in this PR rather than keep both. Maintaining the hybrid with proper safeguards (dual-path CI, equivalence tests, etc.) ends up being a similar amount of work to just doing the cutover with solid NNX tests, and the cutover leaves us in a cleaner place. Proposed plan for this PR:
Does this work for you? I'll need a bit of time to get the NNX tests passing before this is ready for merge. On our side, we'll also need #3114 to land in the main branch before we can merge these changes. I'm happy to adjust the plan if you'd sequence the test work differently.
That sounds great. Thank you @mesakhcienet
| remat_policy = self.get_remat_policy()
| nnx_blocks = self._get_nnx_decoder_block_classes()
|
| def stage_factory(rngs):
Maybe give it a better name and add some comments.
| model_mode=self.model_mode,
| )
| return stage_module
@gobbleturk could you review the changes in decoder.py as well?
| out = layer(y, *layer_args, **layer_kwargs)
| y = out[0] if isinstance(out, tuple) else out
|
| else:
Below is a huge if-else block with many branches. I wonder if there is a way to simplify the logic for better readability.
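One possible direction, offered only as a sketch: move the tuple unwrapping into a single helper and replace the branches with a lookup table. All names here (`HANDLERS`, `run_sequential`, `run_first_only`, and the `mode` keys) are hypothetical placeholders, not identifiers from this PR; the real branch conditions in decoder.py would determine the actual keys and handlers.

```python
from typing import Callable, Dict, Sequence


def unwrap(out):
    """Return the activation whether a layer yields `y` or `(y, aux)`."""
    return out[0] if isinstance(out, tuple) else out


def run_sequential(y, layers: Sequence[Callable]):
    """Hypothetical branch: apply every layer in order."""
    for layer in layers:
        y = unwrap(layer(y))
    return y


def run_first_only(y, layers: Sequence[Callable]):
    """Hypothetical branch: apply only the first layer, if any."""
    return unwrap(layers[0](y)) if layers else y


# One entry per former if/elif branch (placeholder keys).
HANDLERS: Dict[str, Callable] = {
    "sequential": run_sequential,
    "first_only": run_first_only,
}


def apply_layers(mode: str, y, layers: Sequence[Callable]):
    """Dispatch to the handler registered for `mode`."""
    try:
        handler = HANDLERS[mode]
    except KeyError:
        raise ValueError(
            f"unknown mode {mode!r}; expected one of {sorted(HANDLERS)}"
        ) from None
    return handler(y, layers)


# Toy layers: one returns a bare value, one returns (value, aux).
toy_layers = [lambda v: v + 1, lambda v: (v * 2, "aux")]
sequential_result = apply_layers("sequential", 1, toy_layers)
first_only_result = apply_layers("first_only", 1, toy_layers)
```

With a table like this, adding a new case means registering one more handler, and an unsupported mode fails loudly instead of falling through to a final `else`.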
| stages on the same physical devices. To hide the communication overhead of Fully
| Sharded Data Parallelism (FSDP), this module utilizes a Buffer Sliding Window (BSW)
| to prefetch and all-gather the weights for the *next* pipeline repeat while the
| *current* repeat is executing.
Why were the previous comments deleted?
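For readers following this thread, the prefetch ordering that the docstring describes can be illustrated with a toy, framework-free sketch. This is not the module's implementation: `all_gather_weights` and `run_repeat` are hypothetical stand-ins, a worker thread plays the role of the communication, and the actual code operates on sharded JAX arrays. It only shows the order of operations, where the gather for repeat `r + 1` is started before repeat `r` is computed.

```python
from concurrent.futures import ThreadPoolExecutor


def all_gather_weights(repeat_idx):
    """Hypothetical stand-in for gathering one repeat's sharded weights."""
    return [repeat_idx + 1] * 4


def run_repeat(x, weights):
    """Hypothetical stand-in for executing one pipeline repeat."""
    return x + sum(weights)


def run_with_prefetch(x, num_repeats):
    """Start fetching the next repeat's weights while the current one runs."""
    if num_repeats <= 0:
        return x
    with ThreadPoolExecutor(max_workers=1) as pool:
        pending = pool.submit(all_gather_weights, 0)
        for r in range(num_repeats):
            current = pending.result()  # weights for repeat r are ready
            if r + 1 < num_repeats:
                # Kick off the gather for repeat r + 1 before computing r.
                pending = pool.submit(all_gather_weights, r + 1)
            x = run_repeat(x, current)
    return x


def run_without_prefetch(x, num_repeats):
    """Reference version: gather, then compute, strictly in sequence."""
    for r in range(num_repeats):
        x = run_repeat(x, all_gather_weights(r))
    return x


prefetched = run_with_prefetch(0, 3)
reference = run_without_prefetch(0, 3)
```

Both versions produce the same result; the only difference is when each gather is issued relative to the compute step.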
| return layer_cls(config=config, mesh=mesh, quant=quant, model_mode=model_mode, rngs=rng)
|
| # Workaround for Deepseek MTP test failure.
| # TODO: Handle this properly.
Please specify which error it is instead of using a general except.
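As an illustration of the kind of narrowing being asked for, here is a minimal sketch. It assumes, purely for the example, that the failure is a `TypeError` raised because a layer constructor does not accept an `rngs` keyword; the actual error behind the Deepseek MTP workaround is not stated in this thread. `create_layer`, `NNXStyleLayer`, and `LegacyLayer` are hypothetical names.

```python
class NNXStyleLayer:
    """Hypothetical layer whose constructor accepts `rngs`."""

    def __init__(self, config, rngs):
        self.config = config
        self.rngs = rngs


class LegacyLayer:
    """Hypothetical layer whose constructor does not accept `rngs`."""

    def __init__(self, config):
        self.config = config


def create_layer(layer_cls, config, rngs):
    """Build a layer, falling back only for the one anticipated failure."""
    try:
        return layer_cls(config=config, rngs=rngs)
    except TypeError as err:
        # Re-raise anything other than the "unexpected keyword 'rngs'" case,
        # so unrelated bugs are not silently swallowed.
        if "rngs" not in str(err):
            raise
        return layer_cls(config=config)


new_style = create_layer(NNXStyleLayer, {"dim": 8}, rngs=0)
legacy = create_layer(LegacyLayer, {"dim": 8}, rngs=0)
```

Catching one named exception type and re-raising everything unexpected keeps the workaround visible and prevents it from hiding unrelated failures, which a bare `except` would do.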
| layers: nn.Module
| mesh: Mesh
| remat_policy: Any = None
| class NNXPipelineBase(nnx.Module):
If the goal is to deprecate the Linen version, we could just call it PipelineBase.
| Saves two named tensors during jax.checkpoint recomputation:
| - "iteration_input": routed microbatch data entering the decoder
| - "decoder_layer_input": input to the decoder layer itself
| Everything else is recomputed during backward to save memory.
This docstring is not accurate. The remat policy, even when using pipeline parallelism, should be controlled by MaxText flags, e.g. remat_policy=custom, context=device...
| enable_nnx: false
| pure_nnx_decoder: false
| enable_nnx: true
| pure_nnx_decoder: true
Does this change (using the NNX decoder) have to be part of this PR?
This change belongs in a different PR; it is only here to show reviewers that the unit tests pass with the NNX decoder enabled.
FYI, we are currently in the process of pushing the nnx_decoders.py update in this PR. We expect that PR to be merged first, since it contains changes similar to this branch's nnx_decoders.py.
Description
Implement an NNX-based pipeline.
This PR extends PR #2831.
Main changes:
NNXPipeline, an NNX-based pipeline class.
Tests
We run the pipeline with the command below:
MODEL_NAME=llama2-7b
python -m MaxText.train src/maxtext/configs/base.yml \
  run_name=pipeline_test_${MODEL_NAME}_nnx \
  base_output_directory=/dev/shm/pipeline_test_nnx \
  model_name=${MODEL_NAME} \
  dataset_type=synthetic \
  steps=15 \
  debug_sharding=true \
  per_device_batch_size=2 \
  max_target_length=32 \
  ici_pipeline_parallelism=2 \
  num_pipeline_microbatches=4 \
  num_layers_per_pipeline_stage=2 \
  enable_checkpointing=false \
  enable_nnx=true \
  pure_nnx_decoder=true \
  scan_layers_per_stage=false \
  async_checkpointing=false > nnx-porting-log/pipeline/custom_${MODEL_NAME}.log 2>&1
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.