From bae4e3700610170ef2a5cab1dc95241dc4eee0b0 Mon Sep 17 00:00:00 2001 From: Siddartha Pothapragada Date: Sat, 30 May 2026 01:38:16 -0700 Subject: [PATCH] Fix Gemma RMSNorm +1 offset missing on --checkpoint path The `--checkpoint` code path skipped the Gemma-specific RMSNorm weight adjustment (`weight + 1`). Gemma stores norm weights as deviations from 1 and computes `(1 + w) * x`, but ExecuTorch's RMSNorm computes `w * x`. The HF download path applied the +1 offset correctly, but passing a converted checkpoint via `--checkpoint` silently produced garbage output from all 36+ norm layers, regardless of quantization recipe. --- .../oss_scripts/llama/wrappers/llm_wrappers.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py b/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py index 0d5052c89bd..73279733711 100644 --- a/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py +++ b/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py @@ -192,6 +192,19 @@ def _prepare_model(self): # noqa: C901 k.replace("_orig_mod.", ""): v for k, v in state_dict.items() } + if self.control_args.decoder_model in { + "gemma-2b", + "gemma2-2b", + "gemma3-1b", + }: + for k, v in state_dict.items(): + if "norm" not in k: + continue + # Gemma RMSNorm uses (1 + w) * x, so converted checkpoints + # that haven't been offset need +1 applied here. + # See https://github.com/huggingface/transformers/pull/29402 + state_dict[k] = v.float() + torch.ones(v.shape, dtype=torch.float32) + # change to HF weight to improve the performance of RoPE in HTP backend. if self.config.transform_weight: