Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 4 additions & 32 deletions src/maxtext/layers/nnx_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import jax.numpy as jnp
from flax import linen as nn
from flax import nnx
from flax.nnx import wrappers as nnx_wrappers
from maxtext.layers import nnx_wrappers
from jax.ad_checkpoint import checkpoint_name
from jax.sharding import Mesh

Expand Down Expand Up @@ -716,17 +716,8 @@ def pure_layer_fn(state_in, y_in):
out = merged_layer(y_in, **kwargs)
return out, nnx.state(merged_layer)

# Linen FP8 ops keep amax_history in mutable Linen scope; jax.checkpoint
# re-traces and hits UnexpectedTracerError. Skip remat for FP8.
uses_linen_fp8_mutable_state = self.config.quantization in {
"fp8_nanoo",
"fp8_gpu",
}
if uses_linen_fp8_mutable_state:
out, new_state = pure_layer_fn(state, y)
else:
checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
out, new_state = checkpointed_fn(state, y)
checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
out, new_state = checkpointed_fn(state, y)
nnx.update(layer, new_state)

return out
Expand Down Expand Up @@ -854,26 +845,7 @@ def layer_fn(carry, scanned_vars):
params = nnx_ensure_scan_leading_axis(params, length)
state = nnx_ensure_scan_leading_axis(state, length)

# Linen FP8 ops keep amax_history in mutable Linen scope; jax.lax.scan
# leaks the tracer and hits UnexpectedTracerError. Use a Python for-loop
# for FP8 instead.
uses_linen_fp8_mutable_state = self.config.quantization in {
"fp8_nanoo",
"fp8_gpu",
}
if uses_linen_fp8_mutable_state:
carry = x_in
per_layer_states = []
for i in range(length):
current_params = jax.tree.map(lambda x, i=i: x[i], params)
current_state = jax.tree.map(lambda x, i=i: x[i], state)
carry, new_state_i = layer_fn(carry, (current_params, current_state))
per_layer_states.append(new_state_i)
final_carry = carry
# pylint: disable-next=no-value-for-parameter (*per_layer_states supplies the `tree` arg)
scanned_state = jax.tree.map(lambda *xs: jnp.stack(list(xs)), *per_layer_states)
else:
final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state))
final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state))
returned_kv_stacked = None

