fix(runtime): RuntimeCache pickle path — wrapper exclusion + handle bytes round-trip#4368
fix(runtime): RuntimeCache pickle path — wrapper exclusion + handle bytes round-trip#4368tp5uiuc wants to merge 2 commits into
Conversation
There was a problem hiding this comment.
I am still trying to figure out if RuntimeCache is runtime specific or engine specific. If its not engine specific then the module should not serialize. The module should be 1:1 to the engine
There was a problem hiding this comment.
I see your point Naren. RuntimeCache is a runtime-config knob and is not engine specific and so not a part of the engine's identity. The same instance can be 1:N with engines: that's why we have an in-memory runtime cache object that the users can attach and share across multiple engines.
Also, since the TorchTensorRTModule wrapper is 1:1 with an engine, we should not serialize runtime cache, and so TorchTensorRTModule.getstate drops _runtime_settings (and _implicit_cache_handle, which aliases the same RuntimeCache) on pickle.
This is consistent with the two serialization paths toda:
- _pack_engine_info (line 236 (https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py#L236)): "RuntimeSettings are intentionally NOT serialized: they're per-engine, in-memory init values, not part of the engine's identity (see Move TRT-RTX runtime mode controls from CompilationSettings to runtime context managers #4310)."
- set_extra_state (line 515 (https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py#L515)): explicitly resets _runtime_settings = RuntimeSettings() on load — "Caller can reapply via mod.runtime_settings = ...".
This MR ensures that the pickle path also adheres to the same settings. I have added getstate and setstate for the RuntimeCache in case users want to pickle that object separately, independent from the torch module. Hope that makes sense, happy to sync up offline in case more info needed.
1ad40e3 to
9776534
Compare
|
[by Claude Code] CI status after rebase on #4367:
Other red is environmental, identical to #4362 and to main:
|
…ytes round-trip Two independent pickle-correctness fixes that together let ``torch.save(module)`` and other pickle paths handle ``RuntimeCache`` cleanly. These don't depend on any new attribute additions to ``RuntimeCache``; they're standalone improvements. 1. ``TorchTensorRTModule.__getstate__`` / ``__setstate__`` exclude ``_runtime_settings`` and ``_implicit_cache_handle`` from the pickle stream and reset both to defaults on load. This mirrors ``set_extra_state`` (line 515) and the documented intent at line 236: "RuntimeSettings are intentionally NOT serialized: they're per-engine, in-memory init values, not part of the engine's identity (see pytorch#4310)." Aligns the pickle path with the state_dict path -- the caller is expected to reapply any ``runtime_cache`` / strategy / cuda-graph configuration on the loading side. 2. ``_RuntimeCacheHandle`` (python rt) gains ``__getstate__`` / ``__setstate__``. The default pickle walks ``__dict__``, which crashes on the ``threading.Lock`` (``_thread.lock`` is unpicklable). The new protocol pickles ``(path, bytes)``, matching the cpp ``def_pickle`` shape. The bytes blob captures live cache contents when materialized, falling back to ``_pending_warm_bytes`` so the pre-materialize window is preserved end-to-end. Factored a lock-free ``_serialize_unlocked`` helper to share read logic with ``serialize``. 3. Cpp ``RuntimeCacheHandle::serialize()`` now falls back to ``pending_warm_bytes_`` when ``trt_handle_`` is null, so callers asking for the handle's persistable bytes (``save_to_stream``, ``def_pickle``) get the correct answer in every lifecycle state. The misleading "called before materialized" warning is retired now that the pre-materialize case has a real answer. 4. Cpp ``def_pickle`` switches from ``std::string`` (path only) to ``std::tuple<std::string, at::Tensor>`` (path + bytes) and round-trips via the existing ``serialize`` / ``deserialize`` API. The matching python ``_RuntimeCacheHandle.__getstate__`` uses the same shape. Pre-existing CI failures on ``test_cross_runtime_serde::test_save_*`` (python-runtime subprocess save/load) trace to (2) -- the ``_thread.lock`` failure -- and should resolve once this lands. Test ``test_python_handle_pickle_preserves_pending_warm_bytes`` covers the python handle's bytes round-trip end-to-end. Refs pytorch#4359
- ``RuntimeCacheHandle::serialize``: switch the pending-bytes copy from ``std::memcpy`` to ``std::copy`` (with free-function ``std::cbegin`` / ``std::cend``), and use free-function ``std::empty`` for the empty-vector check. Idiomatic C++17 in a hot path that didn't need the C-string-style API. - Drop the explanatory comment block above the pending-bytes branch. - Trim the wrapper ``__getstate__`` docstring: the rationale that named specific picklability hazards (``weakref``, ``threading.Lock``) is implementation noise that doesn't belong in the wrapper's contract. - Drop the cpp-rt skip on ``test_python_handle_pickle_preserves_pending_warm_bytes``. The test instantiates ``_RuntimeCacheHandle`` directly rather than going through ``RuntimeCache._handle`` selection, so the python class is exercisable regardless of which runtime is active. Updated docstring notes the directness. Behavior is unchanged. Addresses review comments on core/runtime/RuntimeSettings.cpp:131, :136 and tests/py/dynamo/runtime/test_000_runtime_cache.py:341.
9776534 to
f547d50
Compare
|
[by Claude Code] CI on f547d50 (rebased onto main after #4362 merged):
Remaining red is environmental / pre-existing on main:
Nothing this PR modifies is on any failure list. |
Summary
This MR adds three independent pickle-correctness fixes that together let
torch.save(module)and standalone-pickle ofRuntimeCache/RuntimeCacheHandlework cleanly.__getstate__/__setstate__onTorchTensorRTModuleexcludes_runtime_settingsand_implicit_cache_handlefrom the pickle stream and resets both to defaults on load. Mirrorsset_extra_state(line 515) and the line-236 design note: "RuntimeSettings are intentionally NOT serialized: they're per-engine, in-memory init values, not part of the engine's identity." Pickle path now matchesstate_dict+load_state_dictbehavior._RuntimeCacheHandle.__getstate__/__setstate__ships a protocol pair pickling as(path, bytes). Default pickle crashes on_thread.lock; the new protocol carries the cache contents (live cache when materialized,_pending_warm_bytesfallback otherwise) and rebuilds a freshthreading.Lockon the load side. Factored a lock-free_serialize_unlockedhelper.RuntimeCacheHandle::serialize()falls back topending_warm_bytes_whentrt_handle_is null, so callers asking for the handle's persistable bytes (save_to_stream,def_pickle) get the right answer in every lifecycle state. The misleading "called before materialized" warning is retired now that the pre-materialize case hooks into serializable bytes.def_picklemoves fromstd::string(path-only) tostd::tuple<std::string, at::Tensor>(path + bytes) and round-trips via the existingserialize/deserializeAPI. Matches the python facade's shape.Pre-existing CI noise this resolves
test_cross_runtime_serde::test_save_{cpp_load_python,python_load_python,python_load_cpp}: root cause is_thread.lockfailing pickle in python-runtime subprocesses. Fixed by the_RuntimeCacheHandle.__getstate__protocol._atexit_tokenslot containing aweakref. The wrapper exclusion here keeps that out of pickle entirely, unblocking the sibling PR's CI if needed.Refs #4359.
Type of change
Test plan
test_python_handle_pickle_preserves_pending_warm_bytes— pythonhandle's bytes round-trip end-to-end (path + pending bytes
preserved; fresh lock; no live cache reused).
test_000_runtime_cache.pypasses locally (24 passed, 1unrelated cpp-rt-only skip).
pre-commit runclean on touched files (mypy skipped for twopre-existing errors at
_TorchTensorRTModule.py:380and:508,unrelated).