diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index e2e6d09c29..ee35e70df0 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -4,6 +4,7 @@ """PyTorch related extensions.""" +import importlib.util import os from pathlib import Path from importlib import metadata @@ -58,6 +59,25 @@ def setup_pytorch_extension( ] ) + # apache-tvm-ffi: headers for the C++ API (Module / Function / TensorView) + # and libtvm_ffi.so for symbol resolution. Used by tvm_ffi_bridge.h / + # applyTVMFunction. Python registers AOT-compiled CuTeDSL kernels into + # the global registry; TE C++ looks them up via Function::GetGlobalRequired. + tvm_ffi_spec = importlib.util.find_spec("tvm_ffi") + if tvm_ffi_spec is None or not tvm_ffi_spec.submodule_search_locations: + raise RuntimeError( + "apache-tvm-ffi package not found; install it (e.g. " + "`pip install apache-tvm-ffi`) — required for the TVM FFI bridge." + ) + tvm_ffi_root = Path(tvm_ffi_spec.submodule_search_locations[0]) + tvm_ffi_include = tvm_ffi_root / "include" + tvm_ffi_lib_dir = tvm_ffi_root / "lib" + if not tvm_ffi_include.is_dir() or not (tvm_ffi_lib_dir / "libtvm_ffi.so").exists(): + raise RuntimeError( + f"apache-tvm-ffi assets missing at {tvm_ffi_root} (need include/ and lib/libtvm_ffi.so)" + ) + include_dirs.append(tvm_ffi_include) + # Compiler flags cxx_flags = ["-O3", "-fvisibility=hidden"] if debug_build_enabled(): @@ -77,8 +97,11 @@ def setup_pytorch_extension( setup_mpi_flags(include_dirs, cxx_flags) - library_dirs = [] - libraries = [] + library_dirs = [tvm_ffi_lib_dir] + libraries = ["tvm_ffi"] + # rpath pinned to the pip install dir so the loader finds libtvm_ffi.so + # without LD_LIBRARY_PATH at runtime. + extra_link_args = [f"-Wl,-rpath,{tvm_ffi_lib_dir}"] if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", 0))): assert ( os.getenv("NVSHMEM_HOME") is not None @@ -102,6 +125,7 @@ def setup_pytorch_extension( sources=[str(src) for src in sources], include_dirs=[str(inc) for inc in include_dirs], extra_compile_args={"cxx": cxx_flags}, + extra_link_args=extra_link_args, libraries=[str(lib) for lib in libraries], library_dirs=[str(lib_dir) for lib_dir in library_dirs], ) diff --git a/pyproject.toml b/pyproject.toml index 4a8fded172..826a0e54a7 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,10 @@ # See LICENSE for license information. [build-system] -requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"] +# apache-tvm-ffi is required at configure/compile/link time: the common C++ +# library finds it via find_package(tvm_ffi) and links libtvm_ffi.so (the +# CuTeDSL quant backend bridge). It is also a runtime dependency (see setup.py). +requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1", "apache-tvm-ffi>=0.1.12"] # Use legacy backend to import local packages in setup.py build-backend = "setuptools.build_meta:__legacy__" diff --git a/setup.py b/setup.py index 64ed120268..ed6fe977b4 100644 --- a/setup.py +++ b/setup.py @@ -136,6 +136,9 @@ def setup_requirements() -> Tuple[List[str], List[str]]: "importlib-metadata>=1.0", "packaging", cusolvermp_pypi_package_name(), + # The core C++ library links libtvm_ffi.so (CuTeDSL quant backend bridge), + # so apache-tvm-ffi is required at runtime by every TE install. + "apache-tvm-ffi>=0.1.12", ] test_reqs: List[str] = ["pytest>=8.2.1"] diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index edb8c5e109..faebee00f4 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -106,6 +106,24 @@ set(CUTLASS_TOOLS_INCLUDE_DIR # Python find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) +# tvm-ffi: the quantize dispatch layer bridges to JIT-compiled CuTeDSL kernels +# through tvm-ffi (see common/tvm_ffi_bridge.h). Locate the tvm_ffi package that +# ships with the Python install and use its exported CMake config (provides the +# tvm_ffi::shared imported target with headers + libtvm_ffi.so). +execute_process( + COMMAND ${Python_EXECUTABLE} -c "import tvm_ffi.libinfo as li; print(li.find_cmake_path())" + OUTPUT_VARIABLE TVM_FFI_CMAKE_DIR + OUTPUT_STRIP_TRAILING_WHITESPACE + RESULT_VARIABLE TVM_FFI_CMAKE_QUERY) +if(NOT TVM_FFI_CMAKE_QUERY EQUAL 0) + message(FATAL_ERROR + "Could not import the tvm_ffi Python package (with '${Python_EXECUTABLE}'), " + "which Transformer Engine requires to build the CuTeDSL quantize backend " + "bridge (common/tvm_ffi_bridge.h). Install it into this Python environment: " + "`pip install apache-tvm-ffi`.") +endif() +find_package(tvm_ffi CONFIG REQUIRED PATHS "${TVM_FFI_CMAKE_DIR}") + function(find_nccl_version OUT_VERSION OUT_INCLUDE_DIR) find_path(_nvte_nccl_include_dir NAMES nccl.h @@ -360,6 +378,22 @@ target_link_libraries(transformer_engine PUBLIC CUDA::cudart CUDNN::cudnn_all) +# CuTeDSL quantize backend bridge. PRIVATE: tvm_ffi_bridge.h is an internal +# header (not in the installed public include dir), so the symbols and headers +# are only needed to compile transformer_engine itself, not by downstream +# consumers. The INTERFACE include dirs of tvm_ffi::shared still apply to our +# own TUs, which is what fixes the not-found error. +target_link_libraries(transformer_engine PRIVATE tvm_ffi::shared) + +# libtvm_ffi.so ships inside the tvm_ffi Python package (not a system lib dir), +# so add its directory to the RPATH; otherwise the runtime loader can't satisfy +# the DT_NEEDED on libtvm_ffi.so and dlopen fails with "cannot open shared +# object file". Applied to both the build tree and the installed library. +get_target_property(TVM_FFI_SHARED_LOCATION tvm_ffi::shared IMPORTED_LOCATION) +get_filename_component(TVM_FFI_LIB_DIR "${TVM_FFI_SHARED_LOCATION}" DIRECTORY) +set_property(TARGET transformer_engine APPEND PROPERTY BUILD_RPATH "${TVM_FFI_LIB_DIR}") +set_property(TARGET transformer_engine APPEND PROPERTY INSTALL_RPATH "${TVM_FFI_LIB_DIR}") + target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_include_directories(transformer_engine SYSTEM PRIVATE diff --git a/transformer_engine/common/CuTeDSL/__init__.py b/transformer_engine/common/CuTeDSL/__init__.py new file mode 100644 index 0000000000..5621c01e64 --- /dev/null +++ b/transformer_engine/common/CuTeDSL/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""CuTeDSL kernels for Transformer Engine. + +Importing this package has a side effect: it registers the CuTeDSL kernel +entrypoints (e.g. ``get_mxfp8_quantization_function``) as TVM-FFI global +functions. The C++ dispatcher probes for those names via +``tvm::ffi::Function::GetGlobal`` — finding one means the process is running +inside a Python environment with the CuTeDSL toolchain available, so the kernel +may be compiled on demand; not finding it means a plain C++ environment, and +the dispatcher falls back to the CUDA C++ kernel. + +Importing requires the optional CuTeDSL toolchain (cutlass, tvm_ffi). Callers +that want graceful degradation should guard the import in a try/except. +""" + +from . import cast # noqa: F401 (import side effect: registers global funcs) diff --git a/transformer_engine/common/CuTeDSL/cast/__init__.py b/transformer_engine/common/CuTeDSL/cast/__init__.py new file mode 100644 index 0000000000..c4890ee489 --- /dev/null +++ b/transformer_engine/common/CuTeDSL/cast/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""CuTeDSL cast/quantization kernels. Importing pulls in each kernel module so +its TVM-FFI entrypoint is registered.""" + +from . import mxfp8 # noqa: F401 (import side effect: registers global funcs) diff --git a/transformer_engine/common/CuTeDSL/cast/mxfp8/__init__.py b/transformer_engine/common/CuTeDSL/cast/mxfp8/__init__.py new file mode 100644 index 0000000000..c42df11c01 --- /dev/null +++ b/transformer_engine/common/CuTeDSL/cast/mxfp8/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""MXFP8 CuTeDSL kernels. Importing ``quantize_mxfp8`` runs its module body, +which registers the ``get_mxfp8_quantization_function`` TVM-FFI global func.""" + +from . import quantize_mxfp8 # noqa: F401 (import side effect: registers the global func) diff --git a/transformer_engine/common/CuTeDSL/cast/mxfp8/mxfp8_utils.py b/transformer_engine/common/CuTeDSL/cast/mxfp8/mxfp8_utils.py new file mode 100644 index 0000000000..081219b67d --- /dev/null +++ b/transformer_engine/common/CuTeDSL/cast/mxfp8/mxfp8_utils.py @@ -0,0 +1,1124 @@ +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass import Float32, Int64, Int32, Int16, Uint8, Uint32 +from cutlass._mlir.dialects import arith as mlir_arith +from cutlass._mlir.dialects import llvm +from cutlass.base_dsl.compiler import GPUArch +from cutlass.cute.runtime import make_ptr +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass.cute.arch import cvt_f32_bf16 + +from types import SimpleNamespace + +# FP8E4M3 max representable value +FP8E4M3_MAX_NORM = 448.0 +FP8E4M3_MAX_NORM_RCP = 1.0 / FP8E4M3_MAX_NORM +FP8E5M2_MAX_NORM = 57344.0 +FP8E5M2_MAX_NORM_RCP = 1.0 / FP8E5M2_MAX_NORM + +# NVFP4 (fp4e2m1) — 4-bit float, max representable value is 6.0 +FP4_E2M1_MAX = 6.0 +FP4_E2M1_MAX_RCP = 1.0 / FP4_E2M1_MAX +# Largest finite f32 — used to clamp the per-block scale inverse against +# division-by-zero (which produces +inf and then NaN downstream). +FP32_MAX = 3.4028234663852886e38 + +FP32_MANTISSA_BITS = 23 + + +@dsl_user_op +def _bitcast_f32_to_i32(val: Float32, *, loc=None, ip=None) -> Int32: + return Int32(mlir_arith.bitcast(T.i32(), val.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + + +@dsl_user_op +def _bitcast_i32_to_f32(val: Int32, *, loc=None, ip=None) -> Float32: + return Float32(mlir_arith.bitcast(T.f32(), val.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + + +@dsl_user_op +def fabs_f32(val: Float32, *, loc=None, ip=None) -> Float32: + val_i32 = _bitcast_f32_to_i32(val, loc=loc, ip=ip) + abs_i32 = val_i32 & Int32(0x7FFFFFFF) + return _bitcast_i32_to_f32(abs_i32, loc=loc, ip=ip) + + +@dsl_user_op +def float_to_e8m0(val: Float32, *, loc=None, ip=None) -> Int32: + """Branchless float->E8M0: add mantissa mask to round up, clamp to 254.""" + val_i32 = _bitcast_f32_to_i32(val, loc=loc, ip=ip) + rounded = val_i32 + Int32(0x7FFFFF) + exponent = (rounded >> Int32(FP32_MANTISSA_BITS)) & Int32(0xFF) + return Int32( + mlir_arith.minsi( + exponent.ir_value(loc=loc, ip=ip), Int32(254).ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + ) + + +@dsl_user_op +def exp2f_rcp(biased_exp: Int32, *, loc=None, ip=None) -> Float32: + """2^(127 - biased_exp) with special-case handling.""" + new_exp = (Int32(254) - biased_exp) << Int32(FP32_MANTISSA_BITS) + result = _bitcast_i32_to_f32(new_exp, loc=loc, ip=ip) + for cmp_val, repl_bits in [(255, 0x7FFFFFFF), (254, 0x00400000), (0, 0x7F000000)]: + cond = mlir_arith.cmpi( + mlir_arith.CmpIPredicate.eq, + biased_exp.ir_value(loc=loc, ip=ip), + Int32(cmp_val).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + alt = _bitcast_i32_to_f32(Int32(repl_bits), loc=loc, ip=ip) + result = Float32( + mlir_arith.select( + cond, alt.ir_value(loc=loc, ip=ip), result.ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + ) + return result + + +@dsl_user_op +def cvt_f32_to_fp8e4m3(val: Float32, *, loc=None, ip=None) -> Int32: + """float32 -> fp8e4m3fn via PTX cvt.rn.satfinite.e4m3x2.f32.""" + zero = Float32(0.0) + result_i16 = Int16( + llvm.inline_asm( + T.i16(), + [zero.ir_value(loc=loc, ip=ip), val.ir_value(loc=loc, ip=ip)], + "cvt.rn.satfinite.e4m3x2.f32 $0, $1, $2;", + "=h,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + result_i32 = Int32( + mlir_arith.extui(T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + ) + return result_i32 & Int32(0xFF) + + +@dsl_user_op +def cvt_f32_to_fp8e5m2(val: Float32, *, loc=None, ip=None) -> Int32: + """float32 -> fp8e5m2 via PTX cvt.rn.satfinite.e5m2x2.f32.""" + zero = Float32(0.0) + result_i16 = Int16( + llvm.inline_asm( + T.i16(), + [zero.ir_value(loc=loc, ip=ip), val.ir_value(loc=loc, ip=ip)], + "cvt.rn.satfinite.e5m2x2.f32 $0, $1, $2;", + "=h,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + result_i32 = Int32( + mlir_arith.extui(T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + ) + return result_i32 & Int32(0xFF) + + +@dsl_user_op +def fma_f32(a: Float32, b: Float32, c: Float32, *, loc=None, ip=None) -> Float32: + """`fma.rn.f32 d, a, b, c;` — single-instruction fused multiply-add + matching nvcc's FFMA. Used for explicit `partial += a * b` patterns + where we need the same rounding as TE's compiler-fused FFMA.""" + return Float32( + llvm.inline_asm( + T.f32(), + [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip), c.ir_value(loc=loc, ip=ip)], + "fma.rn.f32 $0, $1, $2, $3;", + "=f,f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def tanh_approx(val: Float32, *, loc=None, ip=None) -> Float32: + """`tanh.approx.f32` — fast tanh approximation. Matches CUDA `__tanhf`.""" + return Float32( + llvm.inline_asm( + T.f32(), + [val.ir_value(loc=loc, ip=ip)], + "tanh.approx.f32 $0, $1;", + "=f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def pack_f32x2(lo: Float32, hi: Float32, *, loc=None, ip=None) -> Int64: + """Pack two f32 scalars into a single 64-bit register (`floatx2` layout). + + Low 32 bits = `lo`, high 32 bits = `hi`. Uses `mov.b64 %dst, {%lo, %hi};` + which lowers to a single register move — no actual memory traffic. + """ + return Int64( + llvm.inline_asm( + T.i64(), + [lo.ir_value(loc=loc, ip=ip), hi.ir_value(loc=loc, ip=ip)], + "mov.b64 $0, {$1, $2};", + "=l,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def pack_i32x2(lo: Int32, hi: Int32, *, loc=None, ip=None) -> Int64: + """i32 sibling of `pack_f32x2` — concat two i32 into a single b64 register. + Used by NVFP4 to glue two `(bf16,bf16)`/`(f16,f16)` Int32 packs into the + `Int64` operand the `mul_cvt.*x4` PTX expects.""" + return Int64( + llvm.inline_asm( + T.i64(), + [lo.ir_value(loc=loc, ip=ip), hi.ir_value(loc=loc, ip=ip)], + "mov.b64 $0, {$1, $2};", + "=l,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def _trunc_i32_to_i16(val: Int32, *, loc=None, ip=None) -> Int16: + """Narrow an Int32 to Int16 by keeping the low 16 bits. + + Lives here because the existing arith-dialect narrowing pattern requires + loc/ip kwargs (see other `mlir_arith.trunci` callers); wrapping it as a + `@dsl_user_op` lets `@cute.jit` bodies use it without plumbing those in.""" + return Int16(mlir_arith.trunci(T.i16(), val.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + + +@dsl_user_op +def cvt_fp8e4m3_to_f32(byte_i32: Int32, *, loc=None, ip=None) -> Float32: + """One fp8e4m3 byte (low 8 bits of `byte_i32`) → f32. + + PTX has no direct `cvt.f32.e4m3` for a scalar; route through the packed + `cvt.rn.f16x2.e4m3x2` and then `cvt.f32.f16`. The high byte of the .b16 + register is forced to zero so the discarded high f16 lane is well-defined.""" + asm = ( + "{\n" + ".reg .b32 masked; .reg .b16 b16; .reg .b16 b16_hi;\n\t" + ".reg .b32 f16pair; .reg .b16 lo_f16; .reg .b16 hi_f16;\n\t" + "and.b32 masked, $1, 0xFF;\n\t" + "mov.b32 {b16, b16_hi}, masked;\n\t" + "cvt.rn.f16x2.e4m3x2 f16pair, b16;\n\t" + "mov.b32 {lo_f16, hi_f16}, f16pair;\n\t" + "cvt.f32.f16 $0, lo_f16;\n\t" + "}" + ) + return Float32( + llvm.inline_asm( + T.f32(), + [byte_i32.ir_value(loc=loc, ip=ip)], + asm, + "=f,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +# --------------------------------------------------------------------------- +# 16-bit packed input PTX kit (bf16 / f16) +# +# bf16 and f16 share the same fast-path shape: packed-x2 amax via +# `max.xorsign.abs.x2`, then per-lane widen-to-f32 + `mul.f32x2` + +# `cvt.rn.satfinite.x2.f32`. Only the opcodes differ. Build one PTX kit +# per format at module load and let the kernel pick the right kit at JIT +# trace time via `cfg.DTYPE` — equivalent to a C++ template arg specialization +# on `IType`, with no runtime branch. +# --------------------------------------------------------------------------- +def _build_packed16_kit(in_fmt: str): + """Build a kit of PTX wrappers for a 16-bit input format. + + `in_fmt` is the PTX format string ('bf16' or 'f16'). Returns a namespace + with the per-format ops the rowwise/colwise inner loops need: + + abs_max_x2(Int32, Int32) -> Int32 # `max.xorsign.abs.x2` + abs_max_scalar(Int16, Int16) -> Int16 # `max.xorsign.abs.` + bits_to_f32(Int16) -> Float32 # widen one 16-bit element + x2_lo_to_f32(Int32) -> Float32 # extract+widen low half + x2_hi_to_f32(Int32) -> Float32 # extract+widen high half + mul_cvt_to_fp8x2(fp8_dtype) -> callable(Int32, Int64)->Int32 + # fused x2 * f32x2 -> fp8x2 + """ + + @dsl_user_op + def abs_max_x2(a: Int32, b: Int32, *, loc=None, ip=None) -> Int32: + return Int32( + llvm.inline_asm( + T.i32(), + [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)], + f"max.xorsign.abs.{in_fmt}x2 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + @dsl_user_op + def max_x2(a: Int32, b: Int32, *, loc=None, ip=None) -> Int32: + return Int32( + llvm.inline_asm( + T.i32(), + [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)], + f"max.{in_fmt}x2 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + @dsl_user_op + def abs_max_scalar(a: Int16, b: Int16, *, loc=None, ip=None) -> Int16: + return Int16( + llvm.inline_asm( + T.i16(), + [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)], + f"max.xorsign.abs.{in_fmt} $0, $1, $2;", + "=h,h,h", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + if in_fmt == "bf16": + # bf16 == top 16 bits of f32 — widening is a free bit-shift. + @dsl_user_op + def bits_to_f32(bits: Int16, *, loc=None, ip=None) -> Float32: + i32 = Int32(mlir_arith.extui(T.i32(), bits.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + return _bitcast_i32_to_f32(i32 << Int32(16), loc=loc, ip=ip) + + @dsl_user_op + def x2_lo_to_f32(bits: Int32, *, loc=None, ip=None) -> Float32: + return _bitcast_i32_to_f32((bits & Int32(0xFFFF)) << Int32(16), loc=loc, ip=ip) + + @dsl_user_op + def x2_hi_to_f32(bits: Int32, *, loc=None, ip=None) -> Float32: + # `(x >> 16) << 16` ≡ `x & 0xFFFF0000`, sidestepping signed-literal + # issues. Sign bits from the arith-right shift get zeroed by the + # left shift. + return _bitcast_i32_to_f32((bits >> Int32(16)) << Int32(16), loc=loc, ip=ip) + + @dsl_user_op + def truncate_f32(val: Float32, *, loc=None, ip=None) -> Float32: + """Round f32 to bf16 precision (round-to-nearest-even), keep f32. + Matches C++'s `static_cast(static_cast(elt))`.""" + bf16_bits = Int16( + llvm.inline_asm( + T.i16(), + [val.ir_value(loc=loc, ip=ip)], + "cvt.rn.bf16.f32 $0, $1;", + "=h,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + i32 = Int32( + mlir_arith.extui(T.i32(), bf16_bits.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + ) + return _bitcast_i32_to_f32(i32 << Int32(16), loc=loc, ip=ip) + + else: + # f16 has its own bit layout; widening requires `cvt.f32.f16`. + @dsl_user_op + def bits_to_f32(bits: Int16, *, loc=None, ip=None) -> Float32: + return Float32( + llvm.inline_asm( + T.f32(), + [bits.ir_value(loc=loc, ip=ip)], + "cvt.f32.f16 $0, $1;", + "=f,h", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + @dsl_user_op + def x2_lo_to_f32(bits: Int32, *, loc=None, ip=None) -> Float32: + lo_i16 = Int16( + mlir_arith.trunci(T.i16(), bits.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + ) + return bits_to_f32(lo_i16, loc=loc, ip=ip) + + @dsl_user_op + def x2_hi_to_f32(bits: Int32, *, loc=None, ip=None) -> Float32: + hi_shifted = bits >> Int32(16) + hi_i16 = Int16( + mlir_arith.trunci(T.i16(), hi_shifted.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + ) + return bits_to_f32(hi_i16, loc=loc, ip=ip) + + @dsl_user_op + def truncate_f32(val: Float32, *, loc=None, ip=None) -> Float32: + """Round f32 to f16 precision, keep f32.""" + f16_bits = Int16( + llvm.inline_asm( + T.i16(), + [val.ir_value(loc=loc, ip=ip)], + "cvt.rn.f16.f32 $0, $1;", + "=h,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + return Float32( + llvm.inline_asm( + T.f32(), + [f16_bits.ir_value(loc=loc, ip=ip)], + "cvt.f32.f16 $0, $1;", + "=f,h", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + def _build_mul_cvt(out_fmt: str, relu: bool = False): + """Build a fused `x2 * f32x2 → fp8x2` PTX wrapper. + + The shape is identical across (in_fmt, out_fmt) combos — only the + widening opcode (`cvt.f32.`) and the final saturating cvt + (`cvt.rn.satfinite.x2.f32`) differ. + """ + out_op = "e4m3x2" if out_fmt == "e4m3" else "e5m2x2" + asm = ( + "{\n" + ".reg.b64 vp0; .reg.b64 vp1;\n\t" + ".reg.b32 v1; .reg.b32 v2;\n\t" + ".reg.b16 vb1; .reg.b16 vb2;\n\t" + "mov.b32 {vb1, vb2}, $1;\n\t" + f"cvt.f32.{in_fmt} v1, vb1;\n\t" + f"cvt.f32.{in_fmt} v2, vb2;\n\t" + "mov.b64 vp0, {v1, v2};\n\t" + "mul.f32x2 vp1, vp0, $2;\n\t" + "mov.b64 {v2, v1}, vp1;\n\t" + f"cvt.rn.satfinite{'.relu' if relu else ''}.{out_op}.f32 $0, v1, v2;\n\t" + "}" + ) + + @dsl_user_op + def fn(val_2x: Int32, scale_2x: Int64, *, loc=None, ip=None) -> Int32: + result_i16 = Int16( + llvm.inline_asm( + T.i16(), + [val_2x.ir_value(loc=loc, ip=ip), scale_2x.ir_value(loc=loc, ip=ip)], + asm, + "=h,r,l", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + return Int32( + mlir_arith.extui(T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + ) + + return fn + + def mul_cvt_to_fp8x2(fp8_dtype: str, relu: bool = False): + if fp8_dtype == "e5m2": + return _build_mul_cvt("e5m2", relu) + return _build_mul_cvt("e4m3", relu) + + # NVFP4 fused cast: x4 × f32x2 → fp4e2m1x4 (4 fp4 packed in 16 + # bits). Same shape as `mul_cvt_to_fp8x2` but produces 4 elements at a + # time because the `cvt.rn.satfinite.e2m1x2.f32` PTX consumes pairs and + # writes a single byte (high nibble = first input, low nibble = second). + # The shuffled `mov.b64 {v1, v0}, v01` lines after the muls undo the + # PTX's hi/lo packing so the resulting byte is naturally + # `(fp4(elt1) << 4) | fp4(elt0)` — matches TE's C++ asm. + @dsl_user_op + def mul_cvt_to_fp4x4(in_4x: Int64, scale_2x: Int64, *, loc=None, ip=None) -> Int32: + asm = ( + "{\n" + ".reg.b64 v01; .reg.b64 v23;\n\t" + ".reg.b16 i0; .reg.b16 i1; .reg.b16 i2; .reg.b16 i3;\n\t" + ".reg.b32 v0; .reg.b32 v1; .reg.b32 v2; .reg.b32 v3;\n\t" + ".reg.b8 f0; .reg.b8 f1;\n\t" + "mov.b64 {i0, i1, i2, i3}, $1;\n\t" + f"cvt.f32.{in_fmt} v0, i0;\n\t" + f"cvt.f32.{in_fmt} v1, i1;\n\t" + f"cvt.f32.{in_fmt} v2, i2;\n\t" + f"cvt.f32.{in_fmt} v3, i3;\n\t" + "mov.b64 v01, {v0, v1};\n\t" + "mov.b64 v23, {v2, v3};\n\t" + "mul.f32x2 v01, v01, $2;\n\t" + "mul.f32x2 v23, v23, $2;\n\t" + "mov.b64 {v1, v0}, v01;\n\t" + "mov.b64 {v3, v2}, v23;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" + "mov.b32 $0, {f0, f1, f0, f1};\n\t" + "}" + ) + return Int32( + llvm.inline_asm( + T.i32(), + [in_4x.ir_value(loc=loc, ip=ip), scale_2x.ir_value(loc=loc, ip=ip)], + asm, + "=r,l,l", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + return SimpleNamespace( + abs_max_x2=abs_max_x2, + max_x2=max_x2, + abs_max_scalar=abs_max_scalar, + bits_to_f32=bits_to_f32, + x2_lo_to_f32=x2_lo_to_f32, + x2_hi_to_f32=x2_hi_to_f32, + truncate_f32=truncate_f32, + mul_cvt_to_fp8x2=mul_cvt_to_fp8x2, + mul_cvt_to_fp4x4=mul_cvt_to_fp4x4, + ) + + +_BF16_KIT = _build_packed16_kit("bf16") +_F16_KIT = _build_packed16_kit("f16") + + +def _is_packed16(dtype) -> bool: + """True if `dtype` is one of the 16-bit packed input formats.""" + return dtype is cutlass.BFloat16 or dtype is cutlass.Float16 + + +def _packed16_kit(dtype): + """Trace-time selector — pick a Packed16Kit for the input dtype.""" + if dtype is cutlass.Float16: + return _F16_KIT + return _BF16_KIT + + +# --------------------------------------------------------------------------- +# Forward-activation registry +# +# Each entry is a Float32 → Float32 callable applied per element before the +# MXFP8 amax + cast. Selection is by Python string at JIT trace time, so the +# const-expr machinery treats `cfg.ACTIVATION` like a C++ template argument +# — no runtime branch in the inner loop, separate kernel cached per choice. +# +# Math primitives match CUDA fast-math intrinsics so outputs are bit-exact +# with PyTorch's CUDA implementations of the same activations: +# tanh -> tanh.approx.f32 (== __tanhf) +# exp(x) -> exp2.approx.f32(x · log2(e)) (== __expf) +# --------------------------------------------------------------------------- +def _act_relu(x: Float32) -> Float32: + return cute.arch.fmax(x, Float32(0.0)) + + +def _act_gelu(x: Float32) -> Float32: + """Tanh-approximation GELU. Constants and operator grouping match TE's + `transformer_engine/common/util/math.h::gelu` exactly (factored form + `x · (0.5 + 0.5·tanh(x·(a + b·x²)))`) so quantized output is bit-exact + against the C++ fused IS_ACT path. Uses `cute.math.tanh(fastmath=False)` + rather than the `tanh.approx.f32` PTX intrinsic — TE compiles activation + kernels without `--use_fast_math` by default, so its `tanhf` is the + IEEE-precise expansion.""" + A = Float32(0.79788456) # sqrt(2/π) truncated to TE's 8-digit literal + B = Float32(0.03567741) # = sqrt(2/π) · 0.044715, same truncation + return x * (Float32(0.5) + Float32(0.5) * cute.math.tanh(x * (A + B * x * x))) + + +def _act_silu(x: Float32) -> Float32: + """SiLU/Swish: x · σ(x) = x / (1 + e^-x). + Matches TE's `silu` (`val / (1 + expf(-val))`).""" + return x / (Float32(1.0) + cute.arch.exp(-x)) + + +def _act_qgelu(x: Float32) -> Float32: + """Quick GELU: x · σ(1.702·x). Matches TE `qgelu_with_alpha(val, 1.702)` = + `cval · (1 / (1 + expf(-1.702·cval)))` (multiply by sigmoid, not a divide).""" + z = Float32(1.702) * x + return x * (Float32(1.0) / (Float32(1.0) + cute.arch.exp(-z))) + + +def _act_srelu(x: Float32) -> Float32: + """Squared ReLU: x>0 ? x·x : 0 == (max(0,x))². Matches TE `srelu`.""" + r = cute.arch.fmax(x, Float32(0.0)) + return r * r + + +SUPPORTED_ACTIVATIONS = { + "relu": _act_relu, + "gelu": _act_gelu, + "silu": _act_silu, + "qgelu": _act_qgelu, + "srelu": _act_srelu, +} + + +# --------------------------------------------------------------------------- +# Backward-activation (dact) registry +# +# Each entry is the derivative act'(x) as a Float32 → Float32 callable, matching +# the corresponding `d` in transformer_engine/common/util/math.h. The dact +# kernel computes `grad · act'(x)` per element before the MXFP8 amax + cast. +# Primitives mirror the forward registry (cute.math.tanh fastmath=False for +# gelu, cute.arch.exp for the sigmoid) so output is bit-exact with the C++ path. +# --------------------------------------------------------------------------- +@dsl_user_op +def _dact_drelu(x: Float32, *, loc=None, ip=None) -> Float32: + """drelu: x > 0 ? 1 : 0. Matches math.h `drelu` (NaN → 0 via ordered compare).""" + cond = mlir_arith.cmpf( + mlir_arith.CmpFPredicate.OGT, + x.ir_value(loc=loc, ip=ip), + Float32(0.0).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + return Float32( + mlir_arith.select( + cond, + Float32(1.0).ir_value(loc=loc, ip=ip), + Float32(0.0).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + ) + + +def _dact_dsrelu(x: Float32) -> Float32: + """dsrelu: fmax(2x, 0). Matches math.h `dsrelu`.""" + return cute.arch.fmax(Float32(2.0) * x, Float32(0.0)) + + +def _sigmoid(x: Float32) -> Float32: + """σ(x) = 1 / (1 + e^-x), same exp intrinsic as the forward silu/qgelu.""" + return Float32(1.0) / (Float32(1.0) + cute.arch.exp(-x)) + + +def _dact_dsilu(x: Float32) -> Float32: + """dsilu: x·σ(x)·(1-σ(x)) + σ(x). Matches math.h `dsilu` + (`cval·dsigmoid + sigmoid`, dsigmoid = s·(1-s)).""" + s = _sigmoid(x) + return x * (s * (Float32(1.0) - s)) + s + + +def _dact_dqgelu(x: Float32) -> Float32: + """dqgelu (alpha=1.702): a·x·dσ(a·x) + σ(a·x). Matches math.h + `dqgelu_with_alpha(val, 1.702)`.""" + a = Float32(1.702) + ax = a * x + s = _sigmoid(ax) + return a * x * (s * (Float32(1.0) - s)) + s + + +def _dact_dgelu(x: Float32) -> Float32: + """dgelu (tanh approximation). Matches math.h `dgelu` term-for-term; + same tanh argument as the forward `_act_gelu`.""" + t = cute.math.tanh( + Float32(0.79788456) * x * (Float32(1.0) + Float32(0.044715) * x * x), + fastmath=False, + ) + return Float32(0.5) * x * ( + (Float32(1.0) - t * t) * (Float32(0.79788456) + Float32(0.1070322243) * x * x) + ) + Float32(0.5) * (Float32(1.0) + t) + + +SUPPORTED_DACTIVATIONS = { + "drelu": _dact_drelu, + "dgelu": _dact_dgelu, + "dsilu": _dact_dsilu, + "dqgelu": _dact_dqgelu, + "dsrelu": _dact_dsrelu, +} + + +@dsl_user_op +def cvt_f32x2_to_fp8e4m3x2( + val_hi: Float32, val_lo: Float32, relu: bool = False, *, loc=None, ip=None +) -> Int32: + """Convert two float32 values to two packed fp8e4m3fn bytes in one instruction. + + Returns an int32 where bits [7:0] = fp8(val_lo), bits [15:8] = fp8(val_hi). + This mirrors ptx::mul_cvt_2x which converts 2 values in one instruction. + """ + result_i16 = Int16( + llvm.inline_asm( + T.i16(), + [val_hi.ir_value(loc=loc, ip=ip), val_lo.ir_value(loc=loc, ip=ip)], + f"cvt.rn.satfinite{".relu" if relu else ""}.e4m3x2.f32 $0, $1, $2;", + "=h,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + return Int32(mlir_arith.extui(T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + + +@dsl_user_op +def cvt_f32x2_to_fp8e5m2x2( + val_hi: Float32, val_lo: Float32, relu: bool = False, *, loc=None, ip=None +) -> Int32: + """e5m2 sibling of `cvt_f32x2_to_fp8e4m3x2`.""" + result_i16 = Int16( + llvm.inline_asm( + T.i16(), + [val_hi.ir_value(loc=loc, ip=ip), val_lo.ir_value(loc=loc, ip=ip)], + f"cvt.rn.satfinite{".relu" if relu else ""}.e5m2x2.f32 $0, $1, $2;", + "=h,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + return Int32(mlir_arith.extui(T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + + +@dsl_user_op +def mul_cvt_f32x4_to_fp4x4( + in01: Int64, in23: Int64, scale_2x: Int64, *, loc=None, ip=None +) -> Int32: + """f32x4 sibling of `kit.mul_cvt_to_fp4x4` — for the NVFP4 colwise path + where elements live on a strided column and we've already widened to f32 + for the amax reduction. `in01` = pack(f32_0, f32_1), `in23` similarly.""" + asm = ( + "{\n" + ".reg.b64 v01; .reg.b64 v23;\n\t" + ".reg.b32 v0; .reg.b32 v1; .reg.b32 v2; .reg.b32 v3;\n\t" + ".reg.b8 f0; .reg.b8 f1;\n\t" + "mov.b64 {v0, v1}, $1;\n\t" + "mov.b64 {v2, v3}, $2;\n\t" + "mov.b64 v01, {v0, v1};\n\t" + "mov.b64 v23, {v2, v3};\n\t" + "mul.f32x2 v01, v01, $3;\n\t" + "mul.f32x2 v23, v23, $3;\n\t" + "mov.b64 {v1, v0}, v01;\n\t" + "mov.b64 {v3, v2}, v23;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" + "mov.b32 $0, {f0, f1, f0, f1};\n\t" + "}" + ) + return Int32( + llvm.inline_asm( + T.i32(), + [ + in01.ir_value(loc=loc, ip=ip), + in23.ir_value(loc=loc, ip=ip), + scale_2x.ir_value(loc=loc, ip=ip), + ], + asm, + "=r,l,l,l", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +def _cvt_f32_to_fp8(fp8_dtype: str): + """Const-expr dispatch: pick the f32→fp8 scalar PTX op based on output dtype. + + `fp8_dtype` is the Python string from `cfg.FP8_DTYPE`, evaluated at JIT + trace time; the unused branch is never traced. + """ + if fp8_dtype == "e5m2": + return cvt_f32_to_fp8e5m2 + return cvt_f32_to_fp8e4m3 + + +def _cvt_f32x2_to_fp8x2(fp8_dtype: str): + """Const-expr dispatch for the packed f32x2→fp8x2 cvt.""" + if fp8_dtype == "e5m2": + return cvt_f32x2_to_fp8e5m2x2 + return cvt_f32x2_to_fp8e4m3x2 + + +@cute.jit +def quantize_rowwise_mxfp8( + sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA + sO_row_tile, # (TILE_Y, TILE_X) uint8 smem view (rowwise FP8 output) + mS_row_stage, # rowwise scale tensor (1D swizzled, or 2D linear) + max_norm_rcp, + tile_row_start, # Int32 — global row index of this stage's row 0 + # (= tile_idx_y * TILE_Y). Used to mask OOB scale stores + # for irregular shapes. + tile_col_start, # Int32 — global col index of this CTA's col 0 + # (= bidx * TILE_X). Same purpose. + M, + N, # Int32 — full tensor extents; OOB threads skip their + # scale store. + ACTIVATION, + DTYPE, + ROWWISE, + COLWISE, + FP8_DTYPE, + TILE_Y, + SCALE_DIM, + WAVES, + THREADS_PER_WARP, + THREADS_PER_BANK, + PACK_SIZE, + WITH_ACT=False, # forward: apply activation to the element + WITH_DACT=False, # backward: out = grad · act'(act_input) + sA_tile=None, # (TILE_Y, TILE_X) activation-input smem tile (dact only) + DBIAS_REDUCTION=False, # rowwise-only dbias: accumulate per-column partials + dbias_acc=None, # rmem Float32[SCALE_DIM]; += this row's pre-truncate elt per column +): + tidx, _, _ = cute.arch.thread_idx() + + # Match the C++ reference's thread layout: pairs of adjacent lanes + # share a row (lanes 2k / 2k+1 both own row k), each pair covering + # the two 32-element scale blocks of that row. Express as a cute + # layout mapping `(tid_Y, tid_X) -> tidx` with stride (2, 1): + # linear(tidx) = tid_Y*2 + tid_X, so `get_flat_coord` inverts to + # `(tidx // 2, tidx % 2)` — semantically clearer than the raw + # divmod, and readily reusable if we later partition via TiledCopy. + # print(f"sX_tile: {sX_tile}") + # print(f"sO_row_tile: {sO_row_tile}") + # print(f"mS_row_stage: {mS_row_stage}") + + tiler, tv_layout = cute.make_layout_tv( + thr_layout=cute.make_layout((TILE_Y, 2), stride=(2, 1)), + val_layout=cute.make_layout((1, SCALE_DIM), stride=(0, 1)), + ) + # print(f"tv_layout: {tv_layout}") + # print(f"tiler: {tiler}") + + sX_tv = cute.composition(sX_tile, tv_layout) + sO_tv = cute.composition(sO_row_tile, tv_layout) + + # I/O Elements that belong to this thread + sX_thread = sX_tv[tidx, None] # shape (32,) bf16 + sO_thread = sO_tv[tidx, None] # shape (32,) uint8 + + # See https://kainzhong.github.io/CuTe-Layout-Visualizer/?key=tv-2-%2832%2C+2%29%3A%282%2C1%29-%281%2C+32%29%3A%280%2C1%29 + # print(f"sX_thread: {sX_thread}") + # print(f"sO_thread: {sO_thread}") + + sO_thread_u32_ptr = cute.recast_ptr(sO_thread.iterator, dtype=Uint32) + # Each wave it writes 32 bytes = 8 uint32s, so in 4 waves we write all 32 quantized elements. + sO_thread_u32 = cute.make_tensor( + sO_thread_u32_ptr, + cute.make_layout((SCALE_DIM // 4,), stride=(1,)), # 1 uint32 is 4 fp8 elements + ) + # print(f"sO_thread_u32: {sO_thread_u32}") + + FUSE_RELU = cutlass.const_expr(ACTIVATION == "relu") + # For this fast paht we can read in pack of 2 instead of reading individual f16 / bf16 element. + # dbias needs the per-element fp32 values to accumulate, so it forces the slow path. + _row_fast = _is_packed16(DTYPE) and (ACTIVATION is None or FUSE_RELU) and not DBIAS_REDUCTION + + if cutlass.const_expr(_row_fast): + # If no activation, f16 / bf16 and rowwise quantization, we can read 2 f16 / bf16 at once in a pack + # and use max.xorsign.abs.f16x2 / max.xorsign.abs.bf16x2 to compute + kit = _packed16_kit(DTYPE) + sX_thread_rw_i32 = cute.make_tensor( + cute.recast_ptr(sX_thread.iterator, dtype=Int32), + cute.make_layout((1, SCALE_DIM // 2), stride=(0, 1)), # 1 int32 is 2 fp16/bf16 elements + ) + # print(f"sX_thread_rw_i32: {sX_thread_rw_i32}") + # Each wave we read 2 packed i32, which is 4 fp16/bf16 elements (PACK_SIZE) + # In total we have 8 waves where each wave reads + in_r = [[None, None] for _ in range(WAVES)] + bank_group = ( + tidx % THREADS_PER_WARP + ) // THREADS_PER_BANK # Each 4 threads share the same bank, which forms a bank group + offset = bank_group * 2 # Each bank group will read 2 i32 from their bank + for w in cutlass.range_constexpr(WAVES): + idx = (w * 2 + offset) % (SCALE_DIM // 2) + in_r[w][0] = sX_thread_rw_i32[0, idx] + in_r[w][1] = sX_thread_rw_i32[0, idx + 1] + + # 1. Packed-x2 amax — 2 PTX per wave, 16 total per thread. + # Accumulates `|elt|` in both lanes (with xorsign-drifted signs); + # final horizontal max reduces the two lanes to a single f32. + amax_2x = Int32(0) + # Each wave will use max.xorsign.abs.f16x2 or max.xorsign.abs.bf16x2 to compare 2 packed elements in parallel + for w in cutlass.range_constexpr(WAVES): + if cutlass.const_expr(FUSE_RELU): + # If we fuse relu then we don't want to do abs since negative value will be set to 0 and they will lose comparison automatically + amax_2x = kit.max_x2(amax_2x, in_r[w][0]) + amax_2x = kit.max_x2(amax_2x, in_r[w][1]) + else: + amax_2x = kit.abs_max_x2(amax_2x, in_r[w][0]) + amax_2x = kit.abs_max_x2(amax_2x, in_r[w][1]) + if cutlass.const_expr(FUSE_RELU): + # Compare the 2 packed max without abs + amax_r = cute.arch.fmax( + kit.x2_lo_to_f32(amax_2x), + kit.x2_hi_to_f32(amax_2x), + ) + # For relu the max is at least 0 + if cutlass.const_expr(FUSE_RELU): + amax_r = cute.arch.fmax(amax_r, Float32(0.0)) + else: + # Compare the 2 packed abs max + amax_r = cute.arch.fmax( + fabs_f32(kit.x2_lo_to_f32(amax_2x)), + fabs_f32(kit.x2_hi_to_f32(amax_2x)), + ) + else: + # Since we need to do computation on individual f16 / bf16 elements, we can't read in pack + sX_thread_rw = cute.make_tensor( + sX_thread.iterator, + cute.make_layout((1, SCALE_DIM), stride=(0, 1)), + ) + in_r = [[None] * PACK_SIZE for _ in range(WAVES)] + bank_group = ( + tidx % THREADS_PER_WARP + ) // THREADS_PER_BANK # Each 4 threads share the same bank, which forms a bank group + offset = bank_group * 4 # Each bank group will read 4 f16 from their bank + + if cutlass.const_expr(WITH_DACT): + # Backward: out = grad · act'(act_input). sX is grad, sA is act_input. + dop = SUPPORTED_DACTIVATIONS[ACTIVATION] + sA_thread = cute.composition(sA_tile, tv_layout)[tidx, None] + sA_thread_rw = cute.make_tensor( + sA_thread.iterator, + cute.make_layout((1, SCALE_DIM), stride=(0, 1)), + ) + elif cutlass.const_expr(WITH_ACT): + op = SUPPORTED_ACTIVATIONS[ACTIVATION] + + if cutlass.const_expr(_is_packed16(DTYPE) and ACTIVATION is not None): + kit_act = _packed16_kit(DTYPE) + amax_r = Float32(0.0) + for w in cutlass.range_constexpr(WAVES): + idx = (w * PACK_SIZE + offset) % SCALE_DIM + for e in cutlass.range_constexpr(PACK_SIZE): + x = Float32(sX_thread_rw[0, idx + e]) # grad + if cutlass.const_expr(WITH_DACT): + # out = grad · act'(act_input) + x = x * dop(Float32(sA_thread_rw[0, idx + e])) + # If IS_ACT, apply activation function to x in f32 + elif cutlass.const_expr(WITH_ACT): + # If it's relu, we can handle it later + if not cutlass.const_expr(FUSE_RELU): + x = op(x) + # dbias: accumulate this row's column (idx+e) value BEFORE the bf16 + # truncation (matches CUDA's `thread_dbias_rowwise[j] += elt`). idx+e + # is a multiple-of-PACK_SIZE group + e, so it stays within [0, SCALE_DIM). + if cutlass.const_expr(DBIAS_REDUCTION): + dbias_acc[idx + e] = dbias_acc[idx + e] + x + # If 16-bit input with activation, truncate to IType + if cutlass.const_expr(_is_packed16(DTYPE) and ACTIVATION is not None): + x = kit_act.truncate_f32(x) # TODO: Why not just qunatize from f32? + in_r[w][e] = x + if cutlass.const_expr(FUSE_RELU): + amax_r = cute.arch.fmax( + amax_r, x + ) # For relu cases, we don't need abs since negative values will be 0 so they lose comparison automatically + else: + amax_r = cute.arch.fmax(amax_r, fabs_f32(x)) + if cutlass.const_expr(FUSE_RELU): + amax_r = cute.arch.fmax(amax_r, Float32(0.0)) # If relu, the amax is at least 0 + + # 2. E8M0 scale → gmem. mS_row's layout already encodes the swizzle + # when cfg.WITH_GEMM_SWIZZLED_SCALES=True, so 2D access just works. + biased_exp_r = float_to_e8m0(amax_r * max_norm_rcp) + # mS_row_stage has logical shape (32, 2) and we have 64 threads where each is mapped to one scale factor + # The TV layout is equivalent to https://kainzhong.github.io/CuTe-Layout-Visualizer/?key=tv-2-%2832%2C+2%29%3A%282%2C+1%29-%281%29 + # but it's too trival so let's just index it directly without using layout + # Note this is the logical layout, which is on top of the swizzled / non-swizzled scale factor layout that mappes the logical index to the physical offset + # Irregular shapes: skip the scale store if this thread's logical row / + # col-block lies past the input's actual extents. TMA already zero-fills + # OOB input reads and drops OOB output writes; only the direct scale-byte + # gmem store needs an explicit guard. + scale_row = tile_row_start + tidx // 2 + scale_col_first_elt = tile_col_start + (tidx % 2) * SCALE_DIM + if scale_row < M and scale_col_first_elt < N: + mS_row_stage[(tidx // 2, tidx % 2)] = Uint8(biased_exp_r) + + # 3. scale + packed fp8 cast → smem as one u32 per wave. + inv_scale_r = exp2f_rcp(biased_exp_r) # f32 reciprocal of the scale + # Fetch the conversion function based on the FP8 format + cvt_f32x2 = _cvt_f32x2_to_fp8x2(FP8_DTYPE) + if cutlass.const_expr(_row_fast): + kit_cast = _packed16_kit(DTYPE) + mul_cvt_x2 = kit_cast.mul_cvt_to_fp8x2(FP8_DTYPE, FUSE_RELU) + # Pack `(inv_scale_r, inv_scale_r)` as a single 64-bit f32x2 once; + # the per-wave mul_cvt consumes this directly. + scale_2x = pack_f32x2(inv_scale_r, inv_scale_r) + + bank_group = ( + tidx % THREADS_PER_WARP + ) // THREADS_PER_BANK # Each 4 threads share the same bank, which forms a bank group + offset = bank_group * 4 # Each bank group will write 4 fp8 to + for w in cutlass.range_constexpr(WAVES): + idx = (w * 4 + offset) % SCALE_DIM + idx = idx // 4 + if cutlass.const_expr(_row_fast): + # One fused PTX per x2 pair: x2 × f32x2 → fp8x2. + # Byte layout: byte[0]=fp8(lo * s), byte[1]=fp8(hi * s). + p01 = mul_cvt_x2(in_r[w][0], scale_2x) + p23 = mul_cvt_x2(in_r[w][1], scale_2x) + else: + # cvt PTX semantics: `cvt.rn.satfinite..f32 d, a, b` gives + # d[15:8]=fp8(a), d[7:0]=fp8(b). Pass (v1, v0) so the u16 low + # byte ends up as fp8(v0) and the high byte as fp8(v1). + v0 = in_r[w][0] * inv_scale_r + v1 = in_r[w][1] * inv_scale_r + v2 = in_r[w][2] * inv_scale_r + v3 = in_r[w][3] * inv_scale_r + p01 = cvt_f32x2(v1, v0, FUSE_RELU) # u16 little-endian: v0,v1 + p23 = cvt_f32x2(v3, v2, FUSE_RELU) # u16 little-endian: v2,v3 + quad = (p23 << Int32(16)) | p01 + sO_thread_u32[idx] = Uint32(quad) + + return amax_r + + +@cute.jit +def quantize_colwise_mxfp8( + sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA + sO_col_tile, # (TILE_Y, TILE_X) uint8 smem view (colwise FP8 output) + mS_col_stage, # colwise scale tensor (1D swizzled, or 2D linear) + max_norm_rcp, + tile_row_start, # Int32 — global row index of this stage's row 0 + # (= tile_idx_y * TILE_Y). Used to mask OOB scale stores + # for irregular shapes. + tile_col_start, # Int32 — global col index of this CTA's col 0 + # (= bidx * TILE_X). + M, + N, # Int32 — full tensor extents. + ACTIVATION, + DTYPE, + FP8_DTYPE, + SWIZZLE, + TILE_X, + TILE_Y, + SCALE_DIM, + WITH_ACT=False, # forward: apply activation to the element + WITH_DACT=False, # backward: out = grad · act'(act_input) + sA_tile=None, # (TILE_Y, TILE_X) activation-input smem tile (dact only) + WITH_DBIAS=False, # also return this thread's column sum (pre-truncate) +): + tidx, _, _ = cute.arch.thread_idx() + + # print(f"sX_tile: {sX_tile}") + # print(f"sO_col_tile: {sO_col_tile}") + # print(f"mS_col_stage: {mS_col_stage}") + + tiler, tv_layout = cute.make_layout_tv( + thr_layout=cute.make_layout((1, TILE_X), stride=(TILE_X, 1)), + val_layout=cute.make_layout((SCALE_DIM, 1), stride=(1, 1)), + ) + # print(f"tv_layout: {tv_layout}") + + sX_tv = cute.composition(sX_tile, tv_layout) + sO_tv = cute.composition(sO_col_tile, tv_layout) + + # I/O Elements that belong to this thread + sX_thread = sX_tv[tidx, None] + sO_thread = sO_tv[tidx, None] + + # See https://kainzhong.github.io/CuTe-Layout-Visualizer/?key=tv-2-%281%2C+64%29%3A%2864%2C+1%29-%2832%2C+1%29%3A%281%2C+1%29 + # print(f"sX_thread: {sX_thread}") # shape (32,) bf16 + # print(f"sO_thread: {sO_thread}") # shape (32,) uint8 + + # dbias needs the per-element fp32 values to sum, so it takes the f32 path + # (never the i16 fast path) — matching CUDA, whose f16 fast path requires + # `!IS_DBIAS` (quantize_mxfp8.cuh:219). + HALF_PRECISION_PATH = _is_packed16(DTYPE) and ACTIVATION is None and not WITH_DBIAS + dbias_partial = Float32(0.0) + + # 0. Load the 32-element column from smem into registers once (matches + # C++'s `in_colwise_IType[i]` cache). Amax and cast both reuse these. + if cutlass.const_expr(HALF_PRECISION_PATH): + kit = _packed16_kit(DTYPE) + # Per-thread Int16 view of the column. Same byte address as + # `sX_thread` (bf16/fp16 are 16-bit, same width as Int16); the + # element stride is TILE_X because the column elements are + # TILE_X apart in the row-major tile. + sX_thread_i16 = cute.make_tensor( + cute.recast_ptr(sX_thread.iterator, dtype=Int16), + cute.make_layout((SCALE_DIM,), stride=(TILE_X,)), + ) + amax_bits = Int16(0) + for i in cutlass.range_constexpr(SCALE_DIM): + amax_bits = kit.abs_max_scalar(amax_bits, sX_thread_i16[i]) + amax_c = fabs_f32(kit.bits_to_f32(amax_bits)) + else: + # Materialize the column into f32 registers — widen on read so + # bf16/fp16 inputs become real fp32 values (a pointer recast to + # Float32 would not widen; it would reinterpret the 16-bit bytes + # as half of a 32-bit float). + sX_thread_f32 = cute.make_rmem_tensor( + layout_or_shape=cute.make_layout((SCALE_DIM,), stride=(1,)), + dtype=Float32, + ) + for i in cutlass.range_constexpr(SCALE_DIM): + sX_thread_f32[i] = Float32(sX_thread[i]) + # Apply activation (fwd) or grad·act'(act_input) (bwd dact) in f32. + if cutlass.const_expr(WITH_DACT): + dop = SUPPORTED_DACTIVATIONS[ACTIVATION] + sA_thread = cute.composition(sA_tile, tv_layout)[tidx, None] + for i in cutlass.range_constexpr(SCALE_DIM): + sX_thread_f32[i] = sX_thread_f32[i] * dop(Float32(sA_thread[i])) + elif cutlass.const_expr(WITH_ACT): + op = SUPPORTED_ACTIVATIONS[ACTIVATION] + for i in cutlass.range_constexpr(SCALE_DIM): + sX_thread_f32[i] = op(sX_thread_f32[i]) + # dbias = column sum of the (post-act/dact) value, taken BEFORE the bf16 + # truncation — matches CUDA's `partial_dbias_colwise += elt`. + if cutlass.const_expr(WITH_DBIAS): + for i in cutlass.range_constexpr(SCALE_DIM): + dbias_partial += sX_thread_f32[i] + # Numerical truncation through IType so amax/cast match C++. + # Only needed when 16-bit input + activation; without activation + # the widening was already exact. + if cutlass.const_expr(_is_packed16(DTYPE) and ACTIVATION is not None): + kit_act = _packed16_kit(DTYPE) + for i in cutlass.range_constexpr(SCALE_DIM): + sX_thread_f32[i] = kit_act.truncate_f32(sX_thread_f32[i]) + amax_c = Float32(0.0) + for i in cutlass.range_constexpr(SCALE_DIM): + amax_c = cute.arch.fmax(amax_c, fabs_f32(sX_thread_f32[i])) + + # 2. E8M0 scale → gmem. mS_col's layout already encodes the swizzle + # when cfg.WITH_GEMM_SWIZZLED_SCALES=True, so 2D access just works. + # Irregular shapes: skip when this stage's row range or this thread's + # column lies past the input extents. TILE_Y == SCALE_DIM so each stage + # is exactly one scale-row; valid iff `tile_row_start < M`. + biased_exp_c = float_to_e8m0(amax_c * max_norm_rcp) + scale_col = tile_col_start + tidx + if tile_row_start < M and scale_col < N: + if cutlass.const_expr(SWIZZLE): + mS_col_stage[(0, tidx % 32, tidx // 32)] = Uint8(biased_exp_c) + else: + mS_col_stage[(0, tidx)] = Uint8(biased_exp_c) + + # 3. scale + FP8 cast → smem (one byte per (row, tidx)). Caller + # flushes the whole (TILE_Y, TILE_X) tile with a TMA S2G. + inv_scale_c = exp2f_rcp(biased_exp_c) + cvt_to_fp8 = _cvt_f32_to_fp8(FP8_DTYPE) + if cutlass.const_expr(HALF_PRECISION_PATH): + kit_cast = _packed16_kit(DTYPE) + for i in cutlass.range_constexpr(SCALE_DIM): + v_f32 = kit_cast.bits_to_f32(sX_thread_i16[i]) + sO_thread[i] = Uint8(cvt_to_fp8(v_f32 * inv_scale_c)) + else: + for i in cutlass.range_constexpr(SCALE_DIM): + sO_thread[i] = Uint8(cvt_to_fp8(sX_thread_f32[i] * inv_scale_c)) + + return amax_c, dbias_partial diff --git a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py new file mode 100644 index 0000000000..b2211541fc --- /dev/null +++ b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py @@ -0,0 +1,1113 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""MXFP8 quantization kernel implemented in CuTeDSL. + +Replicates the core logic of quantize_mxfp8.cuh: given a 2D tensor of BF16/FP16 +values, quantize to MXFP8 format (FP8E4M3 data + E8M0 per-block scales). + +Matches the C++ kernel's tile dimensions and thread layout: + CHUNK_DIM_Y = 64, CHUNK_DIM_X = 64, THREADS_PER_CHUNK = 64 + BUFF_DIM_Y = 32, BUFF_DIM_X = 64, STAGES = 2 + SCALE_DIM = 32 (elements per MXFP8 scaling block) + +Grid: (ceil(N / 64), ceil(M / 64)) +Each block processes a 64x64 chunk in 2 stages of 32x64 tiles loaded into +shared memory. +""" +import logging + +import transformer_engine +from transformer_engine.common.CuTeDSL.utils import str_to_cutlass_dtype +import transformer_engine_torch as tex + +from typing import Optional, Type + +import torch +import transformer_engine_torch as tex + +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +from cutlass import Float32, Int64, Int32, Int16, Uint8, Uint32 +from cuda.bindings.driver import CUstream + +import hashlib +import tvm_ffi + +from .mxfp8_utils import ( + SUPPORTED_ACTIVATIONS, + SUPPORTED_DACTIVATIONS, + FP8E4M3_MAX_NORM_RCP, + FP8E5M2_MAX_NORM_RCP, + _bitcast_f32_to_i32, + _cvt_f32_to_fp8, + _cvt_f32x2_to_fp8x2, + _is_packed16, + _packed16_kit, + exp2f_rcp, + fabs_f32, + float_to_e8m0, + quantize_colwise_mxfp8, + quantize_rowwise_mxfp8, +) + +# Per-backend logger, so a fallback warning is attributable to *this* CuTeDSL +# backend (the MXFP8 quantize backend). Other CuTeDSL backends should use their +# own `transformer_engine.cutedsl.` logger. +logger = logging.getLogger("transformer_engine.cutedsl.mxfp8") + +# MXFP8 settings +MXFP8_BLOCK_SIZE = ( + 32 # Number of elements per MXFP8 scale block. They will share the same E8M0 scale factor +) +SCALE_DIM = MXFP8_BLOCK_SIZE + +# Double-buffering for async copy + compute overlap +BUFFER_NUM = 2 + +# Vectorised access constants for bank-conflict avoidance (rowwise pass) +PACK_SIZE = 4 # Elements per vector load +WAVES = ( + SCALE_DIM // PACK_SIZE +) # Each thread reads 8 waves with each wave reads 4 packed bf16, so it reads a whole MXFP8 block in total +THREADS_PER_WARP = 32 +TOTAL_BANKS_WIDTH = (32 * 4) // 1 # 32 banks × 4 bytes, in bytes (uint8 stride) +THREADS_PER_BANK = TOTAL_BANKS_WIDTH // SCALE_DIM # 4 threads per bank + +# Tiling sizes +NUM_STAGES = 2 # Pipeline depth of the producer/consumer ring buffer for the TMA-G2S input loads (PipelineTmaAsync stage count) +NUM_TILES = 2 # Each CTA process 2 tiles along the Y (row, slowest-changing) dimension +TILE_Y = 32 # Each tile has 32 rows, so each CTA handles 32 * 2 rows in total +TILE_X = 64 # Each tile has 64 columns + +# CTA size +THREADS_PER_CHUNK = 64 +NUM_WARPS = THREADS_PER_CHUNK // 32 + + +# --------------------------------------------------------------------------- +# Kernel configuration +# --------------------------------------------------------------------------- +class MXFP8QuantizeConfig: + + def __init__( + self, + dtype: str, + fp8_dtype: str, + rowwise: bool, + colwise: bool, + with_gemm_swizzled_scales: bool, + with_amax: bool, + with_dbias: bool = False, + with_dact: bool = False, + with_act: bool = False, + with_noop: bool = False, + activation: Optional[str] = None, + ): + if dtype is None or dtype not in ("fp32", "fp16", "bf16"): + raise ValueError(f"unknown input dtype {dtype!r}; expected fp32|fp16|bf16") + self.DTYPE = str_to_cutlass_dtype(dtype) + self.DTYPE_STR = dtype # readable input-dtype token, for __str__ + if fp8_dtype not in ("e4m3", "e5m2"): + raise ValueError(f"unknown FP8 dtype {fp8_dtype!r}; expected 'e4m3' or 'e5m2'") + self.FP8_DTYPE = fp8_dtype + self.ROWWISE = rowwise + self.COLWISE = colwise + if not (rowwise or colwise): + raise ValueError("at least one of rowwise or colwise must be true") + self.WITH_GEMM_SWIZZLED_SCALES = with_gemm_swizzled_scales + self.WITH_AMAX = with_amax + if not with_dact and not with_act: + if activation == "none": + self.ACTIVATION = None + else: + raise ValueError( + "activation must be none when with_dact and with_act are both False" + ) + else: + if with_dact and with_act: + raise ValueError( + "with_dact and with_act cannot be true at the same time since they are used for" + " different paths (bwd vs fwd)" + ) + elif with_dact: + if activation in SUPPORTED_DACTIVATIONS: + self.ACTIVATION = activation + else: + raise ValueError( + f"unknown activation {activation!r} for with_dact=True; expected one of" + f" {sorted(SUPPORTED_DACTIVATIONS)}" + ) + elif with_act: + if activation in SUPPORTED_ACTIVATIONS: + self.ACTIVATION = activation + else: + raise ValueError( + f"unknown activation {activation!r} for with_act=True; expected one of" + f" {sorted(SUPPORTED_ACTIVATIONS)}" + ) + self.WITH_DACT = with_dact + self.WITH_ACT = with_act + # dbias is the column reduction of the (post-act/dact) element. With colwise + # output each thread owns a full column (trivial reduction); rowwise-only + # uses a cross-thread smem reduction over THREADS_Y. Both mirror the CUDA + # kernel's COLWISE_SCALING / rowwise dbias branches. + self.WITH_DBIAS = with_dbias + self.WITH_NOOP = with_noop + self.MAX_NORM_RCP = FP8E4M3_MAX_NORM_RCP if fp8_dtype == "e4m3" else FP8E5M2_MAX_NORM_RCP + + def __str__(self): + return ( + f"MXFP8QuantizeConfig(dtype={self.DTYPE_STR}, fp8_dtype={self.FP8_DTYPE}, " + f"rowwise={self.ROWWISE}, colwise={self.COLWISE}, " + f"swizzled={self.WITH_GEMM_SWIZZLED_SCALES}, with_amax={self.WITH_AMAX}, " + f"with_dbias={self.WITH_DBIAS}, with_dact={self.WITH_DACT}, " + f"with_act={self.WITH_ACT}, with_noop={self.WITH_NOOP}, " + f"activation={self.ACTIVATION})" + ) + + __repr__ = __str__ + + +# --------------------------------------------------------------------------- +# Unified MXFP8 quantization kernel — shared memory tiled, single-pass +# --------------------------------------------------------------------------- +class MXFP8QuantizeSmemKernel: + """MXFP8 quantization with shared-memory tiling (rowwise, colwise, or both). + + Matches C++ kernel's BIDIMENSIONAL scaling mode: + Grid (ceil(N/64), ceil(M/64)) + Block (64) + Each block processes a 64x64 chunk in 2 stages of 32x64. + + Per stage, the tile is loaded into shared memory once. The colwise + pass reads columns from smem first, then the rowwise pass reads rows. + When both directions are enabled, global memory is read only once per + element — matching the C++ single-pass behaviour. + + Thread mappings (per stage): + Colwise: thread tidx handles column tidx, 32 rows (stride BUFF_DIM_X). + Rowwise: tid_Y = tidx // 2 -> row, tid_X = tidx % 2 -> scale-block. + """ + + def __init__(self, cfg): + self.cfg = cfg + + @cute.jit + def __call__( + self, + mX: cute.Tensor, # Input tensor to quantize + mO_row: Optional[cute.Tensor], + mS_row: Optional[cute.Tensor], # Rowwise output and scale tensors + mO_col: Optional[cute.Tensor], + mS_col: Optional[cute.Tensor], # Colwise output and scale tensors + mAmax: Optional[cute.Tensor], # Global amax accumulator, only used in WITH_AMAX path + mNoop: Optional[cute.Tensor], # 1-element cast_noop flag, only used in WITH_NOOP path + # Backward-only slots, present to mirror the CUDA mxfp8::quantize signature + # (act_input / dbias / workspace). NOT used yet — None on the forward path; + # WITH_DACT/WITH_DBIAS configs are rejected upstream so these never carry data. + mActInput: Optional[cute.Tensor], + mDbias: Optional[cute.Tensor], + mWorkspace: Optional[cute.Tensor], + stream: CUstream, + ): + M = mX.shape[0] + N = mX.shape[1] + cfg = self.cfg + max_norm_rcp = cfg.MAX_NORM_RCP + num_scale_cols = N // SCALE_DIM + num_scale_rows = M // SCALE_DIM + + # Rewrap mS_row / mS_col with the GEMM-swizzled layout when requested. + # Wrapper passes in a tensor with the compact (M, N/32):(N/32, 1) layout + # (built from a compact fake-ptr at compile time), and we re-view the + # underlying buffer here so the per-block scale stores below land at the + # cuBLAS-swizzled byte offsets. + # See https://docs.nvidia.com/cuda/cublas/#d-block-scaling-factors-layout + # and swizzle_demo.svg for a visual of the byte permutation. + if cutlass.const_expr(cfg.WITH_GEMM_SWIZZLED_SCALES): + num_tiles_M = (M + 127) // 128 + num_tiles_SC = (num_scale_cols + 3) // 4 # = ceil(N / 128) + num_tiles_SR = (num_scale_rows + 3) // 4 # = ceil(M / 128) + num_tiles_N = (N + 127) // 128 + # row i = i_lo + 32 * (i_hi + 4 * tile_Y); col j = j_lo + 4 * tile_X. + # Within one 128×4 tile: byte = i_lo*16 + i_hi*4 + j_lo. + + # Tile-major outer dims add (tile_Y * num_tiles_SC + tile_X) * 512. + # For example, if M=256, N=512, then num_scale_cols = 16, num_scale_rows = 8, and num_tiles_M=2, num_tiles_SC=4, num_tiles_SR=2, num_tiles_N=4 + # The swizzled layout is ((32, 4, 2), (4, 4)):((16, 4, 2048), (1, 512)) + if cutlass.const_expr(cfg.ROWWISE): + mS_row = cute.make_tensor( + mS_row.iterator, + cute.make_layout( + ((32, 4, num_tiles_M), (4, num_tiles_SC)), + stride=((16, 4, num_tiles_SC * 512), (1, 512)), + ), + ) + # Colwise: same swizzle, axes swap roles — col axis gets the 32×4 + # inner decomp, scale-row axis gets the 4-extent dim. + if cutlass.const_expr(cfg.COLWISE): + mS_col = cute.make_tensor( + mS_col.iterator, + cute.make_layout( + ((4, num_tiles_SR), (32, 4, num_tiles_N)), + stride=((1, 512), (16, 4, num_tiles_SR * 512)), + ), + ) + + # Divide by the STAGE tile (TILE_Y, TILE_X // SCALE_DIM), not the CTA + # tile. Each CTA owns NUM_TILES consecutive row-tiles; the kernel walks + # them by indexing GRID's row dim with `bidy * NUM_TILES + stage` (cute + # auto-decomposes a flat coord onto GRID's hierarchical row modes). + # + # Critically, this is the only divide that cleanly cuts both layouts: + # - compact `(M, N/32):(N/32, 1)` → SCALE_TILE = (32, 2):(N/32, 1) + # - swizzled `((32,4,n_M),(4,n_SC)):((16,4,n_SC·512),(1,512))` + # → SCALE_TILE = (32, 2):(16, 1) + # The bigger (TILE_Y * NUM_TILES, ...) divide we used before tangles the + # swizzle's (32, 4) row hierarchy under flatten + sub-divide chain. + + # Declare TMA descriptors on the host side. + # make_tiled_tma_atom returns the UNTILED gmem tensor with basis strides. + # Tile it inside the kernel with zipped_divide so each coord selects + # one (TILE_Y, TILE_X) tile. + smem_tile_layout = cute.make_ordered_layout((TILE_Y, TILE_X), order=(1, 0)) + cta_tiler = (TILE_Y, TILE_X) + + # Input: TMA G2S (bf16/fp16 → smem). + op_load = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp() + tma_atom, tma_src = cute.nvgpu.cpasync.make_tiled_tma_atom( + op_load, + mX, + smem_tile_layout, + cta_tiler, + num_multicast=1, + ) + + # Backward (dact): the activation input is a second G2S load, identical to + # mX's. The kernel computes `grad · act'(act_input)`; here mX carries grad. + tma_atom_act = None + tma_src_act = None + if cutlass.const_expr(cfg.WITH_DACT): + tma_atom_act, tma_src_act = cute.nvgpu.cpasync.make_tiled_tma_atom( + op_load, + mActInput, + smem_tile_layout, + cta_tiler, + num_multicast=1, + ) + + # Output: TMA S2G (uint8 smem → gmem) for both directions. Creating + # both atoms unconditionally — if a direction is disabled the kernel + # simply won't dispatch its copy, and the atom cost is negligible. + op_store = cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp() + out_smem_layout = cute.make_ordered_layout((TILE_Y, TILE_X), order=(1, 0)) + tma_atom_out_row = None + tma_dst_out_row = None + tma_atom_out_col = None + tma_dst_out_col = None + if cutlass.const_expr(cfg.ROWWISE): + tma_atom_out_row, tma_dst_out_row = cute.nvgpu.cpasync.make_tiled_tma_atom( + op_store, + mO_row, + out_smem_layout, + cta_tiler, + num_multicast=1, + ) + if cutlass.const_expr(cfg.COLWISE): + tma_atom_out_col, tma_dst_out_col = cute.nvgpu.cpasync.make_tiled_tma_atom( + op_store, + mO_col, + out_smem_layout, + cta_tiler, + num_multicast=1, + ) + + # Decide when to perform dbias reduction + DBIAS_REDUCTION_COLWISE: cutlass.Constexpr = False + DBIAS_REDUCTION_ROWWISE: cutlass.Constexpr = False + if cutlass.const_expr(cfg.WITH_DBIAS): + # We prefer to perform dbias reduction in the colwise pass since it doesn't require shuffle + if cutlass.const_expr(cfg.COLWISE): + DBIAS_REDUCTION_COLWISE = True + else: + DBIAS_REDUCTION_ROWWISE = True + + # CUDA launches in (0,0), (1,0), (2,0)... order, so we should make N the leading dimension for better access pattern + # So consecutive blocks will move along the N dimension first, which is the innermost dimension in memory and we can use cache better + grid = [ + cute.ceil_div(Int32(N), TILE_X), + cute.ceil_div(M, TILE_Y * NUM_TILES), + ] + block = [ + THREADS_PER_CHUNK, + ] + + self.kernel( + mX, + mS_row, + mS_col, + mAmax, + mNoop, + mWorkspace, + max_norm_rcp, + mX.element_type, + tma_atom, + tma_src, + tma_atom_out_row, + tma_dst_out_row, + tma_atom_out_col, + tma_dst_out_col, + tma_atom_act, + tma_src_act, + ).launch( + grid=grid, + block=block, + stream=stream, + ) + + # Device entry (launched by __call__). Reads the cast_noop flag and runs the + # work only if it is not set — matching the CUDA kernel's + # `if (noop[0]==1.0f) return;`. When WITH_NOOP is off, mNoop is None and the + # whole check is compiled out (so no flag is read). + @cute.kernel + def kernel( + self, + mX, + mS_row, + mS_col, + mAmax, + mNoop, + mWorkspace, + max_norm_rcp, + dtype: cutlass.Constexpr[Type[cutlass.Numeric]], + tma_atom, + tma_src, # how to use TMA to copy the input + tma_atom_out_row, + tma_dst_out_row, # how to use TMA to copy the rowwise output + tma_atom_out_col, + tma_dst_out_col, # how to use TMA to copy the colwise output + tma_atom_act, + tma_src_act, # dact only: how to copy the activation input + ): + cfg = self.cfg + # `not const_expr(WITH_NOOP)` is a compile-time True when noop is disabled, + # so Python short-circuits the `or` and never reads mNoop[0] (it is None). + if not cutlass.const_expr(cfg.WITH_NOOP) or mNoop[0] != Float32(1.0): + self._kernel_main( + mX, + mS_row, + mS_col, + mAmax, + mWorkspace, + max_norm_rcp, + dtype, + tma_atom, + tma_src, + tma_atom_out_row, + tma_dst_out_row, + tma_atom_out_col, + tma_dst_out_col, + tma_atom_act, + tma_src_act, + ) + + # The actual quantize work. MUST be @cute.jit (not @cute.kernel): it is invoked + # from the @cute.kernel `kernel` wrapper under a runtime noop branch, and only a + # separately-traced @cute.jit callable may allocate shared memory inside such a + # branch (an inlined/undecorated method or a nested @cute.kernel would fail). + @cute.jit + def _kernel_main( + self, + mX, + mS_row, + mS_col, + mAmax, + mWorkspace, + max_norm_rcp, + dtype: cutlass.Constexpr[Type[cutlass.Numeric]], + tma_atom, + tma_src, # how to use TMA to copy the input + tma_atom_out_row, + tma_dst_out_row, # how to use TMA to copy the rowwise output + tma_atom_out_col, + tma_dst_out_col, # how to use TMA to copy the colwise output + tma_atom_act, + tma_src_act, # dact only: how to copy the activation input + ): + cfg = self.cfg + + if cutlass.const_expr(cfg.ROWWISE): + mS_row = cute.zipped_divide(mS_row, (TILE_Y, TILE_X // SCALE_DIM)) + if cutlass.const_expr(cfg.COLWISE): + mS_col = cute.zipped_divide(mS_col, (TILE_Y // SCALE_DIM, TILE_X)) + # For M=256, N=512: + # Non-swizzled: https://kainzhong.github.io/CuTe-Layout-Visualizer/?key=zipped_divide-%28256%2C+16%29%3A%2816%2C+1%29-32%0A2 + # Swizzled: https://kainzhong.github.io/CuTe-Layout-Visualizer/?key=zipped_divide-%28%2832%2C+4%2C+2%29%2C+%284%2C+4%29%29%3A%28%2816%2C+4%2C+2048%29%2C+%281%2C+512%29%29-32%0A2 + # print(f"mS_row after zipped_divide: {mS_row}") + + # FP8 output smem, one 32×64 tile per stage per enabled direction. + # Allocating a dead sO_col in rowwise-only (or sO_row in colwise-only) + # bumps per-CTA smem from 12 KB to 16 KB, which drops occupancy and + # regresses the single-direction path by ~8-10% at 16384^2. Match + # C++ and only allocate what the active pass actually uses. + # sAmax holds one f32 per warp for the cross-warp amax reduction — + # negligible (8 bytes for NUM_WARPS=2) and we always allocate so the + # struct doesn't fork on a 4th const-expr (cfg.WITH_AMAX) dimension. + if cutlass.const_expr(cfg.ROWWISE and cfg.COLWISE): + + @cute.struct + class SharedStorage: + mbar_storage: cute.struct.MemRange[cute.Int64, 2 * NUM_STAGES] + sX: cute.struct.Align[ + cute.struct.MemRange[dtype, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + sO_row: cute.struct.Align[ + cute.struct.MemRange[Uint8, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + sO_col: cute.struct.Align[ + cute.struct.MemRange[Uint8, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + sAmax: cute.struct.MemRange[Float32, NUM_WARPS] + + elif cutlass.const_expr(cfg.ROWWISE and not cfg.COLWISE): + + @cute.struct + class SharedStorage: + mbar_storage: cute.struct.MemRange[cute.Int64, 2 * NUM_STAGES] + sX: cute.struct.Align[ + cute.struct.MemRange[dtype, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + sO_row: cute.struct.Align[ + cute.struct.MemRange[Uint8, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + sAmax: cute.struct.MemRange[Float32, NUM_WARPS] + + elif cutlass.const_expr(cfg.ROWWISE): + + @cute.struct + class SharedStorage: + mbar_storage: cute.struct.MemRange[cute.Int64, 2 * NUM_STAGES] + sX: cute.struct.Align[ + cute.struct.MemRange[dtype, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + sO_row: cute.struct.Align[ + cute.struct.MemRange[Uint8, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + sAmax: cute.struct.MemRange[Float32, NUM_WARPS] + + else: + + @cute.struct + class SharedStorage: + mbar_storage: cute.struct.MemRange[cute.Int64, 2 * NUM_STAGES] + sX: cute.struct.Align[ + cute.struct.MemRange[dtype, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + sO_col: cute.struct.Align[ + cute.struct.MemRange[Uint8, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + sAmax: cute.struct.MemRange[Float32, NUM_WARPS] + + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + + # dact: the activation-input tile lives in its own smem buffer, same + # shape/layout as sX. Allocated separately so the 4 SharedStorage variants + # above don't have to fork again on WITH_DACT. + if cutlass.const_expr(cfg.WITH_DACT): + + @cute.struct + class DactStorage: + sActInput: cute.struct.Align[ + cute.struct.MemRange[dtype, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + + dact_storage = smem.allocate(DactStorage) + sActInput = dact_storage.sActInput.get_tensor( + cute.make_layout( + ((TILE_Y, TILE_X), NUM_STAGES), + stride=((TILE_X, 1), TILE_Y * TILE_X), + ) + ) + + # Rowwise-only dbias needs a cross-thread (over THREADS_Y) smem reduction, + # since each rowwise thread owns a row, not a column. Buffer is + # [THREADS_Y][THREADS_X*(SCALE_DIM+1)] f32 — the +1 per scale-block padding + # avoids bank conflicts, matching CUDA's DBIAS_BUFF_WIDTH. + DBIAS_REDUCTION_ROWWISE = cutlass.const_expr(cfg.WITH_DBIAS and not cfg.COLWISE) + DBIAS_BUFF_WIDTH = (TILE_X // SCALE_DIM) * (SCALE_DIM + 1) + if cutlass.const_expr(DBIAS_REDUCTION_ROWWISE): + + @cute.struct + class DbiasStorage: + sDbias: cute.struct.MemRange[Float32, TILE_Y * DBIAS_BUFF_WIDTH] + + dbias_storage = smem.allocate(DbiasStorage) + sDbias = dbias_storage.sDbias.get_tensor(cute.make_layout(TILE_Y * DBIAS_BUFF_WIDTH)) + + # Per-stage shmem tile is 2D (TILE_Y, TILE_X); stages laid out back-to-back. + # Mode 0 is hierarchical ((TILE_Y, TILE_X),) so it matches the rank/shape + # of gX_tiled[(None, (ty, tx))] produced by zipped_divide. + # sX[(None, stage)] selects one (TILE_Y, TILE_X) tile. + sX = storage.sX.get_tensor( + cute.make_layout( + ((TILE_Y, TILE_X), NUM_STAGES), + stride=((TILE_X, 1), TILE_Y * TILE_X), + ) + ) + if cutlass.const_expr(cfg.ROWWISE): + sO_row = storage.sO_row.get_tensor( + cute.make_layout( + ((TILE_Y, TILE_X), NUM_STAGES), + stride=((TILE_X, 1), TILE_Y * TILE_X), + ) + ) + if cutlass.const_expr(cfg.COLWISE): + sO_col = storage.sO_col.get_tensor( + cute.make_layout( + ((TILE_Y, TILE_X), NUM_STAGES), + stride=((TILE_X, 1), TILE_Y * TILE_X), + ) + ) + + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # Prefetch TMA descriptor (one-time; warp-0 only). + if warp_idx == 0: + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom) + if cutlass.const_expr(cfg.WITH_DACT): + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_act) + + tidx, _, _ = cute.arch.thread_idx() + bidx, bidy, _ = cute.arch.block_idx() + + # Producer: `arrive_and_expect_tx` is wrapped in `elect_one`, so only + # one lane of warp 0 arrives on the full barrier per stage → arrive_count=1. + producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1) + # Consumer: `consumer_release` arrives only on the `is_signalling_thread` + # (lane 0 of each warp), so arrive_count = num_warps per stage. + num_warps = THREADS_PER_CHUNK // 32 + consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_warps) + + # Bytes transferred per TMA copy: one (TILE_Y, TILE_X) tile of dtype. + # dact loads two tiles (grad + act_input) under the same per-stage barrier, + # so the barrier must expect both copies' bytes. + tx_count = TILE_Y * TILE_X * dtype.width // 8 + if cutlass.const_expr(cfg.WITH_DACT): + tx_count *= 2 + + mainloop_pipeline = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.mbar_storage.data_ptr(), + num_stages=NUM_STAGES, + producer_group=producer_group, + consumer_group=consumer_group, + tx_count=tx_count, + cta_layout_vmnk=None, # single-CTA, no cluster/multicast + ) + + prod_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, NUM_STAGES) + cons_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, NUM_STAGES) + + M = mX.shape[0] + N = mX.shape[1] + + num_tiles = cutlass.min( + NUM_TILES, + cute.ceil_div(M - bidy * TILE_Y * NUM_TILES, TILE_Y), + ) + + # Tile the TMA gmem view: ((TILE_Y, TILE_X), (M/TILE_Y, N/TILE_X)). + gX_tiled = cute.zipped_divide(tma_src, (TILE_Y, TILE_X)) + + # Partition sX/gX for the TMA atom (single-CTA, no cluster/multicast). + tXsX, tXgX = cute.nvgpu.cpasync.tma_partition( + tma_atom, + 0, # Use the only CTA to do the TMA copy + cute.make_layout(1), # This cluster only has 1 CTAs + sX, + gX_tiled, + ) + + # dact: identical partition for the activation-input load. + if cutlass.const_expr(cfg.WITH_DACT): + gA_tiled = cute.zipped_divide(tma_src_act, (TILE_Y, TILE_X)) + tXsA, tXgA = cute.nvgpu.cpasync.tma_partition( + tma_atom_act, + 0, + cute.make_layout(1), + sActInput, + gA_tiled, + ) + + # Same partitioning for S2G outputs: sO_row → mO_row and sO_col → mO_col. + if cutlass.const_expr(cfg.ROWWISE): + gO_row_tiled = cute.zipped_divide(tma_dst_out_row, (TILE_Y, TILE_X)) + tXsO_row, tXgO_row = cute.nvgpu.cpasync.tma_partition( + tma_atom_out_row, + 0, + cute.make_layout(1), + sO_row, + gO_row_tiled, + ) + if cutlass.const_expr(cfg.COLWISE): + gO_col_tiled = cute.zipped_divide(tma_dst_out_col, (TILE_Y, TILE_X)) + tXsO_col, tXgO_col = cute.nvgpu.cpasync.tma_partition( + tma_atom_out_col, + 0, + cute.make_layout(1), + sO_col, + gO_col_tiled, + ) + + # print(f"sX: {sX}\n") + # print(f"gX_tiled: {gX_tiled}\n") + # print(f"tXsX: {tXsX}\n") + # print(f"tXgX: {tXgX}\n") + + # Ensure barrier init is visible to all threads before the pipeline is used. + cute.arch.sync_threads() + + # ---- Producer: warp 0 issues one TMA copy per tile. ---- + if warp_idx == 0: + for stage in cutlass.range(num_tiles, unroll=1): + mainloop_pipeline.producer_acquire(prod_state) + tile_y = bidy * NUM_TILES + stage + cute.copy( + tma_atom, + tXgX[(None, (tile_y, bidx))], + tXsX[(None, prod_state.index)], + tma_bar_ptr=mainloop_pipeline.producer_get_barrier(prod_state), + ) + if cutlass.const_expr(cfg.WITH_DACT): + cute.copy( + tma_atom_act, + tXgA[(None, (tile_y, bidx))], + tXsA[(None, prod_state.index)], + tma_bar_ptr=mainloop_pipeline.producer_get_barrier(prod_state), + ) + mainloop_pipeline.producer_commit(prod_state) + prod_state.advance() + + # Per-thread amax accumulator across all stages of this CTA. Combined + # with the per-warp redux + cross-warp shmem reduce + atomic at the + # bottom to produce a global max(|x|) in mAmax. Initialised to 0 + # since amax is non-negative. + if cutlass.const_expr(cfg.WITH_AMAX): + block_amax = Float32(0.0) + + # Per-thread partial dbias: thread tidx owns column tidx of the colwise + # tile and accumulates its column sum over this CTA's rows (both stages). + # Written to workspace[bidy, col] below; reduced over row-blocks separately. + if cutlass.const_expr(cfg.WITH_DBIAS): + block_dbias = Float32(0.0) + # Rowwise-only dbias: each thread holds per-column partials for its 32-col + # block, summed across stages, then cross-thread reduced (over THREADS_Y) + # into block_dbias after the loop. + if cutlass.const_expr(DBIAS_REDUCTION_ROWWISE): + rowwise_dbias_arr = cute.make_rmem_tensor( + layout_or_shape=cute.make_layout((SCALE_DIM,), stride=(1,)), + dtype=Float32, + ) + for c in cutlass.range_constexpr(SCALE_DIM): + rowwise_dbias_arr[c] = Float32(0.0) + + # ---- Consumer: all threads quantize each completed tile. ---- + for stage in cutlass.range(num_tiles, unroll=1): + mainloop_pipeline.consumer_wait(cons_state) + sX_tile = sX[(None, stage)] # (TILE_Y, TILE_X) bf16 (grad for dact) + sActInput_tile = None + if cutlass.const_expr(cfg.WITH_DACT): + sActInput_tile = sActInput[(None, stage)] # (TILE_Y, TILE_X) act_input + + """ + grid = [ + cute.ceil_div(Int32(N), TILE_X), + cute.ceil_div(M, TILE_Y * NUM_TILES), + ] + So to obtain the tile that belongs to this CTA. + """ + # This is just block's x axis idx + tile_idx_x = bidx + # Each CTA has `NUM_TILES` tiles. Each stage we need to obtain the tile for that specific stage. + # So the tile index along Y dimension is `bidy * NUM_TILES + stage` + tile_idx_y = bidy * NUM_TILES + stage + if cutlass.const_expr(cfg.COLWISE): + # The first row that belongs to this CTA. Each CTA handles NUM_TILES of (TILE_Y, TILE_X) tiles stacked vertically, + # and each stage handles one of them. + sO_col_tile = sO_col[(None, stage)] + mS_col_stage = cute.flatten(mS_col[(None, (tile_idx_y, tile_idx_x))]) + + amax_c, dbias_c = self._process_colwise( + sX_tile, + sO_col_tile, + mS_col_stage, + max_norm_rcp, + tile_idx_y * TILE_Y, + bidx * TILE_X, + M, + N, + sActInput_tile, + ) + if cutlass.const_expr(cfg.WITH_AMAX): + block_amax = cute.arch.fmax(block_amax, amax_c) + if cutlass.const_expr(cfg.WITH_DBIAS): + block_dbias += dbias_c + if cutlass.const_expr(cfg.ROWWISE): + sO_row_tile = sO_row[(None, stage)] + # mS_row is ((SCALE_TILE), (GRID)) where SCALE_TILE = (32, 2). + # Each CTA owns NUM_TILES consecutive row-tiles of GRID. cute + # auto-decomposes the flat row coord `bidy * NUM_TILES + stage` + # onto GRID's hierarchical row modes — which is the + # (i_hi, tile_Y) tile-major order for swizzled, and the plain + # row-tile order for compact. Same source, both layouts correct. + mS_row_stage = cute.flatten(mS_row[(None, (tile_idx_y, tile_idx_x))]) + # print(f"s0_row_tile: {sO_row_tile}\n") + # print(f"sO_row: {sO_row}\n") + # print(f"mS_row: {mS_row}\n") + # print(f"mS_row_stage: {mS_row_stage}\n") + # print(f"mS_row_stage: {mS_row_stage}\n") + amax_r = self._process_rowwise( + sX_tile, + sO_row_tile, + mS_row_stage, + max_norm_rcp, + tile_idx_y * TILE_Y, + bidx * TILE_X, + M, + N, + sActInput_tile, + rowwise_dbias_arr if cutlass.const_expr(DBIAS_REDUCTION_ROWWISE) else None, + ) + + if cutlass.const_expr(cfg.WITH_AMAX): + block_amax = cute.arch.fmax(block_amax, amax_r) + + # Make all smem stores (sO_row and/or sO_col) visible to the TMA + # async proxy, then block-sync so warp 0 sees the fences from all + # warps before issuing the bulk store(s). Matches the C++ + # reference's fence_proxy + __syncthreads pattern. + cute.arch.fence_proxy( + "async.shared", + space="cta", + ) + cute.arch.sync_threads() + + if warp_idx == 0: + tile_y = bidy * NUM_TILES + stage + if cutlass.const_expr(cfg.ROWWISE): + cute.copy( + tma_atom_out_row, + tXsO_row[(None, stage)], + tXgO_row[(None, (tile_y, bidx))], + ) + if cutlass.const_expr(cfg.COLWISE): + cute.copy( + tma_atom_out_col, + tXsO_col[(None, stage)], + tXgO_col[(None, (tile_y, bidx))], + ) + cute.arch.cp_async_bulk_commit_group() + + mainloop_pipeline.consumer_release(cons_state) + cons_state.advance() + + # Wait for in-flight TMA stores so data is visible to the host + # before the kernel returns. + cute.arch.cp_async_bulk_wait_group(0, read=False) + + # ---- rowwise-only dbias: cross-thread reduction over THREADS_Y --------- + # In the rowwise pass each thread owns a row, so its rowwise_dbias_arr holds + # per-column partials for its 32-col block. Transpose through smem so thread + # tidx ends up owning column tidx of the chunk (mirrors CUDA's + # partial_dbias_rowwise smem buffer + reduce over THREADS_Y). + if cutlass.const_expr(DBIAS_REDUCTION_ROWWISE): + THREADS_X = TILE_X // SCALE_DIM # scale-blocks per row (=2) + tid_Y = tidx // THREADS_X + tid_X = tidx % THREADS_X + for c in cutlass.range_constexpr(SCALE_DIM): + sDbias[tid_Y * DBIAS_BUFF_WIDTH + tid_X * (SCALE_DIM + 1) + c] = rowwise_dbias_arr[ + c + ] + cute.arch.sync_threads() + # thread tidx owns column tidx; +block skips the per-block padding slot. + block = tidx // SCALE_DIM + block_dbias = Float32(0.0) + for i in cutlass.range_constexpr(TILE_Y): + block_dbias += sDbias[i * DBIAS_BUFF_WIDTH + tidx + block] + + # ---- dbias: write this CTA's per-column partial to the workspace ------- + # Thread tidx owns column (bidx*TILE_X + tidx). Each CTA-row-block (bidy) + # contributes one row of the (blocks_Y, N) fp32 workspace; the reduction + # over blocks_Y to the final dbias[N] is a separate step. + if cutlass.const_expr(cfg.WITH_DBIAS): + dbias_col = bidx * TILE_X + tidx + if dbias_col < N: + mWorkspace[(bidy, dbias_col)] = block_dbias + + # ---- amax block reduction + cross-CTA atomic ---------------------- + # 1) intra-warp: redux.sync.fmax.f32 (sm_80+, single instruction). + # 2) cross-warp: NUM_WARPS shmem floats + sync_threads. + # 3) cross-CTA: int-atomic-max on the f32 bit pattern. Since amax is + # always ≥ 0, IEEE-754 bit ordering on positives matches float + # magnitude ordering, so atomic_max on i32 bits gives the right + # result. (atomic_max_float32 also exists but its pointer + # normalisation is broken as of this CuTeDSL build.) + if cutlass.const_expr(cfg.WITH_AMAX): + warp_amax = cute.arch.warp_redux_sync(block_amax, kind="fmax") + sAmax = storage.sAmax.get_tensor(cute.make_layout(NUM_WARPS)) + lane_idx = tidx % 32 + if lane_idx == 0: + sAmax[warp_idx] = warp_amax + cute.arch.sync_threads() + if tidx == 0: + cta_amax = Float32(0.0) + for w in cutlass.range_constexpr(NUM_WARPS): + cta_amax = cute.arch.fmax(cta_amax, sAmax[w]) + amax_i32 = cute.make_tensor( + cute.recast_ptr(mAmax.iterator, dtype=Int32), + cute.make_layout(1), + ) + cute.arch.atomic_max( + amax_i32.iterator, + _bitcast_f32_to_i32(cta_amax), + ) + + @cute.jit + def _process_rowwise( + self, + sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA + sO_row_tile, # (TILE_Y, TILE_X) uint8 smem view (rowwise FP8 output) + mS_row_stage, # rowwise scale tensor (1D swizzled, or 2D linear) + max_norm_rcp, + tile_row_start, # Int32 — global row of this stage's row 0 + tile_col_start, # Int32 — global col of this CTA's col 0 + M, + N, # Int32 — full input extents, for OOB masking + sActInput_tile=None, # (TILE_Y, TILE_X) act_input tile (dact only) + dbias_acc=None, # rmem Float32[SCALE_DIM] dbias accumulator (rowwise-only dbias) + ): + """Rowwise MXFP8 pass: thread `(tid_Y, tid_X) = (tidx % 32, tidx // 32)` + owns one 32-element scale block (row `tid_Y`, columns `tid_X*32 .. +32`). + + The bank-group swizzle `((w + bank_group) * PACK_SIZE) % SCALE_DIM` + staggers each 4-thread group's starting wave, which otherwise would + collide on smem banks since all lanes in a warp read different rows + at the same column offset. + + Writes quantized bytes into `sO_row_tile` as u32s (one per wave); + caller is responsible for the TMA S2G flush. + """ + cfg = self.cfg + return quantize_rowwise_mxfp8( + sX_tile, + sO_row_tile, + mS_row_stage, + max_norm_rcp, + tile_row_start, + tile_col_start, + M, + N, + ACTIVATION=cfg.ACTIVATION, + DTYPE=cfg.DTYPE, + ROWWISE=cfg.ROWWISE, + COLWISE=cfg.COLWISE, + FP8_DTYPE=cfg.FP8_DTYPE, + TILE_Y=TILE_Y, + SCALE_DIM=SCALE_DIM, + WAVES=WAVES, + THREADS_PER_WARP=THREADS_PER_WARP, + THREADS_PER_BANK=THREADS_PER_BANK, + PACK_SIZE=PACK_SIZE, + WITH_ACT=cfg.WITH_ACT, + WITH_DACT=cfg.WITH_DACT, + sA_tile=sActInput_tile, + DBIAS_REDUCTION=cfg.WITH_DBIAS and not cfg.COLWISE, + dbias_acc=dbias_acc, + ) + + @cute.jit + def _process_colwise( + self, + sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA + sO_col_tile, # (TILE_Y, TILE_X) uint8 smem view (colwise FP8 output) + mS_col_stage, # colwise scale tensor (1D swizzled, or 2D linear) + max_norm_rcp, + tile_row_start, # Int32 — global row of this stage's row 0 + tile_col_start, # Int32 — global col of this CTA's col 0 + M, + N, # Int32 — full input extents, for OOB masking + sActInput_tile=None, # (TILE_Y, TILE_X) act_input tile (dact only) + ): + """Colwise MXFP8 pass: thread `tidx` owns column `tidx` of the (32, 64) + smem tile — 32 elements down. Writes quantized bytes into `sO_col_tile` + so the caller can flush with a TMA S2G — matches C++'s + `out_colwise_data_sh` + `cp.async.bulk.tensor.2d.shared_to_global`. + """ + cfg = self.cfg + return quantize_colwise_mxfp8( + sX_tile, + sO_col_tile, + mS_col_stage, + max_norm_rcp, + tile_row_start, + tile_col_start, + M, + N, + ACTIVATION=cfg.ACTIVATION, + DTYPE=cfg.DTYPE, + FP8_DTYPE=cfg.FP8_DTYPE, + SWIZZLE=cfg.WITH_GEMM_SWIZZLED_SCALES, + TILE_X=TILE_X, + TILE_Y=TILE_Y, + SCALE_DIM=SCALE_DIM, + WITH_ACT=cfg.WITH_ACT, + WITH_DACT=cfg.WITH_DACT, + sA_tile=sActInput_tile, + WITH_DBIAS=cfg.WITH_DBIAS, + ) + + +def compile_cutedsl_function_from_cfg(cfg): + """ + Return the compiled CuTeDSL function object for the given MXFP8 quantization config. + """ + + kernel_obj = MXFP8QuantizeSmemKernel(cfg) + + # stride_order=(1, 0): row-major, dim 1 stride 1. 1D: (0,). + kw_rm16_2d = dict(stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=16) + kw_rm4_2d = dict(stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=4) + kw_rm4_1d = dict(stride_order=(0,), memspace=cute.AddressSpace.gmem, assumed_align=4) + + def fake(dtype, shape, kw): + return cute.runtime.make_fake_compact_tensor(dtype, shape, **kw) + + # M, N must be divisible by the MXFP8 scale-block size (SCALE_DIM = 32) — the + # same alignment the CUDA C++ kernel requires. The C++ dispatcher gates on the + # matching value (kCuTeDSLMXFP8ShapeAlignment in cast/dispatch/quantize.cuh) + # and falls back to CUDA for anything not divisible by it, so tvm-ffi never + # sees a shape this kernel can't accept. + sym_M = cute.sym_int32(divisibility=SCALE_DIM) + sym_N = cute.sym_int32(divisibility=SCALE_DIM) + in_shape = out_shape = (sym_M, sym_N) + # TE allocates scale tensors at a padded shape (see + # MXFP8Quantizer::get_scale_shape in transformer_engine/pytorch/csrc): + # rowwise: (roundup(M, 128), roundup(N // 32, 4)) + # columnwise: (roundup(M // 32, 4), roundup(N, 128)) + # These padded extents are NOT M/N (and SymInt has no `//`/`+`), so give the + # scales their own fresh syms carrying the divisibility the padding + # guarantees (rowwise: 128 x 4; colwise: 4 x 128). + scale_r_shape = (cute.sym_int32(divisibility=128), cute.sym_int32(divisibility=4)) + scale_c_shape = (cute.sym_int32(divisibility=4), cute.sym_int32(divisibility=128)) + # Scale dim-1 is only 4-byte-divisible, so a 16-byte alignment promise would + # be a lie for many shapes; the per-block scale stores are byte-wise anyway, + # so 4-byte alignment loses nothing. + scale_kw = kw_rm4_2d + + in_fake = fake(cfg.DTYPE, in_shape, kw_rm16_2d) + out_row_fake = fake(cute.Uint8, out_shape, kw_rm16_2d) if cfg.ROWWISE else None + scale_row_fake = fake(cute.Uint8, scale_r_shape, scale_kw) if cfg.ROWWISE else None + out_col_fake = fake(cute.Uint8, out_shape, kw_rm16_2d) if cfg.COLWISE else None + scale_col_fake = fake(cute.Uint8, scale_c_shape, scale_kw) if cfg.COLWISE else None + amax_fake = fake(Float32, (1,), kw_rm4_1d) if cfg.WITH_AMAX else None + noop_fake = fake(Float32, (1,), kw_rm4_1d) if cfg.WITH_NOOP else None + # Backward-only slots (act_input/dbias/workspace). Always None today — + # WITH_DACT/WITH_DBIAS are rejected in the config — but kept in the compile + # signature so the tvm-ffi protocol matches the CUDA mxfp8::quantize args. + act_input_fake = fake(cfg.DTYPE, in_shape, kw_rm16_2d) if cfg.WITH_DACT else None + # dbias: the kernel never writes the dbias tensor — it writes per-row-block + # partials into the workspace (shape (blocks_Y, N) fp32, blocks_Y = ceil(M/64), + # set by the C++ worker's size query). The final reduction lives elsewhere, so + # mDbias stays None and only the workspace fake is built. + dbias_fake = None + ws_shape = (cute.sym_int32(), sym_N) # (blocks_Y, N); N ties to input N + workspace_fake = fake(Float32, ws_shape, kw_rm4_2d) if cfg.WITH_DBIAS else None + + compiled = cute.compile( + kernel_obj, + in_fake, # mX + out_row_fake, + scale_row_fake, # mO_row, mS_row + out_col_fake, + scale_col_fake, # mO_col, mS_col + amax_fake, # mAmax + noop_fake, # mNoop (1-element cast_noop flag) + act_input_fake, # mActInput (backward slot, unused) + dbias_fake, # mDbias (backward slot, unused) + workspace_fake, # mWorkspace(backward slot, unused) + cute.runtime.make_fake_stream(), # stream (compiled as an explicit tvm-ffi + # "handle" arg; C++ passes the CUDA stream + # as void*) + options="--enable-tvm-ffi", + ) + return compiled + + +def get_mxfp8_quantization_function( + fn_name: str, + dtype: str, + fp8_dtype: str, + rowwise: bool, + colwise: bool, + with_gemm_swizzled_scales: bool, + with_amax: bool, + with_dbias: bool, + with_dact: bool, + with_act: bool, + with_noop: bool, + activation: str, +) -> bool: + """Compile the MXFP8 quantize kernel for this config and register it in the + TVM-FFI global registry under EXACTLY `fn_name` (the key the C++ dispatcher + built; Python treats it as an opaque name). Returns True if a kernel is + registered under `fn_name` (the C++ side then fetches it with + GetGlobal(fn_name)); False if the config is unsupported, so the caller caches + the negative result and falls back to the CUDA C++ kernel. + + The registry owns the compiled kernel's lifetime — important because it wraps + a Python object, and tvm-ffi releases registry entries at interpreter + shutdown (whereas a C++-held handle would be released after finalize → crash). + """ + # Already registered (e.g. by a prior call) -> supported. + if tvm_ffi.get_global_func(fn_name, allow_missing=True) is not None: + return True + + try: + cfg = MXFP8QuantizeConfig( + dtype=dtype, + fp8_dtype=fp8_dtype, + rowwise=rowwise, + colwise=colwise, + with_gemm_swizzled_scales=with_gemm_swizzled_scales, + with_amax=with_amax, + with_dbias=with_dbias, + with_dact=with_dact, + with_act=with_act, + with_noop=with_noop, + activation=activation, + ) + except ValueError as e: + # The exception message states exactly why the config is unsupported + # (unknown dtype/activation, dbias not implemented, ...). Surfacing it as a + # warning lets the C++ dispatcher's CUDA fallback be recognized as expected. + logger.warning( + "CuTeDSL MXFP8 backend does not support this config, " + f"falling back to the CUDA C++ kernel: {e}" + ) + return False + + logger.debug(f"Compiling CuTeDSL MXFP8 quantization kernel for {cfg}") + compiled = compile_cutedsl_function_from_cfg(cfg) + tvm_ffi.register_global_func(fn_name, compiled, override=True) + + return True + + +# Exposed so the C++ dispatcher can request on-demand compilation by name. +tvm_ffi.register_global_func( + "get_mxfp8_quantization_function", get_mxfp8_quantization_function, override=True +) diff --git a/transformer_engine/common/CuTeDSL/utils.py b/transformer_engine/common/CuTeDSL/utils.py new file mode 100644 index 0000000000..9ad78fc1d0 --- /dev/null +++ b/transformer_engine/common/CuTeDSL/utils.py @@ -0,0 +1,18 @@ +import cutlass + +_CUTLASS_DTYPE_FROM_STR = { + "fp32": cutlass.Float32, + "fp16": cutlass.Float16, + "bf16": cutlass.BFloat16, +} +_STR_FROM_CUTLASS_DTYPE = {v: k for k, v in _CUTLASS_DTYPE_FROM_STR.items()} + + +def str_to_cutlass_dtype(dtype_str: str): + """Convert a string dtype to a cutlass dtype, or None if unknown.""" + return _CUTLASS_DTYPE_FROM_STR.get(dtype_str, None) + + +def cutlass_dtype_to_str(dtype): + """Convert a cutlass dtype back to its protocol string, or None if unknown.""" + return _STR_FROM_CUTLASS_DTYPE.get(dtype, None) diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 6c71285cd4..a3d237d61c 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -13,13 +13,18 @@ #include +#include +#include + #include "../../common.h" #include "../../transpose/cast_transpose.h" +#include "../../util/cuda_runtime.h" #include "../../util/vectorized_pointwise.h" #include "../core/common.cuh" #include "../fp8/quantize_fp8.cuh" #include "../mxfp8/group_quantize_mxfp8.cuh" #include "../mxfp8/quantize_mxfp8.cuh" +#include "../mxfp8/quantize_mxfp8_cutedsl.cuh" #include "../nvfp4/group_quantize_transpose_nvfp4.cuh" #include "../nvfp4/quantize_4over6_nvfp4.cuh" #include "../nvfp4/quantize_transpose_nvfp4.cuh" @@ -84,9 +89,16 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, const Tensor *dummy_input_tensor = nullptr; Tensor *dummy_dbias_tensor = nullptr; Tensor *dummy_workspace_tensor = nullptr; - mxfp8::quantize( - *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, - dummy_workspace_tensor, stream); + bool quantized_with_cutedsl = + quantize::mxfp8_quantize_cutedsl(input_tensor, dummy_input_tensor, noop_tensor, + output_tensor, dummy_dbias_tensor, + dummy_workspace_tensor, stream); + if (!quantized_with_cutedsl) { + mxfp8::quantize( + *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, + dummy_workspace_tensor, stream); + } break; } case NVTE_NVFP4_1D_SCALING: { @@ -249,9 +261,15 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens break; } case NVTE_MXFP8_1D_SCALING: { - mxfp8::quantize( - *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, - stream); + bool quantized_with_cutedsl = + quantize::mxfp8_quantize_cutedsl( + grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + stream); + if (!quantized_with_cutedsl) { + mxfp8::quantize( + *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + stream); + } break; } case NVTE_NVFP4_1D_SCALING: { diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8_cutedsl.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8_cutedsl.cuh new file mode 100644 index 0000000000..cecbf8a7d0 --- /dev/null +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8_cutedsl.cuh @@ -0,0 +1,260 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_CAST_MXFP8_QUANTIZE_MXFP8_CUTEDSL_CUH_ +#define TRANSFORMER_ENGINE_COMMON_CAST_MXFP8_QUANTIZE_MXFP8_CUTEDSL_CUH_ + +#include +#include + +#include +#include +#include + +#include "../../common.h" +#include "../../tvm_ffi_bridge.h" +#include "../../util/math.h" +#include "../core/common.cuh" // dispatch::common::reduce_dbias + +namespace transformer_engine { +namespace tvm_ffi_bridge { + +struct MXFP8QuantConfig { + static constexpr const char *kEntrypointName = "get_mxfp8_quantization_function"; + + DType dtype; + DType fp8_dtype; + bool rowwise; + bool colwise; + bool swizzled; + bool with_amax; + bool with_dbias = false; + bool with_dact = false; + bool with_act = false; + bool with_noop = false; + Activation activation = Activation::kNone; + + std::string to_key() const { + std::string key; + key.reserve(56); + key.append("cutedsl_mxfp8_") + .append(te_dtype_to_str(dtype)) + .append("_") + .append(te_dtype_to_str(fp8_dtype)) + .append("_") + .append(rowwise ? "1" : "0") + .append("_") + .append(colwise ? "1" : "0") + .append("_") + .append(swizzled ? "1" : "0") + .append("_") + .append(with_amax ? "1" : "0") + .append("_") + .append(with_dbias ? "1" : "0") + .append("_") + .append(with_dact ? "1" : "0") + .append("_") + .append(with_act ? "1" : "0") + .append("_") + .append(with_noop ? "1" : "0") + .append("_") + .append(activation_to_str(activation)); + return key; + } + + bool retrieve_func_from_python(const std::string &fn_name) const { + auto entrypoint = tvm::ffi::Function::GetGlobal(kEntrypointName); + if (!entrypoint.has_value()) { + return false; + } + tvm::ffi::Any result = + (*entrypoint)(tvm::ffi::String(fn_name), tvm::ffi::String(te_dtype_to_str(dtype)), + tvm::ffi::String(te_dtype_to_str(fp8_dtype)), rowwise, colwise, swizzled, + with_amax, with_dbias, with_dact, with_act, with_noop, + tvm::ffi::String(activation_to_str(activation))); + return result.try_cast().value_or(false); + } +}; + +template +struct MXFP8QuantFused { + static constexpr Activation activation = Activation::kNone; + // No fused op: plain quantize, or dbias-only cast (IS_DBIAS, no activation). + static constexpr bool supported = (OP == nullptr) && !IS_DACT && !IS_ACT; +}; +template <> +struct MXFP8QuantFused> { + static constexpr Activation activation = Activation::kReLU; + static constexpr bool supported = true; +}; +template <> +struct MXFP8QuantFused> { + static constexpr Activation activation = Activation::kGeLU; + static constexpr bool supported = true; +}; +template <> +struct MXFP8QuantFused> { + static constexpr Activation activation = Activation::kSiLU; + static constexpr bool supported = true; +}; +template <> +struct MXFP8QuantFused> { + static constexpr Activation activation = Activation::kQGeLU; + static constexpr bool supported = true; +}; +template <> +struct MXFP8QuantFused> { + static constexpr Activation activation = Activation::kSReLU; + static constexpr bool supported = true; +}; +template +struct MXFP8QuantFused> { + static constexpr Activation activation = Activation::kDReLU; + static constexpr bool supported = true; +}; +template +struct MXFP8QuantFused> { + static constexpr Activation activation = Activation::kDGeLU; + static constexpr bool supported = true; +}; +template +struct MXFP8QuantFused> { + static constexpr Activation activation = Activation::kDSiLU; + static constexpr bool supported = true; +}; +template +struct MXFP8QuantFused> { + static constexpr Activation activation = Activation::kDQGeLU; + static constexpr bool supported = true; +}; +template +struct MXFP8QuantFused> { + static constexpr Activation activation = Activation::kDSReLU; + static constexpr bool supported = true; +}; + +} // namespace tvm_ffi_bridge + +namespace quantize { + +// Signature mirrors mxfp8::quantize (input, act_input, noop, output, dbias, +// workspace, stream). Returns false to fall back to the CUDA kernel. +inline bool mxfp8_quantize_cutedsl(const tvm_ffi_bridge::MXFP8QuantConfig &config, + const Tensor *input_tensor, const Tensor *act_input_tensor, + const Tensor *noop_tensor, Tensor *output_tensor, + Tensor *dbias_tensor, Tensor *workspace_tensor, + cudaStream_t stream) { + constexpr size_t kCuTeDSLMXFP8ShapeAlignment = 32; + const size_t flat_m = input_tensor->flat_first_dim(); + const size_t flat_n = input_tensor->flat_last_dim(); + if (flat_m % kCuTeDSLMXFP8ShapeAlignment != 0 || flat_n % kCuTeDSLMXFP8ShapeAlignment != 0) { + return false; + } + + // dbias workspace-size query, mirroring mxfp8::quantize: the framework first + // calls with an unallocated workspace to learn its shape, allocates a buffer of + // that shape, then calls again to run. The kernel writes per-row-block partial + // dbias into this workspace; reducing it to the final dbias is a separate step. + if (config.with_dbias && workspace_tensor != nullptr && workspace_tensor->data.dptr == nullptr) { + constexpr size_t kCuTeDSLMXFP8ChunkRows = 64; // TILE_Y * NUM_TILES (CTA row span) + const size_t dbias_rows = (flat_m + kCuTeDSLMXFP8ChunkRows - 1) / kCuTeDSLMXFP8ChunkRows; + workspace_tensor->data.shape = {dbias_rows, flat_n}; + workspace_tensor->data.dtype = DType::kFloat32; + return true; + } + + std::optional mxfp8_quant_func_opt = + tvm_ffi_bridge::TVMFFICentral::getInstance().lazyload_function(config); + if (!mxfp8_quant_func_opt.has_value()) { + return false; + } + + // Zero out swizzled scale padding when the matrix isn't a multiple of the + // 128x128 GEMM tile. The kernel writes only the meaningful scale region, so + // cuBLAS would otherwise read uninitialized padding. Mirrors the CUDA launcher + // in quantize_mxfp8.cuh (the kernel itself does not pad the scales). + // TODO: move this into the CuTeDSL host code so the padding is handled inside + // the kernel launch — this CUDA-driver memset is an implementation detail that + // doesn't belong in the dispatcher (blocked on calling the driver API there). + if (config.swizzled && (flat_m % 128 != 0 || flat_n % 128 != 0)) { + if (output_tensor->has_data()) { + NVTE_CHECK_CUDA(cudaMemsetAsync(output_tensor->scale_inv.dptr, 0, + output_tensor->scale_inv.buffer_size_bytes(), stream)); + } + if (output_tensor->has_columnwise_data()) { + NVTE_CHECK_CUDA(cudaMemsetAsync(output_tensor->columnwise_scale_inv.dptr, 0, + output_tensor->columnwise_scale_inv.buffer_size_bytes(), + stream)); + } + } + + // Data tensors auto-flatten to 2D (DLTensorWrapper's default), matching the + // kernel's flat (rows, cols) view; scale/amax/noop are rank <= 2 and pass through. + tvm_ffi_bridge::DLTensorWrapper mX(input_tensor->data); + tvm_ffi_bridge::DLTensorWrapper mO_row(output_tensor->data); + tvm_ffi_bridge::DLTensorWrapper mS_row(output_tensor->scale_inv); + tvm_ffi_bridge::DLTensorWrapper mO_col(output_tensor->columnwise_data); + tvm_ffi_bridge::DLTensorWrapper mS_col(output_tensor->columnwise_scale_inv); + tvm_ffi_bridge::DLTensorWrapper mAmax(output_tensor->amax); + tvm_ffi_bridge::DLTensorWrapper mNoop(noop_tensor->data); + // Backward tensors: null wrapper (None) unless present, no allocation when absent. + // mDbias stays None: the kernel writes per-block partials into the workspace, and + // the final dbias is produced by a separate reduction (not by this kernel). + tvm_ffi_bridge::DLTensorWrapper mActInput, mDbias, mWorkspace; + if (act_input_tensor != nullptr) + mActInput = tvm_ffi_bridge::DLTensorWrapper(act_input_tensor->data); + if (workspace_tensor != nullptr) + mWorkspace = tvm_ffi_bridge::DLTensorWrapper(workspace_tensor->data); + // stream is a tvm-ffi opaque "handle"; pass the CUDA stream as void*. + (*mxfp8_quant_func_opt)(&mX, &mO_row, &mS_row, &mO_col, &mS_col, &mAmax, &mNoop, &mActInput, + &mDbias, &mWorkspace, static_cast(stream)); + + // dbias: the kernel wrote per-row-block partials into the workspace; reduce them + // over the row-blocks into the final dbias[N]. Mirrors mxfp8::quantize, which + // launches common::reduce_dbias after its quantize kernel. + if (config.with_dbias) { + const size_t blocks_Y = (flat_m + 63) / 64; // ceil(M/64) = workspace rows + const float *workspace_ptr = reinterpret_cast(workspace_tensor->data.dptr); + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input_tensor->dtype(), IType, + dispatch::common::reduce_dbias(workspace_ptr, dbias_tensor, blocks_Y, flat_n, + stream);) // NOLINT(*) + } + return true; +} + +template +bool mxfp8_quantize_cutedsl(const Tensor *input_tensor, const Tensor *act_input_tensor, + const Tensor *noop_tensor, Tensor *output_tensor, Tensor *dbias_tensor, + Tensor *workspace_tensor, cudaStream_t stream) { + using Fused = tvm_ffi_bridge::MXFP8QuantFused; + if constexpr (!Fused::supported) { + return false; + } else { + const bool with_noop = noop_tensor != nullptr && noop_tensor->data.dptr != nullptr; + const tvm_ffi_bridge::MXFP8QuantConfig config{ + /*dtype=*/input_tensor->dtype(), + /*fp8_dtype=*/output_tensor->dtype(), + /*rowwise=*/output_tensor->has_data(), + /*colwise=*/output_tensor->has_columnwise_data(), + /*swizzled=*/output_tensor->with_gemm_swizzled_scales, + /*with_amax=*/output_tensor->amax.dptr != nullptr, + /*with_dbias=*/IS_DBIAS, + /*with_dact=*/IS_DACT, + /*with_act=*/IS_ACT, + /*with_noop=*/with_noop, + /*activation=*/Fused::activation}; + return mxfp8_quantize_cutedsl(config, input_tensor, act_input_tensor, noop_tensor, + output_tensor, dbias_tensor, workspace_tensor, stream); + } +} + +} // namespace quantize +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_COMMON_CAST_MXFP8_QUANTIZE_MXFP8_CUTEDSL_CUH_ diff --git a/transformer_engine/common/tvm_ffi_bridge.h b/transformer_engine/common/tvm_ffi_bridge.h new file mode 100644 index 0000000000..e7317de664 --- /dev/null +++ b/transformer_engine/common/tvm_ffi_bridge.h @@ -0,0 +1,286 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_TVM_FFI_BRIDGE_H_ +#define TRANSFORMER_ENGINE_COMMON_TVM_FFI_BRIDGE_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "transformer_engine/transformer_engine.h" +#include "util/cuda_runtime.h" +#include "util/logging.h" + +namespace transformer_engine { +namespace tvm_ffi_bridge { + +inline const char *te_dtype_to_str(DType dtype) { + switch (dtype) { + case DType::kFloat32: + return "fp32"; + case DType::kFloat16: + return "fp16"; + case DType::kBFloat16: + return "bf16"; + case DType::kFloat8E4M3: + return "e4m3"; + case DType::kFloat8E5M2: + return "e5m2"; + default: + return ""; + } +} + +// Fused activation token forwarded to Python. Encodes both the family and the +// forward-vs-derivative direction: "relu" is the forward activation, "drelu" its +// backward derivative (dact). This is why no separate is_act/is_dact flag is +// needed — the token carries it; only with_dbias (orthogonal) is a separate flag. +// The d-variants are slots for the not-yet-wired backward path; the forward +// tokens must match Python's SUPPORTED_ACTIVATIONS set. +enum class Activation { + kNone, + kReLU, + kGeLU, + kSiLU, + kQGeLU, + kSReLU, + kDReLU, + kDGeLU, + kDSiLU, + kDQGeLU, + kDSReLU +}; + +inline const char *activation_to_str(Activation act) { + switch (act) { + case Activation::kReLU: + return "relu"; + case Activation::kGeLU: + return "gelu"; + case Activation::kSiLU: + return "silu"; + case Activation::kQGeLU: + return "qgelu"; + case Activation::kSReLU: + return "srelu"; + case Activation::kDReLU: + return "drelu"; + case Activation::kDGeLU: + return "dgelu"; + case Activation::kDSiLU: + return "dsilu"; + case Activation::kDQGeLU: + return "dqgelu"; + case Activation::kDSReLU: + return "dsrelu"; + case Activation::kNone: + return "none"; + } + return "none"; +} + +inline DLDataType convert_to_dltype(NVTEDType type) { + switch (type) { + case kNVTEFloat32: + return DLDataType{kDLFloat, 32, 1}; + case kNVTEFloat16: + return DLDataType{kDLFloat, 16, 1}; + case kNVTEBFloat16: + return DLDataType{kDLBfloat, 16, 1}; + case kNVTEByte: + return DLDataType{kDLUInt, 8, 1}; + case kNVTEInt32: + return DLDataType{kDLInt, 32, 1}; + case kNVTEInt64: + return DLDataType{kDLInt, 64, 1}; + // FP8 / E8M0 → raw 1-byte uint; the kernel interprets the bits. + case kNVTEFloat8E4M3: + return DLDataType{kDLUInt, 8, 1}; + case kNVTEFloat8E5M2: + return DLDataType{kDLUInt, 8, 1}; + case kNVTEFloat8E8M0: + return DLDataType{kDLUInt, 8, 1}; + default: + NVTE_ERROR("unsupported NVTEDType: ", static_cast(type)); + } +} + +class DLTensorWrapper : public DLTensor { + public: + // Null wrapper (data == nullptr): packs as TVM-FFI None, no allocation. + DLTensorWrapper() : DLTensor{} {} + + DLTensorWrapper(const NVTEBasicTensor &tensor, bool flatten_2D = true) { + const int32_t device_index = transformer_engine::cuda::current_device(); + const int n = static_cast(tensor.shape.ndim); + if (flatten_2D && n > 2) { + int64_t flat_first = 1; + for (int i = 0; i + 1 < n; ++i) flat_first *= static_cast(tensor.shape.data[i]); + const int64_t flat_last = static_cast(tensor.shape.data[n - 1]); + shape_buf_ = std::make_unique(2); + strides_buf_ = std::make_unique(2); + shape_buf_[0] = flat_first; + shape_buf_[1] = flat_last; + strides_buf_[0] = flat_last; + strides_buf_[1] = 1; + this->ndim = 2; + } else { + shape_buf_ = std::make_unique(n); + strides_buf_ = std::make_unique(n); + int64_t stride = 1; + for (int i = n - 1; i >= 0; --i) { + shape_buf_[i] = static_cast(tensor.shape.data[i]); + strides_buf_[i] = stride; + stride *= shape_buf_[i]; + } + this->ndim = n; + } + this->data = tensor.data_ptr; + this->device = DLDevice{kDLCUDA, device_index}; + this->dtype = convert_to_dltype(tensor.dtype); + this->shape = shape_buf_.get(); + this->strides = strides_buf_.get(); + this->byte_offset = 0; + } + + ~DLTensorWrapper() = default; + DLTensorWrapper(const DLTensorWrapper &) = delete; + DLTensorWrapper &operator=(const DLTensorWrapper &) = delete; + DLTensorWrapper(DLTensorWrapper &&) = default; + DLTensorWrapper &operator=(DLTensorWrapper &&) = default; + + private: + std::unique_ptr shape_buf_; + std::unique_ptr strides_buf_; +}; + +} // namespace tvm_ffi_bridge +} // namespace transformer_engine + +namespace tvm { +namespace ffi { +// Make a (borrowed) DLTensorWrapper* a first-class TVM-FFI argument, so wrappers +// can be passed straight to Function::operator()(&w, ...). Like DLTensor* it is a +// non-owning DLTensorPtr view (the wrapper must outlive the call), but a null +// pointer OR a wrapper over an absent buffer (null data) packs as TVM-FFI None — +// so a kernel's optional args need no special handling at the call site. Only +// the pack-as-argument path (CopyToAnyView) is provided; reading back is unused. +// Declared after DLTensorWrapper: the specialization needs the complete type +// (it reads src->data and static_casts to its DLTensor base). +template <> +struct TypeTraits + : public TypeTraits { + TVM_FFI_INLINE static void CopyToAnyView(transformer_engine::tvm_ffi_bridge::DLTensorWrapper *src, + TVMFFIAny *result) { + if (src == nullptr || src->data == nullptr) { + TypeTraits::CopyToAnyView(nullptr, result); // -> TVM-FFI None + } else { + TypeTraits::CopyToAnyView(static_cast(src), result); + } + } +}; +} // namespace ffi +} // namespace tvm + +namespace transformer_engine { +namespace tvm_ffi_bridge { + +// Compile-time check that a config provides the lazy-loadable kernel API: +// - std::string to_key() const +// - bool retrieve_func_from_python(const std::string& key) const +// (compiles + globally registers the kernel under `key`; returns whether +// a kernel is now registered / the config is supported) +// Drives the static_assert in TVMFFICentral::lazyload_function so a config that +// is missing either method fails with a clear message instead of a deref-into- +// the-template error. +namespace detail { +template +struct is_lazyloadable_config : std::false_type {}; +template +struct is_lazyloadable_config< + T, std::void_t().to_key()), + decltype(std::declval().retrieve_func_from_python( + std::declval()))>> : std::true_type {}; +} // namespace detail + +class TVMFFICentral { + public: + static TVMFFICentral &getInstance() { + static TVMFFICentral instance; + return instance; + } + + // Resolve the compiled kernel for `cfg`. The kernel itself lives in the tvm-ffi + // global registry (registered by the Python entrypoint under cfg.to_key()), + // which releases its Python-backed entries safely at interpreter shutdown; we + // fetch it per call with GetGlobal(key). C++ caches only a bool per config + // (supported or not), so Python is asked at most once per config and we never + // hold a Python-backed handle in a static-duration object (which would crash + // at exit, when the singleton is torn down after the interpreter is finalized). + template + std::optional lazyload_function(const Config &cfg) { + static_assert(detail::is_lazyloadable_config::value, + "Config must define `std::string to_key() const` and " + "`bool retrieve_func_from_python(const std::string&) const`."); + if (!enabled_) return std::nullopt; + const std::string key = cfg.to_key(); + { + std::shared_lock read_lock(mutex_); + auto it = supported_.find(key); + if (it != supported_.end()) { + return it->second ? tvm::ffi::Function::GetGlobal(key) : std::nullopt; + } + } + // Cold miss: ask Python to compile + globally register the kernel under + // `key`; cache only the support decision (avoids re-asking Python, and + // negative-caches unsupported configs). + const bool supported = cfg.retrieve_func_from_python(key); + { + std::unique_lock write_lock(mutex_); + supported_.emplace(key, supported); + } + return supported ? tvm::ffi::Function::GetGlobal(key) : std::nullopt; + } + + private: + ~TVMFFICentral() = default; + TVMFFICentral() : enabled_(is_cutedsl_backend_enabled()) {} + TVMFFICentral(const TVMFFICentral &) = delete; + TVMFFICentral &operator=(const TVMFFICentral &) = delete; + TVMFFICentral(TVMFFICentral &&) = delete; + TVMFFICentral &operator=(TVMFFICentral &&) = delete; + + static bool is_cutedsl_backend_enabled() { + // On by default; set NVTE_ENABLE_CUTEDSL_QUANT_BACKEND=0 to disable. + const char *flag = std::getenv("NVTE_ENABLE_CUTEDSL_QUANT_BACKEND"); + return flag == nullptr || flag[0] != '0'; + } + + const bool enabled_; + std::shared_mutex mutex_; + // Per-config support decision (cfg.to_key() -> supported). Holds NO Python- + // backed handles, so it is safe to destroy at static teardown — the kernels + // live in the tvm-ffi registry, owned and released by tvm-ffi itself. + std::unordered_map supported_; +}; + +} // namespace tvm_ffi_bridge +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_COMMON_TVM_FFI_BRIDGE_H_ diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 06db28ee27..334dd0eb15 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -18,6 +18,16 @@ load_framework_extension("torch") from transformer_engine.pytorch import constants from transformer_engine.pytorch.constants import DType + +# Register the CuTeDSL kernel entrypoints (TVM-FFI global funcs) so the C++ +# dispatcher can discover them via GetGlobal and compile kernels on demand. The +# CuTeDSL toolchain (cutlass, tvm_ffi) is optional; if it is unavailable the +# import is skipped and C++ simply falls back to the CUDA C++ kernels. +try: + import transformer_engine.common.CuTeDSL # noqa: F401 +except Exception: + pass + from transformer_engine.pytorch.module import LayerNormLinear from transformer_engine.pytorch.module import Linear from transformer_engine.pytorch.module import LayerNormMLP