Skip to content

[fsdp] Exclude fully-padding microbatches from metric aggregation (parity with #1817)#1863

Open
kurtislin wants to merge 2 commits into
NovaSky-AI:mainfrom
kurtislin:fix/fsdp-padding-microbatch-metrics
Open

[fsdp] Exclude fully-padding microbatches from metric aggregation (parity with #1817)#1863
kurtislin wants to merge 2 commits into
NovaSky-AI:mainfrom
kurtislin:fix/fsdp-padding-microbatch-metrics

Conversation

@kurtislin

Copy link
Copy Markdown
Contributor

Summary

#1817 excluded fully-padding microbatches from metric aggregation, but only in megatron_worker.py. The shared forward_backward loops in worker.py (used by FSDP) still append them, so with max_tokens_per_microbatch > 0 mean-reduced metrics (policy_entropy, policy_kl, loss_metrics/*) are dragged toward 0.

This PR mirrors the Megatron-side skip into both loops (policy + critic), same placement and comment. Note: critic_loss is mean-reduced, so its reported value was previously biased toward 0 and changes with this fix.

Testing

New CPU test tests/backends/skyrl_train/workers/test_forward_backward_padding_metrics.py forces one padding microbatch and asserts mean metrics are undiluted (1.0, not 2/3) and summed metrics unchanged. Fails without the fix; adjacent suites pass (34 tests).

NovaSky-AI#1817 excluded fully-padding microbatches from metric aggregation for the
Megatron backend. Apply the same skip to the shared forward_backward loops
in worker.py used by FSDP, mirroring megatron_worker.py: padding
microbatches still run forward/backward (per-rank collective counts stay
equal), only the metric append is skipped.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces changes to skip fully-padding microbatches during metric aggregation in both the policy and critic workers, preventing dummy zero-valued metrics from skewing mean-reduced metrics. It also adds corresponding unit tests. The review feedback correctly identifies that loss_fn_outputs are currently extracted before the padding check in the policy worker, which would pollute the aggregated outputs with dummy data. The reviewer suggests moving the padding check before this extraction and updating the tests to verify that these dummy outputs are successfully excluded.

Comment thread skyrl/backends/skyrl_train/workers/worker.py Outdated
Move the padding skip above the loss_fn_outputs extraction so dummy
per-sample entries from padding microbatches are not returned (review
feedback). Test asserts only real-sample outputs remain.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant