Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions tests/pytorch/distributed/test_comm_gemm_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@
# to avoid numerical tolerance issues of doing comm gemm overlap, limit the number of GPUs used
MAX_GPUS_TO_USE = 4

COMM_GEMM_QUANTIZATION_PARAMS = [
pytest.param(False, "none", id="ub-bf16"),
pytest.param(False, "fp8", id="ub-fp8"),
pytest.param(False, "mxfp8", id="ub-mxfp8"),
pytest.param(True, "none", id="cublasmp-bf16"),
pytest.param(True, "fp8", id="cublasmp-fp8"),
pytest.param(True, "mxfp8", id="cublasmp-mxfp8"),
]

TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(torch.cuda.device_count(), MAX_GPUS_TO_USE)
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]
Expand Down Expand Up @@ -84,10 +93,6 @@ def _run_gemm_with_overlap(
if use_cublasmp:
if not tex.nvte_built_with_cublasmp():
pytest.skip("Transformer Engine not built with cuBLASMp (NVTE_WITH_CUBLASMP=0).")
if quantization == "mxfp8":
pytest.skip(
"cuBLASMp comm+GEMM overlap does not yet support MXFP8 (block scaling)."
)
if comm_type == "RS" and not p2p and not tex.device_supports_multicast():
pytest.skip(
"cuBLASMp non-P2P reduce-scatter requires NVSwitch (multicast support)."
Expand Down Expand Up @@ -140,8 +145,6 @@ def _run_layer_with_overlap(
if use_cublasmp:
if not tex.nvte_built_with_cublasmp():
pytest.skip("Transformer Engine not built with cuBLASMp (NVTE_WITH_CUBLASMP=0).")
if fp8 and quantization == "mxfp8":
pytest.skip("cuBLASMp comm+GEMM overlap does not yet support MXFP8 (block scaling).")
test_cmd.append("--use-cublasmp")

os.environ["PYTORCH_JIT"] = "0"
Expand Down Expand Up @@ -173,8 +176,7 @@ def _run_layer_with_overlap(
raise AssertionError(result.stderr.decode())


@pytest.mark.parametrize("use_cublasmp", (False, True))
@pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))
@pytest.mark.parametrize("use_cublasmp,quantization", COMM_GEMM_QUANTIZATION_PARAMS)
@pytest.mark.parametrize("aggregate", (False, True))
def test_split_all_gather_overlaps(quantization, aggregate, use_cublasmp):
"""
Expand All @@ -184,8 +186,7 @@ def test_split_all_gather_overlaps(quantization, aggregate, use_cublasmp):
_run_gemm_with_overlap("AG", False, True, False, aggregate, quantization, use_cublasmp)


@pytest.mark.parametrize("use_cublasmp", (False, True))
@pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))
@pytest.mark.parametrize("use_cublasmp,quantization", COMM_GEMM_QUANTIZATION_PARAMS)
@pytest.mark.parametrize("p2p", (False, True))
def test_split_reduce_scatter_overlaps(quantization, p2p, use_cublasmp):
"""
Expand Down
104 changes: 72 additions & 32 deletions transformer_engine/common/comm_gemm/comm_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,45 +260,75 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo
const Tensor* d, const Tensor* bias, const Tensor* pre_act_out, bool transa,
bool transb, bool grad, bool accumulate, int comm_sm_count,
cudaStream_t main_stream) {
for (auto t : {a, b, d}) {
NVTE_CHECK(is_tensor_scaling(t->scaling_mode),
"Unsupported scaling mode: " + std::to_string(t->scaling_mode));
const bool tensor_scaling =
is_tensor_scaling(a->scaling_mode) && is_tensor_scaling(b->scaling_mode);
const bool mxfp8 = is_mxfp_scaling(a->scaling_mode) && is_mxfp_scaling(b->scaling_mode);

NVTE_CHECK(tensor_scaling || mxfp8, "Unsupported scaling modes: A=", to_string(a->scaling_mode),
", B=", to_string(b->scaling_mode));
NVTE_CHECK(is_tensor_scaling(d->scaling_mode),
"Unsupported scaling mode for D: ", to_string(d->scaling_mode));

if (mxfp8) {
#if CUBLASMP_VERSION < 801
NVTE_ERROR("MXFP8 GEMM requires cuBLASMp 0.8.1+.");
#else
NVTE_CHECK(a->with_gemm_swizzled_scales,
"MXFP8 scales for A are not in format expected by GEMM");
NVTE_CHECK(b->with_gemm_swizzled_scales,
"MXFP8 scales for B are not in format expected by GEMM");
#endif
}

// Mirror cublaslt_gemm.cu's CanonicalizeGemmInput for FP8 tensor scaling: depending on the
// architecture and the quantizer's usage modes, the appropriate data + scale_inv
// may live on the rowwise or columnwise side of the tensor.
// * Hopper (!nvte_is_non_tn_fp8_gemm_supported): only TN FP8 GEMMs are supported, so an
// FP8 input not already in TN orientation must be swapped to its columnwise (transposed)
// view and the transpose flag flipped.
// * Blackwell+ (nvte_is_non_tn_fp8_gemm_supported): any FP8 GEMM layout is supported, but
// the quantizer usage may have only been set to columnwise. In that case, fall back to
// the columnwise view and flip the transpose flag so the GEMM sees the matching data and
// scale_inv pair.
// The original tensor is never modified; a new Tensor view aliases the columnwise pointers.
// Mirror cublaslt_gemm.cu's input canonicalization. Tensor FP8 columnwise data is a
// transposed view, while MXFP8 columnwise data keeps the logical shape.
const bool fp8_needs_tn = !nvte_is_non_tn_fp8_gemm_supported();
auto canonicalize_fp8_input = [fp8_needs_tn](const Tensor* t, bool current_trans, bool want_trans,
const char* side) -> std::pair<Tensor, bool> {
auto canonicalize_input = [fp8_needs_tn](const Tensor* t, bool current_trans, bool is_a,
const char* side) -> std::pair<Tensor, bool> {
auto use_columnwise = [t](bool new_trans) -> std::pair<Tensor, bool> {
Tensor view = *t;
view.data = t->columnwise_data;
view.scale_inv = t->columnwise_scale_inv;
view.amax = t->columnwise_amax;
return {view, new_trans};
};

if (is_mxfp_scaling(t->scaling_mode)) {
if (is_a) {
if (current_trans) {
NVTE_CHECK(t->has_data(), "MXFP8 transposed input A is missing row-wise data");
return {*t, current_trans};
}
NVTE_CHECK(t->has_columnwise_data(),
"MXFP8 non-transposed input A is missing column-wise data");
return use_columnwise(current_trans);
}
if (current_trans) {
NVTE_CHECK(t->has_columnwise_data(),
"MXFP8 transposed input B is missing column-wise data");
return use_columnwise(current_trans);
}
NVTE_CHECK(t->has_data(), "MXFP8 non-transposed input B is missing row-wise data");
return {*t, current_trans};
}

if (!is_fp8_dtype(t->dtype())) {
return {*t, current_trans};
}

const bool want_trans = is_a;
const bool hopper_tn_swap = fp8_needs_tn && current_trans != want_trans;
const bool blackwell_missing_rowwise = !fp8_needs_tn && !t->has_data();
if (!hopper_tn_swap && !blackwell_missing_rowwise) {
return {*t, current_trans};
}
NVTE_CHECK(t->has_columnwise_data() && is_fp8_dtype(t->columnwise_data.dtype),
"cuBLASMp FP8 GEMM input ", side, " is missing column-wise usage");
Tensor view;
view.scaling_mode = t->scaling_mode;
view.data = t->columnwise_data;
view.scale_inv = t->columnwise_scale_inv;
// Columnwise data is the transposed view of the original — flip the transpose flag.
return {view, !current_trans};
return use_columnwise(!current_trans);
};

auto [a_used, transa_eff] = canonicalize_fp8_input(a, transa, /*want_trans=*/true, "A");
auto [b_used, transb_eff] = canonicalize_fp8_input(b, transb, /*want_trans=*/false, "B");
auto [a_used, transa_eff] = canonicalize_input(a, transa, /*is_a=*/true, "A");
auto [b_used, transb_eff] = canonicalize_input(b, transb, /*is_a=*/false, "B");
transa = transa_eff;
transb = transb_eff;

Expand All @@ -321,20 +351,30 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo
sizeof algo_attr));

const cublasMpMatmulMatrixScale_t scale_mode = CUBLASMP_MATMUL_MATRIX_SCALE_SCALAR_FP32;
if (is_fp8_dtype(a_used.dtype())) {
NVTE_CHECK(a_used.scale_inv.dptr, "Scaling must be set for FP8 dtype");
auto get_input_scale_mode = [&](const Tensor& t) {
#if CUBLASMP_VERSION >= 801
if (is_mxfp_scaling(t.scaling_mode)) {
return CUBLASMP_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
}
#endif
return scale_mode;
};
if (is_fp8_dtype(a_used.dtype()) || is_mxfp_scaling(a_used.scaling_mode)) {
const cublasMpMatmulMatrixScale_t input_scale_mode = get_input_scale_mode(a_used);
NVTE_CHECK(a_used.scale_inv.dptr, "Scaling must be set for input A");
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_MODE, &scale_mode,
sizeof scale_mode));
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_MODE,
&input_scale_mode, sizeof input_scale_mode));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_POINTER,
&a_used.scale_inv.dptr, sizeof(void*)));
}
if (is_fp8_dtype(b_used.dtype())) {
NVTE_CHECK(b_used.scale_inv.dptr, "Scaling must be set for FP8 dtype");
if (is_fp8_dtype(b_used.dtype()) || is_mxfp_scaling(b_used.scaling_mode)) {
const cublasMpMatmulMatrixScale_t input_scale_mode = get_input_scale_mode(b_used);
NVTE_CHECK(b_used.scale_inv.dptr, "Scaling must be set for input B");
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_MODE, &scale_mode,
sizeof scale_mode));
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_MODE,
&input_scale_mode, sizeof input_scale_mode));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_POINTER,
&b_used.scale_inv.dptr, sizeof(void*)));
Expand Down
Loading