From c40ec955de82e3e905cd73fa947e241f8277acdb Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Fri, 26 Jun 2026 13:15:17 -0700 Subject: [PATCH 1/2] =?UTF-8?q?fix(runtime):=20RuntimeCache=20pickle=20pat?= =?UTF-8?q?h=20=E2=80=94=20wrapper=20exclusion=20+=20handle=20bytes=20roun?= =?UTF-8?q?d-trip?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two independent pickle-correctness fixes that together let ``torch.save(module)`` and other pickle paths handle ``RuntimeCache`` cleanly. These don't depend on any new attribute additions to ``RuntimeCache``; they're standalone improvements. 1. ``TorchTensorRTModule.__getstate__`` / ``__setstate__`` exclude ``_runtime_settings`` and ``_implicit_cache_handle`` from the pickle stream and reset both to defaults on load. This mirrors ``set_extra_state`` (line 515) and the documented intent at line 236: "RuntimeSettings are intentionally NOT serialized: they're per-engine, in-memory init values, not part of the engine's identity (see #4310)." Aligns the pickle path with the state_dict path -- the caller is expected to reapply any ``runtime_cache`` / strategy / cuda-graph configuration on the loading side. 2. ``_RuntimeCacheHandle`` (python rt) gains ``__getstate__`` / ``__setstate__``. The default pickle walks ``__dict__``, which crashes on the ``threading.Lock`` (``_thread.lock`` is unpicklable). The new protocol pickles ``(path, bytes)``, matching the cpp ``def_pickle`` shape. The bytes blob captures live cache contents when materialized, falling back to ``_pending_warm_bytes`` so the pre-materialize window is preserved end-to-end. Factored a lock-free ``_serialize_unlocked`` helper to share read logic with ``serialize``. 3. Cpp ``RuntimeCacheHandle::serialize()`` now falls back to ``pending_warm_bytes_`` when ``trt_handle_`` is null, so callers asking for the handle's persistable bytes (``save_to_stream``, ``def_pickle``) get the correct answer in every lifecycle state. The misleading "called before materialized" warning is retired now that the pre-materialize case has a real answer. 4. Cpp ``def_pickle`` switches from ``std::string`` (path only) to ``std::tuple`` (path + bytes) and round-trips via the existing ``serialize`` / ``deserialize`` API. The matching python ``_RuntimeCacheHandle.__getstate__`` uses the same shape. Pre-existing CI failures on ``test_cross_runtime_serde::test_save_*`` (python-runtime subprocess save/load) trace to (2) -- the ``_thread.lock`` failure -- and should resolve once this lands. Test ``test_python_handle_pickle_preserves_pending_warm_bytes`` covers the python handle's bytes round-trip end-to-end. Refs #4359 --- core/runtime/RuntimeSettings.cpp | 14 +++- core/runtime/register_jit_hooks.cpp | 26 ++++++-- .../dynamo/runtime/_TorchTensorRTModule.py | 31 +++++++++ py/torch_tensorrt/runtime/_runtime_cache.py | 66 +++++++++++++------ .../dynamo/runtime/test_000_runtime_cache.py | 26 ++++++++ 5 files changed, 132 insertions(+), 31 deletions(-) diff --git a/core/runtime/RuntimeSettings.cpp b/core/runtime/RuntimeSettings.cpp index 648baf4a67..e8269e738a 100644 --- a/core/runtime/RuntimeSettings.cpp +++ b/core/runtime/RuntimeSettings.cpp @@ -124,9 +124,17 @@ at::Tensor RuntimeCacheHandle::serialize() const { #ifdef TRT_MAJOR_RTX std::lock_guard lock(state_mu_); if (!trt_handle_) { - LOG_WARNING( - "RuntimeCacheHandle::serialize() called before the IRuntimeCache was materialized; returning empty bytes."); - return empty(); + // Pre-materialize: forward any ``pending_warm_bytes_`` so that the + // handle's persistable state survives ``save_to_stream`` and pickle + // even before any engine has triggered ``ensure_materialized``. + // ``save_to_stream`` / ``def_pickle`` round-trip then matches the + // python facade's behavior (it reads ``_pending_warm_bytes`` directly). + if (pending_warm_bytes_.empty()) { + return empty(); + } + auto tensor = at::empty({static_cast(pending_warm_bytes_.size())}, opts); + std::memcpy(tensor.data_ptr(), pending_warm_bytes_.data(), pending_warm_bytes_.size()); + return tensor; } auto host_mem = make_trt(trt_handle_->serialize()); TORCHTRT_CHECK(host_mem, "IRuntimeCache::serialize() returned null host memory; cannot serialize cache bytes."); diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 790a7624c0..8c1f7f8070 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -31,16 +31,28 @@ static auto TORCHTRT_UNUSED RuntimeCacheHandleRegistration = // ``def_pickle`` registers ``__getstate__`` / ``__setstate__`` so the // handle can survive ``deepcopy`` / ``torch.export.save`` paths that // walk Python attributes (e.g. ``TorchTensorRTModule._implicit_cache_handle``). - // We persist only the ``path`` string: the underlying ``IRuntimeCache`` - // is CPU-side state that can't cross a process boundary anyway, and - // ``_resolve_runtime_cache`` re-warms from disk on the deserialized - // path through the standard load -> pending_warm_bytes flow. + // + // We persist ``(path, bytes)``: the bytes blob round-trips the cache + // contents end-to-end (live ``IRuntimeCache`` when materialized, + // ``pending_warm_bytes_`` when not -- ``serialize()`` unifies both + // states). On unpickle, ``deserialize`` stashes the bytes into + // ``pending_warm_bytes_`` so the first engine that calls + // ``ensure_materialized`` drains them into the live cache. The + // matching python facade (``_RuntimeCacheHandle.__getstate__``) + // uses the same shape. .def_pickle( // __getstate__ - [](c10::intrusive_ptr const& self) -> std::string { return self->path; }, + [](c10::intrusive_ptr const& self) -> std::tuple { + return std::make_tuple(self->path, self->serialize()); + }, // __setstate__ - [](std::string path) -> c10::intrusive_ptr { - return c10::make_intrusive(std::move(path)); + [](std::tuple state) -> c10::intrusive_ptr { + auto handle = c10::make_intrusive(std::move(std::get<0>(state))); + auto const& blob = std::get<1>(state); + if (blob.numel() > 0) { + handle->deserialize(blob); + } + return handle; }); // TODO: Implement a call method diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 40ff0dfa89..4ba7051b85 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -600,6 +600,37 @@ def restore_optimization_profile_state( engine._auto_select_profiles = auto engine.set_active_profile(active) + def __getstate__(self) -> dict[str, Any]: + """Exclude per-engine, in-memory state from the pickle stream. + + Mirrors the ``set_extra_state`` reset so that + ``torch.save(module)`` / ``torch.load`` behaves the same way as + ``state_dict`` / ``load_state_dict`` w.r.t. ``RuntimeSettings``: + the caller must reapply any ``runtime_cache`` / strategy / cuda-graph + configuration after load. See ``_pack_engine_info`` for the matching + cpp-side exclusion (engine bytes never carry these fields). + + ``_implicit_cache_handle`` is dropped alongside ``_runtime_settings`` + because it aliases the same ``RuntimeCache`` instance and would + otherwise drag a ``weakref`` (via the handle's ``atexit`` closure) + and a Python-only ``threading.Lock`` (when the python-runtime path + is active) into pickle -- neither is picklable. + """ + get_state = getattr(super(), "__getstate__", None) + state = (get_state() if get_state else self.__dict__).copy() + state.pop("_runtime_settings", None) + state.pop("_implicit_cache_handle", None) + return state + + def __setstate__(self, state: dict[str, Any]) -> None: + state.setdefault("_runtime_settings", RuntimeSettings()) + state.setdefault("_implicit_cache_handle", None) + set_state = getattr(super(), "__setstate__", None) + if set_state is not None: + set_state(state) + else: + self.__dict__.update(state) + def set_pre_allocated_outputs(self, enable: bool) -> None: self.get_engine().use_pre_allocated_outputs = enable diff --git a/py/torch_tensorrt/runtime/_runtime_cache.py b/py/torch_tensorrt/runtime/_runtime_cache.py index c63c267fb6..5c5cbebe05 100644 --- a/py/torch_tensorrt/runtime/_runtime_cache.py +++ b/py/torch_tensorrt/runtime/_runtime_cache.py @@ -103,30 +103,21 @@ def __init__(self, path: str = "") -> None: self._pending_warm_bytes: Optional[bytes] = None self._lock = threading.Lock() - def __getstate__(self) -> dict: - # ``threading.Lock`` is not picklable, which breaks ``copy.deepcopy`` - # on any GraphModule that has us in its state (the cross-runtime - # export path calls deepcopy on the gm before re-tracing). The lock - # guards in-process mutations only; a freshly-deserialized cache - # always needs a new lock anyway. - state = self.__dict__.copy() - state.pop("_lock", None) - return state - - def __setstate__(self, state: dict) -> None: - self.__dict__.update(state) - self._lock = threading.Lock() + def _serialize_unlocked(self) -> torch.Tensor: + """Lock-free helper used by ``serialize`` and ``__getstate__``; + caller must hold ``self._lock``.""" + if self._cache is None: + return torch.empty(0, dtype=torch.uint8) + host_mem = self._cache.serialize() + if host_mem is None or host_mem.nbytes == 0: + return torch.empty(0, dtype=torch.uint8) + return torch.frombuffer( + bytearray(bytes(memoryview(host_mem))), dtype=torch.uint8 + ) def serialize(self) -> torch.Tensor: with self._lock: - if self._cache is None: - return torch.empty(0, dtype=torch.uint8) - host_mem = self._cache.serialize() - if host_mem is None or host_mem.nbytes == 0: - return torch.empty(0, dtype=torch.uint8) - return torch.frombuffer( - bytearray(bytes(memoryview(host_mem))), dtype=torch.uint8 - ) + return self._serialize_unlocked() def deserialize(self, data: torch.Tensor) -> None: with self._lock: @@ -157,6 +148,39 @@ def ensure_materialized(self, runtime_config: Any) -> Any: self._pending_warm_bytes = None return self._cache + def __getstate__(self) -> dict[str, Any]: + """Pickle as ``(path, bytes)`` mirroring the cpp ``def_pickle`` + contract. The bytes blob carries either the live ``IRuntimeCache`` + contents (when materialized) or the pending warm bytes + (pre-materialize window), so the loading side gets a hot handle on + the first engine without an extra ``handle.load()`` from disk. + + The lock and the live ``_cache`` are per-process and never cross + the pickle boundary; ``__setstate__`` rebuilds them fresh. + """ + with self._lock: + if self._cache is not None: + blob = self._serialize_unlocked() + elif self._pending_warm_bytes is not None: + blob = torch.frombuffer( + bytearray(self._pending_warm_bytes), dtype=torch.uint8 + ) + else: + blob = torch.empty(0, dtype=torch.uint8) + path = self.path + return {"path": path, "bytes": blob} + + def __setstate__(self, state: dict[str, Any]) -> None: + self.path = state["path"] + self._cache = None + self._pending_warm_bytes = None + self._lock = threading.Lock() + blob = state.get("bytes") + if blob is not None and blob.numel() > 0: + # Stashes into ``_pending_warm_bytes``; first engine that calls + # ``ensure_materialized`` drains it into the live cache. + self.deserialize(blob) + def _autosave_at_exit(ref: "weakref.ref[RuntimeCache]") -> None: """Module-level so the atexit closure only holds a weakref, not a bound diff --git a/tests/py/dynamo/runtime/test_000_runtime_cache.py b/tests/py/dynamo/runtime/test_000_runtime_cache.py index 135995ce85..d4ee539a59 100644 --- a/tests/py/dynamo/runtime/test_000_runtime_cache.py +++ b/tests/py/dynamo/runtime/test_000_runtime_cache.py @@ -466,6 +466,32 @@ def test_pickle_round_trip_strips_atexit_token(self): "loaded handle must own its own atexit token (fresh weakref)", ) + @unittest.skipIf( + ENABLED_FEATURES.torch_tensorrt_runtime, + "exercises the python ``_RuntimeCacheHandle`` directly; cpp-rt " + "path is covered by ``register_jit_hooks.cpp`` ``def_pickle``", + ) + def test_python_handle_pickle_preserves_pending_warm_bytes(self): + """A python ``_RuntimeCacheHandle`` that has bytes loaded but + hasn't materialized them yet must round-trip those bytes through + pickle. Matches the cpp ``def_pickle`` contract (path + bytes). + """ + import pickle + + from torch_tensorrt.runtime._runtime_cache import _RuntimeCacheHandle + + handle = _RuntimeCacheHandle(path="/tmp/test_cache.bin") + warm = b"\x01\x02\x03\x04\x05\x06\x07\x08" + handle.deserialize(torch.frombuffer(bytearray(warm), dtype=torch.uint8)) + self.assertEqual(handle._pending_warm_bytes, warm) + + loaded = pickle.loads(pickle.dumps(handle)) + + self.assertEqual(loaded.path, "/tmp/test_cache.bin") + self.assertEqual(loaded._pending_warm_bytes, warm) + self.assertIsNone(loaded._cache, "live cache is never persisted") + self.assertIsNotNone(loaded._lock, "lock must be a fresh Lock instance") + @unittest.skipIf( not ENABLED_FEATURES.tensorrt_rtx, From f547d507dbbdc2068e22663249aa3f34e2d85dd5 Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Fri, 26 Jun 2026 17:04:32 -0700 Subject: [PATCH 2/2] review: address feedback on RuntimeCacheHandle pickle changes - ``RuntimeCacheHandle::serialize``: switch the pending-bytes copy from ``std::memcpy`` to ``std::copy`` (with free-function ``std::cbegin`` / ``std::cend``), and use free-function ``std::empty`` for the empty-vector check. Idiomatic C++17 in a hot path that didn't need the C-string-style API. - Drop the explanatory comment block above the pending-bytes branch. - Trim the wrapper ``__getstate__`` docstring: the rationale that named specific picklability hazards (``weakref``, ``threading.Lock``) is implementation noise that doesn't belong in the wrapper's contract. - Drop the cpp-rt skip on ``test_python_handle_pickle_preserves_pending_warm_bytes``. The test instantiates ``_RuntimeCacheHandle`` directly rather than going through ``RuntimeCache._handle`` selection, so the python class is exercisable regardless of which runtime is active. Updated docstring notes the directness. Behavior is unchanged. Addresses review comments on core/runtime/RuntimeSettings.cpp:131, :136 and tests/py/dynamo/runtime/test_000_runtime_cache.py:341. --- core/runtime/RuntimeSettings.cpp | 11 ++++------- .../dynamo/runtime/_TorchTensorRTModule.py | 5 +---- tests/py/dynamo/runtime/test_000_runtime_cache.py | 10 +++++----- 3 files changed, 10 insertions(+), 16 deletions(-) diff --git a/core/runtime/RuntimeSettings.cpp b/core/runtime/RuntimeSettings.cpp index e8269e738a..ef6921619f 100644 --- a/core/runtime/RuntimeSettings.cpp +++ b/core/runtime/RuntimeSettings.cpp @@ -1,5 +1,6 @@ #include "core/runtime/RuntimeSettings.h" +#include #include #include #include @@ -124,16 +125,12 @@ at::Tensor RuntimeCacheHandle::serialize() const { #ifdef TRT_MAJOR_RTX std::lock_guard lock(state_mu_); if (!trt_handle_) { - // Pre-materialize: forward any ``pending_warm_bytes_`` so that the - // handle's persistable state survives ``save_to_stream`` and pickle - // even before any engine has triggered ``ensure_materialized``. - // ``save_to_stream`` / ``def_pickle`` round-trip then matches the - // python facade's behavior (it reads ``_pending_warm_bytes`` directly). - if (pending_warm_bytes_.empty()) { + if (std::empty(pending_warm_bytes_)) { return empty(); } auto tensor = at::empty({static_cast(pending_warm_bytes_.size())}, opts); - std::memcpy(tensor.data_ptr(), pending_warm_bytes_.data(), pending_warm_bytes_.size()); + std::copy( + std::cbegin(pending_warm_bytes_), std::cend(pending_warm_bytes_), static_cast(tensor.data_ptr())); return tensor; } auto host_mem = make_trt(trt_handle_->serialize()); diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 4ba7051b85..0272cf64b5 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -611,10 +611,7 @@ def __getstate__(self) -> dict[str, Any]: cpp-side exclusion (engine bytes never carry these fields). ``_implicit_cache_handle`` is dropped alongside ``_runtime_settings`` - because it aliases the same ``RuntimeCache`` instance and would - otherwise drag a ``weakref`` (via the handle's ``atexit`` closure) - and a Python-only ``threading.Lock`` (when the python-runtime path - is active) into pickle -- neither is picklable. + because it aliases the same ``RuntimeCache`` instance. """ get_state = getattr(super(), "__getstate__", None) state = (get_state() if get_state else self.__dict__).copy() diff --git a/tests/py/dynamo/runtime/test_000_runtime_cache.py b/tests/py/dynamo/runtime/test_000_runtime_cache.py index d4ee539a59..5616a234c8 100644 --- a/tests/py/dynamo/runtime/test_000_runtime_cache.py +++ b/tests/py/dynamo/runtime/test_000_runtime_cache.py @@ -466,15 +466,15 @@ def test_pickle_round_trip_strips_atexit_token(self): "loaded handle must own its own atexit token (fresh weakref)", ) - @unittest.skipIf( - ENABLED_FEATURES.torch_tensorrt_runtime, - "exercises the python ``_RuntimeCacheHandle`` directly; cpp-rt " - "path is covered by ``register_jit_hooks.cpp`` ``def_pickle``", - ) def test_python_handle_pickle_preserves_pending_warm_bytes(self): """A python ``_RuntimeCacheHandle`` that has bytes loaded but hasn't materialized them yet must round-trip those bytes through pickle. Matches the cpp ``def_pickle`` contract (path + bytes). + + The test instantiates ``_RuntimeCacheHandle`` directly rather than + going through ``RuntimeCache`` (which would pick the torchbind + backing on cpp rt), so the python class is exercisable regardless + of which runtime is active. """ import pickle