[train] Vectorize controller-side training-batch collation (SFT + RL) #1808
dyurk-lila wants to merge 3 commits into
Code Review
This pull request vectorizes the collation and preprocessing pipelines in SkyRL by replacing per-token and per-sample Python loops with NumPy slice assignments and broadcast operations, followed by a single conversion to PyTorch tensors. It also adds a comprehensive test suite to verify bit-identical equivalence with the original implementations. The review feedback highlights potential runtime errors when inputs like rewards, logprobs, or loss_masks are PyTorch tensors (especially if they reside on GPU or require gradients), and suggests defensively detaching and moving them to CPU before converting to NumPy arrays.
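The defensive conversion suggested in this review could look roughly like the sketch below. `to_numpy_1d` is a hypothetical name, not the helper in the PR, and the function is duck-typed on `.detach()` so that the example runs without importing PyTorch; it assumes a tensor-like input exposes `detach()`, `cpu()`, `float()`, and `numpy()` the way `torch.Tensor` does.

```python
import numpy as np


def to_numpy_1d(values):
    """Hypothetical helper: coerce a list or tensor-like input to a 1-D float32 array."""
    if hasattr(values, "detach"):
        # Tensor-like input (e.g. torch.Tensor): drop autograd history, move to
        # host memory, and upcast so that a bfloat16 tensor can reach .numpy().
        values = values.detach().cpu().float().numpy()
    arr = np.asarray(values, dtype=np.float32)
    if arr.ndim != 1:
        raise ValueError(f"expected a 1-D sequence, got shape {arr.shape}")
    return arr
```

Plain Python lists take the `np.asarray` path directly, so the list-based reward path is unaffected by the extra branch.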
# What does this PR do?

The controller builds every training batch on the main process before dispatching it to the workers. All three collation paths did so with per-token / per-sample Python loops that dominate the controller-side collate wall-time and serially block the GPU at large batch sizes.

This PR replaces those loops with NumPy slice-assignments and broadcast comparisons. **Outputs are bit-identical** (same dtypes, same layout) for all inputs produced in practice: this is a pure CPU-side latency optimization, not a behavior change.

## Changes

- **`PackedDataCollator` (Megatron SFT FFD packing), `skyrl/train/dataset/collators.py`**: the per-bin packed row tensors (`sequences` / `attention_mask` / `loss_mask`) were built with a per-token Python loop over the reconstructed full loss mask. Each sub-seq is now written with one C-level copy, and `total_nonpad` is a single vectorized reduction.
- **`collate_sft_batch` / `DefaultCollator` (unpacked SFT), `skyrl/train/sft_trainer.py`**: each left-padded row is written with a single slice assignment into a preallocated array instead of building a per-example padded Python list.
- **`convert_prompts_responses_to_batch_tensors` (RL), `skyrl/train/dataset/preprocess.py`**: the left-padded `sequences` are built with two slice copies per row, and the fixed-width `attention_mask` / `action_mask` / `loss_mask` / `rewards` / `logprobs` tensors are produced with broadcast comparisons / slice writes instead of per-token Python loops.

This covers the SFT (packed + unpacked) and RL training-batch construction paths; the RL and SFT data paths are separate functions, so each is vectorized independently. The RL change is inherited unchanged by all `RayPPOTrainer` subclasses (sync / async / full-context / agentic), since none override `convert_to_training_input`.
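As a rough illustration of the unpacked-SFT change described above, the sketch below contrasts building padded Python lists with writing into a preallocated array. The function names and the `(rows, pad_id)` signature are hypothetical and chosen only for this example; the real collator's inputs and outputs may differ.

```python
import numpy as np


def left_pad_loop(rows, pad_id):
    """Reference style: build each padded row as a Python list, then stack."""
    width = max(len(r) for r in rows)
    seqs = [[pad_id] * (width - len(r)) + list(r) for r in rows]
    mask = [[0] * (width - len(r)) + [1] * len(r) for r in rows]
    return np.array(seqs, dtype=np.int64), np.array(mask, dtype=np.int64)


def left_pad_sliced(rows, pad_id):
    """Vectorized style: preallocate once, then one slice assignment per row."""
    width = max(len(r) for r in rows)
    seqs = np.full((len(rows), width), pad_id, dtype=np.int64)
    mask = np.zeros((len(rows), width), dtype=np.int64)
    for i, r in enumerate(rows):
        start = width - len(r)  # number of pad tokens on the left
        seqs[i, start:] = r     # single C-level copy of the token ids
        mask[i, start:] = 1
    return seqs, mask
```

Both functions pin `dtype=np.int64`, so a final `torch.from_numpy` call on either result would yield `torch.long` tensors.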
### Intentionally out of scope

- **MoE router-replay (`rollout_expert_indices`).** The optional `rollout_expert_indices` branch in `convert_prompts_responses_to_batch_tensors` is left exactly as-is; only the dense per-token batch tensors are vectorized. That branch is byte-identical to the prior implementation (zero correctness/regression risk); it is a narrow MoE-only path, so its residual per-sample loop is left for a follow-up rather than folded into this CPU-latency change. The new equivalence suite therefore does not exercise it (the oracle compares the six dense outputs; the 7th return value is intentionally discarded).
- **Eval path.** Packing only fires on the training-step batch (`batch_size == self.batch_size`); on the eval path `PackedDataCollator` delegates to the un-packed `DefaultCollator`, so eval collation is unchanged by this PR.

Notes on the bit-identical claim:

- dtypes are preserved exactly: `int64` / `torch.long` for `sequences` / masks (incl. `action_mask`), `float32` / `torch.float` for `loss_mask` / `rewards` / `logprobs`. `dtype=np.int64` is pinned explicitly (NumPy's platform-default int is `int32` on Windows before NumPy 2.0).
- The RL reward path accepts Python lists and `float32` reward tensors (what the reward postprocessing produces today). A `requires_grad`, CUDA, or `bfloat16` reward tensor is not accepted; no reward producer in the repo emits those.
- The `PackedDataCollator` loss-mask write window keeps the original `row_p < max_packed_len - 1` clamp. That `min()` is a defensive no-op (`max_packed_len` is `>=` every bin's packed length by construction), so the clamp never bites today; it is retained (and now commented) to preserve the original behavior exactly.
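To make the clamp note concrete, here is a hedged reconstruction of how a per-token guard can map onto a `min()` on the slice end. The function names, the in-place `row` argument, and the choice `loss_mask_width = max_packed_len - 1` are assumptions made for this sketch, not code copied from `PackedDataCollator`.

```python
import numpy as np


def write_loss_mask_loop(row, full_mask, row_offset, max_packed_len):
    """Reference style: per-token loop with the `row_p < max_packed_len - 1` guard."""
    s = len(full_mask)
    for j in range(s - 1):
        row_p = row_offset + j
        if row_p < max_packed_len - 1:
            # Next-token shift: position j takes the mask value of token j + 1.
            row[row_p] = full_mask[j + 1]


def write_loss_mask_sliced(row, full_mask, row_offset, max_packed_len):
    """Vectorized style: one slice copy; min() plays the role of the loop guard."""
    s = len(full_mask)
    loss_mask_width = max_packed_len - 1  # assumed relation, see lead-in
    write_end = min(row_offset + s - 1, loss_mask_width)
    n_write = max(write_end - row_offset, 0)
    row[row_offset:row_offset + n_write] = full_mask[1:1 + n_write]
```

When `max_packed_len` is at least the packed length of the bin, `write_end` always equals `row_offset + s - 1` and the `min()` has no effect, which matches the "defensive no-op" description above. The clamp only changes the result when `max_packed_len` is made artificially small.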
## Benchmarks

Controller-side collate, single process, CPU, batch of 1024 (varying sequence lengths):

| Path | Before | After | Speedup |
|------|--------|-------|---------|
| `PackedDataCollator` (FFD, dp=8) | 288.5 ms | 8.4 ms | ~34x |
| `convert_prompts_responses_to_batch_tensors` (RL) | 92.8 ms | 14.5 ms | ~6.4x |
| `collate_sft_batch` (unpacked SFT) | 98.4 ms | 16.2 ms | ~6x |

## Test plan

- [x] New `tests/train/test_collation_vectorization_equivalence.py`: pins a faithful reference of each *original* loop and fuzzes the vectorized output against it with `torch.equal` plus explicit per-tensor `dtype` assertions, covering RL (with/without logprobs, list and `float32`-tensor rewards), unpacked SFT, and packed SFT across TP/PP/CP/DP configs. Because `torch.equal` is dtype-insensitive on matching values, the integer/float dtypes (`action_mask`/`attention_mask` `int64`/`long`, `loss_mask`/`rewards` `float32`) are pinned with explicit `.dtype` assertions. The packed test re-derives the FFD / DP-shard / `max_packed_len` decision inline as its own oracle, so any production drift surfaces as a `torch.equal` mismatch. (Mutation-checked: an injected off-by-one in any vectorized path fails the suite. The unreachable `loss_mask` clamp is the one spot where a localized off-by-one would not be caught, since it never fires under any in-practice input.)
- [x] Existing `tests/train/dataset/test_preprocess.py`, `tests/train/test_sft_packing_collate.py`, `tests/train/test_packing_round_trip.py`, `tests/train/test_sft_tokenization.py` pass unchanged.
- [x] `ruff` + `black` clean.
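The reason for pairing a value comparison with explicit dtype assertions can be shown with a NumPy analog. This sketch uses `np.array_equal` in place of `torch.equal` so that it runs without PyTorch, and `assert_bit_identical` is a hypothetical helper name rather than something taken from the test suite.

```python
import numpy as np


def assert_bit_identical(actual, expected):
    """Fail unless shape, dtype, and every value match."""
    assert actual.shape == expected.shape, (actual.shape, expected.shape)
    assert actual.dtype == expected.dtype, (actual.dtype, expected.dtype)
    assert np.array_equal(actual, expected)


reference = np.array([[0, 1, 1]], dtype=np.int64)
drifted = reference.astype(np.float32)  # same values, wrong dtype

# A value comparison alone does not notice the dtype drift.
values_match = np.array_equal(reference, drifted)
```

Here `values_match` is `True` even though the dtypes differ, while `assert_bit_identical(drifted, reference)` raises on the dtype check.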
```bash
uv run --isolated --extra dev --extra megatron -- pytest \
  tests/train/test_collation_vectorization_equivalence.py \
  tests/train/dataset/test_preprocess.py \
  tests/train/test_sft_packing_collate.py \
  tests/train/test_packing_round_trip.py
```

> Heads-up for reviewers: this overlaps open PR NovaSky-AI#1752 ([train] VLM SFT on Megatron), which edits the same `collate_sft_batch` loop and `TrainingInputBatch` dict to collect `pixel_values` / `image_grid_thw`. Whichever lands second needs a small rebase; if NovaSky-AI#1752 lands first, its per-sample VLM tensor collection should be reinstated inside the vectorized `for i, ex in enumerate(examples):` loop and its two keys re-added to the `from_numpy` batch dict.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
# Reviewers: Where to Look

The behavioral core is three vectorized collation rewrites; all must produce bit-identical tensors (dtype + layout) to the loops they replace. Focus review here:

- **`skyrl/train/dataset/preprocess.py`, `convert_prompts_responses_to_batch_tensors` (RL).** The per-row `sequences` slice writes (`prompt` then `response`), the broadcast `attention_mask` / `action_mask` comparisons (`col >= pad_len`), and the right-aligned `loss_mask` / `rewards` / `logprobs` slice writes. Also the new `_reward_to_numpy` helper: confirm list vs tensor, `float32` cast, `detach().cpu()`, and the 1-D shape guard.
- **`skyrl/train/dataset/collators.py`, `PackedDataCollator.__call__` (Megatron SFT FFD packing).** Highest-risk path. Check the per-sub-seq `sequences` / `attention_mask` slice copies, the right-shifted `loss_mask` write window (`full_mask[1:1+n_write]`) and its `write_end = min(row_offset + s - 1, loss_mask_width)` clamp reproducing the original `row_p < max_packed_len - 1` guard, and `total_nonpad` as a single reduction.
- **`skyrl/train/sft_trainer.py`, `collate_sft_batch` (unpacked SFT).** The left-pad slice writes into preallocated arrays; confirm `sequences` / `attention_mask` / `loss_mask` dtypes stay `torch.long`.

Sanity-check the equivalence oracle in `tests/train/test_collation_vectorization_equivalence.py`: each vectorized path is fuzzed against a faithful copy of the original loop with `torch.equal` plus explicit `.dtype` assertions (the packed test re-derives the FFD / DP-shard / `max_packed_len` decision inline).

Needs less scrutiny: type-annotation widenings (`Union[List[float], torch.Tensor]`), the `import numpy as np` additions, docstrings/comments, and the untouched `rollout_expert_indices` MoE branch (intentionally out of scope; byte-identical to before).
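For reviewers checking the RL path, the following sketch shows one way the two slice copies, the `col >= pad_len` broadcast, and the right-aligned response-width writes can fit together. It assumes a fully left-padded `[pad ... pad, prompt, response]` row layout and a hypothetical `rl_batch_arrays` signature; the actual function returns more tensors and may lay rows out differently.

```python
import numpy as np


def rl_batch_arrays(prompts, responses, loss_masks, pad_id):
    """Hypothetical layout: each row is [pad ... pad, prompt, response]."""
    n = len(prompts)
    p_len = np.array([len(p) for p in prompts], dtype=np.int64)
    r_len = np.array([len(r) for r in responses], dtype=np.int64)
    seq_w = int((p_len + r_len).max())
    resp_w = int(r_len.max())
    pad_len = seq_w - p_len - r_len

    sequences = np.full((n, seq_w), pad_id, dtype=np.int64)
    loss_mask = np.zeros((n, resp_w), dtype=np.float32)
    for i in range(n):
        start = int(pad_len[i])
        mid = start + int(p_len[i])
        sequences[i, start:mid] = prompts[i]  # slice copy 1: prompt
        sequences[i, mid:] = responses[i]     # slice copy 2: response
        # Right-aligned write into the fixed-width response tensor.
        loss_mask[i, resp_w - int(r_len[i]):] = loss_masks[i]

    # Broadcast comparisons: a (1, W) column index against an (n, 1) threshold.
    attention_mask = (np.arange(seq_w)[None, :] >= pad_len[:, None]).astype(np.int64)
    action_mask = (np.arange(resp_w)[None, :] >= (resp_w - r_len)[:, None]).astype(np.int64)
    return sequences, attention_mask, action_mask, loss_mask
```

The masks need no per-row loop at all under this layout, because every row's valid region is a suffix that is fully determined by one integer per row.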