From 1442e4abbeb4f004485a858d93a46318b78657ad Mon Sep 17 00:00:00 2001 From: Zuozhuo Date: Tue, 9 Jun 2026 12:30:11 +0800 Subject: [PATCH] feat: add OPD loss practice problem --- README.md | 7 +- solutions/41_opd_loss_solution.ipynb | 137 +++++++++++++++++++++++++++ templates/00_welcome.ipynb | 3 +- templates/41_opd_loss.ipynb | 120 +++++++++++++++++++++++ torch_judge/tasks/opd_loss.py | 109 +++++++++++++++++++++ 5 files changed, 372 insertions(+), 4 deletions(-) create mode 100644 solutions/41_opd_loss_solution.ipynb create mode 100644 templates/41_opd_loss.ipynb create mode 100644 torch_judge/tasks/opd_loss.py diff --git a/README.md b/README.md index 2b2fd2e..9847ea7 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ Practice implementing operators and architectures from scratch — the exact ski [![GitHub stars](https://img.shields.io/github/stars/duoan/TorchCode?style=social)](https://github.com/duoan/TorchCode) [![GitHub Container Registry](https://img.shields.io/badge/ghcr.io-TorchCode-blue?style=flat-square&logo=github)](https://ghcr.io/duoan/torchcode) [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Spaces-TorchCode-blue?style=flat-square)](https://huggingface.co/spaces/duoan/TorchCode) -![Problems](https://img.shields.io/badge/problems-40-orange?style=flat-square) +![Problems](https://img.shields.io/badge/problems-41-orange?style=flat-square) ![GPU](https://img.shields.io/badge/GPU-not%20required-brightgreen?style=flat-square) [![Star History Chart](https://api.star-history.com/svg?repos=duoan/TorchCode&type=Date)](https://star-history.com/#duoan/TorchCode&Date) @@ -44,7 +44,7 @@ TorchCode gives you a **structured practice environment** with: | | Feature | | |---|---|---| -| 🧩 | **40 curated problems** | The most frequently asked PyTorch interview topics | +| 🧩 | **41 curated problems** | The most frequently asked PyTorch interview topics | | ⚖️ | **Automated judge** | Correctness checks, gradient verification, and 
timing | | 🎨 | **Instant feedback** | Colored pass/fail per test case, just like competitive programming | | 💡 | **Hints when stuck** | Nudges without full spoilers | @@ -198,6 +198,7 @@ If you're interviewing for any role touching LLMs or Transformers, expect at lea | 37 | DPO Loss Open In Colab | `dpo_loss(chosen, rejected, ...)` | ![Hard](https://img.shields.io/badge/Hard-F44336?style=flat-square) | 💡 | Direct preference optimization, alignment training | | 38 | GRPO Loss Open In Colab | `grpo_loss(logps, rewards, group_ids, eps)` | ![Hard](https://img.shields.io/badge/Hard-F44336?style=flat-square) | 💡 | Group relative policy optimization, RLAIF, within-group normalized advantages | | 39 | PPO Loss Open In Colab | `ppo_loss(new_logps, old_logps, advantages, clip_ratio)` | ![Hard](https://img.shields.io/badge/Hard-F44336?style=flat-square) | 💡 | PPO clipped surrogate loss, policy gradient, trust region | +| 41 | OPD Loss Open In Colab | `opd_loss(student_logits, teacher_logits, ...)` | ![Hard](https://img.shields.io/badge/Hard-F44336?style=flat-square) | 💡 | On-policy distillation, reverse KL, multi-teacher alignment | --- @@ -244,7 +245,7 @@ status() # Progress dashboard — solved / attempted / todo | **1** | 🧱 Foundations | ReLU → Softmax → CE Loss → Dropout → Embedding → GELU → Linear → LayerNorm → BatchNorm → RMSNorm → SwiGLU MLP → Conv2d | 2–3 hrs | | **2** | 🧠 Attention Deep Dive | SDPA → MHA → Cross-Attn → Causal → GQA → KV Cache → Sliding Window → RoPE → Linear Attn → Flash Attn | 3–4 hrs | | **3** | 🏗️ Architecture + Training | GPT-2 Block → LoRA → MoE → ViT Patch → Adam → Cosine LR → Grad Clip → Grad Accumulation → Kaiming Init | 3–4 hrs | -| **4** | 🎯 Inference + Advanced | Top-k/p Sampling → Beam Search → Speculative Decoding → BPE → INT8 Quant → DPO Loss → GRPO Loss → PPO Loss + speed run | 3–4 hrs | +| **4** | 🎯 Inference + Advanced | 
Top-k/p Sampling → Beam Search → Speculative Decoding → BPE → INT8 Quant → DPO Loss → GRPO Loss → PPO Loss → OPD Loss + speed run | 3–4 hrs | --- diff --git a/solutions/41_opd_loss_solution.ipynb b/solutions/41_opd_loss_solution.ipynb new file mode 100644 index 0000000..e7c6590 --- /dev/null +++ b/solutions/41_opd_loss_solution.ipynb @@ -0,0 +1,137 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/solutions/41_opd_loss_solution.ipynb)\n", + "\n", + "# Solution: OPD (On-Policy Distillation) Loss\n", + "\n", + "Reference solution." + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", + "try:\n", + " import google.colab\n", + " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", + "except ImportError:\n", + " pass\n" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn.functional as F\n", + "from torch import Tensor" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# ✅ SOLUTION\n", + "\n", + "def opd_loss(student_logits: Tensor,\n", + " teacher_logits: Tensor,\n", + " teacher_weights: Tensor | None = None,\n", + " mask: Tensor | None = None,\n", + " temperature: float = 1.0) -> Tensor:\n", + " \"\"\"On-Policy Distillation (OPD) reverse-KL loss.\n", + "\n", + " student_logits: (..., V) logits from the student policy\n", + " teacher_logits: (..., V) for one teacher, or (T, ..., V) for T teachers\n", + " teacher_weights: optional (T,) weights, normalized internally\n", + " mask: optional (...) 
token mask, where 1 = include and 0 = ignore\n", + " temperature: softmax temperature used for distillation\n", + " returns: scalar loss (Tensor)\n", + " \"\"\"\n", + " if teacher_logits.dim() == student_logits.dim():\n", + " teacher_logits = teacher_logits.unsqueeze(0)\n", + " elif teacher_logits.dim() != student_logits.dim() + 1:\n", + " raise ValueError('teacher_logits must have shape (..., V) or (T, ..., V)')\n", + "\n", + " teacher_logits = teacher_logits.detach()\n", + " t = float(temperature)\n", + "\n", + " student_logp = F.log_softmax(student_logits / t, dim=-1)\n", + " student_prob = student_logp.exp()\n", + " teacher_logp = F.log_softmax(teacher_logits / t, dim=-1)\n", + "\n", + " per_teacher_kl = (\n", + " student_prob.unsqueeze(0) * (student_logp.unsqueeze(0) - teacher_logp)\n", + " ).sum(dim=-1)\n", + "\n", + " num_teachers = per_teacher_kl.shape[0]\n", + " if teacher_weights is None:\n", + " weights = torch.full(\n", + " (num_teachers,),\n", + " 1.0 / num_teachers,\n", + " dtype=per_teacher_kl.dtype,\n", + " device=per_teacher_kl.device,\n", + " )\n", + " else:\n", + " weights = teacher_weights.to(dtype=per_teacher_kl.dtype, device=per_teacher_kl.device)\n", + " weights = weights / weights.sum()\n", + "\n", + " view_shape = (num_teachers,) + (1,) * (per_teacher_kl.dim() - 1)\n", + " per_token_kl = (weights.view(view_shape) * per_teacher_kl).sum(dim=0)\n", + "\n", + " if mask is not None:\n", + " mask = mask.to(dtype=per_token_kl.dtype, device=per_token_kl.device)\n", + " loss = (per_token_kl * mask).sum() / mask.sum().clamp_min(1.0)\n", + " else:\n", + " loss = per_token_kl.mean()\n", + "\n", + " return loss * (t ** 2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Demo\n", + "student_logits = torch.tensor([[[2.0, 0.0, -1.0], [0.5, 1.0, -0.5]]])\n", + "teacher_logits = torch.tensor([[[1.0, 1.5, -0.5], [0.0, 2.0, -1.0]]])\n", + "print('Loss:', opd_loss(student_logits, 
teacher_logits).item())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from torch_judge import check\n", + "check('opd_loss')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.13" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/templates/00_welcome.ipynb b/templates/00_welcome.ipynb index c3498f5..e8e46ef 100644 --- a/templates/00_welcome.ipynb +++ b/templates/00_welcome.ipynb @@ -64,7 +64,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Problem List (40 problems)\n", + "## Problem List (41 problems)\n", "\n", "### 🧱 Fundamentals — \"Implement X from scratch\"\n", "\n", @@ -135,6 +135,7 @@ "| 37 | DPO Loss | 🔴 Hard | [Open](37_dpo_loss.ipynb) · GitHub · Colab | [Open](37_dpo_loss_solution.ipynb) · GitHub · Colab |\n", "| 38 | GRPO Loss | 🔴 Hard | [Open](38_grpo_loss.ipynb) · GitHub · Colab | [Open](38_grpo_loss_solution.ipynb) · GitHub · Colab |\n", "| 39 | PPO Loss | 🔴 Hard | [Open](39_ppo_loss.ipynb) · GitHub · Colab | [Open](39_ppo_loss_solution.ipynb) · GitHub · Colab |\n", + "| 41 | OPD Loss | 🔴 Hard | [Open](41_opd_loss.ipynb) · GitHub · Colab | [Open](41_opd_loss_solution.ipynb) · GitHub · Colab |\n", "\n", "## Useful Commands\n", "\n", diff --git a/templates/41_opd_loss.ipynb b/templates/41_opd_loss.ipynb new file mode 100644 index 0000000..d450fad --- /dev/null +++ b/templates/41_opd_loss.ipynb @@ -0,0 +1,120 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/41_opd_loss.ipynb)\n", + "\n", + "# 🔴 Hard: OPD Loss\n", + "\n", + "Implement the **On-Policy Distillation (OPD)** loss used to distill one or 
more teacher policies into a student policy on trajectories sampled from the student.\n", + "\n", + "For each token position, OPD minimizes a weighted reverse KL from the student distribution to each teacher distribution:\n", + "\n", + "$$\\mathcal{L}_{\\text{OPD}} = \\sum_j w_j\\,D_{\\text{KL}}\\big(\\pi_\\theta(\\cdot \\mid x)\\;||\\;\\pi_{E_j}(\\cdot \\mid x)\\big)$$\n", + "\n", + "where\n", + "\n", + "$$D_{\\text{KL}}(p||q) = \\sum_v p(v)\\,[\\log p(v) - \\log q(v)].$$\n", + "\n", + "### Signature\n", + "```python\n", + "from torch import Tensor\n", + "\n", + "def opd_loss(student_logits: Tensor,\n", + " teacher_logits: Tensor,\n", + " teacher_weights: Tensor | None = None,\n", + " mask: Tensor | None = None,\n", + " temperature: float = 1.0) -> Tensor:\n", + " \"\"\"OPD reverse-KL distillation loss.\n", + "\n", + " student_logits: (..., V) logits from the student policy\n", + " teacher_logits: (..., V) for one teacher, or (T, ..., V) for T teachers\n", + " teacher_weights: optional (T,) weights, normalized internally\n", + " mask: optional (...) 
token mask, where 1 = include and 0 = ignore\n", + " temperature: softmax temperature used for distillation\n", + " returns: scalar loss (Tensor)\n", + " \"\"\"\n", + "```" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", + "try:\n", + " import google.colab\n", + " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", + "except ImportError:\n", + " pass\n" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn.functional as F\n", + "from torch import Tensor" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# ✏️ YOUR IMPLEMENTATION HERE\n", + "\n", + "def opd_loss(student_logits: Tensor,\n", + " teacher_logits: Tensor,\n", + " teacher_weights: Tensor | None = None,\n", + " mask: Tensor | None = None,\n", + " temperature: float = 1.0) -> Tensor:\n", + " pass # reverse KL: sum_v p_student * (log p_student - log p_teacher)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 🧪 Debug\n", + "student_logits = torch.tensor([[[2.0, 0.0, -1.0], [0.5, 1.0, -0.5]]])\n", + "teacher_logits = torch.tensor([[[1.0, 1.5, -0.5], [0.0, 2.0, -1.0]]])\n", + "print('Loss:', opd_loss(student_logits, teacher_logits).item())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# ✅ SUBMIT\n", + "from torch_judge import check\n", + "check('opd_loss')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.13" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/torch_judge/tasks/opd_loss.py 
b/torch_judge/tasks/opd_loss.py new file mode 100644 index 0000000..c44d599 --- /dev/null +++ b/torch_judge/tasks/opd_loss.py @@ -0,0 +1,109 @@ +"""OPD (On-Policy Distillation) Loss task.""" + +TASK = { + "title": "OPD (On-Policy Distillation) Loss", + "difficulty": "Hard", + "function_name": "opd_loss", + "hint": ( + "Compute reverse KL from the student to each teacher: " + "KL(pi_student || pi_teacher) = sum_v p_student(v) * " + "(log p_student(v) - log p_teacher(v)). Average teacher KLs with " + "teacher_weights, apply mask over tokens if provided, and multiply by " + "temperature ** 2." + ), + "tests": [ + { + "name": "Basic shape & type", + "code": "\n" + "import torch\n" + "from torch import Tensor\n" + "student_logits = torch.randn(2, 3, 5, requires_grad=True)\n" + "teacher_logits = torch.randn(2, 3, 5)\n" + "loss = {fn}(student_logits, teacher_logits)\n" + "assert isinstance(loss, Tensor) and loss.dim() == 0, 'Loss must be a scalar Tensor'\n" + }, + { + "name": "Zero when student matches teacher", + "code": "\n" + "import torch\n" + "torch.manual_seed(0)\n" + "student_logits = torch.randn(2, 3, 4, requires_grad=True)\n" + "teacher_logits = student_logits.detach().clone()\n" + "loss = {fn}(student_logits, teacher_logits)\n" + "assert torch.allclose(loss, torch.tensor(0.0), atol=1e-6), f'Expected near-zero loss, got {loss.item():.8f}'\n" + }, + { + "name": "Numeric check vs single-teacher reverse KL", + "code": "\n" + "import torch\n" + "import torch.nn.functional as F\n" + "student_logits = torch.tensor([[[2.0, 0.0, -1.0], [0.5, 1.0, -0.5]]])\n" + "teacher_logits = torch.tensor([[[1.0, 1.5, -0.5], [0.0, 2.0, -1.0]]])\n" + "student_logits = student_logits.requires_grad_()\n" + "loss = {fn}(student_logits, teacher_logits)\n" + "s_logp = F.log_softmax(student_logits, dim=-1)\n" + "t_logp = F.log_softmax(teacher_logits, dim=-1)\n" + "expected = (s_logp.exp() * (s_logp - t_logp)).sum(dim=-1).mean()\n" + "assert torch.allclose(loss, expected, atol=1e-6), 
f'{loss.item():.6f} vs {expected.item():.6f}'\n" + }, + { + "name": "Multi-teacher weighted reverse KL", + "code": "\n" + "import torch\n" + "import torch.nn.functional as F\n" + "student_logits = torch.tensor([[[1.0, 0.0], [0.0, 1.0]]], requires_grad=True)\n" + "teacher_logits = torch.stack([\n" + " torch.tensor([[[2.0, -1.0], [1.0, 0.0]]]),\n" + " torch.tensor([[[-1.0, 2.0], [0.0, 1.0]]]),\n" + "])\n" + "weights = torch.tensor([0.25, 0.75])\n" + "loss = {fn}(student_logits, teacher_logits, teacher_weights=weights)\n" + "s_logp = F.log_softmax(student_logits, dim=-1)\n" + "s_prob = s_logp.exp()\n" + "t_logp = F.log_softmax(teacher_logits, dim=-1)\n" + "kl = (s_prob.unsqueeze(0) * (s_logp.unsqueeze(0) - t_logp)).sum(dim=-1)\n" + "expected = (weights.view(-1, 1, 1) * kl).sum(dim=0).mean()\n" + "assert torch.allclose(loss, expected, atol=1e-6), f'{loss.item():.6f} vs {expected.item():.6f}'\n" + }, + { + "name": "Mask ignores padded tokens", + "code": "\n" + "import torch\n" + "import torch.nn.functional as F\n" + "student_logits = torch.tensor([[[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]]], requires_grad=True)\n" + "teacher_logits = torch.tensor([[[0.0, 2.0], [0.0, 2.0], [100.0, -100.0]]])\n" + "mask = torch.tensor([[1.0, 1.0, 0.0]])\n" + "loss = {fn}(student_logits, teacher_logits, mask=mask)\n" + "s_logp = F.log_softmax(student_logits, dim=-1)\n" + "t_logp = F.log_softmax(teacher_logits, dim=-1)\n" + "per_token = (s_logp.exp() * (s_logp - t_logp)).sum(dim=-1)\n" + "expected = (per_token * mask).sum() / mask.sum()\n" + "assert torch.allclose(loss, expected, atol=1e-6), 'Masked positions should not affect the loss'\n" + }, + { + "name": "Gradient flows only through student logits", + "code": "\n" + "import torch\n" + "student_logits = torch.randn(2, 3, 4, requires_grad=True)\n" + "teacher_logits = torch.randn(2, 3, 4, requires_grad=True)\n" + "loss = {fn}(student_logits, teacher_logits)\n" + "loss.backward()\n" + "assert student_logits.grad is not None, 'Student logits 
should receive gradients'\n" + "assert teacher_logits.grad is None, 'Teacher logits should be treated as frozen targets'\n" + }, + { + "name": "Temperature scaling", + "code": "\n" + "import torch\n" + "import torch.nn.functional as F\n" + "student_logits = torch.tensor([[[2.0, 0.0, -1.0]]], requires_grad=True)\n" + "teacher_logits = torch.tensor([[[0.0, 1.0, -1.0]]])\n" + "temperature = 2.0\n" + "loss = {fn}(student_logits, teacher_logits, temperature=temperature)\n" + "s_logp = F.log_softmax(student_logits / temperature, dim=-1)\n" + "t_logp = F.log_softmax(teacher_logits / temperature, dim=-1)\n" + "expected = (s_logp.exp() * (s_logp - t_logp)).sum(dim=-1).mean() * (temperature ** 2)\n" + "assert torch.allclose(loss, expected, atol=1e-6), f'{loss.item():.6f} vs {expected.item():.6f}'\n" + }, + ], +}