Skip to content

Fix run_with_cache(device=...) permanently moving the model#1345

Merged
jlarson4 merged 2 commits into
TransformerLensOrg:devfrom
RecreationalMath:fix-run-with-cache
May 29, 2026
Merged

Fix run_with_cache(device=...) permanently moving the model#1345
jlarson4 merged 2 commits into
TransformerLensOrg:devfrom
RecreationalMath:fix-run-with-cache

Conversation

@RecreationalMath
Copy link
Copy Markdown
Contributor

Description

Fixes #1336.

run_with_cache(device=...) previously called self.original_model.to(cache_device) in the single-device branch with no restore, silently migrating any non-CPU model to cache_device and leaving cfg.device stale. The move was redundant: make_cache_hook already offloads each captured activation, matching ActivationCache.to (move_model deprecated) and the documented get_caching_hooks contract.

  • Flatten the buggy conditional and add a Mock(wraps=) regression test (catches the bug on CPU CI where to('cpu') is a no-op).
  • Document the device kwarg in the run_with_cache docstring.
  • Replace the temporary cache_dict direct-write in generate(return_cache=True, device=...) with a run_with_cache(device=device) passthrough (offloads at capture time, reducing peak memory).
  • Simplify the corresponding device-offload test.

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

The single-device branch moved the model and inputs to cache_device
with no restore, leaving non-CPU models silently migrated and
cfg.device stale. The move was redundant since make_cache_hook already
offloads each captured activation, matching ActivationCache.to and the
legacy get_caching_hooks contract.

Flatten the conditional, add a regression test asserting
original_model.to is not invoked, and document the device kwarg.
With the run_with_cache model-move fixed, TransformerBridge.generate
return_cache device offload can use a run_with_cache(device=device)
passthrough. The offload now happens at capture time, reducing peak
memory. Drop the cache_dict direct-write and its justifying comment,
simplify the offload test to a device-landing check.
@jlarson4 jlarson4 merged commit 34e6dc4 into TransformerLensOrg:dev May 29, 2026
24 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants