diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index 6b1ad870e9..d8857b97b8 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -33,6 +33,15 @@ # to avoid numerical tolerance issues of doing comm gemm overlap, limit the number of GPUs used MAX_GPUS_TO_USE = 4 +COMM_GEMM_QUANTIZATION_PARAMS = [ + pytest.param(False, "none", id="ub-bf16"), + pytest.param(False, "fp8", id="ub-fp8"), + pytest.param(False, "mxfp8", id="ub-mxfp8"), + pytest.param(True, "none", id="cublasmp-bf16"), + pytest.param(True, "fp8", id="cublasmp-fp8"), + pytest.param(True, "mxfp8", id="cublasmp-mxfp8"), +] + TEST_ROOT = Path(__file__).parent.resolve() NUM_PROCS: int = min(torch.cuda.device_count(), MAX_GPUS_TO_USE) LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] @@ -84,10 +93,6 @@ def _run_gemm_with_overlap( if use_cublasmp: if not tex.nvte_built_with_cublasmp(): pytest.skip("Transformer Engine not built with cuBLASMp (NVTE_WITH_CUBLASMP=0).") - if quantization == "mxfp8": - pytest.skip( - "cuBLASMp comm+GEMM overlap does not yet support MXFP8 (block scaling)." - ) if comm_type == "RS" and not p2p and not tex.device_supports_multicast(): pytest.skip( "cuBLASMp non-P2P reduce-scatter requires NVSwitch (multicast support)." 
@@ -140,8 +145,6 @@ def _run_layer_with_overlap( if use_cublasmp: if not tex.nvte_built_with_cublasmp(): pytest.skip("Transformer Engine not built with cuBLASMp (NVTE_WITH_CUBLASMP=0).") - if fp8 and quantization == "mxfp8": - pytest.skip("cuBLASMp comm+GEMM overlap does not yet support MXFP8 (block scaling).") test_cmd.append("--use-cublasmp") os.environ["PYTORCH_JIT"] = "0" @@ -173,8 +176,7 @@ def _run_layer_with_overlap( raise AssertionError(result.stderr.decode()) -@pytest.mark.parametrize("use_cublasmp", (False, True)) -@pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8")) +@pytest.mark.parametrize("use_cublasmp,quantization", COMM_GEMM_QUANTIZATION_PARAMS) @pytest.mark.parametrize("aggregate", (False, True)) def test_split_all_gather_overlaps(quantization, aggregate, use_cublasmp): """ @@ -184,8 +186,7 @@ def test_split_all_gather_overlaps(quantization, aggregate, use_cublasmp): _run_gemm_with_overlap("AG", False, True, False, aggregate, quantization, use_cublasmp) -@pytest.mark.parametrize("use_cublasmp", (False, True)) -@pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8")) +@pytest.mark.parametrize("use_cublasmp,quantization", COMM_GEMM_QUANTIZATION_PARAMS) @pytest.mark.parametrize("p2p", (False, True)) def test_split_reduce_scatter_overlaps(quantization, p2p, use_cublasmp): """ diff --git a/transformer_engine/common/comm_gemm/comm_gemm.cpp b/transformer_engine/common/comm_gemm/comm_gemm.cpp index 3545347da4..e7276496da 100644 --- a/transformer_engine/common/comm_gemm/comm_gemm.cpp +++ b/transformer_engine/common/comm_gemm/comm_gemm.cpp @@ -260,28 +260,63 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo const Tensor* d, const Tensor* bias, const Tensor* pre_act_out, bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count, cudaStream_t main_stream) { - for (auto t : {a, b, d}) { - NVTE_CHECK(is_tensor_scaling(t->scaling_mode), - "Unsupported scaling mode: " + 
std::to_string(t->scaling_mode)); + const bool tensor_scaling = + is_tensor_scaling(a->scaling_mode) && is_tensor_scaling(b->scaling_mode); + const bool mxfp8 = is_mxfp_scaling(a->scaling_mode) && is_mxfp_scaling(b->scaling_mode); + + NVTE_CHECK(tensor_scaling || mxfp8, "Unsupported scaling modes: A=", to_string(a->scaling_mode), + ", B=", to_string(b->scaling_mode)); + NVTE_CHECK(is_tensor_scaling(d->scaling_mode), + "Unsupported scaling mode for D: ", to_string(d->scaling_mode)); + + if (mxfp8) { +#if CUBLASMP_VERSION < 801 + NVTE_ERROR("MXFP8 GEMM requires cuBLASMp 0.8.1+."); +#else + NVTE_CHECK(a->with_gemm_swizzled_scales, + "MXFP8 scales for A are not in format expected by GEMM"); + NVTE_CHECK(b->with_gemm_swizzled_scales, + "MXFP8 scales for B are not in format expected by GEMM"); +#endif } - // Mirror cublaslt_gemm.cu's CanonicalizeGemmInput for FP8 tensor scaling: depending on the - // architecture and the quantizer's usage modes, the appropriate data + scale_inv - // may live on the rowwise or columnwise side of the tensor. - // * Hopper (!nvte_is_non_tn_fp8_gemm_supported): only TN FP8 GEMMs are supported, so an - // FP8 input not already in TN orientation must be swapped to its columnwise (transposed) - // view and the transpose flag flipped. - // * Blackwell+ (nvte_is_non_tn_fp8_gemm_supported): any FP8 GEMM layout is supported, but - // the quantizer usage may have only been set to columnwise. In that case, fall back to - // the columnwise view and flip the transpose flag so the GEMM sees the matching data and - // scale_inv pair. - // The original tensor is never modified; a new Tensor view aliases the columnwise pointers. + // Mirror cublaslt_gemm.cu's input canonicalization. Tensor FP8 columnwise data is a + // transposed view, while MXFP8 columnwise data keeps the logical shape. 
const bool fp8_needs_tn = !nvte_is_non_tn_fp8_gemm_supported(); - auto canonicalize_fp8_input = [fp8_needs_tn](const Tensor* t, bool current_trans, bool want_trans, - const char* side) -> std::pair<Tensor, bool> { + auto canonicalize_input = [fp8_needs_tn](const Tensor* t, bool current_trans, bool is_a, + const char* side) -> std::pair<Tensor, bool> { + auto use_columnwise = [t](bool new_trans) -> std::pair<Tensor, bool> { + Tensor view = *t; + view.data = t->columnwise_data; + view.scale_inv = t->columnwise_scale_inv; + view.amax = t->columnwise_amax; + return {view, new_trans}; + }; + + if (is_mxfp_scaling(t->scaling_mode)) { + if (is_a) { + if (current_trans) { + NVTE_CHECK(t->has_data(), "MXFP8 transposed input A is missing row-wise data"); + return {*t, current_trans}; + } + NVTE_CHECK(t->has_columnwise_data(), + "MXFP8 non-transposed input A is missing column-wise data"); + return use_columnwise(current_trans); + } + if (current_trans) { + NVTE_CHECK(t->has_columnwise_data(), + "MXFP8 transposed input B is missing column-wise data"); + return use_columnwise(current_trans); + } + NVTE_CHECK(t->has_data(), "MXFP8 non-transposed input B is missing row-wise data"); + return {*t, current_trans}; + } + if (!is_fp8_dtype(t->dtype())) { return {*t, current_trans}; } + + const bool want_trans = is_a; const bool hopper_tn_swap = fp8_needs_tn && current_trans != want_trans; const bool blackwell_missing_rowwise = !fp8_needs_tn && !t->has_data(); if (!hopper_tn_swap && !blackwell_missing_rowwise) { @@ -289,16 +324,11 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo } NVTE_CHECK(t->has_columnwise_data() && is_fp8_dtype(t->columnwise_data.dtype), "cuBLASMp FP8 GEMM input ", side, " is missing column-wise usage"); - Tensor view; - view.scaling_mode = t->scaling_mode; - view.data = t->columnwise_data; - view.scale_inv = t->columnwise_scale_inv; - // Columnwise data is the transposed view of the original — flip the transpose flag.
- return {view, !current_trans}; + return use_columnwise(!current_trans); }; - auto [a_used, transa_eff] = canonicalize_fp8_input(a, transa, /*want_trans=*/true, "A"); - auto [b_used, transb_eff] = canonicalize_fp8_input(b, transb, /*want_trans=*/false, "B"); + auto [a_used, transa_eff] = canonicalize_input(a, transa, /*is_a=*/true, "A"); + auto [b_used, transb_eff] = canonicalize_input(b, transb, /*is_a=*/false, "B"); transa = transa_eff; transb = transb_eff; @@ -321,20 +351,30 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo sizeof algo_attr)); const cublasMpMatmulMatrixScale_t scale_mode = CUBLASMP_MATMUL_MATRIX_SCALE_SCALAR_FP32; - if (is_fp8_dtype(a_used.dtype())) { - NVTE_CHECK(a_used.scale_inv.dptr, "Scaling must be set for FP8 dtype"); + auto get_input_scale_mode = [&](const Tensor& t) { +#if CUBLASMP_VERSION >= 801 + if (is_mxfp_scaling(t.scaling_mode)) { + return CUBLASMP_MATMUL_MATRIX_SCALE_VEC32_UE8M0; + } +#endif + return scale_mode; + }; + if (is_fp8_dtype(a_used.dtype()) || is_mxfp_scaling(a_used.scaling_mode)) { + const cublasMpMatmulMatrixScale_t input_scale_mode = get_input_scale_mode(a_used); + NVTE_CHECK(a_used.scale_inv.dptr, "Scaling must be set for input A"); NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( - ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_MODE, &scale_mode, - sizeof scale_mode)); + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_MODE, + &input_scale_mode, sizeof input_scale_mode)); NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_POINTER, &a_used.scale_inv.dptr, sizeof(void*))); } - if (is_fp8_dtype(b_used.dtype())) { - NVTE_CHECK(b_used.scale_inv.dptr, "Scaling must be set for FP8 dtype"); + if (is_fp8_dtype(b_used.dtype()) || is_mxfp_scaling(b_used.scaling_mode)) { + const cublasMpMatmulMatrixScale_t input_scale_mode = get_input_scale_mode(b_used); + 
NVTE_CHECK(b_used.scale_inv.dptr, "Scaling must be set for input B"); NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( - ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_MODE, &scale_mode, - sizeof scale_mode)); + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_MODE, + &input_scale_mode, sizeof input_scale_mode)); NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_POINTER, &b_used.scale_inv.dptr, sizeof(void*)));