Fix run_with_cache(device=...) permanently moving the model #1345
Merged
jlarson4 merged 2 commits on May 29, 2026
The single-device branch moved the model and inputs to cache_device with no restore, leaving non-CPU models silently migrated and cfg.device stale. The move was redundant since make_cache_hook already offloads each captured activation, matching ActivationCache.to and the legacy get_caching_hooks contract. Flatten the conditional, add a regression test asserting original_model.to is not invoked, and document the device kwarg.
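The regression-test idea can be sketched with the standard library alone. This is a minimal illustration, not the test added in this PR: `ToyModel`, `buggy_run_with_cache`, `fixed_run_with_cache`, and `count_to_calls` are hypothetical stand-ins. It shows why counting calls through `Mock(wraps=...)` catches the bug even when the move itself changes nothing observable.

```python
from unittest.mock import Mock


class ToyModel:
    """Hypothetical stand-in for the wrapped model; not the real class."""

    def __init__(self, device="cpu"):
        self.device = device

    def to(self, device):
        # Moving to the device the model is already on changes nothing visible.
        self.device = device
        return self


def buggy_run_with_cache(model, cache_device):
    # Old behaviour as described above: the model is moved and never restored.
    model.to(cache_device)
    return {}


def fixed_run_with_cache(model, cache_device):
    # Fixed behaviour: the model is left alone; hooks would offload activations.
    return {}


def count_to_calls(run_fn):
    model = ToyModel(device="cpu")
    # wraps= forwards every call to the real method while recording it.
    model.to = Mock(wraps=model.to)
    run_fn(model, "cpu")
    return model.to.call_count, model.device


print(count_to_calls(buggy_run_with_cache))  # (1, 'cpu')
print(count_to_calls(fixed_run_with_cache))  # (0, 'cpu')
```

Both runs end with the model on `cpu`, so a check on the final device alone would pass in either case. Only the call count separates the buggy path from the fixed one.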
With the run_with_cache model-move fixed, the return_cache device offload in TransformerBridge.generate can use a run_with_cache(device=device) passthrough. The offload now happens at capture time, reducing peak memory. Drop the cache_dict direct-write and its justifying comment, and simplify the offload test to a device-landing check.
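The peak-memory argument can be made concrete with a toy counting model. This is a sketch under stated assumptions, not a measurement of the library: `ToyActivation` and `peak_on_compute_device` are hypothetical, the device names are plain string labels, and "peak" counts only how many cached activations sit on the compute device at once.

```python
class ToyActivation:
    """Hypothetical stand-in for a tensor; tracks only which device holds it."""

    def __init__(self, name, device):
        self.name = name
        self.device = device

    def to(self, device):
        # Return a copy on the target device, leaving the original untouched.
        return ToyActivation(self.name, device)


def peak_on_compute_device(layer_names, offload_at_capture,
                           compute="cuda", cache_device="cpu"):
    cache = {}
    peak = 0
    for name in layer_names:
        act = ToyActivation(name, compute)  # produced by the forward pass
        # Capture-time offload moves each activation as soon as it is cached.
        cache[name] = act.to(cache_device) if offload_at_capture else act
        resident = sum(a.device == compute for a in cache.values())
        peak = max(peak, resident)
    if not offload_at_capture:
        # Post-hoc offload: every cached entry stayed on the compute device.
        cache = {k: v.to(cache_device) for k, v in cache.items()}
    return peak, {a.device for a in cache.values()}


layers = ["embed", "block0", "block1", "unembed"]
print(peak_on_compute_device(layers, offload_at_capture=False))  # (4, {'cpu'})
print(peak_on_compute_device(layers, offload_at_capture=True))   # (0, {'cpu'})
```

Both strategies end with every entry on the cache device, which is what a device-landing check verifies. They differ only in how many cached entries occupy the compute device along the way.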
Description
Fixes #1336.
`run_with_cache(device=...)` previously called `self.original_model.to(cache_device)` in the single-device branch with no restore, silently migrating any non-CPU model to `cache_device` and leaving `cfg.device` stale. The move was redundant: `make_cache_hook` already offloads each captured activation, matching `ActivationCache.to` (`move_model` deprecated) and the documented `get_caching_hooks` contract.

- Add a `Mock(wraps=)` regression test (catches the bug on CPU CI, where `to('cpu')` is a no-op).
- Document the `device` kwarg in the `run_with_cache` docstring.
- Replace the `cache_dict` direct-write in `generate(return_cache=True, device=...)` with a `run_with_cache(device=device)` passthrough (offloads at capture time, reducing peak memory).

Type of change
Checklist: