Skip to content

Add low-memory streaming conversion for unscanned DeepSeek-family checkpoints#4160

Open
discobot wants to merge 1 commit into
AI-Hypercomputer:mainfrom
discobot:fix/4071-unscanned-ckpt-low-memory
Open

Add low-memory streaming conversion for unscanned DeepSeek-family checkpoints#4160
discobot wants to merge 1 commit into
AI-Hypercomputer:mainfrom
discobot:fix/4071-unscanned-ckpt-low-memory

Conversation

@discobot

Copy link
Copy Markdown

Description

Fixes #4071.

The unscanned converter buffered every dequantized tensor for all shards before
assembly, OOM-ing on hosts with less than ~2.5 TB RAM for Kimi-K2.6.

Two things beyond what the issue establishes: assembly builds a second full fp16 copy
of the model while the buffered dict is still alive, and even with loading fixed, the
default save path (simulated_cpu_devices_count=16) re-materializes the whole pytree
as RAM-resident jax.Arrays inside shard_jax_weights — so streaming the loads alone
is not enough.

This PR makes tensor loading lazy unconditionally (a header-only index over the shards;
each tensor is read and int4-dequantized exactly once, when assembly consumes it) and
adds an opt-in --low_memory flag that stages converted leaves in read-only disk-backed
memmaps under TMPDIR and saves via the single-device path. Peak RSS drops from
O(2x model) to O(one tensor); on an 8-shard synthetic kimi-style checkpoint, RSS during
the shard scan goes from +400 MB monotonic growth to flat (+0.5 MB).

Checkpoints are bit-identical in all directions (verified via Orbax restore): old
converter vs new default, new default vs --low_memory, and 16-simulated-device
sharded save vs the low-memory single-device save — the saved checkpoint is
topology-independent, so skipping simulated-device sharding in low-memory mode does not
change the artifact.

Adds tests/unit/convert_deepseek_unscanned_low_memory_test.py (synthetic int4
multi-shard checkpoint; asserts no tensor reads before assembly, no re-reads, bit-exact
low-memory equivalence, and a save/restore round trip against independently computed
values), ignored in default CI like the other torch-dependent conversion tests, and
documents the flag in the Kimi runbook.

Tests

  • python3 -m pytest tests/unit/convert_deepseek_unscanned_low_memory_test.py tests/unit/dequantize_pack_quantized_int4_test.py (9 passed)
  • CLI end-to-end on a synthetic checkpoint with and without --low_memory true; restored pytrees bit-identical
  • pylint 10.00/10 and pyink --pyink-indentation=2 --line-length=122 clean on touched files

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

This change was developed with assistance from Claude Code.

…ckpoints

The converter buffered every (dequantized) tensor from all safetensors shards in one dict before assembling a second full copy of the model, so converting Kimi-K2.6 needed ~2.5 TB of host RAM (AI-Hypercomputer#4071). Tensor loading is now lazy: an index over the shard headers is built up front and each tensor is read (and int4-dequantized) only when the assembly consumes it. A new --low_memory flag additionally stages converted leaves in disk-backed numpy memmaps under TMPDIR and saves without simulated-device sharding, keeping peak RSS at O(one tensor) instead of O(2x model). Both the default path and the low-memory path produce bit-identical checkpoints to before; a new unit test covers streaming, disk spilling, and a save/restore round trip on a tiny synthetic kimi-style checkpoint.
@codecov

codecov Bot commented Jun 13, 2026

Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.

📢 Thoughts on this report? Let us know!

@shralex

shralex commented Jun 13, 2026

Copy link
Copy Markdown
Collaborator

How does this approach compare to lazy_load_tensors (which is now true by default) mode in src/maxtext/checkpoint_conversion/to_maxtext.py ?

@discobot

Copy link
Copy Markdown
Author

This is the standalone convert_deepseek_family_unscanned_ckpt.py, which didn't have lazy loading. _LazyShardLoader adds the same on-demand load/dequant. The part lazy_load_tensors doesn't cover is the save: with simulated_cpu_devices_count=16, device_put pulls the whole pytree back into RAM and still OOMs, so --low_memory spills leaves to .npy memmaps and skips simulated sharding to keep them disk-backed.

Tested on a synthetic int4-expert checkpoint, not a full Kimi-K2.6 run (that's the >2TB-host case this is meant to fix), so I've left the e2e box unchecked.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Kimi-K2.6 unscanned checkpoint converter OOMs on hosts < ~2.5 TB RAM (buffers all 64 shards before write)

2 participants