From eb9a802a6543ab773cbb647d49ce474dcfdcd90c Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 12 Jun 2026 05:24:04 -0700 Subject: [PATCH 1/5] Avoid unpickling the extra state if not needed Signed-off-by: Przemek Tredak --- tests/pytorch/test_checkpoint.py | 15 +- tests/pytorch/test_fusible_ops.py | 12 +- tests/pytorch/test_numerics.py | 13 +- tests/pytorch/test_recipe.py | 85 ++++++++++ transformer_engine/common/recipe/__init__.py | 41 ++++- transformer_engine/pytorch/_extra_state.py | 168 +++++++++++++++++++ transformer_engine/pytorch/module/base.py | 28 +++- transformer_engine/pytorch/ops/op.py | 18 +- 8 files changed, 363 insertions(+), 17 deletions(-) create mode 100644 transformer_engine/pytorch/_extra_state.py diff --git a/tests/pytorch/test_checkpoint.py b/tests/pytorch/test_checkpoint.py index 0427886b84..b2780d32c9 100644 --- a/tests/pytorch/test_checkpoint.py +++ b/tests/pytorch/test_checkpoint.py @@ -17,6 +17,7 @@ import transformer_engine.pytorch as te from utils import make_recipe +from transformer_engine.pytorch._extra_state import UNSAFE_PICKLE_EXTRA_STATE_ENV # Check supported quantization schemes fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) @@ -131,8 +132,18 @@ def test_module(self, name: str) -> None: raise FileNotFoundError(f"Could not find checkpoint file at {checkpoint_file}") state_dict = torch.load(checkpoint_file, weights_only=False) - # Update module from checkpoint - module.load_state_dict(state_dict, strict=True) + # Update module from checkpoint. Delayed-scaling legacy extra state is unsafe by + # default and requires an explicit opt-in for trusted compatibility artifacts. 
+ old_unsafe_extra_state = os.environ.get(UNSAFE_PICKLE_EXTRA_STATE_ENV) + if quantization == "fp8": + os.environ[UNSAFE_PICKLE_EXTRA_STATE_ENV] = "1" + try: + module.load_state_dict(state_dict, strict=True) + finally: + if old_unsafe_extra_state is None: + os.environ.pop(UNSAFE_PICKLE_EXTRA_STATE_ENV, None) + else: + os.environ[UNSAFE_PICKLE_EXTRA_STATE_ENV] = old_unsafe_extra_state def main() -> None: diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 7c75d11e3b..454f85c9a2 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -20,6 +20,7 @@ import transformer_engine.common.recipe import transformer_engine.pytorch as te import transformer_engine.pytorch.ops as te_ops +from transformer_engine.pytorch._extra_state import UNSAFE_PICKLE_EXTRA_STATE_ENV from transformer_engine.pytorch.ops._common import ( _cudnn_frontend_supports_grouped_gemm_srelu, _cudnn_frontend_version_supported, @@ -3488,7 +3489,16 @@ def test_linear( ) optim_load = torch.optim.SGD(model_load.parameters(), lr=0.25) state_dict = torch.load(io.BytesIO(checkpoint_bytes), weights_only=False) - model_load.load_state_dict(state_dict["model"]) + old_unsafe_extra_state = os.environ.get(UNSAFE_PICKLE_EXTRA_STATE_ENV) + if quantization in ("fp8", "fp8_delayed_scaling"): + os.environ[UNSAFE_PICKLE_EXTRA_STATE_ENV] = "1" + try: + model_load.load_state_dict(state_dict["model"]) + finally: + if old_unsafe_extra_state is None: + os.environ.pop(UNSAFE_PICKLE_EXTRA_STATE_ENV, None) + else: + os.environ[UNSAFE_PICKLE_EXTRA_STATE_ENV] = old_unsafe_extra_state optim_load.load_state_dict(state_dict["optim"]) # Training steps with loaded model diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 739aefd1f3..f4d5e11ce6 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -14,6 +14,7 @@ from transformer_engine.pytorch.quantization import ( FP8GlobalStateManager, ) +from 
transformer_engine.pytorch._extra_state import UNSAFE_PICKLE_EXTRA_STATE_ENV from transformer_engine.pytorch.utils import ( init_method_normal, scaled_init_method_normal, @@ -847,7 +848,17 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= del block block = _test_e2e_checkpointing_get_model(config, dtype) - block.load_state_dict(torch.load(path, weights_only=False)) + loaded_state_dict = torch.load(path, weights_only=False) + old_unsafe_extra_state = os.environ.get(UNSAFE_PICKLE_EXTRA_STATE_ENV) + if recipe is not None and recipe.delayed(): + os.environ[UNSAFE_PICKLE_EXTRA_STATE_ENV] = "1" + try: + block.load_state_dict(loaded_state_dict) + finally: + if old_unsafe_extra_state is None: + os.environ.pop(UNSAFE_PICKLE_EXTRA_STATE_ENV, None) + else: + os.environ[UNSAFE_PICKLE_EXTRA_STATE_ENV] = old_unsafe_extra_state torch.set_rng_state(_cpu_rng_state) torch.cuda.set_rng_state(_cuda_rng_state) diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index fd3c5a3463..81445af372 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -4,6 +4,8 @@ from typing import Optional +import pickle + import pytest import torch import warnings @@ -31,10 +33,19 @@ ) import transformer_engine.pytorch.ops as te_ops from transformer_engine.common.recipe import ( + CheckpointExtraStatePolicy, + CustomRecipe, DelayedScaling, + Float8CurrentScaling, Float8BlockScaling, MXFP8BlockScaling, NVFP4BlockScaling, + Recipe, +) +from transformer_engine.pytorch._extra_state import ( + UNSAFE_PICKLE_EXTRA_STATE_ENV, + _RECIPE_POLICIES, + should_load_extra_state_pickle, ) # Check if FP8 is supported @@ -691,3 +702,77 @@ def test_fp4_dequantize(dtype, row_scaled_nvfp4, use_4over6, M, N): ) new_dequantized_tensor = new_tensor.dequantize() torch.testing.assert_close(dequantized_tensor, new_dequantized_tensor) + + +def _custom_recipe_qfactory(_role): + return None + + +def _recipe_subclasses(cls): + for subcls in 
cls.__subclasses__(): + yield subcls + yield from _recipe_subclasses(subcls) + + +def _pickled_extra_state_payload(recipe_obj, *, include_delayed_state=False): + state = {"recipe": recipe_obj, "extra_fp8_variables": {}} + if include_delayed_state: + state.update( + { + "scale_fwd": torch.ones(1), + "amax_history_fwd": torch.zeros(1, 1), + "scale_bwd": torch.ones(1), + "amax_history_bwd": torch.zeros(1, 1), + } + ) + return pickle.dumps(state) + + +def test_checkpoint_extra_state_policy_declared_for_all_recipes(): + for cls in _recipe_subclasses(Recipe): + assert "checkpoint_extra_state_policy" in cls.__dict__ + assert cls.checkpoint_extra_state_policy in CheckpointExtraStatePolicy + + +def test_checkpoint_extra_state_policy_classifier_map_covers_all_recipes(): + for cls in _recipe_subclasses(Recipe): + assert ("transformer_engine.common.recipe", cls.__name__) in _RECIPE_POLICIES + + +@pytest.mark.parametrize( + "recipe_obj", + [ + Float8CurrentScaling(), + MXFP8BlockScaling(), + Float8BlockScaling(), + NVFP4BlockScaling(), + ], +) +def test_stateless_pickled_extra_state_is_ignored(recipe_obj): + payload = _pickled_extra_state_payload(recipe_obj) + assert not should_load_extra_state_pickle(payload, "test") + + +def test_stateless_custom_pickled_extra_state_is_ignored(): + payload = _pickled_extra_state_payload(CustomRecipe(qfactory=_custom_recipe_qfactory)) + assert not should_load_extra_state_pickle(payload, "test") + + +@pytest.mark.parametrize( + "payload", + [ + _pickled_extra_state_payload(DelayedScaling(), include_delayed_state=True), + _pickled_extra_state_payload( + CustomRecipe(qfactory=_custom_recipe_qfactory), include_delayed_state=True + ), + pickle.dumps({"scale_inv_fwd": torch.ones(1), "extra_fp8_variables": {}}), + pickle.dumps({"recipe": object(), "extra_fp8_variables": {}}), + b"not a pickle", + ], +) +def test_stateful_unknown_or_malformed_pickled_extra_state_requires_opt_in(payload, monkeypatch): + with pytest.raises(RuntimeError, 
match=UNSAFE_PICKLE_EXTRA_STATE_ENV): + should_load_extra_state_pickle(payload, "test") + + monkeypatch.setenv(UNSAFE_PICKLE_EXTRA_STATE_ENV, "1") + assert should_load_extra_state_pickle(payload, "test") diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 8a03f2f51a..c48a9990a7 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -7,7 +7,7 @@ import abc import os from enum import Enum -from typing import Any, Literal, Optional, Union, Callable, NamedTuple +from typing import Any, ClassVar, Literal, Optional, Union, Callable, NamedTuple from dataclasses import field from pydantic.dataclasses import dataclass @@ -50,6 +50,20 @@ class Format(Enum): HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd) +class CheckpointExtraStatePolicy(Enum): + """How pickled PyTorch ``_extra_state`` should be handled in ``set_extra_state``. + + Each recipe subclass must choose a policy so checkpoint loading can decide + whether unpickling is required. ``DYNAMIC`` means the recipe class alone is + not enough; callers must inspect the checkpoint payload shape inside + ``set_extra_state`` before deciding whether the pickle can be ignored. + """ + + STATELESS = "stateless" + STATEFUL = "stateful" + DYNAMIC = "dynamic" + + @dataclass(frozen=True) class MMParams: """Matrix multiplication options. @@ -113,6 +127,7 @@ class Recipe: # subclasses and invalidated by ``__setattr__`` whenever any attribute # changes. This makes repeated ``str(recipe)`` calls much cheaper _cached_repr: Optional[str] = None + checkpoint_extra_state_policy: ClassVar[Optional[CheckpointExtraStatePolicy]] = None def __setattr__(self, name: str, value: Any) -> None: # Invalidate the cached repr on any attribute mutation. @@ -250,6 +265,10 @@ def scaling_factor_compute(amax: Tensor, subject to change in future Transformer Engine releases. 
""" + checkpoint_extra_state_policy: ClassVar[CheckpointExtraStatePolicy] = ( + CheckpointExtraStatePolicy.STATEFUL + ) + margin: int = 0 fp8_format: Format = Format.HYBRID amax_history_len: int = 1024 @@ -299,6 +318,10 @@ class Float8CurrentScaling(Recipe): compute dtype (e.g. BF16/FP16/FP32) for backward. """ + checkpoint_extra_state_policy: ClassVar[CheckpointExtraStatePolicy] = ( + CheckpointExtraStatePolicy.STATELESS + ) + use_power_2_scales: bool = os.getenv("NVTE_FP8_CURRENT_SCALING_POWER_2_SCALES", "0") == "1" fp8_format: Format = Format.HYBRID fp8_quant_fwd_inp = QParams(power_2_scale=use_power_2_scales, amax_epsilon=0.0) @@ -363,6 +386,10 @@ class MXFP8BlockScaling(Recipe): compute dtype (e.g. BF16/FP16/FP32) for backward. """ + checkpoint_extra_state_policy: ClassVar[CheckpointExtraStatePolicy] = ( + CheckpointExtraStatePolicy.STATELESS + ) + margin: int = 0 fp8_format: Format = Format.E4M3 fp8_dpa: bool = False @@ -416,6 +443,10 @@ class Float8BlockScaling(Recipe): compute dtype (e.g. BF16/FP16/FP32) for backward. """ + checkpoint_extra_state_policy: ClassVar[CheckpointExtraStatePolicy] = ( + CheckpointExtraStatePolicy.STATELESS + ) + use_f32_scales: bool = os.getenv("NVTE_FP8_BLOCK_SCALING_FP32_SCALES", "0") == "1" fp8_format: Format = Format.E4M3 @@ -544,6 +575,10 @@ class NVFP4BlockScaling(Recipe): compute dtype (e.g. BF16/FP16/FP32) for backward. """ + checkpoint_extra_state_policy: ClassVar[CheckpointExtraStatePolicy] = ( + CheckpointExtraStatePolicy.STATELESS + ) + # Configuration envvars disable_rht: bool = os.getenv("NVTE_NVFP4_DISABLE_RHT", "0") == "1" disable_stochastic_rounding: bool = ( @@ -662,6 +697,10 @@ class CustomRecipe(Recipe): compute dtype (e.g. BF16/FP16/FP32) for backward. 
""" + checkpoint_extra_state_policy: ClassVar[CheckpointExtraStatePolicy] = ( + CheckpointExtraStatePolicy.DYNAMIC + ) + qfactory: Callable[..., Any] # fp8_format does not affect quantization (quantization factory controls that), diff --git a/transformer_engine/pytorch/_extra_state.py b/transformer_engine/pytorch/_extra_state.py new file mode 100644 index 0000000000..989a02503a --- /dev/null +++ b/transformer_engine/pytorch/_extra_state.py @@ -0,0 +1,168 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Helpers for Transformer Engine PyTorch extra-state checkpoint handling.""" + +from __future__ import annotations + +from enum import Enum +import os +import pickletools +from typing import Optional + +from ..common.recipe import CheckpointExtraStatePolicy, Recipe + + +UNSAFE_PICKLE_EXTRA_STATE_ENV = "NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE" + +_RECIPE_MODULE = "transformer_engine.common.recipe" +_RECIPE_KEY = "recipe" +_DELAYED_STATE_KEYS = { + "scale_fwd", + "amax_history_fwd", + "scale_bwd", + "amax_history_bwd", +} + + +def _recipe_subclasses(cls: type[Recipe]) -> list[type[Recipe]]: + subclasses = [] + for subcls in cls.__subclasses__(): + subclasses.append(subcls) + subclasses.extend(_recipe_subclasses(subcls)) + return subclasses + + +_RECIPE_POLICIES = { + (_RECIPE_MODULE, cls.__name__): cls.checkpoint_extra_state_policy + for cls in _recipe_subclasses(Recipe) + if cls.checkpoint_extra_state_policy is not None +} + + +class _PickledExtraStateAction(Enum): + """Action to take for a pickled extra-state payload.""" + + IGNORE = "ignore" + UNSAFE_LOAD = "unsafe_load" + + +def unsafe_pickle_extra_state_enabled() -> bool: + """Return whether unsafe extra-state pickle loading is enabled.""" + return os.getenv(UNSAFE_PICKLE_EXTRA_STATE_ENV, "0") == "1" + + +def extra_state_pickle_advisory(context: str) -> str: + """Security advisory for pickled extra state.""" + return ( + f"Refusing to 
load pickled Transformer Engine extra state for {context}. " + "Delayed-scaling FP8 metadata can be stored as a Python pickle, and loading it " + "can execute arbitrary code. Only enable unsafe loading if this checkpoint is from " + f"a trusted source. To load it anyway, set {UNSAFE_PICKLE_EXTRA_STATE_ENV}=1." + ) + + +def should_load_extra_state_pickle(data: bytes, context: str) -> bool: + """Return whether callers should use the unsafe pickle loader. + + ``False`` means the payload was identified as empty/stateless and should be + ignored. ``True`` means the caller may unpickle because the unsafe opt-in is + enabled. Otherwise this raises with the security advisory. + """ + action = _classify_extra_state_pickle(data) + if action is _PickledExtraStateAction.IGNORE: + return False + if unsafe_pickle_extra_state_enabled(): + return True + raise RuntimeError(extra_state_pickle_advisory(context)) + + +def _classify_extra_state_pickle(data: bytes) -> _PickledExtraStateAction: + """Classify a pickled extra-state payload without executing it.""" + if not data: + return _PickledExtraStateAction.IGNORE + + try: + return _classify_extra_state_pickle_impl(data) + except Exception: # pylint: disable=broad-except + return _PickledExtraStateAction.UNSAFE_LOAD + + +def _classify_extra_state_pickle_impl(data: bytes) -> _PickledExtraStateAction: + strings: list[str] = [] + has_recipe_key = False + has_delayed_state_keys = False + policies: set[CheckpointExtraStatePolicy] = set() + + for opcode, arg, _pos in pickletools.genops(data): + if opcode.name in { + "STRING", + "BINSTRING", + "SHORT_BINSTRING", + "UNICODE", + "BINUNICODE", + "BINUNICODE8", + "SHORT_BINUNICODE", + }: + text = _string_opcode_arg(arg) + if text is not None: + strings.append(text) + has_recipe_key = has_recipe_key or text == _RECIPE_KEY + has_delayed_state_keys = has_delayed_state_keys or text in _DELAYED_STATE_KEYS + continue + + if opcode.name == "GLOBAL": + global_ref = _global_opcode_arg(arg) + elif 
opcode.name == "STACK_GLOBAL": + global_ref = _stack_global_args(strings) + else: + continue + + if global_ref is None: + continue + policy = _RECIPE_POLICIES.get(global_ref) + if policy is not None: + policies.add(policy) + + # TE 1.x checkpoints did not store a recipe and only supported delayed scaling. + if not has_recipe_key: + return _PickledExtraStateAction.UNSAFE_LOAD + + if CheckpointExtraStatePolicy.STATEFUL in policies: + return _PickledExtraStateAction.UNSAFE_LOAD + + if has_delayed_state_keys: + return _PickledExtraStateAction.UNSAFE_LOAD + + # Unknown/newer payload shape. Give trusted users the explicit opt-in path. + if not policies: + return _PickledExtraStateAction.UNSAFE_LOAD + + return _PickledExtraStateAction.IGNORE + + +def _string_opcode_arg(arg: object) -> Optional[str]: + if isinstance(arg, str): + return arg + if isinstance(arg, bytes): + try: + return arg.decode("utf-8") + except UnicodeDecodeError: + return None + return None + + +def _global_opcode_arg(arg: object) -> Optional[tuple[str, str]]: + if not isinstance(arg, str): + return None + parts = arg.split() + if len(parts) != 2: + return None + return parts[0], parts[1] + + +def _stack_global_args(strings: list[str]) -> Optional[tuple[str, str]]: + if len(strings) < 2: + return None + return strings[-2], strings[-1] diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 6c7ba8a8ab..358298cb7e 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -21,6 +21,11 @@ import transformer_engine_torch as tex from ._common import _ParameterInitMeta, noop_cat +from .._extra_state import ( + extra_state_pickle_advisory, + should_load_extra_state_pickle, + unsafe_pickle_extra_state_enabled, +) from ..quantization import ( MXFP8BlockScalingRecipeState, DelayedScalingRecipeState, @@ -57,7 +62,7 @@ nvtx_range_pop, ) from ..tensor.storage.float8_blockwise_tensor_storage import 
Float8BlockwiseQTensorStorage -from ...common.recipe import DelayedScaling, Recipe +from ...common.recipe import CheckpointExtraStatePolicy, DelayedScaling, Recipe from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor from ...debug.pytorch.utils import next_iter_when_debug_should_be_run, any_feature_enabled @@ -1271,9 +1276,13 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: if not fp8_checkpoint: return torch.empty(0, dtype=torch.uint8) + recipe = self.fp8_meta["recipe"] + if recipe.checkpoint_extra_state_policy is CheckpointExtraStatePolicy.STATELESS: + return torch.empty(0, dtype=torch.uint8) + # Copy tensors to CPU and store state = {} - state["recipe"] = self.fp8_meta["recipe"] + state["recipe"] = recipe if _has_delayed_scaling_state(self.fp8_meta): state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale) state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history) @@ -1303,16 +1312,21 @@ def set_extra_state(self, state: torch.Tensor) -> None: return # Load state + context = self.__class__.__name__ if isinstance(state, torch.Tensor): - # No FP8 is indicated by an empty tensor we don't need to unpickle. + # Empty extra state does not need pickle handling. if state.numel() == 0: return - # Default format: byte tensor with pickled data - state = pickle.loads(state.detach().cpu().numpy().tobytes()) + state_bytes = state.detach().cpu().numpy().tobytes() + if not should_load_extra_state_pickle(state_bytes, context): + return + state = pickle.loads(state_bytes) elif isinstance(state, io.BytesIO): - # Deprecated format with io.BytesIO + # Deprecated format with io.BytesIO. Treat it as unsafe pickle. 
+ if not unsafe_pickle_extra_state_enabled(): + raise RuntimeError(extra_state_pickle_advisory(context)) state.seek(0) - state = torch.load(state, map_location="cuda") + state = torch.load(state, map_location="cuda", weights_only=False) else: raise RuntimeError("Unsupported checkpoint format.") diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 86bd60ed9c..63e5b157f1 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -13,7 +13,8 @@ import torch -from transformer_engine.common.recipe import Recipe +from transformer_engine.common.recipe import CheckpointExtraStatePolicy, Recipe +from .._extra_state import should_load_extra_state_pickle from ..quantization import ( FP8GlobalStateManager, QuantizerRole, @@ -590,11 +591,14 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: # Quantizer state fp8_meta = self._fp8_metas[mode] + recipe = fp8_meta["recipe"] + if recipe.checkpoint_extra_state_policy is CheckpointExtraStatePolicy.STATELESS: + continue state[mode] = {} - state[mode]["recipe"] = fp8_meta["recipe"] + state[mode]["recipe"] = recipe # Copy tensors to CPU and store - if state[mode]["recipe"].delayed(): + if recipe.delayed(): if mode == "forward": state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale) state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history) @@ -626,8 +630,12 @@ def set_extra_state(self, state: Optional[torch.Tensor]) -> None: if state is None or state.numel() == 0: return - # Deserialize state from byte tensor - state = pickle.loads(state.detach().numpy(force=True).tobytes()) + # Deserialize state from byte tensor only when unsafe loading is enabled. 
+ state_bytes = state.detach().numpy(force=True).tobytes() + context = self.__class__.__name__ + if not should_load_extra_state_pickle(state_bytes, context): + return + state = pickle.loads(state_bytes) if state is None or len(state) == 0: return From 2d4eefee0d2e1159f9f3e35d38b964eb8e01b8e2 Mon Sep 17 00:00:00 2001 From: ksivamani Date: Mon, 22 Jun 2026 17:18:52 -0700 Subject: [PATCH 2/5] Remove pytorch changes from common and have better names Signed-off-by: ksivamani --- tests/pytorch/test_recipe.py | 12 ++-- transformer_engine/common/recipe/__init__.py | 41 +------------ transformer_engine/pytorch/_extra_state.py | 62 +++++++++++++++----- transformer_engine/pytorch/module/base.py | 5 +- transformer_engine/pytorch/ops/op.py | 6 +- 5 files changed, 59 insertions(+), 67 deletions(-) diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 81445af372..7ee12f7def 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -33,7 +33,6 @@ ) import transformer_engine.pytorch.ops as te_ops from transformer_engine.common.recipe import ( - CheckpointExtraStatePolicy, CustomRecipe, DelayedScaling, Float8CurrentScaling, @@ -43,6 +42,7 @@ Recipe, ) from transformer_engine.pytorch._extra_state import ( + CheckpointExtraStatePolicy, UNSAFE_PICKLE_EXTRA_STATE_ENV, _RECIPE_POLICIES, should_load_extra_state_pickle, @@ -728,15 +728,11 @@ def _pickled_extra_state_payload(recipe_obj, *, include_delayed_state=False): return pickle.dumps(state) -def test_checkpoint_extra_state_policy_declared_for_all_recipes(): - for cls in _recipe_subclasses(Recipe): - assert "checkpoint_extra_state_policy" in cls.__dict__ - assert cls.checkpoint_extra_state_policy in CheckpointExtraStatePolicy - - def test_checkpoint_extra_state_policy_classifier_map_covers_all_recipes(): for cls in _recipe_subclasses(Recipe): - assert ("transformer_engine.common.recipe", cls.__name__) in _RECIPE_POLICIES + key = ("transformer_engine.common.recipe", cls.__name__) + assert 
key in _RECIPE_POLICIES + assert _RECIPE_POLICIES[key] in CheckpointExtraStatePolicy @pytest.mark.parametrize( diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index c48a9990a7..8a03f2f51a 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -7,7 +7,7 @@ import abc import os from enum import Enum -from typing import Any, ClassVar, Literal, Optional, Union, Callable, NamedTuple +from typing import Any, Literal, Optional, Union, Callable, NamedTuple from dataclasses import field from pydantic.dataclasses import dataclass @@ -50,20 +50,6 @@ class Format(Enum): HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd) -class CheckpointExtraStatePolicy(Enum): - """How pickled PyTorch ``_extra_state`` should be handled in ``set_extra_state``. - - Each recipe subclass must choose a policy so checkpoint loading can decide - whether unpickling is required. ``DYNAMIC`` means the recipe class alone is - not enough; callers must inspect the checkpoint payload shape inside - ``set_extra_state`` before deciding whether the pickle can be ignored. - """ - - STATELESS = "stateless" - STATEFUL = "stateful" - DYNAMIC = "dynamic" - - @dataclass(frozen=True) class MMParams: """Matrix multiplication options. @@ -127,7 +113,6 @@ class Recipe: # subclasses and invalidated by ``__setattr__`` whenever any attribute # changes. This makes repeated ``str(recipe)`` calls much cheaper _cached_repr: Optional[str] = None - checkpoint_extra_state_policy: ClassVar[Optional[CheckpointExtraStatePolicy]] = None def __setattr__(self, name: str, value: Any) -> None: # Invalidate the cached repr on any attribute mutation. @@ -265,10 +250,6 @@ def scaling_factor_compute(amax: Tensor, subject to change in future Transformer Engine releases. 
""" - checkpoint_extra_state_policy: ClassVar[CheckpointExtraStatePolicy] = ( - CheckpointExtraStatePolicy.STATEFUL - ) - margin: int = 0 fp8_format: Format = Format.HYBRID amax_history_len: int = 1024 @@ -318,10 +299,6 @@ class Float8CurrentScaling(Recipe): compute dtype (e.g. BF16/FP16/FP32) for backward. """ - checkpoint_extra_state_policy: ClassVar[CheckpointExtraStatePolicy] = ( - CheckpointExtraStatePolicy.STATELESS - ) - use_power_2_scales: bool = os.getenv("NVTE_FP8_CURRENT_SCALING_POWER_2_SCALES", "0") == "1" fp8_format: Format = Format.HYBRID fp8_quant_fwd_inp = QParams(power_2_scale=use_power_2_scales, amax_epsilon=0.0) @@ -386,10 +363,6 @@ class MXFP8BlockScaling(Recipe): compute dtype (e.g. BF16/FP16/FP32) for backward. """ - checkpoint_extra_state_policy: ClassVar[CheckpointExtraStatePolicy] = ( - CheckpointExtraStatePolicy.STATELESS - ) - margin: int = 0 fp8_format: Format = Format.E4M3 fp8_dpa: bool = False @@ -443,10 +416,6 @@ class Float8BlockScaling(Recipe): compute dtype (e.g. BF16/FP16/FP32) for backward. """ - checkpoint_extra_state_policy: ClassVar[CheckpointExtraStatePolicy] = ( - CheckpointExtraStatePolicy.STATELESS - ) - use_f32_scales: bool = os.getenv("NVTE_FP8_BLOCK_SCALING_FP32_SCALES", "0") == "1" fp8_format: Format = Format.E4M3 @@ -575,10 +544,6 @@ class NVFP4BlockScaling(Recipe): compute dtype (e.g. BF16/FP16/FP32) for backward. """ - checkpoint_extra_state_policy: ClassVar[CheckpointExtraStatePolicy] = ( - CheckpointExtraStatePolicy.STATELESS - ) - # Configuration envvars disable_rht: bool = os.getenv("NVTE_NVFP4_DISABLE_RHT", "0") == "1" disable_stochastic_rounding: bool = ( @@ -697,10 +662,6 @@ class CustomRecipe(Recipe): compute dtype (e.g. BF16/FP16/FP32) for backward. 
""" - checkpoint_extra_state_policy: ClassVar[CheckpointExtraStatePolicy] = ( - CheckpointExtraStatePolicy.DYNAMIC - ) - qfactory: Callable[..., Any] # fp8_format does not affect quantization (quantization factory controls that), diff --git a/transformer_engine/pytorch/_extra_state.py b/transformer_engine/pytorch/_extra_state.py index 989a02503a..a729428091 100644 --- a/transformer_engine/pytorch/_extra_state.py +++ b/transformer_engine/pytorch/_extra_state.py @@ -11,14 +11,18 @@ import pickletools from typing import Optional -from ..common.recipe import CheckpointExtraStatePolicy, Recipe +from ..common.recipe import Recipe UNSAFE_PICKLE_EXTRA_STATE_ENV = "NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE" _RECIPE_MODULE = "transformer_engine.common.recipe" _RECIPE_KEY = "recipe" -_DELAYED_STATE_KEYS = { + +# Pickle keys that FP8 delayed scaling stores in its ``_extra_state``. This is +# purely a backward-compatibility hack for that recipe; any future stateful +# recipe is expected to checkpoint without pickling. +_FLOAT8_DELAYED_SCALING_STATE_KEYS = { "scale_fwd", "amax_history_fwd", "scale_bwd", @@ -26,21 +30,49 @@ } -def _recipe_subclasses(cls: type[Recipe]) -> list[type[Recipe]]: - subclasses = [] - for subcls in cls.__subclasses__(): - subclasses.append(subcls) - subclasses.extend(_recipe_subclasses(subcls)) - return subclasses +class CheckpointExtraStatePolicy(Enum): + """How pickled PyTorch ``_extra_state`` should be handled in ``set_extra_state``. + + Pickling of ``_extra_state`` is a PyTorch-specific backward-compatibility + concern, so the recipe-to-policy map lives here in ``te.pytorch`` rather than + in ``te.common``. + ``STATELESS`` recipes carry no checkpoint state and never need unpickling. + ``STATEFUL_FP8_DELAYED_SCALING`` is the only recipe that still relies on + unsafe pickling (a legacy of FP8 delayed scaling). 
``DYNAMIC`` means the + recipe class alone is not enough; callers must inspect the checkpoint payload + shape before deciding whether the pickle can be ignored. + """ -_RECIPE_POLICIES = { - (_RECIPE_MODULE, cls.__name__): cls.checkpoint_extra_state_policy - for cls in _recipe_subclasses(Recipe) - if cls.checkpoint_extra_state_policy is not None + STATELESS = "stateless" + STATEFUL_FP8_DELAYED_SCALING = "stateful_fp8_delayed_scaling" + DYNAMIC = "dynamic" + + +# Map of first-party recipes to their checkpoint policy. When a new stateful +# recipe is added, update this map (and any associated checkpoint handling) +# here instead of adding PyTorch-specific logic to ``te.common``. +_RECIPE_POLICIES: dict[tuple[str, str], CheckpointExtraStatePolicy] = { + (_RECIPE_MODULE, "DelayedScaling"): CheckpointExtraStatePolicy.STATEFUL_FP8_DELAYED_SCALING, + (_RECIPE_MODULE, "Float8CurrentScaling"): CheckpointExtraStatePolicy.STATELESS, + (_RECIPE_MODULE, "MXFP8BlockScaling"): CheckpointExtraStatePolicy.STATELESS, + (_RECIPE_MODULE, "Float8BlockScaling"): CheckpointExtraStatePolicy.STATELESS, + (_RECIPE_MODULE, "NVFP4BlockScaling"): CheckpointExtraStatePolicy.STATELESS, + (_RECIPE_MODULE, "CustomRecipe"): CheckpointExtraStatePolicy.DYNAMIC, } +def recipe_extra_state_policy(recipe: Recipe) -> Optional[CheckpointExtraStatePolicy]: + """Return the checkpoint policy for a recipe instance, if known.""" + cls = type(recipe) + return _RECIPE_POLICIES.get((cls.__module__, cls.__name__)) + + +def is_stateless_recipe(recipe: Recipe) -> bool: + """Return whether a recipe carries no extra state to checkpoint.""" + return recipe_extra_state_policy(recipe) is CheckpointExtraStatePolicy.STATELESS + + class _PickledExtraStateAction(Enum): """Action to take for a pickled extra-state payload.""" @@ -109,7 +141,9 @@ def _classify_extra_state_pickle_impl(data: bytes) -> _PickledExtraStateAction: if text is not None: strings.append(text) has_recipe_key = has_recipe_key or text == _RECIPE_KEY - 
has_delayed_state_keys = has_delayed_state_keys or text in _DELAYED_STATE_KEYS + has_delayed_state_keys = ( + has_delayed_state_keys or text in _FLOAT8_DELAYED_SCALING_STATE_KEYS + ) continue if opcode.name == "GLOBAL": @@ -129,7 +163,7 @@ def _classify_extra_state_pickle_impl(data: bytes) -> _PickledExtraStateAction: if not has_recipe_key: return _PickledExtraStateAction.UNSAFE_LOAD - if CheckpointExtraStatePolicy.STATEFUL in policies: + if CheckpointExtraStatePolicy.STATEFUL_FP8_DELAYED_SCALING in policies: return _PickledExtraStateAction.UNSAFE_LOAD if has_delayed_state_keys: diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 358298cb7e..a449d2dc87 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -23,6 +23,7 @@ from ._common import _ParameterInitMeta, noop_cat from .._extra_state import ( extra_state_pickle_advisory, + is_stateless_recipe, should_load_extra_state_pickle, unsafe_pickle_extra_state_enabled, ) @@ -62,7 +63,7 @@ nvtx_range_pop, ) from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage -from ...common.recipe import CheckpointExtraStatePolicy, DelayedScaling, Recipe +from ...common.recipe import DelayedScaling, Recipe from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor from ...debug.pytorch.utils import next_iter_when_debug_should_be_run, any_feature_enabled @@ -1277,7 +1278,7 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: return torch.empty(0, dtype=torch.uint8) recipe = self.fp8_meta["recipe"] - if recipe.checkpoint_extra_state_policy is CheckpointExtraStatePolicy.STATELESS: + if is_stateless_recipe(recipe): return torch.empty(0, dtype=torch.uint8) # Copy tensors to CPU and store diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 63e5b157f1..849c900f95 100644 --- 
a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -13,8 +13,8 @@ import torch -from transformer_engine.common.recipe import CheckpointExtraStatePolicy, Recipe -from .._extra_state import should_load_extra_state_pickle +from transformer_engine.common.recipe import Recipe +from .._extra_state import is_stateless_recipe, should_load_extra_state_pickle from ..quantization import ( FP8GlobalStateManager, QuantizerRole, @@ -592,7 +592,7 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: # Quantizer state fp8_meta = self._fp8_metas[mode] recipe = fp8_meta["recipe"] - if recipe.checkpoint_extra_state_policy is CheckpointExtraStatePolicy.STATELESS: + if is_stateless_recipe(recipe): continue state[mode] = {} state[mode]["recipe"] = recipe From de1318452cebed53f26a0dae7b180b4c11e35bf8 Mon Sep 17 00:00:00 2001 From: ksivamani Date: Tue, 23 Jun 2026 09:13:24 -0700 Subject: [PATCH 3/5] Fix envvar for tests Signed-off-by: ksivamani --- qa/L0_pytorch_unittest/test.sh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 92f73d5885..37e4ca9822 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -28,7 +28,7 @@ NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 python3 -m pytest --tb=auto --junitxml=$XML_L python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_custom_recipe.xml $TE_PATH/tests/pytorch/test_custom_recipe.py || test_fail "test_custom_recipe.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml 
$TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" +NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" @@ -42,15 +42,15 @@ NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 python3 -m pytest --tb=auto --junitxml=$XML_L python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" +NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_backward_override.xml $TE_PATH/tests/pytorch/test_backward_override.py || test_fail "test_backward_override.py" python3 -m 
pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" +NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_flex_attention.xml $TE_PATH/tests/pytorch/attention/test_flex_attention.py || test_fail "test_flex_attention.py" -NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py" +NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_linear_mxfp8_attention.xml $TE_PATH/tests/pytorch/attention/test_linear_mxfp8_attention.py || test_fail "test_linear_mxfp8_attention.py" 
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" @@ -58,7 +58,7 @@ export NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_ if [ ! -d "$NVTE_TEST_CHECKPOINT_ARTIFACT_PATH" ]; then python3 $TE_PATH/tests/pytorch/test_checkpoint.py --save-checkpoint all || error_exit "Failed to generate checkpoint files" fi -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" +NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_partial_cast.xml $TE_PATH/tests/pytorch/test_partial_cast.py || test_fail "test_partial_cast.py" # Disable autotuning to make unittests faster. 
In addition, disable TF32 path to fully align with the pytorch reference implementation's precision From 5b54979aae82b58100a9ca7c5ae45626e4ff4f48 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 23 Jun 2026 20:30:59 +0000 Subject: [PATCH 4/5] Fix tests Signed-off-by: Kirthi Shankar Sivamani --- qa/L0_pytorch_unittest/test.sh | 6 +++--- tests/pytorch/test_fusible_ops.py | 1 + tests/pytorch/test_numerics.py | 2 -- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 37e4ca9822..3f29c3e112 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -28,7 +28,7 @@ NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 python3 -m pytest --tb=auto --junitxml=$XML_L python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_custom_recipe.xml $TE_PATH/tests/pytorch/test_custom_recipe.py || test_fail "test_custom_recipe.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" -NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" 
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" @@ -42,7 +42,7 @@ NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 python3 -m pytest --tb=auto --junitxml=$XML_L python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" -NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_backward_override.xml $TE_PATH/tests/pytorch/test_backward_override.py || test_fail "test_backward_override.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" @@ -58,7 +58,7 @@ export NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_ if [ ! 
-d "$NVTE_TEST_CHECKPOINT_ARTIFACT_PATH" ]; then python3 $TE_PATH/tests/pytorch/test_checkpoint.py --save-checkpoint all || error_exit "Failed to generate checkpoint files" fi -NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_partial_cast.xml $TE_PATH/tests/pytorch/test_partial_cast.py || test_fail "test_partial_cast.py" # Disable autotuning to make unittests faster. In addition, disable TF32 path to fully align with the pytorch reference implementation's precision diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index c06ec6c881..3c6560e7c3 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -6,6 +6,7 @@ from collections.abc import Iterable, Sequence import io +import os import math import random from typing import Optional diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index f4d5e11ce6..8249c7fedd 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -850,8 +850,6 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= block = _test_e2e_checkpointing_get_model(config, dtype) loaded_state_dict = torch.load(path, weights_only=False) old_unsafe_extra_state = os.environ.get(UNSAFE_PICKLE_EXTRA_STATE_ENV) - if recipe is not None and recipe.delayed(): - os.environ[UNSAFE_PICKLE_EXTRA_STATE_ENV] = "1" try: block.load_state_dict(loaded_state_dict) finally: From 
86e032140ba3c511e36a307a772dece1687f3ecf Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 25 Jun 2026 03:46:14 +0000 Subject: [PATCH 5/5] Fix Signed-off-by: Kirthi Shankar Sivamani --- tests/pytorch/test_recipe.py | 8 ++++++++ transformer_engine/pytorch/_extra_state.py | 11 +++++++++++ 2 files changed, 19 insertions(+) diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 7ee12f7def..ccef104a33 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -754,6 +754,14 @@ def test_stateless_custom_pickled_extra_state_is_ignored(): assert not should_load_extra_state_pickle(payload, "test") +@pytest.mark.parametrize("payload", [pickle.dumps({}), pickle.dumps({"extra_fp8_variables": {}})]) +def test_global_free_pickled_extra_state_is_ignored(payload): + # Older stateless checkpoints serialized an empty dict. Such a payload + # resolves no globals and cannot execute code, so it must be classified as + # ignorable (never unpickled) even without the unsafe opt-in. + assert not should_load_extra_state_pickle(payload, "test") + + @pytest.mark.parametrize( "payload", [ diff --git a/transformer_engine/pytorch/_extra_state.py b/transformer_engine/pytorch/_extra_state.py index a729428091..831763e891 100644 --- a/transformer_engine/pytorch/_extra_state.py +++ b/transformer_engine/pytorch/_extra_state.py @@ -125,6 +125,7 @@ def _classify_extra_state_pickle_impl(data: bytes) -> _PickledExtraStateAction: strings: list[str] = [] has_recipe_key = False has_delayed_state_keys = False + has_global = False policies: set[CheckpointExtraStatePolicy] = set() for opcode, arg, _pos in pickletools.genops(data): @@ -147,8 +148,10 @@ def _classify_extra_state_pickle_impl(data: bytes) -> _PickledExtraStateAction: continue if opcode.name == "GLOBAL": + has_global = True global_ref = _global_opcode_arg(arg) elif opcode.name == "STACK_GLOBAL": + has_global = True global_ref = _stack_global_args(strings) else: continue @@ -159,6 +162,14 @@ def _classify_extra_state_pickle_impl(data: 
bytes) -> _PickledExtraStateAction: if policy is not None: policies.add(policy) + # A payload that never resolves a global cannot construct an arbitrary + # callable, so unpickling it cannot execute code (e.g. the empty dict that + # older stateless checkpoints serialized). It carries no state worth + # loading, so treat it as safe to ignore. A genuine TE 1.x delayed-scaling + # checkpoint always serializes torch tensors and thus contains globals. + if not has_global and not has_delayed_state_keys: + return _PickledExtraStateAction.IGNORE + # TE 1.x checkpoints did not store a recipe and only supported delayed scaling. if not has_recipe_key: return _PickledExtraStateAction.UNSAFE_LOAD