if scan_axis != 0:
Expand Down
34 changes: 31 additions & 3 deletions src/maxtext/layers/nnx_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,14 +285,18 @@ def __call__(
# Get `mutable` from top level bridge.Module context if any
if mutable is not None:
pass
elif (m := bdg_module.current_module()) is not None:
elif getattr(bdg_module.MODULE_CONTEXT, "module_stack", None) and (m := bdg_module.current_module()) is not None:
assert m.scope is not None
mutable = m.scope.mutable
elif (m := current_linen_module()) is not None:
assert m.scope is not None
mutable = m.scope.mutable
else:
mutable = False
# Safe fallback mutability: when running functionally isolated inside standard JAX transforms,
# we determine which collections (such as "stats" or "amax_history") are present and mark them mutable.
mutable = [k for k in variables.keys() if k != "params"]
if not mutable:
mutable = False

out = self.to_nnx__module.apply(variables, *args, rngs=_rngs, method=method, mutable=mutable, **kwargs)

Expand Down Expand Up @@ -509,7 +513,31 @@ def maybe_unbox(x):
for path, _ in unknown_state_flat.items():
paths_str += f"\n - {'/'.join(map(str, path))}"

warnings.warn(f"Found unknown module paths in incoming state:{paths_str}")
# Dynamically reconstruct the unknown variables
curr = module
for p in path[:-1]:
if isinstance(curr, dict):
if p not in curr:
curr[p] = nnx.Module()
curr = curr[p]
elif isinstance(curr, list):
if not isinstance(p, int):
raise TypeError(f"Expected int index for list, got {type(p)}: {p}")
while len(curr) <= p:
curr.append(nnx.Module())
curr = curr[p]
elif isinstance(curr, tuple):
raise ValueError(f"Cannot dynamically reconstruct elements within a tuple at path {path}.")
else:
if not isinstance(p, str):
p = str(p)
if not hasattr(curr, p):
setattr(curr, p, nnx.Module())
curr = getattr(curr, p)

Comment thread
hsuan-lun-chiang marked this conversation as resolved.
warnings.warn(
f"Found unknown module paths in incoming state:{paths_str}. Intermediate modules have been reconstructed."
)

_fix_for_qwix_quantization(module)
nnx.update(module, new_state)
Expand Down
51 changes: 47 additions & 4 deletions src/maxtext/layers/quantizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
from flax.linen import fp8_ops
from flax.linen import initializers as flax_initializers
import flax.linen as nn
from flax import nnx
from qwix._src import flax_util
from maxtext.layers import nnx_wrappers

from maxtext.common.common_types import DType, Config
from maxtext.inference.kvcache import KVQuant
Expand Down Expand Up @@ -710,6 +713,32 @@ def configure_kv_quant(config):
return None if not config.quantize_kvcache else KVQuant(config)


def _apply_linen_module_in_nnx(linen_module_cls, op_id, *args, **kwargs):
  """Applies a Linen module, bridging it through `ToNNX` when inside an NNX module.

  Args:
    linen_module_cls: Linen module class; it is instantiated as
      `linen_module_cls(name=op_id)`.
    op_id: Identifier used both as the Linen module name and as the suffix of
      the attribute under which the NNX wrapper is cached on the parent.
    *args: Positional arguments forwarded to the module call (and to
      `lazy_init` the first time the wrapper is created).
    **kwargs: Keyword arguments forwarded the same way.

  Returns:
    Whatever the underlying module call returns.
  """
  try:
    parent = flax_util.get_current_module()
    is_nnx = isinstance(parent, nnx.Module)
  except ValueError:
    # The lookup signalled that there is no usable current module; fall back
    # to the plain Linen call path below.
    is_nnx = False

  if not is_nnx:
    return linen_module_cls(name=op_id)(*args, **kwargs)

  # One wrapper per op_id is stored on the parent so that repeated calls reuse
  # the same wrapped module instead of re-initialising it.
  attr_name = f"_qwix_fp8_gpu_{op_id}"
  if not hasattr(parent, attr_name):
    # RNG source, in order of preference: the parent's `qwix_rngs`, a fork of
    # the parent's `rngs` (when it exposes `fork`), then a fixed-seed Rngs.
    rngs = getattr(parent, "qwix_rngs", None)
    if rngs is None:
      parent_rngs = getattr(parent, "rngs", None)
      if parent_rngs is not None and hasattr(parent_rngs, "fork"):
        rngs = parent_rngs.fork()
      else:
        rngs = nnx.Rngs(0)
    wrapper = nnx_wrappers.ToNNX(linen_module_cls(name=op_id), rngs=rngs)
    wrapper.lazy_init(*args, **kwargs)
    setattr(parent, attr_name, wrapper)

  # The "_overwrite_with_gradient" collection is passed as mutable on every call.
  return getattr(parent, attr_name)(*args, mutable=["_overwrite_with_gradient"], **kwargs)


class NvidaFp8Provider(qwix.QtProvider):
"""Wraps nn.Fp8DirectDotGeneralOp with Qwix's provider interface."""

Expand All @@ -718,13 +747,13 @@ def dot_general(self, *args, **kwargs):
rule, op_id = self._get_current_rule_and_op_id("dot_general")
if rule is None:
return jax.lax.dot_general(*args, **kwargs)
return nn.Fp8DirectDotGeneralOp(name=op_id)(*args, **kwargs)
return _apply_linen_module_in_nnx(nn.Fp8DirectDotGeneralOp, op_id, *args, **kwargs)

def einsum(self, *args, **kwargs):
rule, op_id = self._get_current_rule_and_op_id("einsum")
if rule is None:
return jnp.einsum(*args, **kwargs)
return nn.Fp8Einsum(name=op_id)(*args, **kwargs)
return _apply_linen_module_in_nnx(nn.Fp8Einsum, op_id, *args, **kwargs)


class NANOOFp8Provider(qwix.QtProvider):
Expand All @@ -734,7 +763,7 @@ def dot_general(self, *args, **kwargs):
rule, op_id = self._get_current_rule_and_op_id("dot_general")
if rule is None:
return jax.lax.dot_general(*args, **kwargs)
return nn.NANOOFp8DotGeneralOp(name=op_id)(*args, **kwargs)
return _apply_linen_module_in_nnx(nn.NANOOFp8DotGeneralOp, op_id, *args, **kwargs)


def get_fp8_full_qwix_rule_w_sparsity(config: Config):
Expand Down Expand Up @@ -815,7 +844,21 @@ def maybe_quantize_model(model, config):
if config.use_qwix_quantization and not config.use_batch_split_schedule:
quantization_provider = get_qt_provider(config)
if quantization_provider:
model = qwix.quantize_model(model, quantization_provider)
if config.pure_nnx:
input_shape = (config.micro_batch_size_to_train_on, config.max_target_length)
dummy_tokens = jnp.ones(input_shape, dtype=jnp.int32)
dummy_positions = jnp.ones(input_shape, dtype=jnp.int32)
dummy_segment_ids = jnp.ones(input_shape, dtype=jnp.int32)
model = qwix.quantize_model(
Comment thread
RexBearIU marked this conversation as resolved.
model,
quantization_provider,
dummy_tokens,
dummy_positions,
dummy_segment_ids,
enable_dropout=False,
)
else:
model = qwix.quantize_model(model, quantization_provider)
return model


Expand Down
17 changes: 10 additions & 7 deletions src/maxtext/trainers/pre_train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

from flax import linen as nn, nnx
from flax.linen import partitioning as nn_partitioning
from flax.nnx import variablelib

from maxtext.configs import pyconfig
from maxtext.utils.globals import EPS
Expand Down Expand Up @@ -359,7 +360,9 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
is_train=True,
)
else:
model_graphdef, curr_params, rest = nnx.split(state.model, nnx.Param, ...)
owg_type = variablelib.variable_type_from_name("_overwrite_with_gradient", allow_register=True)
custom_param_filter = nnx.Any(owg_type)
model_graphdef, curr_params, custom_params, rest = nnx.split(state.model, nnx.Param, custom_param_filter, ...)
if config.parameter_memory_host_offload:
# Params are kept on host (pinned_host) in in_shardings. Move only Param
# variables to device before the forward/backward pass so that all dot_general
Expand All @@ -381,15 +384,15 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
)
nnx.update(state.model, curr_params)

def diff_wrapper(param, rest, config, data):
local_model = nnx.merge(model_graphdef, param, rest, copy=True)
def diff_wrapper(curr_params, custom_params, rest, config, data):
local_model = nnx.merge(model_graphdef, curr_params, custom_params, rest, copy=True)
loss, aux = loss_fn(local_model, config, data, None, None, is_train=True)
_, _, new_rest = nnx.split(local_model, nnx.Param, ...)
_, _, _, new_rest = nnx.split(local_model, nnx.Param, custom_param_filter, ...)
return loss, (aux, new_rest)

grad_func = jax.value_and_grad(diff_wrapper, argnums=0, has_aux=True)
(loss, (aux, new_rest)), raw_grads = grad_func(curr_params, rest, config, data)
nnx.update(state.model, new_rest)
grad_func = jax.value_and_grad(diff_wrapper, argnums=(0, 1), has_aux=True)
(loss, (aux, new_rest)), (raw_grads, custom_grads) = grad_func(curr_params, custom_params, rest, config, data)
nnx.update(state.model, nnx.State.merge(custom_grads, new_rest))

raw_grads = jax.tree_util.tree_map(
lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x,
Expand Down
Loading
Loading