Graph Safe Current Scaling Support for GroupedLinear Module/Ops #3143
vthumbe1503 wants to merge 3 commits into
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Removed details about FP8 current scaling methods. Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
/te-ci pytorch
Greptile Summary
This PR extends the graph-safe grouped-tensor / cuBLASLt GEMM path (
Confidence Score: 3/5
Not safe to merge as-is: the module-level forward stores a None activation for the weight gradient computation on any non-FP8 grouped-tensor training pass where weights require gradients.
Lines 335–338 of transformer_engine/pytorch/module/grouped_linear.py need immediate attention; tests/pytorch/test_grouped_mlp.py has a minor gap in Hopper coverage for the new fp8_current_scaling case.
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[GroupedLinear forward] --> B{use_grouped_tensor_path?}
B -->|CC check 9.0-11.0 passes| C{fp8?}
B -->|CC out of range| Z[Legacy path]
C -->|Yes| D{All Float8CurrentScalingQuantizer?}
D -->|Yes| E[Return True — Hopper + Blackwell supported]
D -->|No| F{CC >= 10.0 Blackwell?}
F -->|Yes| G{All MXFP8 or NVFP4+RHT?}
G -->|Yes| H[Return True — Blackwell only]
G -->|No| Z
F -->|No| Z
C -->|No BF16/FP16| I{dtype BF16 or FP16?}
I -->|Yes| J[Return True]
I -->|No| Z
E --> K[grouped_x = tex.group_quantize with columnwise=weight_requires_grad]
H --> K
J --> L[grouped_x = GroupedTensorStorage rowwise only]
K --> M[general_grouped_gemm_for_grouped_tensor]
L --> M
M --> N{is_grad_enabled and weight_requires_grad?}
N -->|ops path: with_quantized_compute and columnwise_data != None| O[Free rowwise_data — correct]
N -->|module path: BUGGY unconditional| P[Free rowwise_data — breaks BF16/FP16 wgrad]
O --> Q[save_for_backward]
P --> Q
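The dispatch portion of the flowchart above can be sketched in plain Python. This is an illustration only, not the Transformer Engine implementation: the `Recipe` enum, the function signature, and the string dtype names are hypothetical stand-ins, the "9.0-11.0" check is assumed to be inclusive at the bottom and exclusive at the top, and "All MXFP8 or NVFP4+RHT" is assumed to mean every quantizer shares one of those two recipes.

```python
from enum import Enum, auto


class Recipe(Enum):
    """Hypothetical stand-ins for the quantizer types named in the flowchart."""
    FP8_CURRENT_SCALING = auto()
    MXFP8 = auto()
    NVFP4_RHT = auto()
    OTHER = auto()


def use_grouped_tensor_path(cc, fp8, recipes=(), dtype="bf16"):
    """Return True when the flowchart would select the grouped-tensor path.

    cc      -- compute capability as a (major, minor) tuple
    fp8     -- whether FP8 quantized compute is enabled
    recipes -- one Recipe per quantizer (consulted only when fp8 is True)
    dtype   -- activation dtype name (consulted only when fp8 is False)
    """
    # Assumption: the range check is 9.0 <= CC < 11.0.
    if not ((9, 0) <= cc < (11, 0)):
        return False  # legacy path
    if fp8:
        recipes = list(recipes)
        if not recipes:
            return False
        # Current scaling is accepted on both Hopper and Blackwell.
        if all(r is Recipe.FP8_CURRENT_SCALING for r in recipes):
            return True
        # MXFP8 and NVFP4+RHT are accepted on Blackwell (CC >= 10.0) only.
        if cc >= (10, 0):
            return (all(r is Recipe.MXFP8 for r in recipes)
                    or all(r is Recipe.NVFP4_RHT for r in recipes))
        return False
    # Non-quantized compute: only BF16 and FP16 take the grouped-tensor path.
    return dtype in ("bf16", "fp16")
```

Representing the compute capability as a tuple keeps the range checks readable, since Python compares tuples element by element.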
    # Free Rowwise Data if columnwise data is available for backward pass
    # (For FP8 per tensor current scaling on Hopper) if fp8 and grouped_x.columnwise_data is not None:
    grouped_x.rowwise_data = None
    grouped_x.scale_inv = None
Conditional guard accidentally embedded in comment — rowwise_data cleared unconditionally
The if fp8 and grouped_x.columnwise_data is not None: guard was intended to precede the two assignments on lines 337–338, but it was appended to the end of the preceding comment on line 336. Because Python ignores everything after # on a line, both grouped_x.rowwise_data = None and grouped_x.scale_inv = None now execute unconditionally whenever is_grad_enabled and weight_requires_grad are both True.
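One way to confirm that a guard sitting on a comment line never reaches the interpreter is to tokenize a reduced copy of the snippet with the standard-library `tokenize` module. The two-line string below is a simplified stand-in for the lines under review (the variable names are shortened for illustration), not the actual file contents.

```python
import io
import tokenize

# Reduced stand-in for the reviewed lines: the guard is fused onto the comment.
buggy = (
    "# (For FP8 per tensor current scaling on Hopper) if fp8 and x is not None:\n"
    "rowwise_data = None\n"
)

tokens = list(tokenize.generate_tokens(io.StringIO(buggy).readline))

# Everything after '#' is a single COMMENT token, including the 'if ...:' text.
comment_tokens = [t.string for t in tokens if t.type == tokenize.COMMENT]

# NAME tokens cover identifiers and keywords that are real code.
name_tokens = [t.string for t in tokens if t.type == tokenize.NAME]

print(comment_tokens)
print(name_tokens)
```

Because `if` never appears as a NAME token, the assignment on the following line is an ordinary top-level statement that always runs.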
For the non-FP8 (BF16/FP16) grouped-tensor path, grouped_x.rowwise_data holds the packed activation buffer that is saved for backward and used to compute the weight gradient. Clearing it to None before ctx.save_for_backward destroys the activation data, causing the wgrad computation to operate on None — resulting in a crash or silently incorrect gradients.
The equivalent change in ops/basic/grouped_linear.py (line 1335) correctly places the condition on its own line: if with_quantized_compute and grouped_x.columnwise_data is not None:.
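A minimal sketch of the guarded behaviour the review asks for is shown below. The helper name `free_rowwise_if_safe` is hypothetical, and `SimpleNamespace` objects stand in for the real grouped-tensor storage; only the attribute names `rowwise_data`, `columnwise_data`, and `scale_inv` come from the reviewed code.

```python
from types import SimpleNamespace


def free_rowwise_if_safe(grouped_x, fp8):
    """Drop rowwise buffers only when a columnwise copy exists for wgrad."""
    # The guard is a real statement on its own line, so it actually executes.
    if fp8 and grouped_x.columnwise_data is not None:
        grouped_x.rowwise_data = None
        grouped_x.scale_inv = None
    return grouped_x


# BF16/FP16 case: there is no columnwise copy, so the rowwise activation
# buffer must survive until it is saved for the backward pass.
bf16_x = free_rowwise_if_safe(
    SimpleNamespace(rowwise_data="activations", columnwise_data=None, scale_inv=None),
    fp8=False,
)

# FP8 current-scaling case: a columnwise copy exists, so the rowwise buffer
# and its scale can be released.
fp8_x = free_rowwise_if_safe(
    SimpleNamespace(rowwise_data="q_row", columnwise_data="q_col", scale_inv="s"),
    fp8=True,
)
```

With the guard in place, the non-FP8 path keeps its activation data while the FP8 path still frees the buffers it no longer needs.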
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: