diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index 5e6f36e8b4..c0acf2e6b3 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -858,6 +858,218 @@ def test_backward_override_recipe_matches_requested_mode( assert quant_recipe.backward_override is None +@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) +def test_linear_backward_override_dequantized_ignores_save_original_input( + recipe_name: str, + use_bias: bool, +) -> None: + reset_rng_states() + dtype = torch.bfloat16 + input_shape = (32, 128) + out_features = 128 + _maybe_skip_recipe_dtype(recipe_name, dtype, "linear") + _maybe_skip_unsupported_recipe_module_combo(recipe_name, "linear") + _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, "linear") + + mode_recipe = make_recipe(recipe_name, backward_override="dequantized") + skip_unsupported_backward_override("linear", mode_recipe, "dequantized") + + module_ref = te.Linear( + input_shape[-1], + out_features, + bias=use_bias, + params_dtype=dtype, + device="cuda", + save_original_input=False, + ) + module_test = te.Linear( + input_shape[-1], + out_features, + bias=use_bias, + params_dtype=dtype, + device="cuda", + save_original_input=True, + ) + _copy_named_parameters(module_ref, module_test) + + x = torch.randn(*input_shape, dtype=dtype, device="cuda") + dy = torch.randn(input_shape[0], out_features, dtype=dtype, device="cuda") + + y_ref, dx_ref, dw_ref, db_ref = _run_single_step(module_ref, x, dy, mode_recipe) + y_test, x_test, saved_operands = _run_single_step_with_saved_operands( + module_test, x, mode_recipe + ) + _assert_saved_quantized_operand_uses_rowwise_only(saved_operands[0], name="linear_input") + + y_test_detached = y_test.detach().clone() + y_test.backward(dy) + assert x_test.grad is not None + assert module_test.weight.grad is not None + dx_test = x_test.grad.detach().clone() + dw_test = module_test.weight.grad.detach().clone() + test_bias = getattr(module_test, "bias", None) + db_test = ( + None if test_bias is None or test_bias.grad is None else test_bias.grad.detach().clone() + ) + + assert_close(y_test_detached, y_ref, rtol=0, atol=0, check_dtype=True) + assert_close(dx_test, dx_ref, rtol=0, atol=0, check_dtype=True) + assert_close(dw_test, dw_ref, rtol=0, atol=0, check_dtype=True) + if use_bias: + assert db_test is not None and db_ref is not None + assert_close(db_test, db_ref, rtol=0, atol=0, check_dtype=True) + + +@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) +def test_grouped_linear_backward_override_dequantized_ignores_save_original_input( + recipe_name: str, + use_bias: bool, +) -> None: + reset_rng_states() + dtype = torch.bfloat16 + in_features = 128 + out_features = 128 + m_splits = [64, 64] + num_gemms = len(m_splits) + num_tokens = sum(m_splits) + _maybe_skip_recipe_dtype(recipe_name, dtype, "grouped_linear") + _maybe_skip_unsupported_recipe_module_combo(recipe_name, "grouped_linear") + _maybe_skip_unsupported_grouped_splits(recipe_name, m_splits) + + mode_recipe = make_recipe(recipe_name, backward_override="dequantized") + skip_unsupported_backward_override("grouped_linear", mode_recipe, "dequantized") + + module_ref = te.GroupedLinear( + num_gemms, + in_features, + out_features, + bias=use_bias, + params_dtype=dtype, + device="cuda", + save_original_input=False, + ) + module_test = te.GroupedLinear( + num_gemms, + in_features, + out_features, + bias=use_bias, + params_dtype=dtype, + device="cuda", + save_original_input=True, + ) + _copy_named_parameters(module_ref, module_test) + + x = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") + dy = torch.randn(num_tokens, out_features, dtype=dtype, device="cuda") + + y_ref, dx_ref, dw_ref, db_ref = _run_grouped_linear_single_step( + module_ref, x, m_splits, dy, mode_recipe + ) + y_test, x_test, saved_operands = _run_grouped_linear_step_with_saved_operands( + module_test, x, m_splits, mode_recipe + ) + saved_inputs = saved_operands[:num_gemms] + for i, saved_input in enumerate(saved_inputs): + _assert_saved_quantized_operand_uses_rowwise_only( + saved_input, name=f"grouped_linear_input{i}" + ) + + y_test_detached = y_test.detach().clone() + y_test.backward(dy) + assert x_test.grad is not None + dx_test = x_test.grad.detach().clone() + dw_test = [getattr(module_test, f"weight{i}").grad.detach().clone() for i in range(num_gemms)] + db_test: list[Optional[torch.Tensor]] = [] + for i in range(num_gemms): + if use_bias: + db_test.append(getattr(module_test, f"bias{i}").grad.detach().clone()) + else: + db_test.append(None) + + assert_close(y_test_detached, y_ref, rtol=0, atol=0, check_dtype=True) + assert_close(dx_test, dx_ref, rtol=0, atol=0, check_dtype=True) + for test_dw, ref_dw in zip(dw_test, dw_ref): + assert_close(test_dw, ref_dw, rtol=0, atol=0, check_dtype=True) + if use_bias: + for test_db, ref_db in zip(db_test, db_ref): + assert test_db is not None and ref_db is not None + assert_close(test_db, ref_db, rtol=0, atol=0, check_dtype=True) + + +@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +def test_linear_backward_override_high_precision_forces_save_original_input( + recipe_name: str, +) -> None: + reset_rng_states() + dtype = torch.bfloat16 + input_shape = (32, 128) + _maybe_skip_recipe_dtype(recipe_name, dtype, "linear") + _maybe_skip_unsupported_recipe_module_combo(recipe_name, "linear") + _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, "linear") + + mode_recipe = make_recipe(recipe_name, backward_override="high_precision") + skip_unsupported_backward_override("linear", mode_recipe, "high_precision") + + module = te.Linear( + input_shape[-1], + 128, + bias=False, + params_dtype=dtype, + device="cuda", + save_original_input=False, + ) + x = torch.randn(*input_shape, dtype=dtype, device="cuda") + + _, _, saved_operands = _run_single_step_with_saved_operands(module, x, mode_recipe) + + assert isinstance(saved_operands[0], torch.Tensor) + + +@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +def test_grouped_linear_backward_override_high_precision_forces_save_original_input( + recipe_name: str, +) -> None: + reset_rng_states() + dtype = torch.bfloat16 + in_features = 128 + out_features = 128 + m_splits = [64, 64] + num_gemms = len(m_splits) + num_tokens = sum(m_splits) + _maybe_skip_recipe_dtype(recipe_name, dtype, "grouped_linear") + _maybe_skip_unsupported_recipe_module_combo(recipe_name, "grouped_linear") + _maybe_skip_unsupported_grouped_splits(recipe_name, m_splits) + + mode_recipe = make_recipe(recipe_name, backward_override="high_precision") + skip_unsupported_backward_override("grouped_linear", mode_recipe, "high_precision") + + module = te.GroupedLinear( + num_gemms, + in_features, + out_features, + bias=False, + params_dtype=dtype, + device="cuda", + save_original_input=False, + ) + x = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") + + _, _, saved_operands = _run_grouped_linear_step_with_saved_operands( + module, x, m_splits, mode_recipe + ) + + saved_inputs = saved_operands[:num_gemms] + assert isinstance(saved_inputs[0], torch.Tensor) + assert saved_inputs[0].shape == x.shape + assert all(saved_input is None for saved_input in saved_inputs[1:]) + + saved_weights = saved_operands[2 * num_gemms : 3 * num_gemms] + for saved_weight in saved_weights: + assert isinstance(saved_weight, torch.Tensor) + + @pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) @pytest.mark.parametrize("module_type", ("linear", "layernorm_linear", "ops_linear")) @pytest.mark.parametrize("input_shape,out_features", _shape_test_cases) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 8d56e423c4..f2fa8b657e 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -431,6 +431,8 @@ def forward( backward_override = None if backward_override == "high_precision": save_original_input = True + elif backward_override == "dequantized": + save_original_input = False num_gemms = len(m_splits) weights = weights_and_biases[:num_gemms] diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 9d033fa01e..dd4d6d6162 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -285,6 +285,8 @@ def _linear_forward_impl( is_fsdp2 = args.is_fsdp2 if backward_override == "high_precision": save_original_input = True + elif backward_override == "dequantized": + save_original_input = False # NVTX label for profiling nvtx_label = "transformer_engine._Linear.forward"