[PyTorch] Preserve fprop operands for dequantized backward override#3141
[PyTorch] Preserve fprop operands for dequantized backward override#3141negvet wants to merge 4 commits into
Conversation
Signed-off-by: Evgeny <etsykunov@nvidia.com>
|
cc @zianglih |
for more information, see https://pre-commit.ci
|
/te-ci L0 L1 |
Greptile SummaryThis PR fixes a semantic conflict in
Confidence Score: 5/5Safe to merge — the two-line change in each module has no effect outside FP8 mode and its interaction with the downstream backward logic has been verified. The fix is minimal and scoped: it only activates when FP8 is enabled and the recipe explicitly sets backward_override=dequantized. Both modules already null-out backward_override on the non-FP8 path, so non-FP8 callers are completely unaffected. The backward logic in grouped_linear.py uses if ctx.save_original_input / elif ctx.backward_override == dequantized branches that are now correctly exclusive. Four new parametrised tests — including bit-exact gradient comparisons — give strong coverage of both override modes for both modules. No files require special attention. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[Linear / GroupedLinear forward] --> B{fp8 active?}
B -- No --> C[backward_override = None]
B -- Yes --> D[backward_override = recipe.backward_override]
C --> E[save_original_input = module setting]
D --> F{backward_override?}
F -- high_precision --> G[save_original_input = True]
F -- dequantized --> H[save_original_input = False]
F -- None --> E
G --> I[Save original inp tensor for re-quantisation in backward]
H --> J[Save fprop-quantized QuantizedTensorStorage rowwise-only layout]
E --> K{module.save_original_input?}
K -- True --> I
K -- False --> J
I --> L[backward wgrad: re-split + requantize from original]
J --> M[backward wgrad: dequantize from fprop quantized tensor]
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
A[Linear / GroupedLinear forward] --> B{fp8 active?}
B -- No --> C[backward_override = None]
B -- Yes --> D[backward_override = recipe.backward_override]
C --> E[save_original_input = module setting]
D --> F{backward_override?}
F -- high_precision --> G[save_original_input = True]
F -- dequantized --> H[save_original_input = False]
F -- None --> E
G --> I[Save original inp tensor for re-quantisation in backward]
H --> J[Save fprop-quantized QuantizedTensorStorage rowwise-only layout]
E --> K{module.save_original_input?}
K -- True --> I
K -- False --> J
I --> L[backward wgrad: re-split + requantize from original]
J --> M[backward wgrad: dequantize from fprop quantized tensor]
Reviews (2): Last reviewed commit: "Merge branch 'main' into fix_dequantized..." | Re-trigger Greptile |
|
Thanks for the fix! |
…original_input test Signed-off-by: root <root@prenyx0017.a51.clusters.nvidia.com>
|
/te-ci L0 L1 |
Description
Follow-up to #2644, which introduced
NVTE_BACKWARD_OVERRIDE=high_precision|dequantized.high_precisionis intended to use original unquantized tensor in backward, whiledequantizedis intended to use dequantized tensor from the forward-quantized one. However,save_original_input=Truecould override thedequantizedbehavior inLinearandGroupedLinear, causing backward to use the original input instead of the fprop-quantized operand.This PR makes the override semantics explicit:
backward_override="high_precision"forcessave_original_input=Truebackward_override="dequantized"forcessave_original_input=FalseFixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: