Skip to content

Fix step-1 NaN: run gemma attention + DeepSeek sparse-indexer logits in fp32#4136

Draft
ecnal-cienet wants to merge 2 commits into
mainfrom
fix/attention-fp32-numerics
Draft

Fix step-1 NaN: run gemma attention + DeepSeek sparse-indexer logits in fp32#4136
ecnal-cienet wants to merge 2 commits into
mainfrom
fix/attention-fp32-numerics

Conversation

@ecnal-cienet

Copy link
Copy Markdown
Collaborator

Description

Two pre-existing numerical-stability bugs produce an early-step NaN loss from bf16 attention-logit overflow. Both reproduce on the Linen path too — so they are not NNX-specific — and are fixed here as one coherent "fp32 for stability" change. (Surfaced by the NNX-defaults e2e matrix: gemma4-31b 09_pdb_lt_1 / 13_scan_layers_false, and deepseek3-671b 30_indexer_sparse.) Two independent commits.

1. gemma4 / gemma3 attention (configs/models/gemma4-31b.yml, gemma4-26b.yml, gemma3-27b.yml).
gemma4 uses head_dim=256 with no attn_logits_soft_cap, and float32_qk_product / float32_logits default to false — so the qk product and softmax inputs run in bf16 and can overflow to inf/nan once the step-0 update perturbs the weights off a clean init. step 0 is finite (loss ~12.98), step 1 is nan → "Aborting training due to NaN loss". gemma2-27b avoids this because it sets attn_logits_soft_cap: 50.0; Gemma3/4 dropped the softcap in favor of qk-norm.

Fix: set float32_qk_product: true + float32_logits: true in the gemma model ymls. This is semantically identical to the model — it just runs the qk product and softmax inputs in fp32 (the gates already exist in layers/attention_op.py). Chosen over re-introducing attn_logits_soft_cap, which would tanh-compress the logits and change the attention distribution (not faithful to the gemma4 architecture). Applied to gemma4-31b, gemma4-26b, and gemma3-27b — they share the same softcap-less attention; 26b/27b are preventive (only 31b is currently in the failing matrix).

2. DeepSeek sparse indexer (layers/attention_mla.py).
The DeepSeek-V3.2 sparse indexer computes its qk product at matmul_precision (bf16) and relus it in bf16, while weights_proj is already fp32 — so large bf16 logits can overflow to inf and propagate to a NaN loss. Fix: cast the indexer logits to fp32 before relu/aggregation, matching how the main attention runs its softmax in fp32. Only affects the use_indexer=True path. This is the leading cause for the 30_indexer_sparse NaN; the qk-overflow class matches the gemma case, but the final NaN-clearing on the 671B model is pending the e2e run (see Tests) — the commit is independent and can be dropped if it doesn't clear it.

Tests

  • gemma: config-load smoke confirms float32_qk_product / float32_logits resolve true for gemma4-31b/26b and gemma3-27b. The NaN itself only reproduces at scale (31B / V6e-32) — re-run gemma4-31b 09_pdb_lt_1 and 13_scan_layers_false and confirm step-1 loss is finite (was nan).
  • indexer: CPU smoke (model_name=deepseek3-tiny use_indexer=True indexer_sparse_training=True indexer_topk=4 attention=dot_product megablox=False, sized so the indexer qk path actually runs) trains 2 steps with finite loss (12.339 → 12.295), no crash. Full NaN-clearing on deepseek3-671b 30_indexer_sparse needs V6e-32.
  • bash lint.sh clean.

Stats

  • +15 across 4 files, 2 commits (3 gemma ymls; attention_mla.py).

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

gemma4-31b/26b use head_dim=256 with no attn_logits_soft_cap, so bf16 attention
logits can overflow to inf/nan at step 1 (fails on both Linen and NNX). Setting
float32_qk_product and float32_logits runs the qk product and softmax inputs in
fp32 — semantically identical, just stable. Applied to gemma4-31b, gemma4-26b,
and gemma3-27b (same softcap-less attention).
The indexer qk product is computed at matmul_precision (bf16) and relu'd in bf16
while weights_proj is fp32; large bf16 logits can overflow to inf and propagate
to a NaN loss. Cast the indexer logits to fp32 before relu/aggregation, matching
how the main attention runs softmax in fp32.
@codecov

codecov Bot commented Jun 10, 2026

Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.

📢 Thoughts on this report? Let us know!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant