Add reward_functions_path + reward_functions CLI knobs for custom rewards #4149

Open
py4 wants to merge 1 commit into main from pr/reward-functions-knobs

py4 (Collaborator) commented Jun 11, 2026
Summary

The post-train RL reward stack in create_rl_components is hardcoded to a 3-function list:

reward_fns = [
    make_reward_fn(utils_rl.match_format_exactly),
    make_reward_fn(utils_rl.match_format_approximately),
    make_reward_fn(utils_rl.check_numbers),
]

Replacing it requires editing train_rl.py, so any user-provided reward (VTC-style partial credit, a math_verify-based scorer, a custom domain reward) currently means forking maxtext.

This PR adds two config knobs:

  • reward_functions_path (str, default ""): filesystem path to a user Python file
  • reward_functions (str, default ""): comma-separated list of function names to import from that file

When both are set, the built-in reward list is replaced entirely by the user-provided functions. Each function must accept (prompts, completions, tmvp_config, **kwargs) and return list[float]. Default behavior is unchanged when either knob is empty.
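
As an illustration of the contract above, a user file might look like the following minimal sketch. The function name `partial_credit_reward`, the `<answer>` tag convention, and the `answer` keyword argument are assumptions made for this example, not part of the PR; it also assumes `completions` is a list of strings.

```python
import re

# Hypothetical convention: the model wraps its final answer in <answer> tags.
_ANSWER_RE = re.compile(r"<answer>\s*(.*?)\s*</answer>", re.DOTALL)


def partial_credit_reward(prompts, completions, tmvp_config, **kwargs):
  """Returns one float per completion, as the reward contract requires.

  0.0 if no <answer> block is found, 0.5 if one is found, and 1.0 if its
  content equals the reference answer (when references are available).
  """
  # Assumption: reference answers, if provided, arrive as kwargs["answer"].
  answers = kwargs.get("answer")
  scores = []
  for i, completion in enumerate(completions):
    match = _ANSWER_RE.search(completion)
    if match is None:
      scores.append(0.0)
      continue
    score = 0.5
    if answers is not None and match.group(1) == str(answers[i]).strip():
      score = 1.0
    scores.append(score)
  return scores
```

With such a file saved on disk, the two knobs would point at its path and at the name `partial_credit_reward` respectively.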

Reuses the load_custom_callable helper (added in #4031) for the file import, so this is a small wiring change on top of an already-reviewed loader.
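
For readers unfamiliar with this kind of file-based loading, the wiring can be sketched roughly as below. This is a standalone approximation built on `importlib`, not the actual helper from #4031; the function name `load_reward_functions` and the module name `user_rewards` are hypothetical.

```python
import importlib.util


def load_reward_functions(path, names):
  """Hypothetical stand-in for the loader; not the helper used in this PR.

  Returns None when either knob is empty so the caller can keep the built-in
  reward stack, otherwise a list of callables looked up by name in the file.
  """
  if not path or not names:
    return None
  spec = importlib.util.spec_from_file_location("user_rewards", path)
  if spec is None or spec.loader is None:
    raise ValueError(f"Cannot import reward functions from {path!r}")
  module = importlib.util.module_from_spec(spec)
  spec.loader.exec_module(module)
  fns = []
  for name in (n.strip() for n in names.split(",")):
    if not name:
      continue  # tolerate stray commas such as "a,,b"
    fn = getattr(module, name, None)
    if not callable(fn):
      raise ValueError(f"{name!r} is not a callable defined in {path!r}")
    fns.append(fn)
  return fns
```

Returning None rather than an empty list keeps the "either knob empty means built-in stack" decision in one place on the caller side.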

Verified on TPU (v6e 8x8)

Ran two live GRPO smoke tests on a Qwen3-1.7B / GSM8K recipe driving the reward through these knobs:

  1. VTC partial-credit reward (reward_functions=match_format_exactly from a user file): the full run (pre-eval, 3 training steps, checkpoint, post-checks) completed with exit code 0. The log confirmed "reward_fns: using 1 custom reward function(s) ['match_format_exactly']".

  2. Random reward (returns uniform(0,1), an unforgeable signal): confirmed the custom function is actually called and averaged, not silently defaulted:

    • Called during training: a batch of n=32 completions (4 prompts x 8 generations) was scored each step.
    • Called during eval: the eval mean_reward exactly equalled the average of the printed random values (pre-RL: mean of 8 values = 0.6065 == reported 0.6065; post-RL: 0.3761 == 0.3761).
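
The random-reward check in item 2 can be reproduced with a function along these lines. This is a sketch of the idea, not the exact file used in the smoke test; the name `random_reward` and the print format are assumptions.

```python
import random


def random_reward(prompts, completions, tmvp_config, **kwargs):
  """Scores each completion with an independent uniform(0, 1) draw.

  The values are printed so their average can be compared by hand with the
  mean reward that the trainer reports.
  """
  scores = [random.uniform(0.0, 1.0) for _ in completions]
  print(f"random_reward: {scores}")
  return scores
```

Because the built-in rewards cannot produce these values, a reported mean that matches the printed average indicates that the custom function was the one actually called.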

Files

| File | Change |
|---|---|
| src/maxtext/configs/types.py | New reward_functions_path + reward_functions fields on the RLDataset mixin |
| src/maxtext/configs/post_train/rl.yml | Defaults ('') + comment |
| src/maxtext/trainers/post_train/rl/train_rl.py | Load + use custom reward fns when both knobs set; else built-in stack |

Checklist

  • Pyink-clean (--pyink-indentation=2 --line-length=122)
  • Pydantic schema updated in types.py (RLDataset mixin)
  • Backward compatible: default-off when either knob is empty
  • No effect on non-RL paths
  • Verified end-to-end on TPU with two distinct custom rewards (VTC + random)

…ards

Currently the reward stack is hardcoded to a 3-fn list:
[match_format_exactly, match_format_approximately, check_numbers].
Replacing it requires editing train_rl.py.

Add two new config fields:
  - reward_functions_path: path to a user Python file
  - reward_functions: comma-separated list of function names to import

When both are set, the built-in reward stack is REPLACED entirely by the
user-provided functions (so users can pin a single VTC-style partial-credit
reward, swap in a math_verify-based scorer, etc., without editing maxtext).

Each user function must accept (prompts, completions, tmvp_config, **kwargs)
and return a list of floats.

Default (both empty) keeps existing behavior unchanged.

Reuses `_load_custom_callable` helper added in the previous commit.

codecov Bot commented Jun 11, 2026
Codecov Report

❌ Patch coverage is 0% with 7 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/maxtext/trainers/post_train/rl/train_rl.py | 0.00% | 7 Missing ⚠️ |
