Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions core/runtime/RuntimeSettings.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "core/runtime/RuntimeSettings.h"

#include <algorithm>
#include <array>
#include <cstring>
#include <iterator>
Expand Down Expand Up @@ -124,9 +125,13 @@ at::Tensor RuntimeCacheHandle::serialize() const {
#ifdef TRT_MAJOR_RTX
std::lock_guard<std::mutex> lock(state_mu_);
if (!trt_handle_) {
LOG_WARNING(
"RuntimeCacheHandle::serialize() called before the IRuntimeCache was materialized; returning empty bytes.");
return empty();
if (std::empty(pending_warm_bytes_)) {
return empty();
}
auto tensor = at::empty({static_cast<int64_t>(pending_warm_bytes_.size())}, opts);
std::copy(
std::cbegin(pending_warm_bytes_), std::cend(pending_warm_bytes_), static_cast<uint8_t*>(tensor.data_ptr()));
return tensor;
}
auto host_mem = make_trt(trt_handle_->serialize());
TORCHTRT_CHECK(host_mem, "IRuntimeCache::serialize() returned null host memory; cannot serialize cache bytes.");
Expand Down
26 changes: 19 additions & 7 deletions core/runtime/register_jit_hooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,28 @@ static auto TORCHTRT_UNUSED RuntimeCacheHandleRegistration =
// ``def_pickle`` registers ``__getstate__`` / ``__setstate__`` so the
// handle can survive ``deepcopy`` / ``torch.export.save`` paths that
// walk Python attributes (e.g. ``TorchTensorRTModule._implicit_cache_handle``).
// We persist only the ``path`` string: the underlying ``IRuntimeCache``
// is CPU-side state that can't cross a process boundary anyway, and
// ``_resolve_runtime_cache`` re-warms from disk on the deserialized
// path through the standard load -> pending_warm_bytes flow.
//
// We persist ``(path, bytes)``: the bytes blob round-trips the cache
// contents end-to-end (live ``IRuntimeCache`` when materialized,
// ``pending_warm_bytes_`` when not -- ``serialize()`` unifies both
// states). On unpickle, ``deserialize`` stashes the bytes into
// ``pending_warm_bytes_`` so the first engine that calls
// ``ensure_materialized`` drains them into the live cache. The
// matching python facade (``_RuntimeCacheHandle.__getstate__``)
// uses the same shape.
.def_pickle(
// __getstate__
[](c10::intrusive_ptr<RuntimeCacheHandle> const& self) -> std::string { return self->path; },
[](c10::intrusive_ptr<RuntimeCacheHandle> const& self) -> std::tuple<std::string, at::Tensor> {
return std::make_tuple(self->path, self->serialize());
},
// __setstate__
[](std::string path) -> c10::intrusive_ptr<RuntimeCacheHandle> {
return c10::make_intrusive<RuntimeCacheHandle>(std::move(path));
[](std::tuple<std::string, at::Tensor> state) -> c10::intrusive_ptr<RuntimeCacheHandle> {
auto handle = c10::make_intrusive<RuntimeCacheHandle>(std::move(std::get<0>(state)));
auto const& blob = std::get<1>(state);
if (blob.numel() > 0) {
handle->deserialize(blob);
}
return handle;
});

// TODO: Implement a call method
Expand Down
28 changes: 28 additions & 0 deletions py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am still trying to figure out if RuntimeCache is runtime specific or engine specific. If its not engine specific then the module should not serialize. The module should be 1:1 to the engine

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see your point Naren. RuntimeCache is a runtime-config knob and is not engine specific and so not a part of the engine's identity. The same instance can be 1:N with engines: that's why we have an in-memory runtime cache object that the users can attach and share across multiple engines.

Also, since the TorchTensorRTModule wrapper is 1:1 with an engine, we should not serialize runtime cache, and so TorchTensorRTModule.getstate drops _runtime_settings (and _implicit_cache_handle, which aliases the same RuntimeCache) on pickle.

This is consistent with the two serialization paths toda:

This MR ensures that the pickle path also adheres to the same settings. I have added getstate and setstate for the RuntimeCache in case users want to pickle that object separately, independent from the torch module. Hope that makes sense, happy to sync up offline in case more info needed.

Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,34 @@ def restore_optimization_profile_state(
engine._auto_select_profiles = auto
engine.set_active_profile(active)

def __getstate__(self) -> dict[str, Any]:
"""Exclude per-engine, in-memory state from the pickle stream.

Mirrors the ``set_extra_state`` reset so that
``torch.save(module)`` / ``torch.load`` behaves the same way as
``state_dict`` / ``load_state_dict`` w.r.t. ``RuntimeSettings``:
the caller must reapply any ``runtime_cache`` / strategy / cuda-graph
configuration after load. See ``_pack_engine_info`` for the matching
cpp-side exclusion (engine bytes never carry these fields).

``_implicit_cache_handle`` is dropped alongside ``_runtime_settings``
because it aliases the same ``RuntimeCache`` instance.
"""
Comment thread
tp5uiuc marked this conversation as resolved.
get_state = getattr(super(), "__getstate__", None)
state = (get_state() if get_state else self.__dict__).copy()
state.pop("_runtime_settings", None)
state.pop("_implicit_cache_handle", None)
return state

def __setstate__(self, state: dict[str, Any]) -> None:
state.setdefault("_runtime_settings", RuntimeSettings())
state.setdefault("_implicit_cache_handle", None)
set_state = getattr(super(), "__setstate__", None)
if set_state is not None:
set_state(state)
else:
self.__dict__.update(state)

def set_pre_allocated_outputs(self, enable: bool) -> None:
self.get_engine().use_pre_allocated_outputs = enable

Expand Down
66 changes: 45 additions & 21 deletions py/torch_tensorrt/runtime/_runtime_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,30 +103,21 @@ def __init__(self, path: str = "") -> None:
self._pending_warm_bytes: Optional[bytes] = None
self._lock = threading.Lock()

def __getstate__(self) -> dict:
# ``threading.Lock`` is not picklable, which breaks ``copy.deepcopy``
# on any GraphModule that has us in its state (the cross-runtime
# export path calls deepcopy on the gm before re-tracing). The lock
# guards in-process mutations only; a freshly-deserialized cache
# always needs a new lock anyway.
state = self.__dict__.copy()
state.pop("_lock", None)
return state

def __setstate__(self, state: dict) -> None:
self.__dict__.update(state)
self._lock = threading.Lock()
def _serialize_unlocked(self) -> torch.Tensor:
"""Lock-free helper used by ``serialize`` and ``__getstate__``;
caller must hold ``self._lock``."""
if self._cache is None:
return torch.empty(0, dtype=torch.uint8)
host_mem = self._cache.serialize()
if host_mem is None or host_mem.nbytes == 0:
return torch.empty(0, dtype=torch.uint8)
return torch.frombuffer(
bytearray(bytes(memoryview(host_mem))), dtype=torch.uint8
)

def serialize(self) -> torch.Tensor:
with self._lock:
if self._cache is None:
return torch.empty(0, dtype=torch.uint8)
host_mem = self._cache.serialize()
if host_mem is None or host_mem.nbytes == 0:
return torch.empty(0, dtype=torch.uint8)
return torch.frombuffer(
bytearray(bytes(memoryview(host_mem))), dtype=torch.uint8
)
return self._serialize_unlocked()

def deserialize(self, data: torch.Tensor) -> None:
with self._lock:
Expand Down Expand Up @@ -157,6 +148,39 @@ def ensure_materialized(self, runtime_config: Any) -> Any:
self._pending_warm_bytes = None
return self._cache

def __getstate__(self) -> dict[str, Any]:
"""Pickle as ``(path, bytes)`` mirroring the cpp ``def_pickle``
contract. The bytes blob carries either the live ``IRuntimeCache``
contents (when materialized) or the pending warm bytes
(pre-materialize window), so the loading side gets a hot handle on
the first engine without an extra ``handle.load()`` from disk.

The lock and the live ``_cache`` are per-process and never cross
the pickle boundary; ``__setstate__`` rebuilds them fresh.
"""
with self._lock:
if self._cache is not None:
blob = self._serialize_unlocked()
elif self._pending_warm_bytes is not None:
blob = torch.frombuffer(
bytearray(self._pending_warm_bytes), dtype=torch.uint8
)
else:
blob = torch.empty(0, dtype=torch.uint8)
path = self.path
return {"path": path, "bytes": blob}

def __setstate__(self, state: dict[str, Any]) -> None:
self.path = state["path"]
self._cache = None
self._pending_warm_bytes = None
self._lock = threading.Lock()
blob = state.get("bytes")
if blob is not None and blob.numel() > 0:
# Stashes into ``_pending_warm_bytes``; first engine that calls
# ``ensure_materialized`` drains it into the live cache.
self.deserialize(blob)


def _autosave_at_exit(ref: "weakref.ref[RuntimeCache]") -> None:
"""Module-level so the atexit closure only holds a weakref, not a bound
Expand Down
26 changes: 26 additions & 0 deletions tests/py/dynamo/runtime/test_000_runtime_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,32 @@ def test_pickle_round_trip_strips_atexit_token(self):
"loaded handle must own its own atexit token (fresh weakref)",
)

def test_python_handle_pickle_preserves_pending_warm_bytes(self):
"""A python ``_RuntimeCacheHandle`` that has bytes loaded but
hasn't materialized them yet must round-trip those bytes through
pickle. Matches the cpp ``def_pickle`` contract (path + bytes).

The test instantiates ``_RuntimeCacheHandle`` directly rather than
going through ``RuntimeCache`` (which would pick the torchbind
backing on cpp rt), so the python class is exercisable regardless
of which runtime is active.
"""
import pickle

from torch_tensorrt.runtime._runtime_cache import _RuntimeCacheHandle

handle = _RuntimeCacheHandle(path="/tmp/test_cache.bin")
warm = b"\x01\x02\x03\x04\x05\x06\x07\x08"
handle.deserialize(torch.frombuffer(bytearray(warm), dtype=torch.uint8))
self.assertEqual(handle._pending_warm_bytes, warm)

loaded = pickle.loads(pickle.dumps(handle))

self.assertEqual(loaded.path, "/tmp/test_cache.bin")
self.assertEqual(loaded._pending_warm_bytes, warm)
self.assertIsNone(loaded._cache, "live cache is never persisted")
self.assertIsNotNone(loaded._lock, "lock must be a fresh Lock instance")


@unittest.skipIf(
not ENABLED_FEATURES.tensorrt_rtx,
Expand Down
Loading