Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified .coverage
Binary file not shown.
1 change: 1 addition & 0 deletions cspell/library-words.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ libcuda
PYTHONPATH
venv
cuda
funcs
6 changes: 6 additions & 0 deletions cspell/project-words.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,9 @@ vocab
finetune
accum
nonpositive
GRPO
styledistance
Wegmann
embs
cdfs
unprimed
23 changes: 23 additions & 0 deletions docs/03_rl.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Reward Function

---

## Preliminaries

Let $\pi_{\theta}$ and $\pi_{\text{base}}$ denote the online and reference (base) models respectively, with completions $d_i \sim \pi_{\theta}(\cdot \mid x_i)$ and $d_i^{\text{base}} \sim \pi_{\text{base}}(\cdot \mid x_i)$ for prompt $x_i$.

## Typicality Reward

The typicality reward reuses the stylometric metrics directly (see [Stylometric Metrics](01_stylometry.md)), and asks how typical a single completion is of the author's general style.

For each metric $f$, let $\hat{F}_f$ be its empirical CDF over the author's training split, the same split used to construct $\mathcal{W}_0$ in the eval suite. For a text $a$, define its percentile $u_f(a) = \hat{F}_f(f(a)) \in [0, 1]$ and

$$\tau_f(a) = 1 - 2\left|u_f(a) - \frac{1}{2}\right| \in [0, 1]$$

so that $\tau_f(a) = 1$ when $a$ sits exactly at the author's median for $f$ and $\tau_f(a) = 0$ at either extreme. Group and aggregate exactly as in the eval suite (see [Evaluation Suite](02_evals.md)), averaging within group and taking the floored geometric mean across groups:

$$\bar{\tau}_g(a) = \frac{1}{|g|} \sum_{f \in g} \tau_f(a), \qquad T(a) = \left(\prod_{g=1}^{G} \max(\bar{\tau}_g(a),\, \varepsilon)\right)^{1/G}$$

using the same floor $\varepsilon$ as the eval suite. The reward is then:

