Skip to content

[train] Skip building unused per-token loss_fn_outputs when the caller does not consume them#1807

Open
dyurk-lila wants to merge 2 commits into
NovaSky-AI:mainfrom
dyurk-lila:perf/skip-unused-per-token-loss-outputs
Open

[train] Skip building unused per-token loss_fn_outputs when the caller does not consume them#1807
dyurk-lila wants to merge 2 commits into
NovaSky-AI:mainfrom
dyurk-lila:perf/skip-unused-per-token-loss-outputs

Conversation

@dyurk-lila

@dyurk-lila dyurk-lila commented Jun 18, 2026

Copy link
Copy Markdown

Reviewers: Where to Look

The behavioral core of this PR is small; concentrate review on these three sites:

  • skyrl/backends/skyrl_train/workers/worker_utils.pypop_return_per_token_outputs() (and the RETURN_PER_TOKEN_OUTPUTS_KEY constant). This is the single authoritative source for the flag's contract: it pops the flag from a shallow copy (caller dict never mutated) and returns default True for None. Verify the copy-before-pop and the None → (None, True) default — every gate site depends on these two invariants.
  • The gate + merge-guard at the two FSDP sites in skyrl/backends/skyrl_train/workers/worker.py_forward_backward_micro (train) and _forward_micro_with_loss (eval), and the equivalent gate in Megatron's loss_func inside forward_backward_mini_batch in skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py. Key things to confirm: (a) loss, the backward pass, and all consumed scalar metrics are computed outside/before the gated block so they are byte-identical regardless of the flag; (b) the if loss_fn_config is not None:if loss_fn_config: merge-guard change is a true no-op (an empty dict after the pop must skip the merge exactly as before); (c) the else branch yields one empty dict per sequence and only the cross_entropy branch is gated — the RL (non-cross_entropy) else-branch that builds logprobs must stay ungated.
  • Caller opt-out in skyrl/train/sft_trainer.pytrain_step (forward_backward) and run_eval (forward) now pass loss_fn_config={RETURN_PER_TOKEN_OUTPUTS_KEY: False}. Confirm the SFT trainer truly reads only output.metrics and never output.loss_fn_outputs, so opting out is safe.

Rebase note: this branch was rebased onto latest upstream/main, which now carries compute_minibatch_rollout_logprob_diff_metrics (in worker_utils.py and used in the RL branch of both backends). Those additions were merged alongside this PR's gate and are not part of this change — they need no scrutiny here beyond confirming the gate wraps only the cross_entropy branch and does not touch the RL metric path.

Needs less scrutiny: the near-identical per-token-build bodies moved wholesale under the if return_per_token_outputs: guard (unchanged logic, just re-indented); the docstring/comment additions at each pop/gate site; and the test scaffolding (CPU gate tests, extended test_worker_utils.py / test_sft_callbacks.py, and the parametrized GPU test_training_step.py leg).


Summary

Threads a per-request return_per_token_outputs flag (default True) carried on the already-plumbed loss_fn_config dict to gate the per-token loss_fn_outputs build on the cross_entropy branch of both backends and both the train and eval/forward paths. When the flag is False, the per-token NLL, the two detached [mb, seq] D2H copies (logprobs + elementwise loss), and the .tolist() loop are skipped; each sequence gets an empty dict instead. The loss / response_length metrics and the loss_fn_output_type tag are unchanged.

SkyRL's own SFTTrainer reads only output.metrics (loss / response_length), never output.loss_fn_outputs, so it now opts out — eliminating dead work on the SFT train + eval hot path. RL and Tinker callers pass no flag and keep the existing contract (default True).

