From bad0b2c9af774f72f4932e477d2c198017fb17bb Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Thu, 25 Jun 2026 00:40:35 +0000 Subject: [PATCH 1/3] support in grouped linear and relevant tests Signed-off-by: Varun Thumbe --- tests/pytorch/test_grouped_linear.py | 31 +++++++++++++++---- tests/pytorch/test_grouped_mlp.py | 7 ++++- .../pytorch/module/grouped_linear.py | 16 +++++++--- .../pytorch/ops/basic/grouped_linear.py | 25 ++++++++++++--- 4 files changed, 63 insertions(+), 16 deletions(-) diff --git a/tests/pytorch/test_grouped_linear.py b/tests/pytorch/test_grouped_linear.py index caa84ec02a..01a7cf2415 100644 --- a/tests/pytorch/test_grouped_linear.py +++ b/tests/pytorch/test_grouped_linear.py @@ -1496,6 +1496,7 @@ def test_fp8_grouped_gemm(shape, accumulate): _FUSED_GROUPED_GEMM_ENV = "NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM" _ALL_BOOLEAN = all_boolean +_fp8_available, _reason_for_no_fp8 = fp8_available, reason_for_no_fp8 _mxfp8_available, _reason_for_no_mxfp8 = mxfp8_available, reason_for_no_mxfp8 _nvfp4_available, _reason_for_no_nvfp4 = nvfp4_available, reason_for_no_nvfp4 @@ -1577,6 +1578,10 @@ def _run_grouped_linear_path( "fp8_recipe", [ None, + pytest.param( + recipe.Float8CurrentScaling(), + marks=pytest.mark.skipif(not _fp8_available, reason=_reason_for_no_fp8), + ), pytest.param( recipe.MXFP8BlockScaling(), marks=pytest.mark.skipif(not _mxfp8_available, reason=_reason_for_no_mxfp8), @@ -1586,7 +1591,7 @@ def _run_grouped_linear_path( marks=pytest.mark.skipif(not _nvfp4_available, reason=_reason_for_no_nvfp4), ), ], - ids=["bf16", "mxfp8", "nvfp4"], + ids=["bf16", "fp8_current_scaling", "mxfp8", "nvfp4"], ) @pytest.mark.parametrize("bias", _ALL_BOOLEAN) @pytest.mark.parametrize("fp8_model_params", _ALL_BOOLEAN) @@ -1600,8 +1605,13 @@ def test_grouped_linear_grouped_tensor_path_matches_legacy( pytest.skip( "GroupedTensor grouped GEMM path requires Hopper (SM90) or Blackwell (SM10x and SM110)." 
) - if use_fp8 and device_capability < (10, 0): - pytest.skip("Quantized GroupedTensor grouped GEMM path requires Blackwell (SM100+).") + # MXFP8/NVFP4 grouped quantization kernels require Blackwell, but FP8 per-tensor + # current scaling also runs on the Hopper grouped GEMM path. + is_current_scaling = use_fp8 and fp8_recipe.float8_current_scaling() + if use_fp8 and not is_current_scaling and device_capability < (10, 0): + pytest.skip( + "Quantized GroupedTensor grouped GEMM path (MXFP8/NVFP4) requires Blackwell (SM100+)." + ) cublaslt_version = tex.get_cublasLt_version() if device_capability < (10, 0) and cublaslt_version < 130400: pytest.skip("Grouped GEMM on Hopper requires cuBLAS 13.4+.") @@ -1786,6 +1796,10 @@ def test_grouped_linear_grouped_tensor_path_skips_non_rht_nvfp4(monkeypatch): "fp8_recipe", [ None, + pytest.param( + recipe.Float8CurrentScaling(), + marks=pytest.mark.skipif(not _fp8_available, reason=_reason_for_no_fp8), + ), pytest.param( recipe.MXFP8BlockScaling(), marks=pytest.mark.skipif(not _mxfp8_available, reason=_reason_for_no_mxfp8), @@ -1795,7 +1809,7 @@ def test_grouped_linear_grouped_tensor_path_skips_non_rht_nvfp4(monkeypatch): marks=pytest.mark.skipif(not _nvfp4_available, reason=_reason_for_no_nvfp4), ), ], - ids=["bf16", "mxfp8", "nvfp4"], + ids=["bf16", "fp8_current_scaling", "mxfp8", "nvfp4"], ) @pytest.mark.parametrize("bias", _ALL_BOOLEAN) def test_grouped_linear_fused_path_cuda_graph_safe(fp8_recipe, bias, monkeypatch): @@ -1806,8 +1820,13 @@ def test_grouped_linear_fused_path_cuda_graph_safe(fp8_recipe, bias, monkeypatch pytest.skip( "GroupedTensor grouped GEMM path requires Hopper (SM90) or Blackwell (SM10x and SM110)." ) - if use_fp8 and device_capability < (10, 0): - pytest.skip("Quantized GroupedTensor grouped GEMM path requires Blackwell (SM100+).") + # MXFP8/NVFP4 grouped quantization kernels require Blackwell, but FP8 per-tensor + # current scaling also runs on the Hopper grouped GEMM path. 
+ is_current_scaling = use_fp8 and fp8_recipe.float8_current_scaling() + if use_fp8 and not is_current_scaling and device_capability < (10, 0): + pytest.skip( + "Quantized GroupedTensor grouped GEMM path (MXFP8/NVFP4) requires Blackwell (SM100+)." + ) cublaslt_version = tex.get_cublasLt_version() if device_capability < (10, 0) and cublaslt_version < 130400: pytest.skip("Grouped GEMM on Hopper requires cuBLAS 13.4+.") diff --git a/tests/pytorch/test_grouped_mlp.py b/tests/pytorch/test_grouped_mlp.py index cb90ac6bd9..606223ef1e 100644 --- a/tests/pytorch/test_grouped_mlp.py +++ b/tests/pytorch/test_grouped_mlp.py @@ -439,7 +439,10 @@ def test_grouped_linear( @pytest.mark.parametrize("dtype", (torch.bfloat16, torch.float16)) @pytest.mark.parametrize( "quantization", - [None] + (["mxfp8"] if mxfp8_available else []), + [None] + + (["fp8_current_scaling"] if fp8_available else []) + + (["mxfp8"] if mxfp8_available else []) + + (["nvfp4_rht"] if nvfp4_available else []), ) @pytest.mark.parametrize("quantized_weight", (False, True)) @pytest.mark.parametrize("bias", (False, True)) @@ -479,6 +482,8 @@ def test_grouped_linear_cuda_graph_safe( pytest.skip("Grouped GEMM CUDA-graph-safe path requires SM100+ (Blackwell)") if quantization is None and quantized_weight: pytest.skip("quantized_weight requires a quantization recipe") + if quantization is not None and quantization.startswith("nvfp4") and dtype != torch.bfloat16: + pytest.skip("NVFP4 grouped GEMM only supports BF16 output") single_grouped_bias = bias and single_grouped_weight diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 8d56e423c4..4630f128fa 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -103,9 +103,13 @@ def _is_grouped_tensor_path_supported( and be incompatible with CUDA Graphs. Supported Compute Capability (CC) and precisions: - * Hopper (CC 9.0): BF16/FP16. 
- * Blackwell (CC 10.x and 11.0): BF16/FP16/MXFP8/NVFP4 with RHT. - FP8 delayed / current scaling, and FP8 block scaling are not supported because the + * Hopper (CC 9.0): BF16/FP16 and FP8 per-tensor current scaling. + * Blackwell (CC 10.x and 11.0): BF16/FP16/MXFP8/NVFP4 with RHT and FP8 + per-tensor current scaling. + FP8 current scaling uses grouped current-scaling quantization + (``tex.group_quantize``) plus cuBLASLt grouped GEMM with per-batch scalar FP8 + scaling, both of which are available on Hopper and Blackwell. + FP8 delayed scaling and FP8 block scaling are not supported because the corresponding grouped quantization kernels are missing. Non-RHT NVFP4 falls back to the legacy path because graph-safe grouped quantization currently requires RHT. @@ -133,6 +137,9 @@ def _is_grouped_tensor_path_supported( return False # 5. Filter by quantization recipes. if fp8: + if all(isinstance(q, Float8CurrentScalingQuantizer) for q in input_quantizers): + return True + # MXFP8 and NVFP4 grouped quantization kernels require Blackwell. 
if not (10, 0) <= get_device_compute_capability() <= (11, 0): return False return all(isinstance(q, MXFP8Quantizer) for q in input_quantizers) or all( @@ -328,7 +335,8 @@ def _forward_grouped_tensor( if is_grad_enabled: if weight_requires_grad: - if fp8: + # Free Rowwise Data if columnwise data is available for backward pass + # (For FP8 per tensor current scaling on Hopper) if fp8 and grouped_x.columnwise_data is not None: grouped_x.rowwise_data = None grouped_x.scale_inv = None else: diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index 4bbd75bc64..bce7f0d627 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -26,7 +26,13 @@ from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload from ...quantization import FP8GlobalStateManager, QuantizerRole, Recipe from ...quantized_tensor import QuantizedTensorStorage -from ...tensor import MXFP8Quantizer, MXFP8Tensor, NVFP4Quantizer, Quantizer +from ...tensor import ( + Float8CurrentScalingQuantizer, + MXFP8Quantizer, + MXFP8Tensor, + NVFP4Quantizer, + Quantizer, +) from ...utils import ( canonicalize_device, canonicalize_dtype, @@ -768,8 +774,11 @@ def _is_graph_safe_path_supported( requirement without duplicating its cuBLAS version checks. * Quantized compute supports MXFP8 and NVFP4 on Blackwell GPUs with Compute Capability (CC) 10.x and 11.0. NVFP4 requires RHT because graph-safe grouped quantization currently - requires RHT; - Every other quantization recipe (fp8 delayed / current scaling, fp8 block scaling, ...) + requires RHT. + * FP8 per-tensor current scaling is backed by grouped current-scaling quantization + (``tex.group_quantize``) and cuBLASLt grouped GEMM with per-batch scalar FP8 scaling, + which are supported on Hopper (CC 9.0) and Blackwell (CC 10.x and 11.0). 
+ Every other quantization recipe (fp8 delayed scaling, fp8 block scaling, ...) falls back to the legacy flow because the corresponding grouped quantization kernels are missing. * Unquantized compute supports BF16/FP16 on Hopper (CC 9.0) and Blackwell (CC 10.x and 11.0) @@ -780,6 +789,11 @@ def _is_graph_safe_path_supported( if not (9, 0) <= get_device_compute_capability() <= (11, 0): return False if with_quantized_compute: + # FP8 per-tensor current scaling runs on the Hopper and Blackwell grouped GEMM + # path; the compute-capability range was already checked above. + if all(isinstance(q, Float8CurrentScalingQuantizer) for q in input_quantizers): + return True + # MXFP8 and NVFP4 grouped quantization kernels require Blackwell. if not (10, 0) <= get_device_compute_capability() <= (11, 0): return False return all(isinstance(q, MXFP8Quantizer) for q in input_quantizers) or all( @@ -1318,8 +1332,9 @@ def _fuser_forward_grouped_tensor( # [split_sizes, base_split_offsets, split_points, # (scales if _scale_bias), grouped_x, *weights] if grouped_x is not None: - if with_quantized_compute: - # only columnwise data is needed for wgrad + # Free Rowwise Data if columnwise data is available for backward pass + # (For FP8 per tensor current scaling on Hopper) + if with_quantized_compute and grouped_x.columnwise_data is not None: grouped_x.rowwise_data = None grouped_x.scale_inv = None saved: list[Optional[torch.Tensor]] = [split_sizes, base_split_offsets, split_points] From 9af6df6a1663dfc6d668d28dd1caf2a34b68487f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Jun 2026 00:52:39 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_grouped_mlp.py | 6 +++++- transformer_engine/pytorch/module/grouped_linear.py | 8 ++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git 
a/tests/pytorch/test_grouped_mlp.py b/tests/pytorch/test_grouped_mlp.py index 606223ef1e..0fcffd23f5 100644 --- a/tests/pytorch/test_grouped_mlp.py +++ b/tests/pytorch/test_grouped_mlp.py @@ -482,7 +482,11 @@ def test_grouped_linear_cuda_graph_safe( pytest.skip("Grouped GEMM CUDA-graph-safe path requires SM100+ (Blackwell)") if quantization is None and quantized_weight: pytest.skip("quantized_weight requires a quantization recipe") - if quantization is not None and quantization.startswith("nvfp4") and dtype != torch.bfloat16: + if ( + quantization is not None + and quantization.startswith("nvfp4") + and dtype != torch.bfloat16 + ): pytest.skip("NVFP4 grouped GEMM only supports BF16 output") single_grouped_bias = bias and single_grouped_weight diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 4630f128fa..71ac771d2e 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -335,10 +335,10 @@ def _forward_grouped_tensor( if is_grad_enabled: if weight_requires_grad: - # Free Rowwise Data if columnwise data is available for backward pass - # (For FP8 per tensor current scaling on Hopper) if fp8 and grouped_x.columnwise_data is not None: - grouped_x.rowwise_data = None - grouped_x.scale_inv = None + # Free Rowwise Data if columnwise data is available for backward pass + # (For FP8 per tensor current scaling on Hopper) if fp8 and grouped_x.columnwise_data is not None: + grouped_x.rowwise_data = None + grouped_x.scale_inv = None else: grouped_x = None From bd0832cfff43b70f1d760de09b9c3ff23004c1d3 Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Wed, 24 Jun 2026 17:56:22 -0700 Subject: [PATCH 3/3] Remove unnecessary details Removed details about FP8 current scaling methods. 
Signed-off-by: vthumbe1503 --- transformer_engine/pytorch/module/grouped_linear.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 71ac771d2e..e962b4dc8a 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -106,9 +106,6 @@ def _is_grouped_tensor_path_supported( * Hopper (CC 9.0): BF16/FP16 and FP8 per-tensor current scaling. * Blackwell (CC 10.x and 11.0): BF16/FP16/MXFP8/NVFP4 with RHT and FP8 per-tensor current scaling. - FP8 current scaling uses grouped current-scaling quantization - (``tex.group_quantize``) plus cuBLASLt grouped GEMM with per-batch scalar FP8 - scaling, both of which are available on Hopper and Blackwell. FP8 delayed scaling and FP8 block scaling are not supported because the corresponding grouped quantization kernels are missing. Non-RHT NVFP4 falls back to the legacy path because graph-safe grouped quantization