diff --git a/README.md b/README.md
index 2b2fd2e..9847ea7 100644
--- a/README.md
+++ b/README.md
@@ -27,7 +27,7 @@ Practice implementing operators and architectures from scratch โ the exact ski
[](https://github.com/duoan/TorchCode)
[](https://ghcr.io/duoan/torchcode)
[](https://huggingface.co/spaces/duoan/TorchCode)
-
+

[](https://star-history.com/#duoan/TorchCode&Date)
@@ -44,7 +44,7 @@ TorchCode gives you a **structured practice environment** with:
| | Feature | |
|---|---|---|
-| 🧩 | **40 curated problems** | The most frequently asked PyTorch interview topics |
+| 🧩 | **41 curated problems** | The most frequently asked PyTorch interview topics |
| โ๏ธ | **Automated judge** | Correctness checks, gradient verification, and timing |
| ๐จ | **Instant feedback** | Colored pass/fail per test case, just like competitive programming |
| ๐ก | **Hints when stuck** | Nudges without full spoilers |
@@ -198,6 +198,7 @@ If you're interviewing for any role touching LLMs or Transformers, expect at lea
| 37 | DPO Loss
| `dpo_loss(chosen, rejected, ...)` |  | ๐ก | Direct preference optimization, alignment training |
| 38 | GRPO Loss
| `grpo_loss(logps, rewards, group_ids, eps)` |  | ๐ก | Group relative policy optimization, RLAIF, within-group normalized advantages |
| 39 | PPO Loss
| `ppo_loss(new_logps, old_logps, advantages, clip_ratio)` |  | ๐ก | PPO clipped surrogate loss, policy gradient, trust region |
+| 41 | OPD Loss
| `opd_loss(student_logits, teacher_logits, ...)` |  | ๐ก | On-policy distillation, reverse KL, multi-teacher alignment |
---
@@ -244,7 +245,7 @@ status() # Progress dashboard โ solved / attempted / todo
| **1** | ๐งฑ Foundations | ReLU โ Softmax โ CE Loss โ Dropout โ Embedding โ GELU โ Linear โ LayerNorm โ BatchNorm โ RMSNorm โ SwiGLU MLP โ Conv2d | 2โ3 hrs |
| **2** | ๐ง Attention Deep Dive | SDPA โ MHA โ Cross-Attn โ Causal โ GQA โ KV Cache โ Sliding Window โ RoPE โ Linear Attn โ Flash Attn | 3โ4 hrs |
| **3** | ๐๏ธ Architecture + Training | GPT-2 Block โ LoRA โ MoE โ ViT Patch โ Adam โ Cosine LR โ Grad Clip โ Grad Accumulation โ Kaiming Init | 3โ4 hrs |
-| **4** | ๐ฏ Inference + Advanced | Top-k/p Sampling โ Beam Search โ Speculative Decoding โ BPE โ INT8 Quant โ DPO Loss โ GRPO Loss โ PPO Loss + speed run | 3โ4 hrs |
+| **4** | ๐ฏ Inference + Advanced | Top-k/p Sampling โ Beam Search โ Speculative Decoding โ BPE โ INT8 Quant โ DPO Loss โ GRPO Loss โ PPO Loss โ OPD Loss + speed run | 3โ4 hrs |
---
diff --git a/solutions/41_opd_loss_solution.ipynb b/solutions/41_opd_loss_solution.ipynb
new file mode 100644
index 0000000..e7c6590
--- /dev/null
+++ b/solutions/41_opd_loss_solution.ipynb
@@ -0,0 +1,137 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "[](https://colab.research.google.com/github/duoan/TorchCode/blob/master/solutions/41_opd_loss_solution.ipynb)\n",
+ "\n",
+ "# Solution: OPD (On-Policy Distillation) Loss\n",
+ "\n",
+ "Reference solution."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n",
+ "try:\n",
+ " import google.colab\n",
+ " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n",
+ "except ImportError:\n",
+ " pass\n"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "from torch import Tensor"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# ✅ SOLUTION\n",
+ "\n",
+ "def opd_loss(student_logits: Tensor,\n",
+ " teacher_logits: Tensor,\n",
+ " teacher_weights: Tensor | None = None,\n",
+ " mask: Tensor | None = None,\n",
+ " temperature: float = 1.0) -> Tensor:\n",
+ " \"\"\"On-Policy Distillation (OPD) reverse-KL loss.\n",
+ "\n",
+ " student_logits: (..., V) logits from the student policy\n",
+ " teacher_logits: (..., V) for one teacher, or (T, ..., V) for T teachers\n",
+ " teacher_weights: optional (T,) weights, normalized internally\n",
+ " mask: optional (...) token mask, where 1 = include and 0 = ignore\n",
+ " temperature: softmax temperature used for distillation\n",
+ " returns: scalar loss (Tensor)\n",
+ " \"\"\"\n",
+ " if teacher_logits.dim() == student_logits.dim():\n",
+ " teacher_logits = teacher_logits.unsqueeze(0)\n",
+ " elif teacher_logits.dim() != student_logits.dim() + 1:\n",
+ " raise ValueError('teacher_logits must have shape (..., V) or (T, ..., V)')\n",
+ "\n",
+ " teacher_logits = teacher_logits.detach()\n",
+ " t = float(temperature)\n",
+ "\n",
+ " student_logp = F.log_softmax(student_logits / t, dim=-1)\n",
+ " student_prob = student_logp.exp()\n",
+ " teacher_logp = F.log_softmax(teacher_logits / t, dim=-1)\n",
+ "\n",
+ " per_teacher_kl = (\n",
+ " student_prob.unsqueeze(0) * (student_logp.unsqueeze(0) - teacher_logp)\n",
+ " ).sum(dim=-1)\n",
+ "\n",
+ " num_teachers = per_teacher_kl.shape[0]\n",
+ " if teacher_weights is None:\n",
+ " weights = torch.full(\n",
+ " (num_teachers,),\n",
+ " 1.0 / num_teachers,\n",
+ " dtype=per_teacher_kl.dtype,\n",
+ " device=per_teacher_kl.device,\n",
+ " )\n",
+ " else:\n",
+ " weights = teacher_weights.to(dtype=per_teacher_kl.dtype, device=per_teacher_kl.device)\n",
+ " weights = weights / weights.sum()\n",
+ "\n",
+ " view_shape = (num_teachers,) + (1,) * (per_teacher_kl.dim() - 1)\n",
+ " per_token_kl = (weights.view(view_shape) * per_teacher_kl).sum(dim=0)\n",
+ "\n",
+ " if mask is not None:\n",
+ " mask = mask.to(dtype=per_token_kl.dtype, device=per_token_kl.device)\n",
+ " loss = (per_token_kl * mask).sum() / mask.sum().clamp_min(1.0)\n",
+ " else:\n",
+ " loss = per_token_kl.mean()\n",
+ "\n",
+ " return loss * (t ** 2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Demo\n",
+ "student_logits = torch.tensor([[[2.0, 0.0, -1.0], [0.5, 1.0, -0.5]]])\n",
+ "teacher_logits = torch.tensor([[[1.0, 1.5, -0.5], [0.0, 2.0, -1.0]]])\n",
+ "print('Loss:', opd_loss(student_logits, teacher_logits).item())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from torch_judge import check\n",
+ "check('opd_loss')"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.11.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/templates/00_welcome.ipynb b/templates/00_welcome.ipynb
index c3498f5..e8e46ef 100644
--- a/templates/00_welcome.ipynb
+++ b/templates/00_welcome.ipynb
@@ -64,7 +64,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "## Problem List (40 problems)\n",
+ "## Problem List (41 problems)\n",
"\n",
"### ๐งฑ Fundamentals โ \"Implement X from scratch\"\n",
"\n",
@@ -135,6 +135,7 @@
"| 37 | DPO Loss | ๐ด Hard | [Open](37_dpo_loss.ipynb) ยท GitHub ยท Colab | [Open](37_dpo_loss_solution.ipynb) ยท GitHub ยท Colab |\n",
"| 38 | GRPO Loss | ๐ด Hard | [Open](38_grpo_loss.ipynb) ยท GitHub ยท Colab | [Open](38_grpo_loss_solution.ipynb) ยท GitHub ยท Colab |\n",
"| 39 | PPO Loss | ๐ด Hard | [Open](39_ppo_loss.ipynb) ยท GitHub ยท Colab | [Open](39_ppo_loss_solution.ipynb) ยท GitHub ยท Colab |\n",
+ "| 41 | OPD Loss | 🔴 Hard | [Open](41_opd_loss.ipynb) · GitHub · Colab | [Open](41_opd_loss_solution.ipynb) · GitHub · Colab |\n",
"\n",
"## Useful Commands\n",
"\n",
diff --git a/templates/41_opd_loss.ipynb b/templates/41_opd_loss.ipynb
new file mode 100644
index 0000000..d450fad
--- /dev/null
+++ b/templates/41_opd_loss.ipynb
@@ -0,0 +1,120 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "[](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/41_opd_loss.ipynb)\n",
+ "\n",
+ "# 🔴 Hard: OPD Loss\n",
+ "\n",
+ "Implement the **On-Policy Distillation (OPD)** loss used to distill one or more teacher policies into a student policy on trajectories sampled from the student.\n",
+ "\n",
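+ "For each token position, OPD minimizes a weighted reverse KL from the student distribution to each teacher distribution. The weights $w_j$ are normalized to sum to 1 (uniform when `teacher_weights` is omitted), both distributions are softmaxes of `logits / temperature`, teacher logits are treated as constants (no gradient), the per-token values are averaged over the unmasked positions, and the final loss is multiplied by `temperature ** 2`:\n",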
+ "\n",
+ "$$\\mathcal{L}_{\\text{OPD}} = \\sum_j w_j\\,D_{\\text{KL}}\\big(\\pi_\\theta(\\cdot \\mid x)\\;||\\;\\pi_{E_j}(\\cdot \\mid x)\\big)$$\n",
+ "\n",
+ "where\n",
+ "\n",
+ "$$D_{\\text{KL}}(p||q) = \\sum_v p(v)\\,[\\log p(v) - \\log q(v)].$$\n",
+ "\n",
+ "### Signature\n",
+ "```python\n",
+ "from torch import Tensor\n",
+ "\n",
+ "def opd_loss(student_logits: Tensor,\n",
+ " teacher_logits: Tensor,\n",
+ " teacher_weights: Tensor | None = None,\n",
+ " mask: Tensor | None = None,\n",
+ " temperature: float = 1.0) -> Tensor:\n",
+ " \"\"\"OPD reverse-KL distillation loss.\n",
+ "\n",
+ " student_logits: (..., V) logits from the student policy\n",
+ " teacher_logits: (..., V) for one teacher, or (T, ..., V) for T teachers\n",
+ " teacher_weights: optional (T,) weights, normalized internally\n",
+ " mask: optional (...) token mask, where 1 = include and 0 = ignore\n",
+ " temperature: softmax temperature used for distillation\n",
+ " returns: scalar loss (Tensor)\n",
+ " \"\"\"\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n",
+ "try:\n",
+ " import google.colab\n",
+ " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n",
+ "except ImportError:\n",
+ " pass\n"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "from torch import Tensor"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# โ๏ธ YOUR IMPLEMENTATION HERE\n",
+ "\n",
+ "def opd_loss(student_logits: Tensor,\n",
+ " teacher_logits: Tensor,\n",
+ " teacher_weights: Tensor | None = None,\n",
+ " mask: Tensor | None = None,\n",
+ " temperature: float = 1.0) -> Tensor:\n",
+ " pass # reverse KL: sum_v p_student * (log p_student - log p_teacher)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 🧪 Debug\n",
+ "student_logits = torch.tensor([[[2.0, 0.0, -1.0], [0.5, 1.0, -0.5]]])\n",
+ "teacher_logits = torch.tensor([[[1.0, 1.5, -0.5], [0.0, 2.0, -1.0]]])\n",
+ "print('Loss:', opd_loss(student_logits, teacher_logits).item())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# ✅ SUBMIT\n",
+ "from torch_judge import check\n",
+ "check('opd_loss')"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.11.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/torch_judge/tasks/opd_loss.py b/torch_judge/tasks/opd_loss.py
new file mode 100644
index 0000000..c44d599
--- /dev/null
+++ b/torch_judge/tasks/opd_loss.py
@@ -0,0 +1,109 @@
+"""OPD (On-Policy Distillation) Loss task."""
+
+TASK = {  # Task definition for the OPD loss problem: metadata, a hint, and a list of test snippets.
+ "title": "OPD (On-Policy Distillation) Loss",
+ "difficulty": "Hard",
+ "function_name": "opd_loss",  # presumably the callable name the judge substitutes for {fn} below — confirm against the judge engine
+ "hint": (  # one string built by implicit concatenation of the adjacent literals
+ "Compute reverse KL from the student to each teacher: "
+ "KL(pi_student || pi_teacher) = sum_v p_student(v) * "
+ "(log p_student(v) - log p_teacher(v)). Average teacher KLs with "
+ "teacher_weights, apply mask over tokens if provided, and multiply by "
+ "temperature ** 2."
+ ),
+ "tests": [  # each entry pairs a display name with a Python source string that calls {fn}
+ {
+ "name": "Basic shape & type",  # result must be a Tensor with dim() == 0
+ "code": "\n"
+ "import torch\n"
+ "from torch import Tensor\n"
+ "student_logits = torch.randn(2, 3, 5, requires_grad=True)\n"
+ "teacher_logits = torch.randn(2, 3, 5)\n"
+ "loss = {fn}(student_logits, teacher_logits)\n"
+ "assert isinstance(loss, Tensor) and loss.dim() == 0, 'Loss must be a scalar Tensor'\n"
+ },
+ {
+ "name": "Zero when student matches teacher",  # teacher is a detached clone of the student, so the loss must be ~0 (atol=1e-6)
+ "code": "\n"
+ "import torch\n"
+ "torch.manual_seed(0)\n"
+ "student_logits = torch.randn(2, 3, 4, requires_grad=True)\n"
+ "teacher_logits = student_logits.detach().clone()\n"
+ "loss = {fn}(student_logits, teacher_logits)\n"
+ "assert torch.allclose(loss, torch.tensor(0.0), atol=1e-6), f'Expected near-zero loss, got {loss.item():.8f}'\n"
+ },
+ {
+ "name": "Numeric check vs single-teacher reverse KL",  # reference: sum over the last dim of p_s * (log p_s - log p_t), then mean over tokens
+ "code": "\n"
+ "import torch\n"
+ "import torch.nn.functional as F\n"
+ "student_logits = torch.tensor([[[2.0, 0.0, -1.0], [0.5, 1.0, -0.5]]])\n"
+ "teacher_logits = torch.tensor([[[1.0, 1.5, -0.5], [0.0, 2.0, -1.0]]])\n"
+ "student_logits = student_logits.requires_grad_()\n"
+ "loss = {fn}(student_logits, teacher_logits)\n"
+ "s_logp = F.log_softmax(student_logits, dim=-1)\n"
+ "t_logp = F.log_softmax(teacher_logits, dim=-1)\n"
+ "expected = (s_logp.exp() * (s_logp - t_logp)).sum(dim=-1).mean()\n"
+ "assert torch.allclose(loss, expected, atol=1e-6), f'{loss.item():.6f} vs {expected.item():.6f}'\n"
+ },
+ {
+ "name": "Multi-teacher weighted reverse KL",  # two teachers stacked on a new leading dim; weights 0.25/0.75 already sum to 1
+ "code": "\n"
+ "import torch\n"
+ "import torch.nn.functional as F\n"
+ "student_logits = torch.tensor([[[1.0, 0.0], [0.0, 1.0]]], requires_grad=True)\n"
+ "teacher_logits = torch.stack([\n"
+ " torch.tensor([[[2.0, -1.0], [1.0, 0.0]]]),\n"
+ " torch.tensor([[[-1.0, 2.0], [0.0, 1.0]]]),\n"
+ "])\n"
+ "weights = torch.tensor([0.25, 0.75])\n"
+ "loss = {fn}(student_logits, teacher_logits, teacher_weights=weights)\n"
+ "s_logp = F.log_softmax(student_logits, dim=-1)\n"
+ "s_prob = s_logp.exp()\n"
+ "t_logp = F.log_softmax(teacher_logits, dim=-1)\n"
+ "kl = (s_prob.unsqueeze(0) * (s_logp.unsqueeze(0) - t_logp)).sum(dim=-1)\n"
+ "expected = (weights.view(-1, 1, 1) * kl).sum(dim=0).mean()\n"
+ "assert torch.allclose(loss, expected, atol=1e-6), f'{loss.item():.6f} vs {expected.item():.6f}'\n"
+ },
+ {
+ "name": "Mask ignores padded tokens",  # third token has mask 0 and extreme teacher logits; reference is masked sum / mask.sum()
+ "code": "\n"
+ "import torch\n"
+ "import torch.nn.functional as F\n"
+ "student_logits = torch.tensor([[[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]]], requires_grad=True)\n"
+ "teacher_logits = torch.tensor([[[0.0, 2.0], [0.0, 2.0], [100.0, -100.0]]])\n"
+ "mask = torch.tensor([[1.0, 1.0, 0.0]])\n"
+ "loss = {fn}(student_logits, teacher_logits, mask=mask)\n"
+ "s_logp = F.log_softmax(student_logits, dim=-1)\n"
+ "t_logp = F.log_softmax(teacher_logits, dim=-1)\n"
+ "per_token = (s_logp.exp() * (s_logp - t_logp)).sum(dim=-1)\n"
+ "expected = (per_token * mask).sum() / mask.sum()\n"
+ "assert torch.allclose(loss, expected, atol=1e-6), 'Masked positions should not affect the loss'\n"
+ },
+ {
+ "name": "Gradient flows only through student logits",  # teacher is created with requires_grad=True, yet its .grad must stay None after backward()
+ "code": "\n"
+ "import torch\n"
+ "student_logits = torch.randn(2, 3, 4, requires_grad=True)\n"
+ "teacher_logits = torch.randn(2, 3, 4, requires_grad=True)\n"
+ "loss = {fn}(student_logits, teacher_logits)\n"
+ "loss.backward()\n"
+ "assert student_logits.grad is not None, 'Student logits should receive gradients'\n"
+ "assert teacher_logits.grad is None, 'Teacher logits should be treated as frozen targets'\n"
+ },
+ {
+ "name": "Temperature scaling",  # reference divides both logits by temperature and multiplies the mean KL by temperature ** 2
+ "code": "\n"
+ "import torch\n"
+ "import torch.nn.functional as F\n"
+ "student_logits = torch.tensor([[[2.0, 0.0, -1.0]]], requires_grad=True)\n"
+ "teacher_logits = torch.tensor([[[0.0, 1.0, -1.0]]])\n"
+ "temperature = 2.0\n"
+ "loss = {fn}(student_logits, teacher_logits, temperature=temperature)\n"
+ "s_logp = F.log_softmax(student_logits / temperature, dim=-1)\n"
+ "t_logp = F.log_softmax(teacher_logits / temperature, dim=-1)\n"
+ "expected = (s_logp.exp() * (s_logp - t_logp)).sum(dim=-1).mean() * (temperature ** 2)\n"
+ "assert torch.allclose(loss, expected, atol=1e-6), f'{loss.item():.6f} vs {expected.item():.6f}'\n"
+ },
+ ],
+}