diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 3ce7e05e0c..7a7e2b9c5f 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -49,9 +49,9 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" +NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_flex_attention.xml $TE_PATH/tests/pytorch/attention/test_flex_attention.py || test_fail "test_flex_attention.py" -NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py" +NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_linear_mxfp8_attention.xml $TE_PATH/tests/pytorch/attention/test_linear_mxfp8_attention.py || test_fail "test_linear_mxfp8_attention.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" diff --git a/tests/pytorch/test_checkpoint.py b/tests/pytorch/test_checkpoint.py index 0427886b84..b2780d32c9 100644 --- a/tests/pytorch/test_checkpoint.py +++ b/tests/pytorch/test_checkpoint.py @@ -17,6 +17,7 @@ import transformer_engine.pytorch as te from utils import make_recipe +from transformer_engine.pytorch._extra_state import UNSAFE_PICKLE_EXTRA_STATE_ENV # Check supported quantization schemes fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) @@ -131,8 +132,18 @@ def test_module(self, name: str) -> None: raise FileNotFoundError(f"Could not find checkpoint file at {checkpoint_file}") state_dict = torch.load(checkpoint_file, weights_only=False) - # Update module from checkpoint - module.load_state_dict(state_dict, strict=True) + # Update module from checkpoint. Delayed-scaling legacy extra state is unsafe by + # default and requires an explicit opt-in for trusted compatibility artifacts. + old_unsafe_extra_state = os.environ.get(UNSAFE_PICKLE_EXTRA_STATE_ENV) + if quantization == "fp8": + os.environ[UNSAFE_PICKLE_EXTRA_STATE_ENV] = "1" + try: + module.load_state_dict(state_dict, strict=True) + finally: + if old_unsafe_extra_state is None: + os.environ.pop(UNSAFE_PICKLE_EXTRA_STATE_ENV, None) + else: + os.environ[UNSAFE_PICKLE_EXTRA_STATE_ENV] = old_unsafe_extra_state def main() -> None: diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 43c7965518..3c6560e7c3 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -6,6 +6,7 @@ from collections.abc import Iterable, Sequence import io +import os import math import random from typing import Optional @@ -18,6 +19,7 @@ import transformer_engine.common.recipe import transformer_engine.pytorch as te import transformer_engine.pytorch.ops as te_ops +from transformer_engine.pytorch._extra_state import UNSAFE_PICKLE_EXTRA_STATE_ENV from transformer_engine.pytorch.ops.fused import ( BackwardActivationBias, @@ -3217,7 +3219,16 @@ def test_linear( ) optim_load = torch.optim.SGD(model_load.parameters(), lr=0.25) state_dict = torch.load(io.BytesIO(checkpoint_bytes), weights_only=False) - model_load.load_state_dict(state_dict["model"]) + old_unsafe_extra_state = os.environ.get(UNSAFE_PICKLE_EXTRA_STATE_ENV) + if quantization in ("fp8", "fp8_delayed_scaling"): + os.environ[UNSAFE_PICKLE_EXTRA_STATE_ENV] = "1" + try: + model_load.load_state_dict(state_dict["model"]) + finally: + if old_unsafe_extra_state is None: + os.environ.pop(UNSAFE_PICKLE_EXTRA_STATE_ENV, None) + else: + os.environ[UNSAFE_PICKLE_EXTRA_STATE_ENV] = old_unsafe_extra_state optim_load.load_state_dict(state_dict["optim"]) # Training steps with loaded model diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 739aefd1f3..8249c7fedd 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -14,6 +14,7 @@ from transformer_engine.pytorch.quantization import ( FP8GlobalStateManager, ) +from transformer_engine.pytorch._extra_state import UNSAFE_PICKLE_EXTRA_STATE_ENV from transformer_engine.pytorch.utils import ( init_method_normal, scaled_init_method_normal, @@ -847,7 +848,15 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= del block block = _test_e2e_checkpointing_get_model(config, dtype) - block.load_state_dict(torch.load(path, weights_only=False)) + loaded_state_dict = torch.load(path, weights_only=False) + old_unsafe_extra_state = os.environ.get(UNSAFE_PICKLE_EXTRA_STATE_ENV) + try: + block.load_state_dict(loaded_state_dict) + finally: + if old_unsafe_extra_state is None: + os.environ.pop(UNSAFE_PICKLE_EXTRA_STATE_ENV, None) + else: + os.environ[UNSAFE_PICKLE_EXTRA_STATE_ENV] = old_unsafe_extra_state torch.set_rng_state(_cpu_rng_state) torch.cuda.set_rng_state(_cuda_rng_state) diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index fd3c5a3463..ccef104a33 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -4,6 +4,8 @@ from typing import Optional +import pickle + import pytest import torch import warnings @@ -31,10 +33,19 @@ ) import transformer_engine.pytorch.ops as te_ops from transformer_engine.common.recipe import ( + CustomRecipe, DelayedScaling, + Float8CurrentScaling, Float8BlockScaling, MXFP8BlockScaling, NVFP4BlockScaling, + Recipe, +) +from transformer_engine.pytorch._extra_state import ( + CheckpointExtraStatePolicy, + UNSAFE_PICKLE_EXTRA_STATE_ENV, + _RECIPE_POLICIES, + should_load_extra_state_pickle, ) # Check if FP8 is supported @@ -691,3 +702,81 @@ def test_fp4_dequantize(dtype, row_scaled_nvfp4, use_4over6, M, N): ) new_dequantized_tensor = new_tensor.dequantize() torch.testing.assert_close(dequantized_tensor, new_dequantized_tensor) + + +def _custom_recipe_qfactory(_role): + return None + + +def _recipe_subclasses(cls): + for subcls in cls.__subclasses__(): + yield subcls + yield from _recipe_subclasses(subcls) + + +def _pickled_extra_state_payload(recipe_obj, *, include_delayed_state=False): + state = {"recipe": recipe_obj, "extra_fp8_variables": {}} + if include_delayed_state: + state.update( + { + "scale_fwd": torch.ones(1), + "amax_history_fwd": torch.zeros(1, 1), + "scale_bwd": torch.ones(1), + "amax_history_bwd": torch.zeros(1, 1), + } + ) + return pickle.dumps(state) + + +def test_checkpoint_extra_state_policy_classifier_map_covers_all_recipes(): + for cls in _recipe_subclasses(Recipe): + key = ("transformer_engine.common.recipe", cls.__name__) + assert key in _RECIPE_POLICIES + assert _RECIPE_POLICIES[key] in CheckpointExtraStatePolicy + + +@pytest.mark.parametrize( + "recipe_obj", + [ + Float8CurrentScaling(), + MXFP8BlockScaling(), + Float8BlockScaling(), + NVFP4BlockScaling(), + ], +) +def test_stateless_pickled_extra_state_is_ignored(recipe_obj): + payload = _pickled_extra_state_payload(recipe_obj) + assert not should_load_extra_state_pickle(payload, "test") + + +def test_stateless_custom_pickled_extra_state_is_ignored(): + payload = _pickled_extra_state_payload(CustomRecipe(qfactory=_custom_recipe_qfactory)) + assert not should_load_extra_state_pickle(payload, "test") + + +@pytest.mark.parametrize("payload", [pickle.dumps({}), pickle.dumps({"extra_fp8_variables": {}})]) +def test_global_free_pickled_extra_state_is_ignored(payload): + # Older stateless checkpoints serialized an empty dict. Such a payload + # resolves no globals and cannot execute code, so it must load without the + # unsafe opt-in. + assert not should_load_extra_state_pickle(payload, "test") + + +@pytest.mark.parametrize( + "payload", + [ + _pickled_extra_state_payload(DelayedScaling(), include_delayed_state=True), + _pickled_extra_state_payload( + CustomRecipe(qfactory=_custom_recipe_qfactory), include_delayed_state=True + ), + pickle.dumps({"scale_inv_fwd": torch.ones(1), "extra_fp8_variables": {}}), + pickle.dumps({"recipe": object(), "extra_fp8_variables": {}}), + b"not a pickle", + ], +) +def test_stateful_unknown_or_malformed_pickled_extra_state_requires_opt_in(payload, monkeypatch): + with pytest.raises(RuntimeError, match=UNSAFE_PICKLE_EXTRA_STATE_ENV): + should_load_extra_state_pickle(payload, "test") + + monkeypatch.setenv(UNSAFE_PICKLE_EXTRA_STATE_ENV, "1") + assert should_load_extra_state_pickle(payload, "test") diff --git a/transformer_engine/pytorch/_extra_state.py b/transformer_engine/pytorch/_extra_state.py new file mode 100644 index 0000000000..831763e891 --- /dev/null +++ b/transformer_engine/pytorch/_extra_state.py @@ -0,0 +1,213 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Helpers for Transformer Engine PyTorch extra-state checkpoint handling.""" + +from __future__ import annotations + +from enum import Enum +import os +import pickletools +from typing import Optional + +from ..common.recipe import Recipe + + +UNSAFE_PICKLE_EXTRA_STATE_ENV = "NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE" + +_RECIPE_MODULE = "transformer_engine.common.recipe" +_RECIPE_KEY = "recipe" + +# Pickle keys that FP8 delayed scaling stores in its ``_extra_state``. This is +# purely a backward-compatibility hack for that recipe; any future stateful +# recipe is expected to checkpoint without pickling. +_FLOAT8_DELAYED_SCALING_STATE_KEYS = { + "scale_fwd", + "amax_history_fwd", + "scale_bwd", + "amax_history_bwd", +} + + +class CheckpointExtraStatePolicy(Enum): + """How pickled PyTorch ``_extra_state`` should be handled in ``set_extra_state``. + + Pickling of ``_extra_state`` is a PyTorch-specific backward-compatibility + concern, so the recipe-to-policy map lives here in ``te.pytorch`` rather than + in ``te.common``. + + ``STATELESS`` recipes carry no checkpoint state and never need unpickling. + ``STATEFUL_FP8_DELAYED_SCALING`` is the only recipe that still relies on + unsafe pickling (a legacy of FP8 delayed scaling). ``DYNAMIC`` means the + recipe class alone is not enough; callers must inspect the checkpoint payload + shape before deciding whether the pickle can be ignored. + """ + + STATELESS = "stateless" + STATEFUL_FP8_DELAYED_SCALING = "stateful_fp8_delayed_scaling" + DYNAMIC = "dynamic" + + +# Map of first-party recipes to their checkpoint policy. When a new stateful +# recipe is added, update this map (and any associated checkpoint handling) +# here instead of adding PyTorch-specific logic to ``te.common``. +_RECIPE_POLICIES: dict[tuple[str, str], CheckpointExtraStatePolicy] = { + (_RECIPE_MODULE, "DelayedScaling"): CheckpointExtraStatePolicy.STATEFUL_FP8_DELAYED_SCALING, + (_RECIPE_MODULE, "Float8CurrentScaling"): CheckpointExtraStatePolicy.STATELESS, + (_RECIPE_MODULE, "MXFP8BlockScaling"): CheckpointExtraStatePolicy.STATELESS, + (_RECIPE_MODULE, "Float8BlockScaling"): CheckpointExtraStatePolicy.STATELESS, + (_RECIPE_MODULE, "NVFP4BlockScaling"): CheckpointExtraStatePolicy.STATELESS, + (_RECIPE_MODULE, "CustomRecipe"): CheckpointExtraStatePolicy.DYNAMIC, +} + + +def recipe_extra_state_policy(recipe: Recipe) -> Optional[CheckpointExtraStatePolicy]: + """Return the checkpoint policy for a recipe instance, if known.""" + cls = type(recipe) + return _RECIPE_POLICIES.get((cls.__module__, cls.__name__)) + + +def is_stateless_recipe(recipe: Recipe) -> bool: + """Return whether a recipe carries no extra state to checkpoint.""" + return recipe_extra_state_policy(recipe) is CheckpointExtraStatePolicy.STATELESS + + +class _PickledExtraStateAction(Enum): + """Action to take for a pickled extra-state payload.""" + + IGNORE = "ignore" + UNSAFE_LOAD = "unsafe_load" + + +def unsafe_pickle_extra_state_enabled() -> bool: + """Return whether unsafe extra-state pickle loading is enabled.""" + return os.getenv(UNSAFE_PICKLE_EXTRA_STATE_ENV, "0") == "1" + + +def extra_state_pickle_advisory(context: str) -> str: + """Security advisory for pickled extra state.""" + return ( + f"Refusing to load pickled Transformer Engine extra state for {context}. " + "Delayed-scaling FP8 metadata can be stored as a Python pickle, and loading it " + "can execute arbitrary code. Only enable unsafe loading if this checkpoint is from " + f"a trusted source. To load it anyway, set {UNSAFE_PICKLE_EXTRA_STATE_ENV}=1." + ) + + +def should_load_extra_state_pickle(data: bytes, context: str) -> bool: + """Return whether callers should use the unsafe pickle loader. + + ``False`` means the payload was identified as empty/stateless and should be + ignored. ``True`` means the caller may unpickle because the unsafe opt-in is + enabled. Otherwise this raises with the security advisory. + """ + action = _classify_extra_state_pickle(data) + if action is _PickledExtraStateAction.IGNORE: + return False + if unsafe_pickle_extra_state_enabled(): + return True + raise RuntimeError(extra_state_pickle_advisory(context)) + + +def _classify_extra_state_pickle(data: bytes) -> _PickledExtraStateAction: + """Classify a pickled extra-state payload without executing it.""" + if not data: + return _PickledExtraStateAction.IGNORE + + try: + return _classify_extra_state_pickle_impl(data) + except Exception: # pylint: disable=broad-except + return _PickledExtraStateAction.UNSAFE_LOAD + + +def _classify_extra_state_pickle_impl(data: bytes) -> _PickledExtraStateAction: + strings: list[str] = [] + has_recipe_key = False + has_delayed_state_keys = False + has_global = False + policies: set[CheckpointExtraStatePolicy] = set() + + for opcode, arg, _pos in pickletools.genops(data): + if opcode.name in { + "STRING", + "BINSTRING", + "SHORT_BINSTRING", + "UNICODE", + "BINUNICODE", + "BINUNICODE8", + "SHORT_BINUNICODE", + }: + text = _string_opcode_arg(arg) + if text is not None: + strings.append(text) + has_recipe_key = has_recipe_key or text == _RECIPE_KEY + has_delayed_state_keys = ( + has_delayed_state_keys or text in _FLOAT8_DELAYED_SCALING_STATE_KEYS + ) + continue + + if opcode.name == "GLOBAL": + has_global = True + global_ref = _global_opcode_arg(arg) + elif opcode.name == "STACK_GLOBAL": + has_global = True + global_ref = _stack_global_args(strings) + else: + continue + + if global_ref is None: + continue + policy = _RECIPE_POLICIES.get(global_ref) + if policy is not None: + policies.add(policy) + + # A payload that never resolves a global cannot construct an arbitrary + # callable, so unpickling it cannot execute code (e.g. the empty dict that + # older stateless checkpoints serialized). It carries no state worth + # loading, so treat it as safe to ignore. A genuine TE 1.x delayed-scaling + # checkpoint always serializes torch tensors and thus contains globals. + if not has_global and not has_delayed_state_keys: + return _PickledExtraStateAction.IGNORE + + # TE 1.x checkpoints did not store a recipe and only supported delayed scaling. + if not has_recipe_key: + return _PickledExtraStateAction.UNSAFE_LOAD + + if CheckpointExtraStatePolicy.STATEFUL_FP8_DELAYED_SCALING in policies: + return _PickledExtraStateAction.UNSAFE_LOAD + + if has_delayed_state_keys: + return _PickledExtraStateAction.UNSAFE_LOAD + + # Unknown/newer payload shape. Give trusted users the explicit opt-in path. + if not policies: + return _PickledExtraStateAction.UNSAFE_LOAD + + return _PickledExtraStateAction.IGNORE + + +def _string_opcode_arg(arg: object) -> Optional[str]: + if isinstance(arg, str): + return arg + if isinstance(arg, bytes): + try: + return arg.decode("utf-8") + except UnicodeDecodeError: + return None + return None + + +def _global_opcode_arg(arg: object) -> Optional[tuple[str, str]]: + if not isinstance(arg, str): + return None + parts = arg.split() + if len(parts) != 2: + return None + return parts[0], parts[1] + + +def _stack_global_args(strings: list[str]) -> Optional[tuple[str, str]]: + if len(strings) < 2: + return None + return strings[-2], strings[-1] diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 6c7ba8a8ab..a449d2dc87 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -21,6 +21,12 @@ import transformer_engine_torch as tex from ._common import _ParameterInitMeta, noop_cat +from .._extra_state import ( + extra_state_pickle_advisory, + is_stateless_recipe, + should_load_extra_state_pickle, + unsafe_pickle_extra_state_enabled, +) from ..quantization import ( MXFP8BlockScalingRecipeState, DelayedScalingRecipeState, @@ -1271,9 +1277,13 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: if not fp8_checkpoint: return torch.empty(0, dtype=torch.uint8) + recipe = self.fp8_meta["recipe"] + if is_stateless_recipe(recipe): + return torch.empty(0, dtype=torch.uint8) + # Copy tensors to CPU and store state = {} - state["recipe"] = self.fp8_meta["recipe"] + state["recipe"] = recipe if _has_delayed_scaling_state(self.fp8_meta): state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale) state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history) @@ -1303,16 +1313,21 @@ def set_extra_state(self, state: torch.Tensor) -> None: return # Load state + context = self.__class__.__name__ if isinstance(state, torch.Tensor): - # No FP8 is indicated by an empty tensor we don't need to unpickle. + # Empty extra state does not need pickle handling. if state.numel() == 0: return - # Default format: byte tensor with pickled data - state = pickle.loads(state.detach().cpu().numpy().tobytes()) + state_bytes = state.detach().cpu().numpy().tobytes() + if not should_load_extra_state_pickle(state_bytes, context): + return + state = pickle.loads(state_bytes) elif isinstance(state, io.BytesIO): - # Deprecated format with io.BytesIO + # Deprecated format with io.BytesIO. Treat it as unsafe pickle. + if not unsafe_pickle_extra_state_enabled(): + raise RuntimeError(extra_state_pickle_advisory(context)) state.seek(0) - state = torch.load(state, map_location="cuda") + state = torch.load(state, map_location="cuda", weights_only=False) else: raise RuntimeError("Unsupported checkpoint format.") diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 86bd60ed9c..849c900f95 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -14,6 +14,7 @@ import torch from transformer_engine.common.recipe import Recipe +from .._extra_state import is_stateless_recipe, should_load_extra_state_pickle from ..quantization import ( FP8GlobalStateManager, QuantizerRole, @@ -590,11 +591,14 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: # Quantizer state fp8_meta = self._fp8_metas[mode] + recipe = fp8_meta["recipe"] + if is_stateless_recipe(recipe): + continue state[mode] = {} - state[mode]["recipe"] = fp8_meta["recipe"] + state[mode]["recipe"] = recipe # Copy tensors to CPU and store - if state[mode]["recipe"].delayed(): + if recipe.delayed(): if mode == "forward": state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale) state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history) @@ -626,8 +630,12 @@ def set_extra_state(self, state: Optional[torch.Tensor]) -> None: if state is None or state.numel() == 0: return - # Deserialize state from byte tensor - state = pickle.loads(state.detach().numpy(force=True).tobytes()) + # Deserialize state from byte tensor only when unsafe loading is enabled. + state_bytes = state.detach().numpy(force=True).tobytes() + context = self.__class__.__name__ + if not should_load_extra_state_pickle(state_bytes, context): + return + state = pickle.loads(state_bytes) if state is None or len(state) == 0: return