[train] Skip building unused per-token loss_fn_outputs when the caller does not consume them #1807
Code Review
This pull request introduces a performance optimization to skip the generation of per-token loss function outputs (such as log probabilities and elementwise NLL) when they are not needed by the caller. A new configuration flag return_per_token_outputs is introduced and extracted from loss_fn_config via a utility function pop_return_per_token_outputs. In SFTTrainer, this flag is set to False during training and evaluation steps to avoid unnecessary GPU-to-CPU transfers and Python loops, as only scalar metrics are consumed. The Megatron and FSDP workers are updated to respect this flag, and comprehensive unit and integration tests are added to verify the correctness of this optimization. No review comments were provided, so there is no feedback to address.
…r does not consume them
Threads a per-request `return_per_token_outputs` flag (default `True`) carried on the already-plumbed `loss_fn_config` dict to gate the per-token `loss_fn_outputs` build on the `cross_entropy` branch of **both** backends and **both** the train and eval/forward paths. When the flag is `False`, the per-token NLL, the two detached `[mb, seq]` D2H copies (logprobs + elementwise loss), and the `.tolist()` loop are skipped; each sequence gets an empty dict instead. The `loss` / `response_length` metrics and the `loss_fn_output_type` tag are unchanged.
SkyRL's own `SFTTrainer` reads only `output.metrics` (loss / response_length), never `output.loss_fn_outputs`, so it now opts out — eliminating dead work on the SFT train + eval hot path. RL and Tinker callers pass no flag and keep the existing contract (default `True`).
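The gate described above can be illustrated with a minimal sketch. It is not the worker code: plain Python lists stand in for the detached tensors, and `build_loss_fn_outputs`, `valid_lens`, and the truncation to valid length are hypothetical choices made for illustration. Only the `logprobs` / `elementwise_loss` keys and the empty-dict-per-sequence behavior come from this description.

```python
from typing import Dict, List


def build_loss_fn_outputs(
    logprobs: List[List[float]],
    elementwise_loss: List[List[float]],
    valid_lens: List[int],
    return_per_token_outputs: bool = True,
) -> List[Dict[str, List[float]]]:
    """Hypothetical stand-in for the per-token build on the cross_entropy branch."""
    if not return_per_token_outputs:
        # Opt-out: skip the per-token work and return one empty dict per sequence.
        return [{} for _ in valid_lens]

    outputs = []
    for row_logprobs, row_loss, n in zip(logprobs, elementwise_loss, valid_lens):
        # Assumption: each sequence keeps only its first n (valid) tokens.
        outputs.append(
            {"logprobs": row_logprobs[:n], "elementwise_loss": row_loss[:n]}
        )
    return outputs
```

With the default the per-sequence dicts are populated; with `return_per_token_outputs=False` the list has the same length but every entry is `{}`, so a caller that only reads scalar metrics sees no difference.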
- **`worker_utils.py`**: new module-level helper `pop_return_per_token_outputs(loss_fn_config) -> (Optional[dict], bool)` that copies the dict before popping the flag (callers' dicts are never mutated) and returns the default `True` for `None`. Its docstring is the single authoritative source for the flag's contract and the why-popped-before-merge rationale.
- **Megatron** (`megatron/megatron_model_wrapper.py`): pop the flag once in `forward_backward_mini_batch` and gate the per-token build inside the shared `loss_func` (covers train + `forward_only` eval).
- **FSDP** (`workers/worker.py`): pop + gate in `_forward_backward_micro` (train) and `_forward_micro_with_loss` (eval).
- **Caller wiring** (`train/sft_trainer.py`): `train_step` (forward_backward) and `run_eval` (forward) now pass `loss_fn_config={"return_per_token_outputs": False}`.
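For orientation, a helper with the contract described in the first bullet might look roughly like the sketch below. The name, the key constant, and the return shape follow this description; the body is a reconstruction under those stated invariants, not the code in `worker_utils.py`.

```python
from typing import Any, Dict, Optional, Tuple

RETURN_PER_TOKEN_OUTPUTS_KEY = "return_per_token_outputs"


def pop_return_per_token_outputs(
    loss_fn_config: Optional[Dict[str, Any]],
) -> Tuple[Optional[Dict[str, Any]], bool]:
    """Return (config without the flag, flag value); the default flag is True."""
    if loss_fn_config is None:
        # No config at all: nothing to pop, keep the existing contract.
        return None, True
    # Shallow copy first so the caller's dict is never mutated.
    remaining = dict(loss_fn_config)
    flag = remaining.pop(RETURN_PER_TOKEN_OUTPUTS_KEY, True)
    return remaining, flag
```

A call such as `pop_return_per_token_outputs({"return_per_token_outputs": False})` would hand back an empty dict plus `False`, which is the case the truthiness-based merge guard is meant to handle.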
The flag is popped from a per-call **copy** of `loss_fn_config` **before** the `AlgorithmConfig` merge — it is not an `AlgorithmConfig` field, so leaving it in would trip the key validation in `build_nested_dataclass` (reached via `from_dict_config`). The merge guard was changed from `if loss_fn_config is not None:` to `if loss_fn_config:` so an empty dict after the pop skips the merge exactly as today.
`return_per_token_outputs` is documented in the `pop_return_per_token_outputs` helper docstring and at each pop/gate site; the `loss_fn_config` docstrings on the three consuming worker methods that read the flag — `_forward_backward_micro`, `_forward_micro_with_loss` (FSDP), and `forward_backward_mini_batch` (Megatron) — also note the reserved key and its default. The per-build-site rationale comments were collapsed to a single line each (pointing at the helper docstring) to avoid hand-sync drift across the three near-identical sites.
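The ordering constraint (pop first, then merge only when something is left) can be sketched as follows. `AlgoConfigSketch`, `merge_overrides`, and `resolve_config` are hypothetical stand-ins: the strict key check imitates the kind of validation attributed to `build_nested_dataclass` above and is not its implementation. `eps_clip_low` is borrowed from the test plan as an example of a legitimate override.

```python
from dataclasses import dataclass, fields, replace
from typing import Any, Dict, Optional, Tuple


@dataclass(frozen=True)
class AlgoConfigSketch:
    """Hypothetical stand-in for the algorithm config; one real-looking field."""
    eps_clip_low: float = 0.2


def merge_overrides(base: AlgoConfigSketch, overrides: Dict[str, Any]) -> AlgoConfigSketch:
    """Strict merge: unknown keys raise, mimicking schema validation."""
    known = {f.name for f in fields(base)}
    unknown = set(overrides) - known
    if unknown:
        raise KeyError(f"unknown config keys: {sorted(unknown)}")
    return replace(base, **overrides)


def resolve_config(
    base: AlgoConfigSketch, loss_fn_config: Optional[Dict[str, Any]]
) -> Tuple[AlgoConfigSketch, bool]:
    return_per_token = True
    if loss_fn_config is not None:
        loss_fn_config = dict(loss_fn_config)  # per-call copy
        return_per_token = loss_fn_config.pop("return_per_token_outputs", True)
    # Truthiness guard: both None and {} (flag-only config after the pop) skip the merge.
    if loss_fn_config:
        base = merge_overrides(base, loss_fn_config)
    return base, return_per_token
```

In this sketch, passing the flag straight to `merge_overrides` would raise, while `resolve_config` accepts it alone or alongside a real override.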
Byte-identical for all existing callers:
- When `loss_fn_config` is `None` the pop never runs, and when the flag is simply absent the helper returns the default `True`; either way the full prior per-token build executes unchanged, reproducing the exact prior code paths.
- The merge-guard change (`is not None` → truthy) is a strict no-op: no existing caller ever passed `{}`, a non-empty override dict is still truthy and still merges, and `OmegaConf.merge(base, {})` is itself a no-op.
- `loss`, the backward pass, and all consumed scalar metrics (`loss`, `response_length`, `lr`) are computed before/independent of the gated block, so they are identical whether per-token outputs are kept or skipped.
- `loss_fn_output_type` is a `WorkerOutput` field that always defaults to `"scalar"` (never set explicitly), so the type tag survives automatically — only the arrays become empty. The empty-dict-with-`"scalar"`-tag combination is only reachable behind the explicit opt-out whose sole caller ignores the payload.
- RL's separate (non-`cross_entropy`) `loss_fn_outputs` else-branch is untouched; the RL trainer and Tinker backend pass no flag, so their contracts hold. The Tinker public API whitelists `loss_fn_config` keys (the allowed-key set for `cross_entropy` is empty), so users cannot inject the flag, and even if it were present it would be popped before any merge.
Note on the eval path: `run_eval` already iterates eval batches serially and reads only `output.metrics["loss"]`, so opting out there removes per-token work without changing any reported eval metric.
- `tests/backends/skyrl_train/workers/test_sft_loss_fn_outputs_gate.py` (new, CPU): drives the real FSDP `_forward_backward_micro` / `_forward_micro_with_loss` `cross_entropy` builds on CPU; asserts default/explicit-True populate `logprobs` + `elementwise_loss`, `False` yields empty dicts, `loss`/`response_length`/`lr` are identical across the flag, the flag is popped before the `AlgorithmConfig` merge (a real `eps_clip_low` override alongside the flag still merges without raising), and the caller dict is not mutated. Adds an RL-path test confirming the non-`cross_entropy` else-branch is ungated (logprobs still built; outputs + loss identical across the flag); that test disables `use_kl_loss`/`use_entropy_loss` explicitly (rather than relying on a default) so it isolates the gate from the KL/entropy terms. CPU run: 12 passed.
- `tests/backends/skyrl_train/workers/test_worker_utils.py` (extended): unit tests for `pop_return_per_token_outputs` — `None`→`(None, True)`, absent-flag→`True` with config preserved, explicit `False`/`True` popped leaving legitimate overrides, and no caller-dict mutation.
- `tests/train/test_sft_callbacks.py` (extended): assert `SFTTrainer.train_step` and `run_eval` pass `loss_fn="cross_entropy"` and `loss_fn_config={"return_per_token_outputs": False}` to the dispatch.
- `tests/backends/skyrl_train/gpu/gpu_ci/test_training_step.py` (extended, GPU): parametrized over FSDP + Megatron (the Megatron leg carries `@pytest.mark.megatron` like the sibling tests), DP=2; runs the real worker `forward_backward`/`forward` twice on the same dummy batch (flag default-True vs explicit-False) and asserts `loss` + `response_length` identical and `loss_fn_output_type == "scalar"` in both, with per-token outputs populated when kept vs empty when skipped.
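As an illustration of the caller-wiring assertion in `test_sft_callbacks.py`, a check of this shape could be written with `unittest.mock`. `TrainerSketch` and its `forward_backward` call signature are hypothetical; the real `SFTTrainer` and dispatch API may take different arguments.

```python
from unittest.mock import MagicMock


class TrainerSketch:
    """Hypothetical stand-in for the SFT trainer's dispatch wiring."""

    def __init__(self, dispatch):
        self.dispatch = dispatch

    def train_step(self, batch):
        output = self.dispatch.forward_backward(
            batch,
            loss_fn="cross_entropy",
            loss_fn_config={"return_per_token_outputs": False},
        )
        # Only scalar metrics are consumed; loss_fn_outputs is never read.
        return output.metrics["loss"]


def check_train_step_opts_out():
    dispatch = MagicMock()
    dispatch.forward_backward.return_value.metrics = {"loss": 0.5}
    trainer = TrainerSketch(dispatch)

    loss = trainer.train_step(batch={"input_ids": [[1, 2, 3]]})

    # Inspect the keyword arguments the trainer passed to the dispatch.
    _, kwargs = dispatch.forward_backward.call_args
    assert kwargs["loss_fn"] == "cross_entropy"
    assert kwargs["loss_fn_config"] == {"return_per_token_outputs": False}
    return loss
```

Mocking the dispatch keeps the check on CPU and independent of either backend, which matches the intent of asserting only what the trainer sends.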
Run locally:
```bash
uv run --isolated --extra skyrl-train --extra dev pytest tests/backends/skyrl_train/workers/test_sft_loss_fn_outputs_gate.py tests/backends/skyrl_train/workers/test_worker_utils.py tests/train/test_sft_callbacks.py
```
The CPU suite for the new/extended tests runs automatically under the standard CPU CI job (`tests/backends/skyrl_train/` + `tests/train/`); confirm that check is green before merge.
Covered: both backends (Megatron `loss_func`; FSDP micro methods) and both the train (`forward_backward`) and eval/forward (`forward(loss_fn="cross_entropy")`) paths. RL and Tinker contracts preserved (default `True`).
Intentionally out of scope:
- The RL `loss_fn_outputs` build (separate else-branch, not `cross_entropy`) is untouched; the `forward(loss_fn=None)` pure-inference path is untouched.
- The JAX backend builds `loss_fn_outputs` in a jit-traced path driven by a structured `LossFnConfig` dataclass rather than the runtime `loss_fn_config` dict, so the dict-borne flag does not reach it — a possible follow-up, not a regression.
- No config schema field is added — the flag rides the runtime `loss_fn_config` dict only, preferring a function-arg/flag with a safe default over a global switch.
Doc follow-up: the caller-facing dispatch docstrings `WorkerDispatch.forward` / `WorkerDispatch.forward_backward` (the API the SFT trainer / Tinker call into) are outside this PR's touched files and so are left unchanged here; a one-line note about the reserved key could be added there in a follow-up so future callers can discover the opt-out from the dispatch layer.
This gate is orthogonal/complementary to several in-flight efforts touching the same files:
- **NovaSky-AI#1513** (SFT loss-aggregation rewrite of the same FSDP `cross_entropy` branch) — note it renames the SFT status key `loss`→`sft_loss`, so merge ordering matters; only the test coupling to the literal `loss` metric key would need a touch-up if it lands first.
- **NovaSky-AI#1752** (VLM SFT on Megatron) — disjoint regions in the shared files.
- **NovaSky-AI#1534** (preserve staged `forward_backward` loss_fn_outputs across DP ranks; `worker_dispatch.py`) — no overlapping file.
The gate logic itself is composable with all of these.
9107e37 to fb3a6b0
Reviewers: Where to Look
The behavioral core of this PR is small; concentrate review on these three sites:
- `skyrl/backends/skyrl_train/workers/worker_utils.py`: `pop_return_per_token_outputs()` (and the `RETURN_PER_TOKEN_OUTPUTS_KEY` constant). This is the single authoritative source for the flag's contract: it pops the flag from a shallow copy (caller dict never mutated) and returns default `True` for `None`. Verify the copy-before-pop and the `None` → `(None, True)` default; every gate site depends on these two invariants.
- `skyrl/backends/skyrl_train/workers/worker.py`: `_forward_backward_micro` (train) and `_forward_micro_with_loss` (eval), and the equivalent gate in Megatron's `loss_func` inside `forward_backward_mini_batch` in `skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py`. Key things to confirm: (a) `loss`, the backward pass, and all consumed scalar metrics are computed outside/before the gated block so they are byte-identical regardless of the flag; (b) the `if loss_fn_config is not None:` → `if loss_fn_config:` merge-guard change is a true no-op (an empty dict after the pop must skip the merge exactly as before); (c) the `else` branch yields one empty dict per sequence and only the `cross_entropy` branch is gated; the RL (non-`cross_entropy`) else-branch that builds `logprobs` must stay ungated.
- `skyrl/train/sft_trainer.py`: `train_step` (forward_backward) and `run_eval` (forward) now pass `loss_fn_config={RETURN_PER_TOKEN_OUTPUTS_KEY: False}`. Confirm the SFT trainer truly reads only `output.metrics` and never `output.loss_fn_outputs`, so opting out is safe.

Rebase note: this branch was rebased onto latest `upstream/main`, which now carries `compute_minibatch_rollout_logprob_diff_metrics` (in `worker_utils.py` and used in the RL branch of both backends). Those additions were merged alongside this PR's gate and are not part of this change; they need no scrutiny here beyond confirming the gate wraps only the `cross_entropy` branch and does not touch the RL metric path.

Needs less scrutiny: the near-identical per-token-build bodies moved wholesale under the `if return_per_token_outputs:` guard (unchanged logic, just re-indented); the docstring/comment additions at each pop/gate site; and the test scaffolding (CPU gate tests, extended `test_worker_utils.py` / `test_sft_callbacks.py`, and the parametrized GPU `test_training_step.py` leg).