[train] Async batch collation (double-buffering) for the SFT trainer#1809
[train] Async batch collation (double-buffering) for the SFT trainer#1809dyurk-lila wants to merge 3 commits into
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces async batch collation (double-buffering) to the SFT training loop, allowing the CPU-side batch slicing and collation for step N+1 to run on a background thread while step N's forward/backward pass runs on the GPU. This change hides collation latency under GPU compute. It adds the AsyncBatchCollator utility class, integrates it into SFTTrainer, adds configuration options, and includes comprehensive tests. Feedback was provided on skyrl/train/utils/async_batch_collator.py to avoid catching BaseException in the clear method, as it can intercept system-exiting exceptions like SystemExit and KeyboardInterrupt, recommending catching Exception instead.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
The SFT training loop builds each step's batch with a CPU-side slice + collate on the main thread, which blocks the GPU. The per-step slice is fully deterministic from (global_step, batch_size, len(tokenized)) given the current data order, and the order only changes at epoch boundaries (rng.shuffle). So step N+1's batch can be collated on a background thread while step N's forward/backward runs on the GPU, hiding the collate latency under GPU compute. Add BatchPrefetcher (skyrl/train/utils/batch_prefetcher.py), a generic single-slot async double-buffer: a max_workers=1 ThreadPoolExecutor with a strict submit/get contract that asserts the in-flight step matches the step the loop expects, turning any mis-wiring into a loud failure rather than a silent training-order corruption. Wire it into SFTTrainer.train(): * Cross-epoch prefetch is withheld: _can_prefetch_next mirrors the loop's own reshuffle predicate, so step N+1 is never prefetched against the pre-shuffle order. The first step of each new epoch is computed synchronously against the post-shuffle order. * The prefetch slot is drained (clear()) before any reshuffle, so a worker thread can never read `tokenized` while it is being shuffled. * The executor is shut down in the loop's finally block (drains the in-flight batch + joins the worker) so the background thread is never leaked, even on exception. Gated by SFTConfig.prefetch_data (default True); set False to A/B against the serial data-loading path. Scope: this is an SFT-trainer optimization. The RL trainers do not need it -- in synchronous RL the per-step batch depends on the just-generated rollout, so it cannot be prefetched ahead of generation, and the fully async RL trainer already overlaps generation with training via its asyncio queue. BatchPrefetcher is generic, so a future deterministic producer can reuse it. tests/train/test_sft_prefetch.py: runs the real SFTTrainer.train() loop twice over the same dummy dataset/seed (prefetch OFF vs ON) across two epoch boundaries and asserts every batch handed to train_step is tensor-equal step-for-step, for both DefaultCollator and PackedDataCollator. Examples have distinct token ids, so a stale pre-shuffle batch would diverge from the baseline and fail. Plus BatchPrefetcher invariant unit tests (step-mismatch assertion, single-slot guard, worker-exception propagation, clear() drain). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
b563e03 to
6ef26b2
Compare
Standardize on async-batch-collation vocabulary: rename the class BatchPrefetcher → AsyncBatchCollator (module batch_prefetcher.py → async_batch_collator.py), the SFTConfig flag prefetch_data → async_batch_collation, and the loop wiring / helper / test symbols accordingly. Behavior is unchanged. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Reviewers: Where to Look
The correctness-critical logic is the concurrency contract and the loop wiring that guarantees byte-identical batches. Focus review here:
skyrl/train/utils/async_batch_collator.py—AsyncBatchCollator(the whole file, ~105 lines). This is the generic single-slot async double-buffer. Scrutinize thesubmit/get/clear/shutdowncontract: the single-slotassertinsubmit, the step-matchassertinget(guards against silently serving a stale/mismatched batch), and thatclear/shutdownalways drain the in-flight future and join the executor so no worker thread or dataset reference leaks.skyrl/train/sft_trainer.py—SFTTrainer.train()async-batch-collation block (~L1557-1743). The load-bearing pieces:_slice_examples/_collate_batch/_epoch_of(~L1568-1579): the deterministic per-step producer that must reproduce the serial slice exactly._can_collate_ahead(~L1598-1602): the cross-epoch withhold — mirrors the loop's own reshuffle predicate so step N+1 is never collated ahead against the pre-shuffle order.async_collator.clear()beforerng.shuffle, ~L1721-1727) — the reshuffle/collate-ahead race guard.try/finally: async_collator.shutdown()around the loop (~L1604, L1737-1743): no leaked thread on exception.Less scrutiny needed:
skyrl/train/config/sft_config.py— just theasync_batch_collation: bool = Trueflag + docstring (config plumbing).tests/train/test_async_batch_collation.py— test scaffolding; useful to confirm coverage (byte-identity across epoch boundaries for bothDefaultCollatorandPackedDataCollator, uneven wrap-around, and theAsyncBatchCollatorinvariant unit tests) but not core logic.What does this PR do?
The SFT training loop builds each step's batch with a CPU-side slice + collate on the main thread, which blocks the GPU. The per-step slice is fully deterministic from
(global_step, batch_size, len(tokenized))given the current data order, and the order only changes at epoch boundaries (rng.shuffle). So stepN+1's batch can be collated on a background thread while stepN's forward/backward runs on the GPU, hiding the collate latency under GPU compute.Note this overlaps the in-memory slice + collation that assembles each training batch from the already-tokenized dataset; it is not disk/network data prefetching (that is a separate, orthogonal optimization).
This is default-safe and byte-identical to the serial path: the collated-ahead batch is exactly what the synchronous loop would have produced for the same step.
Changes
AsyncBatchCollator,skyrl/train/utils/async_batch_collator.py— a generic single-slot async double-buffer: amax_workers=1ThreadPoolExecutorwith a one-slot future and a strictsubmit/getcontract that asserts the in-flight step matches the step the loop expects, turning any mis-wiring into a loud failure rather than a silent training-order corruption. Teardown is exception-safe:clear()discards the in-flight batch (and any exception the orphaned worker raised) on a reset/teardown path so it cannot mask the caller's own unwind, andshutdown()always joins the executor even if draining the in-flight batch raises.SFTTrainer.train(),skyrl/train/sft_trainer.py— wire the async collator into the data-loading block.SFTConfig.async_batch_collation(defaultTrue),skyrl/train/config/sft_config.py— setFalseto A/B against the serial batch-building path. Defaulting toTrueflips the default SFT batch-building path to threaded; it is behavior-preserving (byte-identical batches; see tests) and only changes timing.Correctness guarantees
_can_collate_aheadmirrors the loop's own reshuffle predicate, so stepN+1is never collated against the pre-shuffle order. The first step of each new epoch is collated synchronously against the post-shuffle order.clear()) before any reshuffle, so a worker thread can never readtokenizedwhile it is being shuffled.finallyblock (drains the in-flight batch + joins the worker). Teardown is double-fault-safe — if the loop body raises and an in-flight collation worker also raised, the original loop exception is what propagates (the worker's exception is swallowed during the reset), and the worker thread is still joined.Scope
This is a controller-side SFT-trainer optimization. The other trainers do not need it:
trainer.py), the per-step training batch is built from the just-generated rollout (convert_to_training_input), so it cannot be collated ahead of generation.fully_async_trainer.py) already overlaps generation with training via itsasyncioqueue.SFTTrainer.run_eval()has the same collate-then-dispatch shape thattrain()optimizes here, and (since the eval slice has no wrap-around/reshuffle) would be an even cleanerAsyncBatchCollatorfit, but it is intentionally left serial: eval is off by default (eval_interval=0) and infrequent, and the per-batch model forward dominates the small eval collate, so the payoff is marginal. It can adopt the async collator later without API changes.AsyncBatchCollatoris written as a generic, reusable component (it holds no trainer state), so a future deterministic per-step producer can reuse it.Test plan
tests/train/test_async_batch_collation.py— 7 integration cases + 5AsyncBatchCollatorunit tests:SFTTrainer.train()loop twice over the same dummy dataset/seed (async collation OFF vs ON) across two epoch boundaries and assert everybatchhanded totrain_stepis tensor-equal step-for-step — for bothDefaultCollator(FSDP left-pad) andPackedDataCollator(Megatron FFD packing), plus an uneven wrap-around case, a mid-epoch resume case, an eval-enabled case, and anum_epochs-derivednum_stepscase. Examples have distinct token ids, so a stale pre-shuffle batch would diverge from the baseline and fail. (Mutation-checked: disabling either the cross-epoch withhold or the epoch-boundary drain makes the integration tests fail.)train_stepfailure coincides with an in-flight collation worker that also raises, asserting the original loop exception propagates (not the worker's re-raised error) and that nosft-batch-collateworker thread survives aftertrain()returns.AsyncBatchCollatorinvariant unit tests: step-mismatch assertion, single-slot guard, worker-exception propagation throughget(), andclear()drain.ruff==0.11.9+black==24.10.0clean.