Support Qwix quantization on NNX#4040
This Pull Request introduces support for Qwix quantization on NNX models, primarily focusing on FP8 support. It includes changes to the NNX bridge wrappers, quantization providers, and the training loop to handle special variable types for FP8 stats.
🔍 General Feedback
- The move to NNX for quantization is a positive step, as it simplifies the state management and removes previous workarounds for jax.lax.scan.
- However, the dynamic reconstruction of module structures in ToLinen is risky and could lead to runtime errors with complex model architectures (especially those involving lists or sequences).
- The training loop logic for OverwriteWithGradient variables relies on custom gradient behaviors that should be clearly documented to avoid confusion with standard parameter updates.
- Test coverage for pure NNX quantization is a good addition.
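To make the concern about OverwriteWithGradient variables concrete, here is a minimal sketch of the update rule the feedback asks to have documented. It assumes the usual FP8 convention that the "gradient" returned for such a variable is really its next value (for example an updated scale or amax history), so the training loop overwrites it instead of applying an optimizer step. The classes `Param` and `OverwriteWithGradient` and the function `apply_updates` are hypothetical stand-ins written for this illustration, not the PR's code or the Flax API.

```python
from dataclasses import dataclass


@dataclass
class Param:
  """Hypothetical stand-in for an ordinary trainable parameter."""
  value: float


@dataclass
class OverwriteWithGradient:
  """Hypothetical stand-in for an FP8 stat whose 'gradient' is its next value."""
  value: float


def apply_updates(variables, grads, learning_rate):
  """Plain SGD for Param; direct overwrite for OverwriteWithGradient."""
  new_vars = {}
  for name, var in variables.items():
    grad = grads[name]
    if isinstance(var, OverwriteWithGradient):
      # Assumption: the backward pass already computed the new statistic,
      # so it replaces the old value without any optimizer math.
      new_vars[name] = OverwriteWithGradient(grad)
    else:
      new_vars[name] = Param(var.value - learning_rate * grad)
  return new_vars


variables = {"kernel": Param(1.0), "amax_history": OverwriteWithGradient(0.0)}
grads = {"kernel": 0.5, "amax_history": 3.0}
updated = apply_updates(variables, grads, learning_rate=0.1)
```

The point of the sketch is only that the two variable kinds take different code paths in the same step, which is why a reader could mistake one for the other without documentation.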
bvandermoon
left a comment
What is the NaN in the last line of https://paste.googleplex.com/5058255918333952?
The warning Since perplexity is calculated as This is completely normal and occurs identically in unquantized pre-training. Here is a log snippet from a standard, unquantized baseline pre-training run showing the exact same loss progression and warning:
return out, nnx.state(merged_layer)

# Linen FP8 ops keep amax_history in mutable Linen scope; jax.checkpoint
# re-traces and hits UnexpectedTracerError. Skip remat for FP8.
Why are we skipping remat for FP8? This is different from Linen, right?
Under Flax Linen, nn.remat is a custom wrapper provided by Flax that has specialized handling for Flax's internal variable collections (such as the amax_history mutable collection used by FP8 operations) during rematerialization.
Under the NNX decoder when bridging NNX to Linen, the _apply_layer_with_remat helper splits the NNX module state and wraps the pure functionalized pass in standard jax.checkpoint. Since standard jax.checkpoint does not possess Flax Linen's specialized state-tracking mechanisms, the re-tracing process intercepts Flax's mutable tracers (which are captured inside Flax's Scope object), resulting in UnexpectedTracerError.
Therefore, we must skip the standard jax.checkpoint rematerialization when using Linen-based FP8 quantization configurations (fp8_nanoo, fp8_gpu) inside the NNX framework to avoid this tracing issue. This is indeed a difference between pure Linen and NNX decoder checkpointing.
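The decision described above can be sketched as a small guard. This is an illustration under stated assumptions, not the PR's implementation: `maybe_remat` and `LINEN_FP8_CONFIGS` are hypothetical names, and the rematerialization transform is passed in as an argument so the sketch stays self-contained. In real code that argument would be jax.checkpoint.

```python
from typing import Any, Callable

# The two config names mentioned in the discussion as Linen-based FP8.
LINEN_FP8_CONFIGS = frozenset({"fp8_nanoo", "fp8_gpu"})


def maybe_remat(
    fn: Callable[..., Any],
    quantization: str,
    remat: Callable[[Callable[..., Any]], Callable[..., Any]],
) -> Callable[..., Any]:
  """Wraps fn with remat unless the config uses Linen-based FP8 ops."""
  if quantization in LINEN_FP8_CONFIGS:
    # Re-tracing would touch tracers held in the Linen scope, so the
    # layer function is returned unwrapped.
    return fn
  return remat(fn)
```

The cost of this guard is that FP8 layers lose the memory savings of rematerialization, which is the gap raised in the follow-up comment.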
@cgarciae can you please take a look at this? Is there a way to make this work from the NNX perspective? I don't think we will be able to complete the migration without FP8 remat working
return fn


def isolate_linen_stacks(fn: tp.Callable[..., tp.Any], *args, **kwargs):
What isolate_linen_stacks does and why it's needed:
During JAX compilation (.init and .apply), Flax Linen uses a thread-local stack (linen.module._context.module_stack) to resolve the hierarchy and scopes of active Linen modules.
However, on our Flax NNX path, we dynamically execute wrapped Linen operators (via ToNNX) inside an outer NNX model context. If the outer context or JAX transforms have active module contexts, the thread-local Linen stack can retain stale/unexpected elements, leading to compiled tracer leaks and JAX UnexpectedTracerError.
isolate_linen_stacks solves this by temporarily shadowing and clearing the thread-local Linen and Bridge module stacks during the inner ToNNX execution, and restoring them immediately after.
init_with_output already does this isolation though...
You're totally right, good catch. I overlooked that init_with_output already handles that isolation under the hood. I've removed the redundant wrapper.
except Exception:  # pylint: disable=broad-exception-caught
  is_nnx = False
What is the exception being caught here? Is it ValueError from Qwix? Mind just catching the specific error if so?
Thanks for pointing this out! When called outside of a module context, flax_util.get_current_module() raises a ValueError ("Current module is not known."). I have updated the block to catch ValueError explicitly.
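A minimal sketch of the narrowed handler follows. It assumes the surrounding code decides `is_nnx` from the type of the current module; that check is not shown in the diff context, so the `isinstance` test is a guess. `get_current_module` below is a hypothetical stand-in that only mimics the ValueError quoted above, and `detect_is_nnx` is a name invented for this example.

```python
def get_current_module():
  """Hypothetical stand-in for flax_util.get_current_module()."""
  raise ValueError("Current module is not known.")


def detect_is_nnx(get_module=get_current_module, nnx_type=object):
  """Returns True only if a current module exists and is an nnx_type."""
  try:
    module = get_module()
  except ValueError:
    # No module context is active, so this is not the NNX path.
    return False
  # Assumption: the real code performs a similar type check here.
  return isinstance(module, nnx_type)
```

Catching only ValueError means an unrelated failure, such as a TypeError from a bad call, still surfaces instead of silently setting the flag to False.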
Co-authored-by: Jacky Fang <rexbear87941@gmail.com>
Description
This PR enables full support for Qwix quantization algorithms on NNX models and restores support for Linen models. Because the underlying Transformer architecture has been migrated to an NNX module, the existing Qwix quantization integration for Linen models was broken. This PR fixes the Linen integration and extends full feature parity to the NNX backend.
Tests
Gpt3-6b
Qwix Quantization Test - Linen
Qwix Quantization Test - NNX
Gemma4-26b
Qwix Quantization Test - Linen
Qwix Quantization Test - NNX
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.