diff --git a/core/runtime/RuntimeSettings.cpp b/core/runtime/RuntimeSettings.cpp index 648baf4a67..ef6921619f 100644 --- a/core/runtime/RuntimeSettings.cpp +++ b/core/runtime/RuntimeSettings.cpp @@ -1,5 +1,6 @@ #include "core/runtime/RuntimeSettings.h" +#include #include #include #include @@ -124,9 +125,13 @@ at::Tensor RuntimeCacheHandle::serialize() const { #ifdef TRT_MAJOR_RTX std::lock_guard lock(state_mu_); if (!trt_handle_) { - LOG_WARNING( - "RuntimeCacheHandle::serialize() called before the IRuntimeCache was materialized; returning empty bytes."); - return empty(); + if (std::empty(pending_warm_bytes_)) { + return empty(); + } + auto tensor = at::empty({static_cast(pending_warm_bytes_.size())}, opts); + std::copy( + std::cbegin(pending_warm_bytes_), std::cend(pending_warm_bytes_), static_cast(tensor.data_ptr())); + return tensor; } auto host_mem = make_trt(trt_handle_->serialize()); TORCHTRT_CHECK(host_mem, "IRuntimeCache::serialize() returned null host memory; cannot serialize cache bytes."); diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 790a7624c0..8c1f7f8070 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -31,16 +31,28 @@ static auto TORCHTRT_UNUSED RuntimeCacheHandleRegistration = // ``def_pickle`` registers ``__getstate__`` / ``__setstate__`` so the // handle can survive ``deepcopy`` / ``torch.export.save`` paths that // walk Python attributes (e.g. ``TorchTensorRTModule._implicit_cache_handle``). - // We persist only the ``path`` string: the underlying ``IRuntimeCache`` - // is CPU-side state that can't cross a process boundary anyway, and - // ``_resolve_runtime_cache`` re-warms from disk on the deserialized - // path through the standard load -> pending_warm_bytes flow. + // + // We persist ``(path, bytes)``: the bytes blob round-trips the cache + // contents end-to-end (live ``IRuntimeCache`` when materialized, + // ``pending_warm_bytes_`` when not -- ``serialize()`` unifies both + // states). On unpickle, ``deserialize`` stashes the bytes into + // ``pending_warm_bytes_`` so the first engine that calls + // ``ensure_materialized`` drains them into the live cache. The + // matching python facade (``_RuntimeCacheHandle.__getstate__``) + // uses the same shape. .def_pickle( // __getstate__ - [](c10::intrusive_ptr const& self) -> std::string { return self->path; }, + [](c10::intrusive_ptr const& self) -> std::tuple { + return std::make_tuple(self->path, self->serialize()); + }, // __setstate__ - [](std::string path) -> c10::intrusive_ptr { - return c10::make_intrusive(std::move(path)); + [](std::tuple state) -> c10::intrusive_ptr { + auto handle = c10::make_intrusive(std::move(std::get<0>(state))); + auto const& blob = std::get<1>(state); + if (blob.numel() > 0) { + handle->deserialize(blob); + } + return handle; }); // TODO: Implement a call method diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 40ff0dfa89..0272cf64b5 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -600,6 +600,34 @@ def restore_optimization_profile_state( engine._auto_select_profiles = auto engine.set_active_profile(active) + def __getstate__(self) -> dict[str, Any]: + """Exclude per-engine, in-memory state from the pickle stream. + + Mirrors the ``set_extra_state`` reset so that + ``torch.save(module)`` / ``torch.load`` behaves the same way as + ``state_dict`` / ``load_state_dict`` w.r.t. ``RuntimeSettings``: + the caller must reapply any ``runtime_cache`` / strategy / cuda-graph + configuration after load. See ``_pack_engine_info`` for the matching + cpp-side exclusion (engine bytes never carry these fields). + + ``_implicit_cache_handle`` is dropped alongside ``_runtime_settings`` + because it aliases the same ``RuntimeCache`` instance. + """ + get_state = getattr(super(), "__getstate__", None) + state = (get_state() if get_state else self.__dict__).copy() + state.pop("_runtime_settings", None) + state.pop("_implicit_cache_handle", None) + return state + + def __setstate__(self, state: dict[str, Any]) -> None: + state.setdefault("_runtime_settings", RuntimeSettings()) + state.setdefault("_implicit_cache_handle", None) + set_state = getattr(super(), "__setstate__", None) + if set_state is not None: + set_state(state) + else: + self.__dict__.update(state) + def set_pre_allocated_outputs(self, enable: bool) -> None: self.get_engine().use_pre_allocated_outputs = enable diff --git a/py/torch_tensorrt/runtime/_runtime_cache.py b/py/torch_tensorrt/runtime/_runtime_cache.py index c63c267fb6..5c5cbebe05 100644 --- a/py/torch_tensorrt/runtime/_runtime_cache.py +++ b/py/torch_tensorrt/runtime/_runtime_cache.py @@ -103,30 +103,21 @@ def __init__(self, path: str = "") -> None: self._pending_warm_bytes: Optional[bytes] = None self._lock = threading.Lock() - def __getstate__(self) -> dict: - # ``threading.Lock`` is not picklable, which breaks ``copy.deepcopy`` - # on any GraphModule that has us in its state (the cross-runtime - # export path calls deepcopy on the gm before re-tracing). The lock - # guards in-process mutations only; a freshly-deserialized cache - # always needs a new lock anyway. - state = self.__dict__.copy() - state.pop("_lock", None) - return state - - def __setstate__(self, state: dict) -> None: - self.__dict__.update(state) - self._lock = threading.Lock() + def _serialize_unlocked(self) -> torch.Tensor: + """Lock-free helper used by ``serialize`` and ``__getstate__``; + caller must hold ``self._lock``.""" + if self._cache is None: + return torch.empty(0, dtype=torch.uint8) + host_mem = self._cache.serialize() + if host_mem is None or host_mem.nbytes == 0: + return torch.empty(0, dtype=torch.uint8) + return torch.frombuffer( + bytearray(bytes(memoryview(host_mem))), dtype=torch.uint8 + ) def serialize(self) -> torch.Tensor: with self._lock: - if self._cache is None: - return torch.empty(0, dtype=torch.uint8) - host_mem = self._cache.serialize() - if host_mem is None or host_mem.nbytes == 0: - return torch.empty(0, dtype=torch.uint8) - return torch.frombuffer( - bytearray(bytes(memoryview(host_mem))), dtype=torch.uint8 - ) + return self._serialize_unlocked() def deserialize(self, data: torch.Tensor) -> None: with self._lock: @@ -157,6 +148,39 @@ def ensure_materialized(self, runtime_config: Any) -> Any: self._pending_warm_bytes = None return self._cache + def __getstate__(self) -> dict[str, Any]: + """Pickle as ``(path, bytes)`` mirroring the cpp ``def_pickle`` + contract. The bytes blob carries either the live ``IRuntimeCache`` + contents (when materialized) or the pending warm bytes + (pre-materialize window), so the loading side gets a hot handle on + the first engine without an extra ``handle.load()`` from disk. + + The lock and the live ``_cache`` are per-process and never cross + the pickle boundary; ``__setstate__`` rebuilds them fresh. + """ + with self._lock: + if self._cache is not None: + blob = self._serialize_unlocked() + elif self._pending_warm_bytes is not None: + blob = torch.frombuffer( + bytearray(self._pending_warm_bytes), dtype=torch.uint8 + ) + else: + blob = torch.empty(0, dtype=torch.uint8) + path = self.path + return {"path": path, "bytes": blob} + + def __setstate__(self, state: dict[str, Any]) -> None: + self.path = state["path"] + self._cache = None + self._pending_warm_bytes = None + self._lock = threading.Lock() + blob = state.get("bytes") + if blob is not None and blob.numel() > 0: + # Stashes into ``_pending_warm_bytes``; first engine that calls + # ``ensure_materialized`` drains it into the live cache. + self.deserialize(blob) + def _autosave_at_exit(ref: "weakref.ref[RuntimeCache]") -> None: """Module-level so the atexit closure only holds a weakref, not a bound diff --git a/tests/py/dynamo/runtime/test_000_runtime_cache.py b/tests/py/dynamo/runtime/test_000_runtime_cache.py index 135995ce85..5616a234c8 100644 --- a/tests/py/dynamo/runtime/test_000_runtime_cache.py +++ b/tests/py/dynamo/runtime/test_000_runtime_cache.py @@ -466,6 +466,32 @@ def test_pickle_round_trip_strips_atexit_token(self): "loaded handle must own its own atexit token (fresh weakref)", ) + def test_python_handle_pickle_preserves_pending_warm_bytes(self): + """A python ``_RuntimeCacheHandle`` that has bytes loaded but + hasn't materialized them yet must round-trip those bytes through + pickle. Matches the cpp ``def_pickle`` contract (path + bytes). + + The test instantiates ``_RuntimeCacheHandle`` directly rather than + going through ``RuntimeCache`` (which would pick the torchbind + backing on cpp rt), so the python class is exercisable regardless + of which runtime is active. + """ + import pickle + + from torch_tensorrt.runtime._runtime_cache import _RuntimeCacheHandle + + handle = _RuntimeCacheHandle(path="/tmp/test_cache.bin") + warm = b"\x01\x02\x03\x04\x05\x06\x07\x08" + handle.deserialize(torch.frombuffer(bytearray(warm), dtype=torch.uint8)) + self.assertEqual(handle._pending_warm_bytes, warm) + + loaded = pickle.loads(pickle.dumps(handle)) + + self.assertEqual(loaded.path, "/tmp/test_cache.bin") + self.assertEqual(loaded._pending_warm_bytes, warm) + self.assertIsNone(loaded._cache, "live cache is never persisted") + self.assertIsNotNone(loaded._lock, "lock must be a fresh Lock instance") + @unittest.skipIf( not ENABLED_FEATURES.tensorrt_rtx,