From 33e938ccf6e7de0dc1ba1af43a1f12b05b9f9888 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 25 Jun 2026 12:35:18 +0200 Subject: [PATCH 1/3] Make quantized-tensor __repr__ fake-safe under torch.compile Under torch.compile, TE quantized-tensor __repr__ methods are invoked on FakeTensors during AOT autograd's structured logging. The repr bodies call self._scale_inv.item() and/or self.dequantize() (which dispatches to the raw C++ op tex.dequantize), both of which access a FakeTensor's data pointer and raise: RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor) ... This was the sole cause of six fp8 failures in tests/pytorch/test_torch_compile.py. Fix: add one shared helper, safe_quantized_repr, in tensor/_quantization_helpers.py (a safe leaf module importing only torch) that builds a metadata-only repr string. Each data-touching __repr__ now wraps its existing body in a try/except and falls back to the helper when the data cannot be materialized. The eager (non-fake) repr output is unchanged; only a fallback path is added. Wrapped reprs: Float8Tensor, Float8BlockwiseQTensor, MXFP8Tensor, NVFP4Tensor and their *Storage counterparts. 
Co-Authored-By: Claude Opus 4.8 Signed-off-by: Pawel Gadzinski --- .../pytorch/tensor/_quantization_helpers.py | 56 +++++++++++++++++++ .../pytorch/tensor/float8_blockwise_tensor.py | 20 +++++-- .../pytorch/tensor/float8_tensor.py | 19 ++++--- .../pytorch/tensor/mxfp8_tensor.py | 7 ++- .../pytorch/tensor/nvfp4_tensor.py | 7 ++- .../float8_blockwise_tensor_storage.py | 31 ++++++---- .../tensor/storage/float8_tensor_storage.py | 18 +++--- .../tensor/storage/mxfp8_tensor_storage.py | 22 +++++--- .../tensor/storage/nvfp4_tensor_storage.py | 24 ++++---- 9 files changed, 149 insertions(+), 55 deletions(-) diff --git a/transformer_engine/pytorch/tensor/_quantization_helpers.py b/transformer_engine/pytorch/tensor/_quantization_helpers.py index 56cf503630..932160ff20 100644 --- a/transformer_engine/pytorch/tensor/_quantization_helpers.py +++ b/transformer_engine/pytorch/tensor/_quantization_helpers.py @@ -10,6 +10,7 @@ from __future__ import annotations from typing import Callable, Optional, Tuple, Any, Dict, TYPE_CHECKING +import warnings import torch if TYPE_CHECKING: @@ -83,3 +84,58 @@ def _stride_from_shape(shape: list[int]): for d in reversed(shape[1:]): rstride.append(rstride[-1] * d) return list(reversed(rstride)) + + +def _is_fake_data_access_error(error): + """Heuristic: is this the error PyTorch raises when a repr tries to read the + data pointer of a fake/functional tensor (e.g. under torch.compile tracing)? + + The exact exception type and message are not contractual across PyTorch + versions, so match defensively on the message text. Anything that is not + recognized is treated as an unexpected error worth surfacing. + """ + message = str(error).lower() + return "data pointer" in message or "faketensor" in message or "functionaltensor" in message + + +def safe_quantized_repr(obj, cls_name, extras=None, error=None): + """Metadata-only repr fallback for quantized tensors whose data cannot be + materialized (e.g. 
a FakeTensor under torch.compile tracing, where + dequantize()/.item() would access a data pointer). + + Parameters + ---------- + extras : dict, optional + Additional plain-Python (non-tensor) attributes to include, e.g. + ``{"is_2D_scaled": self._is_2D_scaled}``. Values are inserted after + ``fp8_dtype`` and before ``shape``. + error : BaseException, optional + The exception that triggered the fallback. The expected trigger is the + data-pointer access error PyTorch raises for fake/functional tensors; + anything else is surfaced as a warning so that real eager-path failures + (e.g. CUDA OOM, shape bugs) are not silently swallowed by ``__repr__``. + """ + if error is not None and not _is_fake_data_access_error(error): + warnings.warn( + f"{cls_name}.__repr__ fell back to a metadata-only representation " + f"because an unexpected error occurred while materializing data: " + f"{type(error).__name__}: {error}", + stacklevel=2, + ) + parts = [] + fp8_dtype = getattr(obj, "_fp8_dtype", None) + if fp8_dtype is not None: + parts.append(f"fp8_dtype={fp8_dtype}") + if extras: + for key, value in extras.items(): + parts.append(f"{key}={value}") + try: + parts.append(f"shape={tuple(obj.shape)}") + except Exception: # pylint: disable=broad-except + pass + try: + parts.append(f"dtype={obj.dtype}") + except Exception: # pylint: disable=broad-except + pass + parts.append("data=") + return f"{cls_name}({', '.join(parts)})" diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index ba46508d74..d2d28aecfb 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -14,7 +14,7 @@ from transformer_engine.common.recipe import Float8BlockScaling, Recipe from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ..quantized_tensor import QuantizedTensor, Quantizer -from ._quantization_helpers 
import _IdentityFunc +from ._quantization_helpers import _IdentityFunc, safe_quantized_repr from ..constants import DType from ..utils import devices_match, round_up_to_nearest_multiple @@ -267,11 +267,19 @@ def __new__( return instance def __repr__(self, *, tensor_contents=None): - return ( - f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype}," - f" is_2D_scaled={self._is_2D_scaled}," - f" data={self.dequantize()})" - ) + try: + return ( + f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype}," + f" is_2D_scaled={self._is_2D_scaled}," + f" data={self.dequantize()})" + ) + except Exception as exc: # pylint: disable=broad-except + return safe_quantized_repr( + self, + "Float8BlockwiseQTensor", + extras={"is_2D_scaled": self._is_2D_scaled}, + error=exc, + ) def quantize_( self, diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 4de8d82217..17ba87201b 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -18,7 +18,7 @@ from ..utils import canonicalize_process_group, devices_match from .storage.float8_tensor_storage import Float8TensorStorage, _FromFloat8Func from ..quantized_tensor import QuantizedTensor, Quantizer -from ._quantization_helpers import _IdentityFunc +from ._quantization_helpers import _IdentityFunc, safe_quantized_repr from ..constants import dist_group_type, DType aten = torch.ops.aten @@ -412,13 +412,16 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor): """ def __repr__(self, *, tensor_contents=None): - return ( - "Float8Tensor(" - f"fp8_dtype={self._fp8_dtype}, " - f"scale_inv={self._scale_inv.item()}, " - f"data={self.dequantize()}" - ")" - ) + try: + return ( + "Float8Tensor(" + f"fp8_dtype={self._fp8_dtype}, " + f"scale_inv={self._scale_inv.item()}, " + f"data={self.dequantize()}" + ")" + ) + except Exception as exc: # pylint: disable=broad-except + return safe_quantized_repr(self, "Float8Tensor", 
error=exc) def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """ diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index d759aaf5c4..33db63d059 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -18,7 +18,7 @@ from ..utils import devices_match, round_up_to_nearest_multiple from .storage.mxfp8_tensor_storage import MXFP8TensorStorage, _FromMXFP8Func from ..quantized_tensor import QuantizedTensor, Quantizer -from ._quantization_helpers import _IdentityFunc +from ._quantization_helpers import _IdentityFunc, safe_quantized_repr aten = torch.ops.aten @@ -233,7 +233,10 @@ def __new__( ) def __repr__(self, *, tensor_contents=None): - return f"MXFP8Tensor(fp8_dtype={self._fp8_dtype}, data={self.dequantize()})" + try: + return f"MXFP8Tensor(fp8_dtype={self._fp8_dtype}, data={self.dequantize()})" + except Exception as exc: # pylint: disable=broad-except + return safe_quantized_repr(self, "MXFP8Tensor", error=exc) def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """ diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 5a2765b9f5..f131615e72 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -23,7 +23,7 @@ from .storage.nvfp4_tensor_storage import NVFP4TensorStorage, _FromNVFP4Func from ..quantized_tensor import QuantizedTensor, Quantizer -from ._quantization_helpers import _IdentityFunc +from ._quantization_helpers import _IdentityFunc, safe_quantized_repr aten = torch.ops.aten @@ -398,7 +398,10 @@ def __new__( return instance def __repr__(self, *, tensor_contents=None): - return f"NVFP4Tensor, data={self.dequantize()})" + try: + return f"NVFP4Tensor, data={self.dequantize()})" + except Exception as exc: # pylint: disable=broad-except + return 
safe_quantized_repr(self, "NVFP4Tensor", error=exc) def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """ diff --git a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py index f7a3dae70b..993ead42ee 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py @@ -12,6 +12,7 @@ import transformer_engine_torch as tex from ...quantized_tensor import QuantizedTensorStorage, Quantizer +from .._quantization_helpers import safe_quantized_repr from ...constants import TE_DType_To_Torch, DType @@ -354,17 +355,25 @@ def _transpose_columnwise_data(self): del _old_data def __repr__(self): - if self._rowwise_data is not None: - data = self.dequantize() - descriptor = "rowwise" - else: - data = self.dequantize() - descriptor = "columnwise" - return ( - "Float8BlockwiseQTensorStorage(" - f"fp8_dtype={self._fp8_dtype}, " - f"{descriptor}_scaled_data={data})" - ) + try: + if self._rowwise_data is not None: + data = self.dequantize() + descriptor = "rowwise" + else: + data = self.dequantize() + descriptor = "columnwise" + return ( + "Float8BlockwiseQTensorStorage(" + f"fp8_dtype={self._fp8_dtype}, " + f"{descriptor}_scaled_data={data})" + ) + except Exception as exc: # pylint: disable=broad-except + return safe_quantized_repr( + self, + "Float8BlockwiseQTensorStorage", + extras={"is_2D_scaled": self._is_2D_scaled}, + error=exc, + ) def update_usage( self, rowwise_usage: Optional[bool] = None, columnwise_usage: Optional[bool] = None diff --git a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py index a97162f91c..374d0e1e72 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py +++ 
b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py @@ -12,6 +12,7 @@ import transformer_engine_torch as tex from ...quantized_tensor import QuantizedTensorStorage, Quantizer +from .._quantization_helpers import safe_quantized_repr from ...constants import TE_DType as torch_to_transformer_engine_dtype, TE_DType_To_Torch, DType @@ -209,13 +210,16 @@ def view(self, shape: torch.Size): ) def __repr__(self): - return ( - "Float8TensorStorage(" - f"fp8_dtype={self._fp8_dtype}, " - f"scale_inv={self._scale_inv.item()}, " - f"data={self.dequantize()}" - ")" - ) + try: + return ( + "Float8TensorStorage(" + f"fp8_dtype={self._fp8_dtype}, " + f"scale_inv={self._scale_inv.item()}, " + f"data={self.dequantize()}" + ")" + ) + except Exception as exc: # pylint: disable=broad-except + return safe_quantized_repr(self, "Float8TensorStorage", error=exc) def _create_transpose(self): """Update FP8 transpose cache""" diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index ea592cd989..606ac9e74b 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -13,6 +13,7 @@ import transformer_engine_torch as tex from ...quantized_tensor import QuantizedTensorStorage, Quantizer +from .._quantization_helpers import safe_quantized_repr from ...constants import TE_DType as torch_to_transformer_engine_dtype, DType @@ -257,15 +258,18 @@ def view(self, shape: torch.Size): ) def __repr__(self): - data_rowwise = self.dequantize() - - return ( - "MXFP8TensorStorage(" - f"fp8_dtype={self._fp8_dtype}, " - f"rowwise_scaled_data={data_rowwise}" - f"rowwise_scale_inv={self._rowwise_scale_inv}, " - ")" - ) + try: + data_rowwise = self.dequantize() + + return ( + "MXFP8TensorStorage(" + f"fp8_dtype={self._fp8_dtype}, " + f"rowwise_scaled_data={data_rowwise}" + 
f"rowwise_scale_inv={self._rowwise_scale_inv}, " + ")" + ) + except Exception as exc: # pylint: disable=broad-except + return safe_quantized_repr(self, "MXFP8TensorStorage", error=exc) def update_usage( self, diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index 53bb5e7c11..09f040ba67 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -16,6 +16,7 @@ import transformer_engine_torch as tex from ...quantized_tensor import QuantizedTensorStorage, Quantizer +from .._quantization_helpers import safe_quantized_repr from ...constants import TE_DType as torch_to_transformer_engine_dtype, DType from ...utils import _empty_tensor @@ -340,16 +341,19 @@ def view(self, shape: torch.Size): ) def __repr__(self): - data_rowwise = self.dequantize() - - return ( - "NVFP4TensorStorage(" - f"rowwise_scaled_data={data_rowwise}," - f"rowwise_scale_inv={self._rowwise_scale_inv}," - f"amax_rowwise={self._amax_rowwise}," - f"amax_columnwise={self._amax_columnwise}," - ")" - ) + try: + data_rowwise = self.dequantize() + + return ( + "NVFP4TensorStorage(" + f"rowwise_scaled_data={data_rowwise}," + f"rowwise_scale_inv={self._rowwise_scale_inv}," + f"amax_rowwise={self._amax_rowwise}," + f"amax_columnwise={self._amax_columnwise}," + ")" + ) + except Exception as exc: # pylint: disable=broad-except + return safe_quantized_repr(self, "NVFP4TensorStorage", error=exc) def update_usage( self, From 8f8761558f45edd6359ab85a5a7064841dad325c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Jun 2026 10:56:56 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/tensor/_quantization_helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/transformer_engine/pytorch/tensor/_quantization_helpers.py b/transformer_engine/pytorch/tensor/_quantization_helpers.py index 932160ff20..bb5a766bc2 100644 --- a/transformer_engine/pytorch/tensor/_quantization_helpers.py +++ b/transformer_engine/pytorch/tensor/_quantization_helpers.py @@ -118,7 +118,7 @@ def safe_quantized_repr(obj, cls_name, extras=None, error=None): if error is not None and not _is_fake_data_access_error(error): warnings.warn( f"{cls_name}.__repr__ fell back to a metadata-only representation " - f"because an unexpected error occurred while materializing data: " + "because an unexpected error occurred while materializing data: " f"{type(error).__name__}: {error}", stacklevel=2, ) From b30b6eed1fa0746b00b0eab31bf8bc8a56b4c7ad Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 25 Jun 2026 13:12:12 +0200 Subject: [PATCH 3/3] Make quantized __repr__ fallback universal, drop FakeTensor-specific logic Remove the FakeTensor-specific heuristic (_is_fake_data_access_error) and the warning path from safe_quantized_repr. The fallback is now a plain metadata-only repr triggered by any exception while materializing data, with each attribute access individually guarded so __repr__ never raises. 
Co-Authored-By: Claude Opus 4.8 Signed-off-by: Pawel Gadzinski --- .../pytorch/tensor/_quantization_helpers.py | 37 +++++-------------- 1 file changed, 10 insertions(+), 27 deletions(-) diff --git a/transformer_engine/pytorch/tensor/_quantization_helpers.py b/transformer_engine/pytorch/tensor/_quantization_helpers.py index bb5a766bc2..1b08039dda 100644 --- a/transformer_engine/pytorch/tensor/_quantization_helpers.py +++ b/transformer_engine/pytorch/tensor/_quantization_helpers.py @@ -10,7 +10,6 @@ from __future__ import annotations from typing import Callable, Optional, Tuple, Any, Dict, TYPE_CHECKING -import warnings import torch if TYPE_CHECKING: @@ -86,22 +85,11 @@ def _stride_from_shape(shape: list[int]): return list(reversed(rstride)) -def _is_fake_data_access_error(error): - """Heuristic: is this the error PyTorch raises when a repr tries to read the - data pointer of a fake/functional tensor (e.g. under torch.compile tracing)? - - The exact exception type and message are not contractual across PyTorch - versions, so match defensively on the message text. Anything that is not - recognized is treated as an unexpected error worth surfacing. - """ - message = str(error).lower() - return "data pointer" in message or "faketensor" in message or "functionaltensor" in message - - def safe_quantized_repr(obj, cls_name, extras=None, error=None): """Metadata-only repr fallback for quantized tensors whose data cannot be - materialized (e.g. a FakeTensor under torch.compile tracing, where - dequantize()/.item() would access a data pointer). + materialized for any reason. + + Each attribute access is guarded so that ``__repr__`` never raises. Parameters ---------- @@ -110,18 +98,10 @@ def safe_quantized_repr(obj, cls_name, extras=None, error=None): ``{"is_2D_scaled": self._is_2D_scaled}``. Values are inserted after ``fp8_dtype`` and before ``shape``. error : BaseException, optional - The exception that triggered the fallback. 
The expected trigger is the - data-pointer access error PyTorch raises for fake/functional tensors; - anything else is surfaced as a warning so that real eager-path failures - (e.g. CUDA OOM, shape bugs) are not silently swallowed by ``__repr__``. + The exception that triggered the fallback. When given, its type and + message are included in the ``data=`` field so that it is visible *why* + the data could not be materialized. """ - if error is not None and not _is_fake_data_access_error(error): - warnings.warn( - f"{cls_name}.__repr__ fell back to a metadata-only representation " - "because an unexpected error occurred while materializing data: " - f"{type(error).__name__}: {error}", - stacklevel=2, - ) parts = [] fp8_dtype = getattr(obj, "_fp8_dtype", None) if fp8_dtype is not None: @@ -137,5 +117,8 @@ def safe_quantized_repr(obj, cls_name, extras=None, error=None): parts.append(f"dtype={obj.dtype}") except Exception: # pylint: disable=broad-except pass - parts.append("data=") + if error is not None: + parts.append(f"data=<{type(error).__name__}: {error}>") + else: + parts.append("data=") return f"{cls_name}({', '.join(parts)})"