
[train] Async batch collation (double-buffering) for the SFT trainer #1809

Open
dyurk-lila wants to merge 3 commits into
NovaSky-AI:main from
dyurk-lila:feat/data-prefetch

@dyurk-lila dyurk-lila commented Jun 18, 2026

Reviewers: Where to Look

The correctness-critical logic is the concurrency contract and the loop wiring that guarantees byte-identical batches. Focus review here:

  • skyrl/train/utils/async_batch_collator.py, AsyncBatchCollator (the whole file, ~105 lines). This is the generic single-slot async double-buffer. Scrutinize the submit/get/clear/shutdown contract: the single-slot assert in submit, the step-match assert in get (guards against silently serving a stale/mismatched batch), and that clear/shutdown always drain the in-flight future and join the executor so no worker thread or dataset reference leaks.
  • skyrl/train/sft_trainer.py, SFTTrainer.train() async-batch-collation block (~L1557-1743). The load-bearing pieces:
    • _slice_examples / _collate_batch / _epoch_of (~L1568-1579): the deterministic per-step producer that must reproduce the serial slice exactly.
    • _can_collate_ahead (~L1598-1602): the cross-epoch withhold — mirrors the loop's own reshuffle predicate so step N+1 is never collated ahead against the pre-shuffle order.
    • The consume/submit in the data-loading block (~L1614-1629) and the epoch-boundary drain (async_collator.clear() before rng.shuffle, ~L1721-1727) — the reshuffle/collate-ahead race guard.
    • The try/finally: async_collator.shutdown() around the loop (~L1604, L1737-1743): no leaked thread on exception.

Less scrutiny needed:

  • skyrl/train/config/sft_config.py — just the async_batch_collation: bool = True flag + docstring (config plumbing).
  • tests/train/test_async_batch_collation.py — test scaffolding; useful to confirm coverage (byte-identity across epoch boundaries for both DefaultCollator and PackedDataCollator, uneven wrap-around, and the AsyncBatchCollator invariant unit tests) but not core logic.

What does this PR do?

The SFT training loop builds each step's batch with a CPU-side slice + collate on the main thread, which blocks the GPU. The per-step slice is fully deterministic from (global_step, batch_size, len(tokenized)) given the current data order, and the order only changes at epoch boundaries (rng.shuffle). So step N+1's batch can be collated on a background thread while step N's forward/backward runs on the GPU, hiding the collate latency under GPU compute.

Note this overlaps the in-memory slice + collation that assembles each training batch from the already-tokenized dataset; it is not disk/network data prefetching (that is a separate, orthogonal optimization).

This is default-safe and byte-identical to the serial path: the collated-ahead batch is exactly what the synchronous loop would have produced for the same step.
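
To make the determinism argument concrete, here is a minimal sketch of a per-step slice that depends only on (global_step, batch_size, len(data)) and the current data order. The function names slice_examples and epoch_of and the modular wrap-around rule are assumptions for illustration; the trainer's actual _slice_examples / _epoch_of helpers may differ.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def slice_examples(data: Sequence[T], global_step: int, batch_size: int) -> List[T]:
    """Return the batch for `global_step`, wrapping around the end of `data`.

    Hypothetical helper: the result is a pure function of the arguments and
    the current order of `data`, so it can be computed ahead of time as long
    as `data` is not reshuffled in between.
    """
    n = len(data)
    start = (global_step * batch_size) % n
    return [data[(start + i) % n] for i in range(batch_size)]


def epoch_of(global_step: int, batch_size: int, n: int) -> int:
    """Epoch index in which the batch for `global_step` starts (assumed rule)."""
    return (global_step * batch_size) // n
```

With data = list(range(10)) and batch_size = 4, slice_examples(data, 2, 4) returns [8, 9, 0, 1], the uneven wrap-around case, and calling it again with the same arguments returns the same list.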

Changes

  • AsyncBatchCollator, skyrl/train/utils/async_batch_collator.py — a generic single-slot async double-buffer: a max_workers=1 ThreadPoolExecutor with a one-slot future and a strict submit/get contract that asserts the in-flight step matches the step the loop expects, turning any mis-wiring into a loud failure rather than a silent training-order corruption. Teardown is exception-safe: clear() discards the in-flight batch (and any exception the orphaned worker raised) on a reset/teardown path so it cannot mask the caller's own unwind, and shutdown() always joins the executor even if draining the in-flight batch raises.
  • SFTTrainer.train(), skyrl/train/sft_trainer.py — wire the async collator into the data-loading block.
  • SFTConfig.async_batch_collation (default True), skyrl/train/config/sft_config.py — set False to A/B against the serial batch-building path. Defaulting to True flips the default SFT batch-building path to threaded; it is behavior-preserving (byte-identical batches; see tests) and only changes timing.
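
The submit/get/clear/shutdown contract described above can be sketched roughly as follows. This is not the code in async_batch_collator.py: the class name SingleSlotCollator, its constructor signature, and the thread name prefix are hypothetical, and the real implementation may structure the slot differently.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, Optional


class SingleSlotCollator:
    """Hypothetical stand-in for AsyncBatchCollator: one worker, one in-flight batch."""

    def __init__(self, collate_fn: Callable[[int], Any]) -> None:
        self._collate_fn = collate_fn
        self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="batch-collate")
        self._future: Optional[Future] = None
        self._step: Optional[int] = None

    def submit(self, step: int) -> None:
        # Single-slot guard: at most one batch may be in flight.
        assert self._future is None, "slot already occupied"
        self._step = step
        self._future = self._executor.submit(self._collate_fn, step)

    def get(self, step: int) -> Any:
        # Step-match guard: never serve a batch built for a different step.
        assert self._future is not None, "nothing in flight"
        assert self._step == step, f"in-flight step {self._step} != expected {step}"
        future, self._future, self._step = self._future, None, None
        return future.result()  # re-raises any exception from the worker

    def clear(self) -> None:
        # Wait for the in-flight batch, then discard it and any error it raised.
        future, self._future, self._step = self._future, None, None
        if future is not None:
            try:
                future.result()
            except Exception:
                pass

    def shutdown(self) -> None:
        # Always join the worker thread, even if draining fails.
        try:
            self.clear()
        finally:
            self._executor.shutdown(wait=True)
```

A caller would submit(step + 1) right after consuming step's batch and call get(step + 1) at the top of the next iteration; a second submit without an intervening get or clear trips the single-slot assert.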

Correctness guarantees

  • Cross-epoch collate-ahead is withheld: _can_collate_ahead mirrors the loop's own reshuffle predicate, so step N+1 is never collated against the pre-shuffle order. The first step of each new epoch is collated synchronously against the post-shuffle order.
  • Drain before reshuffle: the in-flight slot is drained (clear()) before any reshuffle, so a worker thread can never read tokenized while it is being shuffled.
  • No leaked thread, even on exception: the executor is shut down in the loop's finally block (drains the in-flight batch + joins the worker). Teardown is double-fault-safe — if the loop body raises and an in-flight collation worker also raised, the original loop exception is what propagates (the worker's exception is swallowed during the reset), and the worker thread is still joined.
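
The three guarantees above can be seen together in a toy loop. The following is a sketch under stated assumptions, not the wiring in sft_trainer.py: run_loop, build, and the epoch rule (step * batch_size) // n are hypothetical, and the "training step" is just recording the batch.

```python
import random
from concurrent.futures import ThreadPoolExecutor


def run_loop(data, batch_size, num_steps, seed, collate_ahead):
    """Return the batches a toy loop would train on, with or without collate-ahead."""
    data = list(data)
    rng = random.Random(seed)
    n = len(data)

    def build(step):
        start = (step * batch_size) % n
        return [data[(start + i) % n] for i in range(batch_size)]

    def epoch_of(step):
        return (step * batch_size) // n

    def can_collate_ahead(step):
        # Withhold collate-ahead when the next step belongs to a new epoch.
        return step + 1 < num_steps and epoch_of(step + 1) == epoch_of(step)

    executor = ThreadPoolExecutor(max_workers=1)
    in_flight = None
    seen = []
    try:
        for step in range(num_steps):
            if in_flight is not None:
                batch, in_flight = in_flight.result(), None
            else:
                batch = build(step)  # synchronous path (first step of an epoch)
            if collate_ahead and can_collate_ahead(step):
                in_flight = executor.submit(build, step + 1)
            seen.append(batch)  # stand-in for train_step
            if step + 1 < num_steps and epoch_of(step + 1) != epoch_of(step):
                if in_flight is not None:  # drain before touching the data order
                    in_flight.result()
                    in_flight = None
                rng.shuffle(data)
    finally:
        try:
            if in_flight is not None:
                in_flight.result()
        except Exception:
            pass  # do not mask an exception raised by the loop body
        finally:
            executor.shutdown(wait=True)
    return seen
```

Running run_loop(range(10), 4, 8, 0, True) and run_loop(range(10), 4, 8, 0, False) yields the same list of batches, because the worker only ever reads the data order that the serial path would have read for the same step.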

Scope

This is a controller-side SFT-trainer optimization. The other trainers do not need it:

  • In synchronous RL (trainer.py), the per-step training batch is built from the just-generated rollout (convert_to_training_input), so it cannot be collated ahead of generation.
  • The fully async RL trainer (fully_async_trainer.py) already overlaps generation with training via its asyncio queue.
  • SFTTrainer.run_eval() has the same collate-then-dispatch shape that train() optimizes here, and (since the eval slice has no wrap-around/reshuffle) would be an even cleaner AsyncBatchCollator fit, but it is intentionally left serial: eval is off by default (eval_interval=0) and infrequent, and the per-batch model forward dominates the small eval collate, so the payoff is marginal. It can adopt the async collator later without API changes.

AsyncBatchCollator is written as a generic, reusable component (it holds no trainer state), so a future deterministic per-step producer can reuse it.

Test plan

  • tests/train/test_async_batch_collation.py: 7 integration cases + 5 AsyncBatchCollator unit tests:
    • The integration cases run the real SFTTrainer.train() loop twice over the same dummy dataset/seed (async collation OFF vs ON) across two epoch boundaries and assert every batch handed to train_step is tensor-equal step-for-step — for both DefaultCollator (FSDP left-pad) and PackedDataCollator (Megatron FFD packing), plus an uneven wrap-around case, a mid-epoch resume case, an eval-enabled case, and a num_epochs-derived num_steps case. Examples have distinct token ids, so a stale pre-shuffle batch would diverge from the baseline and fail. (Mutation-checked: disabling either the cross-epoch withhold or the epoch-boundary drain makes the integration tests fail.)
    • A teardown double-fault case: a loop-body train_step failure coincides with an in-flight collation worker that also raises, asserting the original loop exception propagates (not the worker's re-raised error) and that no sft-batch-collate worker thread survives after train() returns.
    • AsyncBatchCollator invariant unit tests: step-mismatch assertion, single-slot guard, worker-exception propagation through get(), and clear() drain.
    • The collate path is pure Python + torch (no numpy), so the packed-collator test needs no Megatron runtime.
  • ruff==0.11.9 + black==24.10.0 clean.
uv run --isolated --extra dev --extra fsdp pytest tests/train/test_async_batch_collation.py -v
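
For readers unfamiliar with the double-fault scenario, the shape of that check can be illustrated in isolation. This is a hypothetical reduction, not the test in the PR: the function name, the "demo-collate" thread prefix, and the exception types are invented for the example.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


def train_with_failing_worker():
    """Simulate a loop-body failure while an in-flight collation also fails."""
    executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="demo-collate")

    def bad_collate():
        raise ValueError("worker failed")

    in_flight = executor.submit(bad_collate)
    try:
        raise RuntimeError("train_step failed")  # the original loop exception
    finally:
        try:
            in_flight.result()  # drain; the worker's error surfaces here
        except Exception:
            pass  # swallowed so it cannot replace the loop exception
        finally:
            executor.shutdown(wait=True)  # joins the worker thread


def surviving_workers():
    """Names of live threads created by the demo executor."""
    return [t.name for t in threading.enumerate() if t.name.startswith("demo-collate")]
```

Calling train_with_failing_worker() raises the RuntimeError from the loop body rather than the worker's ValueError, and surviving_workers() is empty afterwards because shutdown(wait=True) joins the thread.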

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces async batch collation (double-buffering) to the SFT training loop, allowing the CPU-side batch slicing and collation for step N+1 to run on a background thread while step N's forward/backward pass runs on the GPU. This change hides collation latency under GPU compute. It adds the AsyncBatchCollator utility class, integrates it into SFTTrainer, adds configuration options, and includes comprehensive tests. Feedback was provided on skyrl/train/utils/async_batch_collator.py to avoid catching BaseException in the clear method, as it can intercept system-exiting exceptions like SystemExit and KeyboardInterrupt, recommending catching Exception instead.
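
The distinction behind that feedback can be shown with a small sketch. drain_quietly is a hypothetical helper, not the PR's clear() method; it only demonstrates that catching Exception discards ordinary worker errors while letting SystemExit and KeyboardInterrupt propagate, since both derive from BaseException but not from Exception.

```python
from concurrent.futures import Future


def drain_quietly(future: Future) -> None:
    """Wait for `future` and discard ordinary errors, but not system-exiting ones."""
    try:
        future.result()
    except Exception:
        # ValueError, RuntimeError, etc. are discarded here.
        # SystemExit and KeyboardInterrupt are not subclasses of Exception,
        # so they pass through this handler and reach the caller.
        pass
```

Replacing `except Exception` with `except BaseException` in this helper would also swallow an interrupt delivered while draining, which is the behavior the review comment warns against.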

Comment thread skyrl/train/utils/async_batch_collator.py Outdated
The SFT training loop builds each step's batch with a CPU-side slice +
collate on the main thread, which blocks the GPU. The per-step slice is
fully deterministic from (global_step, batch_size, len(tokenized)) given
the current data order, and the order only changes at epoch boundaries
(rng.shuffle). So step N+1's batch can be collated on a background thread
while step N's forward/backward runs on the GPU, hiding the collate
latency under GPU compute.

Add BatchPrefetcher (skyrl/train/utils/batch_prefetcher.py), a generic
single-slot async double-buffer: a max_workers=1 ThreadPoolExecutor with
a strict submit/get contract that asserts the in-flight step matches the
step the loop expects, turning any mis-wiring into a loud failure rather
than a silent training-order corruption. Wire it into SFTTrainer.train():

* Cross-epoch prefetch is withheld: _can_prefetch_next mirrors the loop's
  own reshuffle predicate, so step N+1 is never prefetched against the
  pre-shuffle order. The first step of each new epoch is computed
  synchronously against the post-shuffle order.
* The prefetch slot is drained (clear()) before any reshuffle, so a
  worker thread can never read `tokenized` while it is being shuffled.
* The executor is shut down in the loop's finally block (drains the
  in-flight batch + joins the worker) so the background thread is never
  leaked, even on exception.

Gated by SFTConfig.prefetch_data (default True); set False to A/B against
the serial data-loading path.

Scope: this is an SFT-trainer optimization. The RL trainers do not need
it -- in synchronous RL the per-step batch depends on the just-generated
rollout, so it cannot be prefetched ahead of generation, and the fully
async RL trainer already overlaps generation with training via its
asyncio queue. BatchPrefetcher is generic, so a future deterministic
producer can reuse it.

tests/train/test_sft_prefetch.py: runs the real SFTTrainer.train() loop
twice over the same dummy dataset/seed (prefetch OFF vs ON) across two
epoch boundaries and asserts every batch handed to train_step is
tensor-equal step-for-step, for both DefaultCollator and
PackedDataCollator. Examples have distinct token ids, so a stale
pre-shuffle batch would diverge from the baseline and fail. Plus
BatchPrefetcher invariant unit tests (step-mismatch assertion,
single-slot guard, worker-exception propagation, clear() drain).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@dyurk-lila dyurk-lila force-pushed the feat/data-prefetch branch from b563e03 to 6ef26b2 Compare July 1, 2026 20:17
dyurk-lila and others added 2 commits July 1, 2026 13:41
Standardize on async-batch-collation vocabulary: rename the class
BatchPrefetcher → AsyncBatchCollator (module batch_prefetcher.py →
async_batch_collator.py), the SFTConfig flag prefetch_data →
async_batch_collation, and the loop wiring / helper / test symbols
accordingly. Behavior is unchanged.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>