Fix Gemma RMSNorm +1 offset missing on --checkpoint path#19901
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19901
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit dce6eca with merge base ec31735 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
|
This PR needs a
|
There was a problem hiding this comment.
Pull request overview
Fixes incorrect Gemma model behavior when supplying a pre-converted checkpoint via --checkpoint by ensuring Gemma RMSNorm weights are offset by +1 (to match Gemma’s (1 + w) * x convention) on that code path as well.
Changes:
- Apply Gemma RMSNorm
+1weight offset when loading weights from--checkpoint. - Keep Gemma model handling consistent between HF-download and
--checkpointloading paths.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
The `--checkpoint` code path skipped the Gemma-specific RMSNorm weight adjustment (`weight + 1`). Gemma stores norm weights as deviations from 1 and computes `(1 + w) * x`, but ExecuTorch's RMSNorm computes `w * x`. The HF download path applied the +1 offset correctly, but passing a converted checkpoint via `--checkpoint` silently produced garbage output from all 36+ norm layers, regardless of quantization recipe.
The
--checkpointcode path skipped the Gemma-specific RMSNorm weight adjustment (weight + 1). Gemma stores norm weights as deviations from 1 and computes(1 + w) * x, but ExecuTorch's RMSNorm computesw * x. The HF download path applied the +1 offset correctly, but passing a converted checkpoint via--checkpointsilently produced garbage output from all 36+ norm layers, regardless of quantization recipe.##Test Plan