From 7606746d03bcdc90132c951b53cdf5bdba3cd073 Mon Sep 17 00:00:00 2001 From: Evgeny Date: Tue, 23 Jun 2026 06:16:24 -0700 Subject: [PATCH 1/3] Preserve fprop operands for dequantized backward override Signed-off-by: Evgeny --- tests/pytorch/test_backward_override.py | 173 ++++++++++++++++++ .../pytorch/module/grouped_linear.py | 2 + transformer_engine/pytorch/module/linear.py | 2 + 3 files changed, 177 insertions(+) diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index 5e6f36e8b4..e5aa0e8d34 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -858,6 +858,179 @@ def test_backward_override_recipe_matches_requested_mode( assert quant_recipe.backward_override is None +@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) +def test_linear_backward_override_dequantized_ignores_save_original_input( + recipe_name: str, + use_bias: bool, +) -> None: + reset_rng_states() + dtype = torch.bfloat16 + input_shape = (32, 128) + out_features = 128 + _maybe_skip_recipe_dtype(recipe_name, dtype, "linear") + _maybe_skip_unsupported_recipe_module_combo(recipe_name, "linear") + _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, "linear") + + mode_recipe = make_recipe(recipe_name, backward_override="dequantized") + skip_unsupported_backward_override("linear", mode_recipe, "dequantized") + + module_ref = te.Linear( + input_shape[-1], + out_features, + bias=use_bias, + params_dtype=dtype, + device="cuda", + save_original_input=False, + ) + module_test = te.Linear( + input_shape[-1], + out_features, + bias=use_bias, + params_dtype=dtype, + device="cuda", + save_original_input=True, + ) + _copy_named_parameters(module_ref, module_test) + + x = torch.randn(*input_shape, dtype=dtype, device="cuda") + dy = torch.randn(input_shape[0], out_features, dtype=dtype, device="cuda") + + y_ref, dx_ref, dw_ref, db_ref = _run_single_step(module_ref, x, dy, mode_recipe) + y_test, x_test, saved_operands = _run_single_step_with_saved_operands( + module_test, x, mode_recipe + ) + _assert_saved_quantized_operand_uses_rowwise_only(saved_operands[0], name="linear_input") + + y_test_detached = y_test.detach().clone() + y_test.backward(dy) + assert x_test.grad is not None + assert module_test.weight.grad is not None + dx_test = x_test.grad.detach().clone() + dw_test = module_test.weight.grad.detach().clone() + test_bias = getattr(module_test, "bias", None) + db_test = ( + None + if test_bias is None or test_bias.grad is None + else test_bias.grad.detach().clone() + ) + + assert_close(y_test_detached, y_ref, rtol=0, atol=0, check_dtype=True) + assert_close(dx_test, dx_ref, rtol=0, atol=0, check_dtype=True) + assert_close(dw_test, dw_ref, rtol=0, atol=0, check_dtype=True) + if use_bias: + assert db_test is not None and db_ref is not None + assert_close(db_test, db_ref, rtol=0, atol=0, check_dtype=True) + + +@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) +def test_grouped_linear_backward_override_dequantized_ignores_save_original_input( + recipe_name: str, + use_bias: bool, +) -> None: + reset_rng_states() + dtype = torch.bfloat16 + in_features = 128 + out_features = 128 + m_splits = [64, 64] + num_gemms = len(m_splits) + num_tokens = sum(m_splits) + _maybe_skip_recipe_dtype(recipe_name, dtype, "grouped_linear") + _maybe_skip_unsupported_recipe_module_combo(recipe_name, "grouped_linear") + _maybe_skip_unsupported_grouped_splits(recipe_name, m_splits) + + mode_recipe = make_recipe(recipe_name, backward_override="dequantized") + skip_unsupported_backward_override("grouped_linear", mode_recipe, "dequantized") + + module_ref = te.GroupedLinear( + num_gemms, + in_features, + out_features, + bias=use_bias, + params_dtype=dtype, + device="cuda", + save_original_input=False, + ) + module_test = te.GroupedLinear( + num_gemms, + in_features, + out_features, + bias=use_bias, + params_dtype=dtype, + device="cuda", + save_original_input=True, + ) + _copy_named_parameters(module_ref, module_test) + + x = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") + dy = torch.randn(num_tokens, out_features, dtype=dtype, device="cuda") + + y_ref, dx_ref, dw_ref, db_ref = _run_grouped_linear_single_step( + module_ref, x, m_splits, dy, mode_recipe + ) + y_test, x_test, saved_operands = _run_grouped_linear_step_with_saved_operands( + module_test, x, m_splits, mode_recipe + ) + saved_inputs = saved_operands[:num_gemms] + for i, saved_input in enumerate(saved_inputs): + _assert_saved_quantized_operand_uses_rowwise_only( + saved_input, name=f"grouped_linear_input{i}" + ) + + y_test_detached = y_test.detach().clone() + y_test.backward(dy) + assert x_test.grad is not None + dx_test = x_test.grad.detach().clone() + dw_test = [ + getattr(module_test, f"weight{i}").grad.detach().clone() for i in range(num_gemms) + ] + db_test: list[Optional[torch.Tensor]] = [] + for i in range(num_gemms): + if use_bias: + db_test.append(getattr(module_test, f"bias{i}").grad.detach().clone()) + else: + db_test.append(None) + + assert_close(y_test_detached, y_ref, rtol=0, atol=0, check_dtype=True) + assert_close(dx_test, dx_ref, rtol=0, atol=0, check_dtype=True) + for test_dw, ref_dw in zip(dw_test, dw_ref): + assert_close(test_dw, ref_dw, rtol=0, atol=0, check_dtype=True) + if use_bias: + for test_db, ref_db in zip(db_test, db_ref): + assert test_db is not None and ref_db is not None + assert_close(test_db, ref_db, rtol=0, atol=0, check_dtype=True) + + +@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +def test_linear_backward_override_high_precision_forces_save_original_input( + recipe_name: str, +) -> None: + reset_rng_states() + dtype = torch.bfloat16 + input_shape = (32, 128) + _maybe_skip_recipe_dtype(recipe_name, dtype, "linear") + _maybe_skip_unsupported_recipe_module_combo(recipe_name, "linear") + _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, "linear") + + mode_recipe = make_recipe(recipe_name, backward_override="high_precision") + skip_unsupported_backward_override("linear", mode_recipe, "high_precision") + + module = te.Linear( + input_shape[-1], + 128, + bias=False, + params_dtype=dtype, + device="cuda", + save_original_input=False, + ) + x = torch.randn(*input_shape, dtype=dtype, device="cuda") + + _, _, saved_operands = _run_single_step_with_saved_operands(module, x, mode_recipe) + + assert isinstance(saved_operands[0], torch.Tensor) + + @pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) @pytest.mark.parametrize("module_type", ("linear", "layernorm_linear", "ops_linear")) @pytest.mark.parametrize("input_shape,out_features", _shape_test_cases) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 0ddfaacad1..699823c013 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -431,6 +431,8 @@ def forward( backward_override = None if backward_override == "high_precision": save_original_input = True + elif backward_override == "dequantized": + save_original_input = False num_gemms = len(m_splits) weights = weights_and_biases[:num_gemms] diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 6c2d98d160..c68c689b97 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -285,6 +285,8 @@ def _linear_forward_impl( is_fsdp2 = args.is_fsdp2 if backward_override == "high_precision": save_original_input = True + elif backward_override == "dequantized": + save_original_input = False # NVTX label for profiling nvtx_label = "transformer_engine._Linear.forward" From 941db9b0d5457c986f6b9a06bd960d6186c8a592 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Jun 2026 13:21:48 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_backward_override.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index e5aa0e8d34..ac05cc1658 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -910,9 +910,7 @@ def test_linear_backward_override_dequantized_ignores_save_original_input( dw_test = module_test.weight.grad.detach().clone() test_bias = getattr(module_test, "bias", None) db_test = ( - None - if test_bias is None or test_bias.grad is None - else test_bias.grad.detach().clone() + None if test_bias is None or test_bias.grad is None else test_bias.grad.detach().clone() ) assert_close(y_test_detached, y_ref, rtol=0, atol=0, check_dtype=True) @@ -982,9 +980,7 @@ def test_grouped_linear_backward_override_dequantized_ignores_save_original_inpu y_test.backward(dy) assert x_test.grad is not None dx_test = x_test.grad.detach().clone() - dw_test = [ - getattr(module_test, f"weight{i}").grad.detach().clone() for i in range(num_gemms) - ] + dw_test = [getattr(module_test, f"weight{i}").grad.detach().clone() for i in range(num_gemms)] db_test: list[Optional[torch.Tensor]] = [] for i in range(num_gemms): if use_bias: From 6aef491d027f5b5e817a5f28badccccee0dd6872 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 25 Jun 2026 06:30:17 -0700 Subject: [PATCH 3/3] Add test_grouped_linear_backward_override_high_precision_forces_save_original_input test Signed-off-by: root --- tests/pytorch/test_backward_override.py | 43 +++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index ac05cc1658..c0acf2e6b3 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -1027,6 +1027,49 @@ def test_linear_backward_override_high_precision_forces_save_original_input( assert isinstance(saved_operands[0], torch.Tensor) +@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +def test_grouped_linear_backward_override_high_precision_forces_save_original_input( + recipe_name: str, +) -> None: + reset_rng_states() + dtype = torch.bfloat16 + in_features = 128 + out_features = 128 + m_splits = [64, 64] + num_gemms = len(m_splits) + num_tokens = sum(m_splits) + _maybe_skip_recipe_dtype(recipe_name, dtype, "grouped_linear") + _maybe_skip_unsupported_recipe_module_combo(recipe_name, "grouped_linear") + _maybe_skip_unsupported_grouped_splits(recipe_name, m_splits) + + mode_recipe = make_recipe(recipe_name, backward_override="high_precision") + skip_unsupported_backward_override("grouped_linear", mode_recipe, "high_precision") + + module = te.GroupedLinear( + num_gemms, + in_features, + out_features, + bias=False, + params_dtype=dtype, + device="cuda", + save_original_input=False, + ) + x = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") + + _, _, saved_operands = _run_grouped_linear_step_with_saved_operands( + module, x, m_splits, mode_recipe + ) + + saved_inputs = saved_operands[:num_gemms] + assert isinstance(saved_inputs[0], torch.Tensor) + assert saved_inputs[0].shape == x.shape + assert all(saved_input is None for saved_input in saved_inputs[1:]) + + saved_weights = saved_operands[2 * num_gemms : 3 * num_gemms] + for saved_weight in saved_weights: + assert isinstance(saved_weight, torch.Tensor) + + @pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) @pytest.mark.parametrize("module_type", ("linear", "layernorm_linear", "ops_linear")) @pytest.mark.parametrize("input_shape,out_features", _shape_test_cases)