$$r_{\text{typicality}}(d_i, d_i^{\text{base}}) = \max\big(T(d_i) - T(d_i^{\text{base}}),\ 0\big)$$
File renamed without changes.
4 changes: 1 addition & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ dependencies = [
"pyyaml>=6.0.3",
"scipy>=1.17.0",
"seaborn>=0.13.2",
"sentence-transformers>=5.3.0",
"transformers",
"click>=8.0.0",
"wandb>=0.25.1",
Expand All @@ -28,6 +27,7 @@ voice = "voice.finetune.cli:main"

[tool.uv]
required-environments = ["sys_platform == 'linux'"]
python-preference = "only-managed"

[tool.uv.sources]
axolotl = { git = "https://github.com/axolotl-ai-cloud/axolotl.git", branch = "main" }
Expand Down Expand Up @@ -128,8 +128,6 @@ module = [
"seaborn.*",
"matplotlib",
"matplotlib.*",
"sentence_transformers",
"sentence_transformers.*",
"peft",
"peft.*",
"torch",
Expand Down
1 change: 1 addition & 0 deletions src/voice/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
stylometric_distribution,
)
from voice.datasets import DatasetSpec, LocalDatasetSpec, get_dataset
from voice.rl._utils import prompt_transform as _prompt_transform # noqa: F401
from voice.stylometry import get_groups, get_metrics

__all__: list[str] = [
Expand Down
2 changes: 1 addition & 1 deletion src/voice/datasets/get_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def _canonicalise_dataset(ds: Dataset) -> Dataset:
if {"system", "question", "answer"}.issubset(cols):
return ds

if "messages" in cols and len(cols) == 1:
if "messages" in cols:
return ds.map(
_extract_chat_style,
remove_columns=ds.column_names,
Expand Down
49 changes: 29 additions & 20 deletions src/voice/finetune/_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
}
)

# Axes that require custom output keys rather than a direct pass through.
# lora_r -> lora_r + lora_alpha (alpha mirrors rank)
# target_layers -> target_layers_name + lora_target_modules
_SPECIAL_AXES: frozenset[str] = frozenset({"lora_r", "target_layers"})

# -----------------------------------------------------------------------------
# Run naming
# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -68,6 +73,13 @@ def expand_grid(sweep_cfg: dict[str, Any]) -> list[dict[str, Any]]:
"""
Expand the sweep section into a flat list of per-run hyperparameter dicts.

Required axes (``learning_rate``, ``lora_r``, ``micro_batch_size``,
``gradient_accumulation_steps``, ``target_layers``) must always be
present. Any additional axes (e.g. dotted keys such as
``trl.beta``) are expanded into the product and passed through
for :func:`~voice.finetune._orchestrator._build_axolotl_cfg`
to merge into the config.

:param sweep_cfg: The ``sweep`` section of a sweep config YAML.
:return: List of hyperparameter dicts, one per combination.
:raises ValueError: If required sweep axes are missing.
Expand All @@ -78,25 +90,22 @@ def expand_grid(sweep_cfg: dict[str, Any]) -> list[dict[str, Any]]:
f"Sweep config missing required axes: {sorted(missing)}"
)

generic_keys = [k for k in sweep_cfg if k not in _SPECIAL_AXES]
special_keys = [k for k in sweep_cfg if k in _SPECIAL_AXES]
all_keys = generic_keys + special_keys
all_values = [sweep_cfg[k] for k in all_keys]

runs: list[dict[str, Any]] = []
combos = itertools.product(
sweep_cfg["learning_rate"],
sweep_cfg["lora_r"],
sweep_cfg["micro_batch_size"],
sweep_cfg["gradient_accumulation_steps"],
sweep_cfg["target_layers"],
)
for idx, (lr, rank, mbs, gas, tl) in enumerate(combos):
runs.append(
{
"run_idx": idx,
"learning_rate": lr,
"lora_r": rank,
"lora_alpha": rank,
"micro_batch_size": mbs,
"gradient_accumulation_steps": gas,
"target_layers_name": tl["name"],
"lora_target_modules": tl["modules"],
}
)
for idx, combo in enumerate(itertools.product(*all_values)):
params: dict[str, Any] = {"run_idx": idx}
for key, value in zip(all_keys, combo): # noqa: B905
if key == "lora_r":
params["lora_r"] = value
params["lora_alpha"] = value
elif key == "target_layers":
params["target_layers_name"] = value["name"]
params["lora_target_modules"] = value["modules"]
else:
params[key] = value
runs.append(params)
return runs
12 changes: 10 additions & 2 deletions src/voice/finetune/_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ def _build_axolotl_cfg(
"""
Build a per-run axolotl config by merging hyperparams into the base.

Sets ``output_dir`` to ``run_dir.adapter_dir`` and copies any
recognised hyperparams keys present in ``hyperparams``.
Sets ``output_dir`` to ``run_dir.adapter_dir``, copies recognised
flat hyperparams keys, and merges dotted keys (e.g. ``"trl.beta"``)
into their named subsection dict.

:param base_cfg: Shared axolotl config from the sweep or single-run
config.
Expand All @@ -82,6 +83,13 @@ def _build_axolotl_cfg(
for key in _HYPERPARAMS_CFG_KEYS:
if key in hyperparams:
cfg[key] = hyperparams[key]
# Dotted keys (e.g. "trl.beta") are merged into the named subsection.
for key, value in hyperparams.items():
if "." in key:
parent, _, child = key.partition(".")
section = dict(cfg.get(parent) or {})
section[child] = value
cfg[parent] = section
return cfg


Expand Down
65 changes: 65 additions & 0 deletions src/voice/finetune/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,49 @@ def on_epoch_end(
# -----------------------------------------------------------------------------


class TypicalityRewardPlugin(BasePlugin): # type: ignore[misc]
"""
Axolotl plugin that primes typicality CDFs before GRPO training starts.

Loads the training split of the configured dataset and builds per-metric
empirical CDFs so that ``typicality_reward`` can look them up at reward
computation time without needing the dataset path.

Register alongside EvalCompletionsPlugin in the axolotl YAML config::

plugins:
- voice.finetune.callbacks.EvalCompletionsPlugin
- voice.finetune.callbacks.TypicalityRewardPlugin
"""

def add_callbacks_post_trainer(
self, cfg: Any, _trainer: Any
) -> list[TrainerCallback]:
"""
Prime typicality CDFs from the configured training dataset.

:param cfg: Axolotl config object.
:param _trainer: The built HF Trainer instance (unused).
:return: Empty list — no callbacks are registered.
:raises ValueError: If ``cfg.datasets[0].path`` is not accessible.
"""
from voice.rl.rewards import prime_typicality_cdfs

try:
dataset: str = cfg.datasets[0].path
except (AttributeError, IndexError) as exc:
raise ValueError(
"TypicalityRewardPlugin requires cfg.datasets[0].path "
"to be set in the axolotl config."
) from exc

log.info(
"TypicalityRewardPlugin: priming typicality CDFs from %s", dataset
)
prime_typicality_cdfs(dataset)
return []


class EvalCompletionsPlugin(BasePlugin): # type: ignore[misc]
"""
Axolotl plugin that registers EvalCompletionsCallback before training.
Expand All @@ -422,6 +465,28 @@ def add_callbacks_post_trainer(
:return: List containing the EvalCompletionsCallback instance.
:raises ValueError: If ``cfg.datasets[0].path`` is not accessible.
"""
chat_template_kwargs: dict[str, Any] = (
getattr(cfg, "chat_template_kwargs", None) or {}
)
if "enable_thinking" in chat_template_kwargs:
enable_thinking: bool = chat_template_kwargs["enable_thinking"]
tokenizer = getattr(trainer, "processing_class", None) or getattr(
trainer, "tokenizer", None
)
if tokenizer is not None:
_orig = tokenizer.apply_chat_template

def _patch(*args: Any, **kwargs: Any) -> Any:
kwargs.setdefault("enable_thinking", enable_thinking)
return _orig(*args, **kwargs)

tokenizer.apply_chat_template = _patch
log.info(
"EvalCompletionsPlugin: patched tokenizer to default "
"enable_thinking=%s",
enable_thinking,
)

try:
dataset: str = cfg.datasets[0].path
except (AttributeError, IndexError) as exc:
Expand Down
7 changes: 7 additions & 0 deletions src/voice/rl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""
Reinforcement learning module for the VOICE project.

Provides utilities for GRPO fine-tuning jobs, including:
- Reward functions
- GRPO dataset transform factory
"""
50 changes: 50 additions & 0 deletions src/voice/rl/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Dataset transform utilities for GRPO training."""

from __future__ import annotations

from typing import Any


def prompt_transform(
cfg: Any, # noqa: ANN401
*args: Any, # noqa: ANN401
**kwargs: Any, # noqa: ANN401
) -> tuple[Any, dict[str, Any]]: # noqa: ANN401
"""
Axolotl dataset transform factory for GRPO.

Strips the assistant turn from ``messages`` to build the prompt, extracts
it as ``true_completion``, and passes through ``ref_completion`` from the
dataset column. Both extra fields are forwarded to reward functions via
``**kwargs`` by TRL.

To register in axolotl config::

datasets:
- path: <hf-dataset>
type: voice._prompt_transform

:param cfg: Axolotl config object (passed by axolotl; not used here).
:param *args: positional arguments passed to ``prompt``.
:param **kwargs: keyword arguments passed to ``prompt``.
:return: ``(map_fn, dataset_map_kwargs)`` tuple consumed by axolotl.
"""
_ = cfg
__ = args
___ = kwargs

def _map(
example: dict[str, Any], # noqa: ANN401
tokenizer: Any = None, # noqa: ANN401
) -> dict[str, Any]: # noqa: ANN401
_ = tokenizer
messages: list[dict[str, str]] = example["messages"]
return {
"prompt": [m for m in messages if m["role"] != "assistant"],
"true_completion": next(
m["content"] for m in messages if m["role"] == "assistant"
),
"ref_completion": example["ref_completion"],
}

return _map, {}
Loading
Loading