What changed

  • worker_utils.py: new module-level helper pop_return_per_token_outputs(loss_fn_config) -> (Optional[dict], bool) that copies the dict before popping the flag (callers' dicts are never mutated) and returns the default True for None. Its docstring is the single authoritative source for the flag's contract and the why-popped-before-merge rationale.
  • Megatron (megatron/megatron_model_wrapper.py): pop the flag once in forward_backward_mini_batch and gate the per-token build inside the shared loss_func (covers train + forward_only eval).
  • FSDP (workers/worker.py): pop + gate in _forward_backward_micro (train) and _forward_micro_with_loss (eval).
  • Caller wiring (train/sft_trainer.py): train_step (forward_backward) and run_eval (forward) now pass loss_fn_config={"return_per_token_outputs": False}.

The flag is popped from a per-call copy of loss_fn_config before the AlgorithmConfig merge — it is not an AlgorithmConfig field, so leaving it in would trip the key validation in build_nested_dataclass (reached via from_dict_config). The merge guard was changed from if loss_fn_config is not None: to if loss_fn_config: so an empty dict after the pop skips the merge exactly as today.

return_per_token_outputs is documented in the pop_return_per_token_outputs helper docstring and at each pop/gate site; the loss_fn_config docstrings on the three consuming worker methods that read the flag — _forward_backward_micro, _forward_micro_with_loss (FSDP), and forward_backward_mini_batch (Megatron) — also note the reserved key and its default. The per-build-site rationale comments were collapsed to a single line each (pointing at the helper docstring) to avoid hand-sync drift across the three near-identical sites.

Numerical equivalence / safety

Byte-identical for all existing callers:

  • When the flag is absent/None, the pop never runs and the full prior per-token build executes unchanged. Default True reproduces the exact prior code paths.
  • The merge-guard change (is not None → truthy) is a strict no-op: no existing caller ever passed {}, a non-empty override dict is still truthy and still merges, and OmegaConf.merge(base, {}) is itself a no-op.
  • loss, the backward pass, and all consumed scalar metrics (loss, response_length, lr) are computed before/independent of the gated block, so they are identical whether per-token outputs are kept or skipped.
  • loss_fn_output_type is a WorkerOutput field that always defaults to "scalar" (never set explicitly), so the type tag survives automatically — only the arrays become empty. The empty-dict-with-"scalar"-tag combination is only reachable behind the explicit opt-out whose sole caller ignores the payload.
  • RL's separate (non-cross_entropy) loss_fn_outputs else-branch is untouched; the RL trainer and Tinker backend pass no flag, so their contracts hold. The Tinker public API whitelists loss_fn_config keys (empty allowed-key set for cross_entropy), so the flag cannot be injected by users; and even if present it is popped before any merge.

Note on the eval path: run_eval already iterates eval batches serially and reads only output.metrics["loss"], so opting out there removes per-token work without changing any reported eval metric.

Test plan

  • tests/backends/skyrl_train/workers/test_sft_loss_fn_outputs_gate.py (new, CPU): drives the real FSDP _forward_backward_micro / _forward_micro_with_loss cross_entropy builds on CPU; asserts default/explicit-True populate logprobs + elementwise_loss, False yields empty dicts, loss/response_length/lr are identical across the flag, the flag is popped before the AlgorithmConfig merge (a real eps_clip_low override alongside the flag still merges without raising), and the caller dict is not mutated. Adds an RL-path test confirming the non-cross_entropy else-branch is ungated (logprobs still built; outputs + loss identical across the flag); that test disables use_kl_loss/use_entropy_loss explicitly (rather than relying on a default) so it isolates the gate from the KL/entropy terms. CPU run: 12 passed.
  • tests/backends/skyrl_train/workers/test_worker_utils.py (extended): unit tests for pop_return_per_token_outputsNone(None, True), absent-flag→True with config preserved, explicit False/True popped leaving legitimate overrides, and no caller-dict mutation.
  • tests/train/test_sft_callbacks.py (extended): assert SFTTrainer.train_step and run_eval pass loss_fn="cross_entropy" and loss_fn_config={"return_per_token_outputs": False} to the dispatch.
  • tests/backends/skyrl_train/gpu/gpu_ci/test_training_step.py (extended, GPU): parametrized over FSDP + Megatron (the Megatron leg carries @pytest.mark.megatron like the sibling tests), DP=2; runs the real worker forward_backward/forward twice on the same dummy batch (flag default-True vs explicit-False) and asserts loss + response_length identical and loss_fn_output_type == "scalar" in both, with per-token outputs populated when kept vs empty when skipped.

Run locally:

uv run --isolated --extra skyrl-train --extra dev pytest tests/backends/skyrl_train/workers/test_sft_loss_fn_outputs_gate.py tests/backends/skyrl_train/workers/test_worker_utils.py tests/train/test_sft_callbacks.py

The CPU suite for the new/extended tests runs automatically under the standard CPU CI job (tests/backends/skyrl_train/ + tests/train/); confirm that check is green before merge.

Generality & follow-ups

Covered: both backends (Megatron loss_func; FSDP micro methods) and both the train (forward_backward) and eval/forward (forward(loss_fn="cross_entropy")) paths. RL and Tinker contracts preserved (default True).

Intentionally out of scope:

  • The RL loss_fn_outputs build (separate else-branch, not cross_entropy) is untouched; the forward(loss_fn=None) pure-inference path is untouched.
  • The JAX backend builds loss_fn_outputs in a jit-traced path driven by a structured LossFnConfig dataclass rather than the runtime loss_fn_config dict, so the dict-borne flag does not reach it — a possible follow-up, not a regression.
  • No config schema field is added — the flag rides the runtime loss_fn_config dict only, preferring a function-arg/flag with a safe default over a global switch.

Doc follow-up: the caller-facing dispatch docstrings WorkerDispatch.forward / WorkerDispatch.forward_backward (the API the SFT trainer / Tinker call into) are outside this PR's touched files and so are left unchanged here; a one-line note about the reserved key could be added there in a follow-up so future callers can discover the opt-out from the dispatch layer.

Relationship to open PRs

This gate is orthogonal/complementary to several in-flight efforts touching the same files:

The gate logic itself is composable with all of these.

@dyurk-lila dyurk-lila marked this pull request as ready for review June 18, 2026 16:34

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a performance optimization to skip the generation of per-token loss function outputs (such as log probabilities and elementwise NLL) when they are not needed by the caller. A new configuration flag return_per_token_outputs is introduced and extracted from loss_fn_config via a utility function pop_return_per_token_outputs. In SFTTrainer, this flag is set to False during training and evaluation steps to avoid unnecessary GPU-to-CPU transfers and Python loops, as only scalar metrics are consumed. The Megatron and FSDP workers are updated to respect this flag, and comprehensive unit and integration tests are added to verify the correctness of this optimization. No review comments were provided, so there is no feedback to address.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

…r does not consume them

Threads a per-request `return_per_token_outputs` flag (default `True`) carried on the already-plumbed `loss_fn_config` dict to gate the per-token `loss_fn_outputs` build on the `cross_entropy` branch of **both** backends and **both** the train and eval/forward paths. When the flag is `False`, the per-token NLL, the two detached `[mb, seq]` D2H copies (logprobs + elementwise loss), and the `.tolist()` loop are skipped; each sequence gets an empty dict instead. The `loss` / `response_length` metrics and the `loss_fn_output_type` tag are unchanged.

SkyRL's own `SFTTrainer` reads only `output.metrics` (loss / response_length), never `output.loss_fn_outputs`, so it now opts out — eliminating dead work on the SFT train + eval hot path. RL and Tinker callers pass no flag and keep the existing contract (default `True`).

- **`worker_utils.py`**: new module-level helper `pop_return_per_token_outputs(loss_fn_config) -> (Optional[dict], bool)` that copies the dict before popping the flag (callers' dicts are never mutated) and returns the default `True` for `None`. Its docstring is the single authoritative source for the flag's contract and the why-popped-before-merge rationale.
- **Megatron** (`megatron/megatron_model_wrapper.py`): pop the flag once in `forward_backward_mini_batch` and gate the per-token build inside the shared `loss_func` (covers train + `forward_only` eval).
- **FSDP** (`workers/worker.py`): pop + gate in `_forward_backward_micro` (train) and `_forward_micro_with_loss` (eval).
- **Caller wiring** (`train/sft_trainer.py`): `train_step` (forward_backward) and `run_eval` (forward) now pass `loss_fn_config={"return_per_token_outputs": False}`.

The flag is popped from a per-call **copy** of `loss_fn_config` **before** the `AlgorithmConfig` merge — it is not an `AlgorithmConfig` field, so leaving it in would trip the key validation in `build_nested_dataclass` (reached via `from_dict_config`). The merge guard was changed from `if loss_fn_config is not None:` to `if loss_fn_config:` so an empty dict after the pop skips the merge exactly as today.

`return_per_token_outputs` is documented in the `pop_return_per_token_outputs` helper docstring and at each pop/gate site; the `loss_fn_config` docstrings on the three consuming worker methods that read the flag — `_forward_backward_micro`, `_forward_micro_with_loss` (FSDP), and `forward_backward_mini_batch` (Megatron) — also note the reserved key and its default. The per-build-site rationale comments were collapsed to a single line each (pointing at the helper docstring) to avoid hand-sync drift across the three near-identical sites.

Byte-identical for all existing callers:

- When the flag is absent/`None`, the pop never runs and the full prior per-token build executes unchanged. Default `True` reproduces the exact prior code paths.
- The merge-guard change (`is not None` → truthy) is a strict no-op: no existing caller ever passed `{}`, a non-empty override dict is still truthy and still merges, and `OmegaConf.merge(base, {})` is itself a no-op.
- `loss`, the backward pass, and all consumed scalar metrics (`loss`, `response_length`, `lr`) are computed before/independent of the gated block, so they are identical whether per-token outputs are kept or skipped.
- `loss_fn_output_type` is a `WorkerOutput` field that always defaults to `"scalar"` (never set explicitly), so the type tag survives automatically — only the arrays become empty. The empty-dict-with-`"scalar"`-tag combination is only reachable behind the explicit opt-out whose sole caller ignores the payload.
- RL's separate (non-`cross_entropy`) `loss_fn_outputs` else-branch is untouched; the RL trainer and Tinker backend pass no flag, so their contracts hold. The Tinker public API whitelists `loss_fn_config` keys (empty allowed-key set for `cross_entropy`), so the flag cannot be injected by users; and even if present it is popped before any merge.

Note on the eval path: `run_eval` already iterates eval batches serially and reads only `output.metrics["loss"]`, so opting out there removes per-token work without changing any reported eval metric.

- `tests/backends/skyrl_train/workers/test_sft_loss_fn_outputs_gate.py` (new, CPU): drives the real FSDP `_forward_backward_micro` / `_forward_micro_with_loss` `cross_entropy` builds on CPU; asserts default/explicit-True populate `logprobs` + `elementwise_loss`, `False` yields empty dicts, `loss`/`response_length`/`lr` are identical across the flag, the flag is popped before the `AlgorithmConfig` merge (a real `eps_clip_low` override alongside the flag still merges without raising), and the caller dict is not mutated. Adds an RL-path test confirming the non-`cross_entropy` else-branch is ungated (logprobs still built; outputs + loss identical across the flag); that test disables `use_kl_loss`/`use_entropy_loss` explicitly (rather than relying on a default) so it isolates the gate from the KL/entropy terms. CPU run: 12 passed.
- `tests/backends/skyrl_train/workers/test_worker_utils.py` (extended): unit tests for `pop_return_per_token_outputs` — `None`→`(None, True)`, absent-flag→`True` with config preserved, explicit `False`/`True` popped leaving legitimate overrides, and no caller-dict mutation.
- `tests/train/test_sft_callbacks.py` (extended): assert `SFTTrainer.train_step` and `run_eval` pass `loss_fn="cross_entropy"` and `loss_fn_config={"return_per_token_outputs": False}` to the dispatch.
- `tests/backends/skyrl_train/gpu/gpu_ci/test_training_step.py` (extended, GPU): parametrized over FSDP + Megatron (the Megatron leg carries `@pytest.mark.megatron` like the sibling tests), DP=2; runs the real worker `forward_backward`/`forward` twice on the same dummy batch (flag default-True vs explicit-False) and asserts `loss` + `response_length` identical and `loss_fn_output_type == "scalar"` in both, with per-token outputs populated when kept vs empty when skipped.

Run locally:

```bash
uv run --isolated --extra skyrl-train --extra dev pytest tests/backends/skyrl_train/workers/test_sft_loss_fn_outputs_gate.py tests/backends/skyrl_train/workers/test_worker_utils.py tests/train/test_sft_callbacks.py
```

The CPU suite for the new/extended tests runs automatically under the standard CPU CI job (`tests/backends/skyrl_train/` + `tests/train/`); confirm that check is green before merge.

Covered: both backends (Megatron `loss_func`; FSDP micro methods) and both the train (`forward_backward`) and eval/forward (`forward(loss_fn="cross_entropy")`) paths. RL and Tinker contracts preserved (default `True`).

Intentionally out of scope:
- The RL `loss_fn_outputs` build (separate else-branch, not `cross_entropy`) is untouched; the `forward(loss_fn=None)` pure-inference path is untouched.
- The JAX backend builds `loss_fn_outputs` in a jit-traced path driven by a structured `LossFnConfig` dataclass rather than the runtime `loss_fn_config` dict, so the dict-borne flag does not reach it — a possible follow-up, not a regression.
- No config schema field is added — the flag rides the runtime `loss_fn_config` dict only, preferring a function-arg/flag with a safe default over a global switch.

Doc follow-up: the caller-facing dispatch docstrings `WorkerDispatch.forward` / `WorkerDispatch.forward_backward` (the API the SFT trainer / Tinker call into) are outside this PR's touched files and so are left unchanged here; a one-line note about the reserved key could be added there in a follow-up so future callers can discover the opt-out from the dispatch layer.

This gate is orthogonal/complementary to several in-flight efforts touching the same files:

- **NovaSky-AI#1513** (SFT loss-aggregation rewrite of the same FSDP `cross_entropy` branch) — note it renames the SFT status key `loss`→`sft_loss`, so merge ordering matters; only the test coupling to the literal `loss` metric key would need a touch-up if it lands first.
- **NovaSky-AI#1752** (VLM SFT on Megatron) — disjoint regions in the shared files.
- **NovaSky-AI#1534** (preserve staged `forward_backward` loss_fn_outputs across DP ranks; `worker_dispatch.py`) — no overlapping file.

The gate logic itself is composable with all of these.
@dyurk-lila dyurk-lila force-pushed the perf/skip-unused-per-token-loss-outputs branch from 9107e37 to fb3a6b0 Compare July 1, 2026 20:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant