
Support Qwix quantization on NNX#4040

Open
hsuan-lun-chiang wants to merge 1 commit into main from feat/Support-Qwix-quantization-on-NNX


@hsuan-lun-chiang hsuan-lun-chiang commented Jun 2, 2026

Collaborator

Description

This PR enables full support for Qwix quantization algorithms on NNX models and restores support for Linen models. Because the underlying Transformer architecture has been migrated to an NNX module, the existing Qwix quantization integration for Linen models was broken. This PR fixes the Linen integration and extends full feature parity to the NNX backend.

Tests

Gpt3-6b
Qwix Quantization Test - Linen
Qwix Quantization Test - NNX

Gemma4-26b
Qwix Quantization Test - Linen
Qwix Quantization Test - NNX

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov Bot commented Jun 2, 2026

Codecov Report

❌ Patch coverage is 66.19718% with 24 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/maxtext/layers/nnx_wrappers.py | 33.33% | 10 Missing and 6 partials ⚠️ |
| src/maxtext/layers/quantizations.py | 81.81% | 4 Missing and 2 partials ⚠️ |
| src/maxtext/layers/nnx_decoders.py | 50.00% | 2 Missing ⚠️ |

@hsuan-lun-chiang hsuan-lun-chiang force-pushed the feat/Support-Qwix-quantization-on-NNX branch 4 times, most recently from 135d9bc to 2993617 Compare June 3, 2026 10:44
@AI-Hypercomputer AI-Hypercomputer deleted a comment from github-actions Bot Jun 3, 2026

@github-actions github-actions Bot left a comment

## 📋 Review Summary

This Pull Request introduces support for Qwix quantization on NNX models, primarily focusing on FP8 support. It includes changes to the NNX bridge wrappers, quantization providers, and the training loop to handle special variable types for FP8 stats.

🔍 General Feedback

  • The move to NNX for quantization is a positive step, as it simplifies the state management and removes previous workarounds for jax.lax.scan.
  • However, the dynamic reconstruction of module structures in ToLinen is risky and could lead to runtime errors with complex model architectures (especially those involving lists or sequences).
  • The training loop logic for OverwriteWithGradient variables relies on custom gradient behaviors that should be clearly documented to avoid confusion with standard parameter updates.
  • Test coverage for pure NNX quantization is a good addition.

Comment thread src/maxtext/layers/quantizations.py
Comment thread src/maxtext/layers/quantizations.py
Comment thread src/maxtext/trainers/pre_train/train.py Outdated
Comment thread src/maxtext/layers/nnx_wrappers.py
Comment thread src/maxtext/utils/model_creation_utils.py Outdated
Comment thread src/maxtext/layers/quantizations.py

@bvandermoon bvandermoon left a comment

Collaborator

What is the NaN in the last line of https://paste.googleplex.com/5058255918333952?

RexBearIU commented Jun 10, 2026

Collaborator

> What is the NaN in the last line of https://paste.googleplex.com/5058255918333952?

The warning NaN or Inf found in input tensor in x2num.py is a benign TensorBoard logging warning caused by the perplexity metric overflowing, not by any actual NaN/Inf in model weights or gradients.

Since perplexity is calculated as exp(loss), and at the start of training the loss is extremely high (above 88.7), exp(loss) overflows to positive infinity (inf) in standard float32 precision. When TensorBoard's summary writer attempts to log this infinity value, its internal validation check in x2num.py raises the warning.

This is completely normal and occurs identically in unquantized pre-training. Here is a log snippet from a standard, unquantized baseline pre-training run showing the exact same loss progression and warning:

    completed step: 0, loss: 445.785, lm_loss: 445.785, perplexity: inf 
    W0610 08:12:16.985936 x2num.py:13] NaN or Inf found in input tensor.
  
    completed step: 1, loss: 445.785, lm_loss: 445.785, perplexity: inf 
    W0610 08:12:18.443920 x2num.py:13] NaN or Inf found in input tensor.
  
    completed step: 2, loss: 217.815, lm_loss: 217.815, perplexity: inf 
    W0610 08:12:19.904166 x2num.py:13] NaN or Inf found in input tensor.
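
The 88.7 threshold mentioned above can be checked with a few lines of standard-library Python. This is a minimal sketch, not code from the PR: `perplexity_float32` is a hypothetical helper that emulates float32 overflow semantics, and the loss values are taken from the log snippet.

```python
import math

# Largest finite float32 value: (2 - 2**-23) * 2**127.
FLOAT32_MAX = (2.0 - 2.0**-23) * 2.0**127

# Loss above which exp(loss) no longer fits in float32 (roughly 88.72).
OVERFLOW_LOSS = math.log(FLOAT32_MAX)


def perplexity_float32(loss: float) -> float:
  """Hypothetical helper: exp(loss) with float32-style overflow to inf."""
  if loss > OVERFLOW_LOSS:
    return math.inf
  return math.exp(loss)


# Loss values from the log above, plus two values below the threshold.
for loss in (445.785, 217.815, 88.0, 2.5):
  print(f"loss={loss:8.3f} perplexity={perplexity_float32(loss)}")
```

Any loss above the threshold logs as `inf`, which is what triggers the TensorBoard warning regardless of quantization settings.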

@RexBearIU RexBearIU force-pushed the feat/Support-Qwix-quantization-on-NNX branch from b65d6f5 to 0804e1b Compare June 10, 2026 08:55
Comment thread src/maxtext/layers/nnx_decoders.py Outdated
return out, nnx.state(merged_layer)

# Linen FP8 ops keep amax_history in mutable Linen scope; jax.checkpoint
# re-traces and hits UnexpectedTracerError. Skip remat for FP8.

Collaborator

Why are we skipping remat for FP8? This is different from Linen, right?

@RexBearIU RexBearIU Jun 15, 2026

Collaborator

Under Flax Linen, nn.remat is a custom wrapper provided by Flax that has specialized handling for Flax's internal variable collections (such as the amax_history mutable collection used by FP8 operations) during rematerialization.

Under the NNX decoder when bridging NNX to Linen, the _apply_layer_with_remat helper splits the NNX module state and wraps the pure functionalized pass in standard jax.checkpoint. Since standard jax.checkpoint does not possess Flax Linen's specialized state-tracking mechanisms, the re-tracing process intercepts Flax's mutable tracers (which are captured inside Flax's Scope object), resulting in UnexpectedTracerError.

Therefore, we must skip the standard jax.checkpoint rematerialization when using Linen-based FP8 quantization configurations (fp8_nanoo, fp8_gpu) inside the NNX framework to avoid this tracing issue. This is indeed a difference between pure Linen and NNX decoder checkpointing.
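
The resulting decision can be sketched as a small predicate. This is an illustration only, not the PR's code: `should_wrap_in_checkpoint` and its parameters are hypothetical names, `"int8"` is just an example of a non-FP8 mode, and the only values taken from the discussion are the two FP8 configuration names.

```python
from typing import Optional

# Quantization configs named above as Linen-based FP8 paths.
LINEN_FP8_QUANTIZATIONS = frozenset({"fp8_nanoo", "fp8_gpu"})


def should_wrap_in_checkpoint(quantization: Optional[str], remat_requested: bool) -> bool:
  """Hypothetical helper: whether a layer call gets wrapped in jax.checkpoint."""
  if not remat_requested:
    return False
  # Linen FP8 ops keep mutable state in the Linen scope, which breaks when
  # jax.checkpoint re-traces the function, so remat is skipped for them.
  return quantization not in LINEN_FP8_QUANTIZATIONS


for mode in (None, "int8", "fp8_gpu", "fp8_nanoo"):
  print(mode, should_wrap_in_checkpoint(mode, remat_requested=True))
```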

Collaborator

@cgarciae can you please take a look at this? Is there a way to make this work from the NNX perspective? I don't think we will be able to complete the migration without FP8 remat working.

@RexBearIU RexBearIU force-pushed the feat/Support-Qwix-quantization-on-NNX branch from 0804e1b to fc657bf Compare June 15, 2026 10:45
Comment thread src/maxtext/layers/nnx_wrappers.py Outdated
return fn


def isolate_linen_stacks(fn: tp.Callable[..., tp.Any], *args, **kwargs):

Collaborator

what is this doing?

Collaborator

What isolate_linen_stacks does and why it's needed:

During JAX compilation (.init and .apply), Flax Linen uses a thread-local stack (linen.module._context.module_stack) to resolve the hierarchy and scopes of active Linen modules.

However, on our Flax NNX path, we dynamically execute wrapped Linen operators (via ToNNX) inside an outer NNX model context. If the outer context or JAX transforms have active module contexts, the thread-local Linen stack can retain stale/unexpected elements, leading to compiled tracer leaks and JAX UnexpectedTracerError.

isolate_linen_stacks solves this by temporarily shadowing and clearing the thread-local Linen and Bridge module stacks during the inner ToNNX execution, and restoring them immediately after.

Collaborator

init_with_output already does this isolation though...

@RexBearIU RexBearIU Jun 16, 2026

Collaborator

You're totally right, good catch. I overlooked that init_with_output already handles that isolation under the hood. I've removed the redundant wrapper.

@RexBearIU RexBearIU force-pushed the feat/Support-Qwix-quantization-on-NNX branch 5 times, most recently from 7e93a31 to 45a799b Compare June 16, 2026 04:16
Comment thread src/maxtext/layers/quantizations.py Outdated
Comment on lines +721 to +722
except Exception: # pylint: disable=broad-exception-caught
is_nnx = False

Collaborator

What is the exception being caught here? Is it ValueError from Qwix? Mind just catching the specific error if so?

@RexBearIU RexBearIU Jun 16, 2026

Collaborator

Thanks for pointing this out! When called outside of a module context, flax_util.get_current_module() raises a ValueError ("Current module is not known."). I have updated the block to catch ValueError explicitly.
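
A rough sketch of the narrowed handler, assuming a stand-in for the Qwix call: `get_current_module_stub`, `FakeNnxModule`, and `detect_is_nnx` are hypothetical and exist only to make the example self-contained. Only the exception type and message come from the comment above.

```python
class FakeNnxModule:
  """Hypothetical placeholder for an NNX module type."""


def get_current_module_stub(module=None):
  """Stand-in that mimics the described behavior outside a module context."""
  if module is None:
    raise ValueError("Current module is not known.")
  return module


def detect_is_nnx(module=None) -> bool:
  """Returns True only when the current module is known and is an NNX module."""
  try:
    current = get_current_module_stub(module)
  except ValueError:
    # No module context: treat as non-NNX. Other exceptions still propagate.
    return False
  return isinstance(current, FakeNnxModule)


print(detect_is_nnx())
print(detect_is_nnx(FakeNnxModule()))
```

Catching only `ValueError` means an unrelated failure, such as a `TypeError` from a bad argument, surfaces instead of being silently read as "not NNX".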

Co-authored-by: Jacky Fang <rexbear87941@gmail.com>
@RexBearIU RexBearIU force-pushed the feat/Support-Qwix-quantization-on-NNX branch from 45a799b to f07713f Compare June 16, 2026 10:38

5 participants