diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index c527b2ade9..44546bd0c8 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -25,7 +25,7 @@ import jax.numpy as jnp from flax import linen as nn from flax import nnx -from flax.nnx import wrappers as nnx_wrappers +from maxtext.layers import nnx_wrappers from jax.ad_checkpoint import checkpoint_name from jax.sharding import Mesh @@ -716,17 +716,8 @@ def pure_layer_fn(state_in, y_in): out = merged_layer(y_in, **kwargs) return out, nnx.state(merged_layer) - # Linen FP8 ops keep amax_history in mutable Linen scope; jax.checkpoint - # re-traces and hits UnexpectedTracerError. Skip remat for FP8. - uses_linen_fp8_mutable_state = self.config.quantization in { - "fp8_nanoo", - "fp8_gpu", - } - if uses_linen_fp8_mutable_state: - out, new_state = pure_layer_fn(state, y) - else: - checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) - out, new_state = checkpointed_fn(state, y) + checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) + out, new_state = checkpointed_fn(state, y) nnx.update(layer, new_state) return out @@ -854,26 +845,7 @@ def layer_fn(carry, scanned_vars): params = nnx_ensure_scan_leading_axis(params, length) state = nnx_ensure_scan_leading_axis(state, length) - # Linen FP8 ops keep amax_history in mutable Linen scope; jax.lax.scan - # leaks the tracer and hits UnexpectedTracerError. Use a Python for-loop - # for FP8 instead. - uses_linen_fp8_mutable_state = self.config.quantization in { - "fp8_nanoo", - "fp8_gpu", - } - if uses_linen_fp8_mutable_state: - carry = x_in - per_layer_states = [] - for i in range(length): - current_params = jax.tree.map(lambda x, i=i: x[i], params) - current_state = jax.tree.map(lambda x, i=i: x[i], state) - carry, new_state_i = layer_fn(carry, (current_params, current_state)) - per_layer_states.append(new_state_i) - final_carry = carry - # pylint: disable-next=no-value-for-parameter (*per_layer_states supplies the `tree` arg) - scanned_state = jax.tree.map(lambda *xs: jnp.stack(list(xs)), *per_layer_states) - else: - final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state)) + final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state)) returned_kv_stacked = None if scan_axis != 0: diff --git a/src/maxtext/layers/nnx_wrappers.py b/src/maxtext/layers/nnx_wrappers.py index b483649c9e..07faeefda2 100644 --- a/src/maxtext/layers/nnx_wrappers.py +++ b/src/maxtext/layers/nnx_wrappers.py @@ -285,14 +285,18 @@ def __call__( # Get `mutable` from top level bridge.Module context if any if mutable is not None: pass - elif (m := bdg_module.current_module()) is not None: + elif getattr(bdg_module.MODULE_CONTEXT, "module_stack", None) and (m := bdg_module.current_module()) is not None: assert m.scope is not None mutable = m.scope.mutable elif (m := current_linen_module()) is not None: assert m.scope is not None mutable = m.scope.mutable else: - mutable = False + # Safe fallback mutability: when running functionally isolated inside standard JAX transforms, + # we determine which collections (such as "stats" or "amax_history") are present and mark them mutable. + mutable = [k for k in variables.keys() if k != "params"] + if not mutable: + mutable = False out = self.to_nnx__module.apply(variables, *args, rngs=_rngs, method=method, mutable=mutable, **kwargs) @@ -509,7 +513,31 @@ def maybe_unbox(x): for path, _ in unknown_state_flat.items(): paths_str += f"\n - {'/'.join(map(str, path))}" - warnings.warn(f"Found unknown module paths in incoming state:{paths_str}") + # Dynamically reconstruct the unknown variables + curr = module + for p in path[:-1]: + if isinstance(curr, dict): + if p not in curr: + curr[p] = nnx.Module() + curr = curr[p] + elif isinstance(curr, list): + if not isinstance(p, int): + raise TypeError(f"Expected int index for list, got {type(p)}: {p}") + while len(curr) <= p: + curr.append(nnx.Module()) + curr = curr[p] + elif isinstance(curr, tuple): + raise ValueError(f"Cannot dynamically reconstruct elements within a tuple at path {path}.") + else: + if not isinstance(p, str): + p = str(p) + if not hasattr(curr, p): + setattr(curr, p, nnx.Module()) + curr = getattr(curr, p) + + warnings.warn( + f"Found unknown module paths in incoming state:{paths_str}. Intermediate modules have been reconstructed." + ) _fix_for_qwix_quantization(module) nnx.update(module, new_state) diff --git a/src/maxtext/layers/quantizations.py b/src/maxtext/layers/quantizations.py index d4688abb80..a91226b5d3 100644 --- a/src/maxtext/layers/quantizations.py +++ b/src/maxtext/layers/quantizations.py @@ -38,6 +38,9 @@ from flax.linen import fp8_ops from flax.linen import initializers as flax_initializers import flax.linen as nn +from flax import nnx +from qwix._src import flax_util +from maxtext.layers import nnx_wrappers from maxtext.common.common_types import DType, Config from maxtext.inference.kvcache import KVQuant @@ -710,6 +713,32 @@ def configure_kv_quant(config): return None if not config.quantize_kvcache else KVQuant(config) +def _apply_linen_module_in_nnx(linen_module_cls, op_id, *args, **kwargs): + """Applies a Linen module within an NNX context.""" + try: + parent = flax_util.get_current_module() + is_nnx = isinstance(parent, nnx.Module) + except ValueError: + is_nnx = False + + if is_nnx: + attr_name = f"_qwix_fp8_gpu_{op_id}" + if not hasattr(parent, attr_name): + rngs = getattr(parent, "qwix_rngs", None) + if rngs is None: + parent_rngs = getattr(parent, "rngs", None) + if parent_rngs is not None and hasattr(parent_rngs, "fork"): + rngs = parent_rngs.fork() + else: + rngs = nnx.Rngs(0) + wrapper = nnx_wrappers.ToNNX(linen_module_cls(name=op_id), rngs=rngs) + wrapper.lazy_init(*args, **kwargs) + setattr(parent, attr_name, wrapper) + return getattr(parent, attr_name)(*args, mutable=["_overwrite_with_gradient"], **kwargs) + else: + return linen_module_cls(name=op_id)(*args, **kwargs) + + class NvidaFp8Provider(qwix.QtProvider): """Wraps nn.Fp8DirectDotGeneralOp with Qwix's provider interface.""" @@ -718,13 +747,13 @@ def dot_general(self, *args, **kwargs): rule, op_id = self._get_current_rule_and_op_id("dot_general") if rule is None: return jax.lax.dot_general(*args, **kwargs) - return nn.Fp8DirectDotGeneralOp(name=op_id)(*args, **kwargs) + return _apply_linen_module_in_nnx(nn.Fp8DirectDotGeneralOp, op_id, *args, **kwargs) def einsum(self, *args, **kwargs): rule, op_id = self._get_current_rule_and_op_id("einsum") if rule is None: return jnp.einsum(*args, **kwargs) - return nn.Fp8Einsum(name=op_id)(*args, **kwargs) + return _apply_linen_module_in_nnx(nn.Fp8Einsum, op_id, *args, **kwargs) class NANOOFp8Provider(qwix.QtProvider): @@ -734,7 +763,7 @@ def dot_general(self, *args, **kwargs): rule, op_id = self._get_current_rule_and_op_id("dot_general") if rule is None: return jax.lax.dot_general(*args, **kwargs) - return nn.NANOOFp8DotGeneralOp(name=op_id)(*args, **kwargs) + return _apply_linen_module_in_nnx(nn.NANOOFp8DotGeneralOp, op_id, *args, **kwargs) def get_fp8_full_qwix_rule_w_sparsity(config: Config): @@ -815,7 +844,21 @@ def maybe_quantize_model(model, config): if config.use_qwix_quantization and not config.use_batch_split_schedule: quantization_provider = get_qt_provider(config) if quantization_provider: - model = qwix.quantize_model(model, quantization_provider) + if config.pure_nnx: + input_shape = (config.micro_batch_size_to_train_on, config.max_target_length) + dummy_tokens = jnp.ones(input_shape, dtype=jnp.int32) + dummy_positions = jnp.ones(input_shape, dtype=jnp.int32) + dummy_segment_ids = jnp.ones(input_shape, dtype=jnp.int32) + model = qwix.quantize_model( + model, + quantization_provider, + dummy_tokens, + dummy_positions, + dummy_segment_ids, + enable_dropout=False, + ) + else: + model = qwix.quantize_model(model, quantization_provider) return model diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 3be6baff8c..033cc5a0d0 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -38,6 +38,7 @@ from flax import linen as nn, nnx from flax.linen import partitioning as nn_partitioning +from flax.nnx import variablelib from maxtext.configs import pyconfig from maxtext.utils.globals import EPS @@ -359,7 +360,9 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat is_train=True, ) else: - model_graphdef, curr_params, rest = nnx.split(state.model, nnx.Param, ...) + owg_type = variablelib.variable_type_from_name("_overwrite_with_gradient", allow_register=True) + custom_param_filter = nnx.Any(owg_type) + model_graphdef, curr_params, custom_params, rest = nnx.split(state.model, nnx.Param, custom_param_filter, ...) if config.parameter_memory_host_offload: # Params are kept on host (pinned_host) in in_shardings. Move only Param # variables to device before the forward/backward pass so that all dot_general @@ -381,15 +384,15 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat ) nnx.update(state.model, curr_params) - def diff_wrapper(param, rest, config, data): - local_model = nnx.merge(model_graphdef, param, rest, copy=True) + def diff_wrapper(curr_params, custom_params, rest, config, data): + local_model = nnx.merge(model_graphdef, curr_params, custom_params, rest, copy=True) loss, aux = loss_fn(local_model, config, data, None, None, is_train=True) - _, _, new_rest = nnx.split(local_model, nnx.Param, ...) + _, _, _, new_rest = nnx.split(local_model, nnx.Param, custom_param_filter, ...) return loss, (aux, new_rest) - grad_func = jax.value_and_grad(diff_wrapper, argnums=0, has_aux=True) - (loss, (aux, new_rest)), raw_grads = grad_func(curr_params, rest, config, data) - nnx.update(state.model, new_rest) + grad_func = jax.value_and_grad(diff_wrapper, argnums=(0, 1), has_aux=True) + (loss, (aux, new_rest)), (raw_grads, custom_grads) = grad_func(curr_params, custom_params, rest, config, data) + nnx.update(state.model, nnx.State.merge(custom_grads, new_rest)) raw_grads = jax.tree_util.tree_map( lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x, diff --git a/tests/unit/quantizations_test.py b/tests/unit/quantizations_test.py index b0af64d9fc..fff430e646 100644 --- a/tests/unit/quantizations_test.py +++ b/tests/unit/quantizations_test.py @@ -22,6 +22,7 @@ from aqt.jax.v2 import aqt_tensor from aqt.jax.v2.flax import aqt_flax from flax import nnx +from flax.nnx import traversals import jax from jax import lax from jax import numpy as jnp @@ -48,7 +49,7 @@ def __init__( self, quantization: quantizations.AqtQuantization, data_type: Any, - rngs: nnx.Rngs, + rngs: nnx.Rngs, # pylint: disable=unused-argument ): self.quantization = quantization self.identity = jnp.identity(2, dtype=data_type) @@ -387,17 +388,116 @@ def compare_fn(path, x, y): jax.tree_util.tree_map_with_path(compare_fn, a, b) - def quantization_config(self, quant, logits_tolerance=2e-1, grad_tolerance=5e-1): + def quantization_config(self, quant, logits_tolerance=2e-1, grad_tolerance=5e-1, **kwargs): """Run forward pass and backward pass for quantized model and compare with base model.""" # pylint: disable=protected-access - cfg = self.init_pyconfig(quantization=quant) - qt_model = model_creation_utils.create_model(cfg, self.mesh) - + cfg = self.init_pyconfig(quantization=quant, **kwargs) ids, decoder_segment_ids, decoder_positions = self.get_data() - if not hasattr(self.__class__, "_cached_base_results"): - model = model_creation_utils.create_model(self.cfg, self.mesh) - var = model.init( + if cfg.pure_nnx: + qt_model = model_creation_utils.create_model(cfg, self.mesh, rngs=nnx.Rngs(0)) + if getattr(self.__class__, "_cached_base_results_nnx", None) is None: + base_cfg = self.init_pyconfig(quantization="", **kwargs) + base_model = model_creation_utils.create_model(base_cfg, self.mesh, rngs=nnx.Rngs(0)) + + def loss_base(model): + logits = model( + decoder_input_tokens=ids, + decoder_positions=decoder_positions, + decoder_segment_ids=decoder_segment_ids, + enable_dropout=False, + ) + return jnp.mean((logits) ** 2) + + grads_base = nnx.grad(loss_base)(base_model) + logits_base = base_model( + decoder_input_tokens=ids, + decoder_positions=decoder_positions, + decoder_segment_ids=decoder_segment_ids, + enable_dropout=False, + ) + self.__class__._cached_base_results_nnx = (grads_base, logits_base) + + grads_base, logits = self.__class__._cached_base_results_nnx + + def loss_quant(model): + logits_q = model( + decoder_input_tokens=ids, + decoder_positions=decoder_positions, + decoder_segment_ids=decoder_segment_ids, + enable_dropout=False, + ) + return jnp.mean((logits_q) ** 2) + + grads_quant = nnx.grad(loss_quant)(qt_model) + quant_logits = qt_model( + decoder_input_tokens=ids, + decoder_positions=decoder_positions, + decoder_segment_ids=decoder_segment_ids, + enable_dropout=False, + ) + + print("relative error in logits:" f" {jnp.abs(quant_logits - logits).mean() / jnp.abs(logits).mean()}") + assert jnp.abs(quant_logits - logits).mean() / jnp.abs(logits).mean() < logits_tolerance + + # nnx.grad returns a State object which is a mapping of paths to gradients. + # Flatten them to check for tolerance. + grads_base_flat = traversals.flatten_mapping(grads_base) + grads_quant_flat = traversals.flatten_mapping(grads_quant) + + # Filter for param collections to compare only parameters and not stats/buffers if any + # Note: NNX grads structure might contain variables like 'kernel', 'bias'. + # For simplicity we compare all matching keys. + def flatten_and_filter(grads_flat): + return {k: v for k, v in grads_flat.items() if hasattr(v, "shape") and "quant_stats" not in str(k)} + + gb_f = flatten_and_filter(grads_base_flat) + gq_f = flatten_and_filter(grads_quant_flat) + + for k in gb_f: + if k in gq_f: + diff = jnp.abs(gb_f[k] - gq_f[k]).mean() / (jnp.abs(gb_f[k]).mean() + 1e-8) + if diff > grad_tolerance: + print(f"Gradient mismatch for {k}: rel_error = {diff}") + assert diff <= grad_tolerance + else: + qt_model = model_creation_utils.create_model(cfg, self.mesh) + if not hasattr(self.__class__, "_cached_base_results"): + model = model_creation_utils.create_model(self.cfg, self.mesh) + var = model.init( + {"params": self.rng, "aqt": self.rng, "dropout": self.rng}, + ids, + decoder_positions, + decoder_segment_ids, + enable_dropout=False, + mutable=True, + ) + + def loss_base_linen(all_vars, inputs): + logits_b, _ = model.apply( + all_vars, + *inputs, + enable_dropout=False, + rngs={"params": self.rng}, + mutable=True, + ) + return jnp.mean((logits_b) ** 2) + + grads_base_linen = jax.grad(loss_base_linen)(var, (ids, decoder_positions, decoder_segment_ids)) + logits_b, _ = model.apply( + var, + ids, + decoder_positions, + decoder_segment_ids, + enable_dropout=False, + rngs={"params": self.rng}, + mutable=True, + ) + self.__class__._cached_base_results = (grads_base_linen, logits_b) + + grads_base_linen, logits = self.__class__._cached_base_results + + quantized_vars = qt_model.init( {"params": self.rng, "aqt": self.rng, "dropout": self.rng}, ids, decoder_positions, @@ -406,19 +506,20 @@ def quantization_config(self, quant, logits_tolerance=2e-1, grad_tolerance=5e-1) mutable=True, ) - def loss_base(all_vars, inputs): - logits, _ = model.apply( + def loss_quant_linen(all_vars, inputs): + logits_q, _ = qt_model.apply( all_vars, *inputs, enable_dropout=False, rngs={"params": self.rng}, mutable=True, ) - return jnp.mean((logits) ** 2) + return jnp.mean((logits_q) ** 2) + + grads_quant_linen = jax.grad(loss_quant_linen)(quantized_vars, (ids, decoder_positions, decoder_segment_ids)) - grads_base = jax.grad(loss_base)(var, (ids, decoder_positions, decoder_segment_ids)) - logits, _ = model.apply( - var, + quant_logits, _ = qt_model.apply( + quantized_vars, ids, decoder_positions, decoder_segment_ids, @@ -426,74 +527,61 @@ def loss_base(all_vars, inputs): rngs={"params": self.rng}, mutable=True, ) - self.__class__._cached_base_results = (grads_base, logits) - - grads_base, logits = self.__class__._cached_base_results - - quantized_vars = qt_model.init( - {"params": self.rng, "aqt": self.rng, "dropout": self.rng}, - ids, - decoder_positions, - decoder_segment_ids, - enable_dropout=False, - mutable=True, - ) - - def loss_quant(all_vars, inputs): - logits, _ = qt_model.apply( - all_vars, - *inputs, - enable_dropout=False, - rngs={"params": self.rng}, - mutable=True, + print("relative error in logits:" f" {jnp.abs(quant_logits - logits).mean() / jnp.abs(logits).mean()}") + assert jnp.abs(quant_logits - logits).mean() / jnp.abs(logits).mean() < logits_tolerance + self.print_grad_diff(grads_base_linen["params"], grads_quant_linen["params"]) + self.assertTrue( + self.pytree_allclose( + grads_base_linen["params"], + grads_quant_linen["params"], + tolerance=grad_tolerance, + ) ) - return jnp.mean((logits) ** 2) - - # Compute gradients w.r.t. both models - grads_quant = jax.grad(loss_quant)(quantized_vars, (ids, decoder_positions, decoder_segment_ids)) - - quant_logits, _ = qt_model.apply( - quantized_vars, - ids, - decoder_positions, - decoder_segment_ids, - enable_dropout=False, - rngs={"params": self.rng}, - mutable=True, - ) - print("relative error in logits:" f" {jnp.abs(quant_logits - logits).mean() / jnp.abs(logits).mean()}") - assert jnp.abs(quant_logits - logits).mean() / jnp.abs(logits).mean() < logits_tolerance - self.print_grad_diff(grads_base["params"], grads_quant["params"]) - self.assertTrue( - self.pytree_allclose( - grads_base["params"], - grads_quant["params"], - tolerance=grad_tolerance, - ) - ) @pytest.mark.tpu_only def test_int8_quantization(self): self.quantization_config("int8") + @pytest.mark.tpu_only + def test_int8_quantization_nnx(self): + self.quantization_config("int8", enable_nnx=True, pure_nnx_decoder=True, pure_nnx=True) + @pytest.mark.tpu_only def test_fp8_quantization(self): self.quantization_config("fp8") + @pytest.mark.tpu_only + def test_fp8_quantization_nnx(self): + self.quantization_config("fp8", enable_nnx=True, pure_nnx_decoder=True, pure_nnx=True) + @pytest.mark.tpu_only def test_fp8_full_quantization(self): self.quantization_config("fp8_full") + @pytest.mark.tpu_only + def test_fp8_full_quantization_nnx(self): + self.quantization_config("fp8_full", enable_nnx=True, pure_nnx_decoder=True, pure_nnx=True) + @pytest.mark.gpu_only @pytest.mark.external_serving def test_fp8_gpu_quantization(self): self.quantization_config("fp8_gpu", grad_tolerance=1.5) + @pytest.mark.gpu_only + @pytest.mark.external_serving + def test_fp8_gpu_quantization_nnx(self): + self.quantization_config("fp8_gpu", grad_tolerance=1.5, enable_nnx=True, pure_nnx_decoder=True, pure_nnx=True) + @pytest.mark.gpu_only @pytest.mark.external_serving def test_fp8_nanoo_quantization(self): self.quantization_config("fp8_nanoo", grad_tolerance=1.5) + @pytest.mark.gpu_only + @pytest.mark.external_serving + def test_fp8_nanoo_quantization_nnx(self): + self.quantization_config("fp8_nanoo", grad_tolerance=1.5, enable_nnx=True, pure_nnx_decoder=True, pure_nnx=True) + @pytest.mark.skip(reason="No runner with GPU arch >= 89 is available") @pytest.mark.gpu_only def test_fp8_te_fp8_delayedscaling_quantization(self):