From ebf88d64c5a080f0925b4fb1fd889d9f78efe4cc Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Wed, 3 Jun 2026 11:57:20 -0700 Subject: [PATCH] Restructure Python custom operators guide Split the custom Python operators tutorial into an overview plus focused pages for functional operators, mutable operators, and optional registrations. Emphasize required schema and mutation/aliasing contracts, opcheck validation, fake kernels for torch.compile/export, and 2.13-only in-place/out custom operator behavior. Add the new pages under the custom operators landing page so the left navigation exposes a nested custom-ops section and users can move through the guide in order. Validated with lintrunner -m main and make clean-cache && make html-noplot. --- advanced_source/custom_ops_landing_page.rst | 15 +- advanced_source/python_custom_ops.py | 278 ------------------ advanced_source/python_custom_ops.rst | 158 ++++++++++ .../python_custom_ops_functional.py | 133 +++++++++ advanced_source/python_custom_ops_mutable.py | 208 +++++++++++++ .../python_custom_ops_registrations.py | 136 +++++++++ extension.rst | 6 +- index.rst | 2 +- 8 files changed, 650 insertions(+), 286 deletions(-) delete mode 100644 advanced_source/python_custom_ops.py create mode 100644 advanced_source/python_custom_ops.rst create mode 100644 advanced_source/python_custom_ops_functional.py create mode 100644 advanced_source/python_custom_ops_mutable.py create mode 100644 advanced_source/python_custom_ops_registrations.py diff --git a/advanced_source/custom_ops_landing_page.rst b/advanced_source/custom_ops_landing_page.rst index f05eee43060..8065e560a35 100644 --- a/advanced_source/custom_ops_landing_page.rst +++ b/advanced_source/custom_ops_landing_page.rst @@ -10,12 +10,21 @@ In order to do so, you must register the custom operation with PyTorch via the P `torch.library docs `_ or C++ ``TORCH_LIBRARY`` APIs. +.. 
toctree:: + :maxdepth: 2 + :hidden: + + python_custom_ops + cpp_custom_ops Authoring a custom operator from Python ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Please see :ref:`python-custom-ops-tutorial`. +Please see :ref:`python-custom-ops-tutorial` for the Python guide. It covers +the required schema and mutation/aliasing contract, ``opcheck``, functional +operators, mutable operators, and optional registrations such as autograd and +``torch.vmap``. You may wish to author a custom operator from Python (as opposed to C++) if: @@ -43,8 +52,8 @@ The Custom Operators Manual ^^^^^^^^^^^^^^^^^^^^^^^^^^^ For information not covered in the tutorials and this page, please see -`The Custom Operators Manual `_ -(we're working on moving the information to our docs site). We recommend that you +`The Custom Operators Manual `_. +We recommend that you first read one of the tutorials above and then use the Custom Operators Manual as a reference; it is not meant to be read head to toe. diff --git a/advanced_source/python_custom_ops.py b/advanced_source/python_custom_ops.py deleted file mode 100644 index 1f20125f785..00000000000 --- a/advanced_source/python_custom_ops.py +++ /dev/null @@ -1,278 +0,0 @@ -# -*- coding: utf-8 -*- - -""" -.. _python-custom-ops-tutorial: - -Custom Python Operators -======================= - -.. grid:: 2 - - .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn - :class-card: card-prerequisites - - * How to integrate custom operators written in Python with PyTorch - * How to test custom operators using ``torch.library.opcheck`` - - .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites - :class-card: card-prerequisites - - * PyTorch 2.4 or later - -PyTorch offers a large library of operators that work on Tensors (e.g. -``torch.add``, ``torch.sum``, etc). However, you might wish to use a new customized -operator with PyTorch, perhaps written by a third-party library. 
This tutorial -shows how to wrap Python functions so that they behave like PyTorch native -operators. Reasons why you may wish to create a custom operator in PyTorch include: - -- Treating an arbitrary Python function as an opaque callable with respect - to ``torch.compile`` (that is, prevent ``torch.compile`` from tracing - into the function). -- Adding training support to an arbitrary Python function - -Use :func:`torch.library.custom_op` to create Python custom operators. -Use the C++ ``TORCH_LIBRARY`` APIs to create C++ custom operators (these -work in Python-less environments). -See the `Custom Operators Landing Page `_ -for more details. - -Please note that if your operation can be expressed as a composition of -existing PyTorch operators, then there is usually no need to use the custom operator -API -- everything (for example ``torch.compile``, training support) should -just work. -""" -###################################################################### -# Example: Wrapping PIL's crop into a custom operator -# ------------------------------------ -# Let's say that we are using PIL's ``crop`` operation. - -import torch -from torchvision.transforms.functional import to_pil_image, pil_to_tensor -import PIL -import IPython -import matplotlib.pyplot as plt - -def crop(pic, box): - img = to_pil_image(pic.cpu()) - cropped_img = img.crop(box) - return pil_to_tensor(cropped_img).to(pic.device) / 255. 
- -def display(img): - plt.imshow(img.numpy().transpose((1, 2, 0))) - -img = torch.ones(3, 64, 64) -img *= torch.linspace(0, 1, steps=64) * torch.linspace(0, 1, steps=64).unsqueeze(-1) -display(img) - -###################################################################### - -cropped_img = crop(img, (10, 10, 50, 50)) -display(cropped_img) - -###################################################################### -# ``crop`` is not handled effectively out-of-the-box by -# ``torch.compile``: ``torch.compile`` induces a -# `"graph break" `_ -# on functions it is unable to handle and graph breaks are bad for performance. -# The following code demonstrates this by raising an error -# (``torch.compile`` with ``fullgraph=True`` raises an error if a -# graph break occurs). - -@torch.compile(fullgraph=True) -def f(img): - return crop(img, (10, 10, 50, 50)) - -# The following raises an error. Uncomment the line to see it. -# cropped_img = f(img) - -###################################################################### -# In order to black-box ``crop`` for use with ``torch.compile``, we need to -# do two things: -# -# 1. wrap the function into a PyTorch custom operator. -# 2. add a "``FakeTensor`` kernel" (aka "meta kernel") to the operator. -# Given some ``FakeTensors`` inputs (dummy Tensors that don't have storage), -# this function should return dummy Tensors of your choice with the correct -# Tensor metadata (shape/strides/``dtype``/device). - - -from typing import Sequence - -# Use torch.library.custom_op to define a new custom operator. -# If your operator mutates any input Tensors, their names must be specified -# in the ``mutates_args`` argument. 
-@torch.library.custom_op("mylib::crop", mutates_args=()) -def crop(pic: torch.Tensor, box: Sequence[int]) -> torch.Tensor: - img = to_pil_image(pic.cpu()) - cropped_img = img.crop(box) - return (pil_to_tensor(cropped_img) / 255.).to(pic.device, pic.dtype) - -# Use register_fake to add a ``FakeTensor`` kernel for the operator -@crop.register_fake -def _(pic, box): - channels = pic.shape[0] - x0, y0, x1, y1 = box - result = pic.new_empty(y1 - y0, x1 - x0, channels).permute(2, 0, 1) - # The result should have the same metadata (shape/strides/``dtype``/device) - # as running the ``crop`` function above. - return result - -###################################################################### -# After this, ``crop`` now works without graph breaks: - -@torch.compile(fullgraph=True) -def f(img): - return crop(img, (10, 10, 50, 50)) - -cropped_img = f(img) -display(img) - -###################################################################### - -display(cropped_img) - -###################################################################### -# Adding training support for crop -# -------------------------------- -# Use ``torch.library.register_autograd`` to add training support for an operator. -# Prefer this over directly using ``torch.autograd.Function``; some compositions of -# ``autograd.Function`` with PyTorch operator registration APIs can lead to (and -# has led to) silent incorrectness when composed with ``torch.compile``. -# -# If you don't need training support, there is no need to use -# ``torch.library.register_autograd``. -# If you end up training with a ``custom_op`` that doesn't have an autograd -# registration, we'll raise an error message. -# -# The gradient formula for ``crop`` is essentially ``PIL.paste`` (we'll leave the -# derivation as an exercise to the reader). 
Let's first wrap ``paste`` into a -# custom operator: - -@torch.library.custom_op("mylib::paste", mutates_args=()) -def paste(im1: torch.Tensor, im2: torch.Tensor, coord: Sequence[int]) -> torch.Tensor: - assert im1.device == im2.device - assert im1.dtype == im2.dtype - im1_pil = to_pil_image(im1.cpu()) - im2_pil = to_pil_image(im2.cpu()) - PIL.Image.Image.paste(im1_pil, im2_pil, coord) - return (pil_to_tensor(im1_pil) / 255.).to(im1.device, im1.dtype) - -@paste.register_fake -def _(im1, im2, coord): - assert im1.device == im2.device - assert im1.dtype == im2.dtype - return torch.empty_like(im1) - -###################################################################### -# And now let's use ``register_autograd`` to specify the gradient formula for ``crop``: - -def backward(ctx, grad_output): - grad_input = grad_output.new_zeros(ctx.pic_shape) - grad_input = paste(grad_input, grad_output, ctx.coords) - return grad_input, None - -def setup_context(ctx, inputs, output): - pic, box = inputs - ctx.coords = box[:2] - ctx.pic_shape = pic.shape - -crop.register_autograd(backward, setup_context=setup_context) - -###################################################################### -# Note that the backward must be a composition of PyTorch-understood operators, -# which is why we wrapped paste into a custom operator instead of directly using -# PIL's paste. - -img = img.requires_grad_() -result = crop(img, (10, 10, 50, 50)) -result.sum().backward() -display(img.grad) - -###################################################################### -# This is the correct gradient, with 1s (white) in the cropped region and 0s -# (black) in the unused region. - -###################################################################### -# Testing Python Custom operators -# ------------------------------- -# Use ``torch.library.opcheck`` to test that the custom operator was registered -# correctly. 
This does not test that the gradients are mathematically correct; -# please write separate tests for that (either manual ones or ``torch.autograd.gradcheck``). -# -# To use ``opcheck``, pass it a set of example inputs to test against. If your -# operator supports training, then the examples should include Tensors that -# require grad. If your operator supports multiple devices, then the examples -# should include Tensors from each device. - -examples = [ - [torch.randn(3, 64, 64), [0, 0, 10, 10]], - [torch.randn(3, 91, 91, requires_grad=True), [10, 0, 20, 10]], - [torch.randn(3, 60, 60, dtype=torch.double), [3, 4, 32, 20]], - [torch.randn(3, 512, 512, requires_grad=True, dtype=torch.double), [3, 4, 32, 45]], -] - -for example in examples: - torch.library.opcheck(crop, example) - -###################################################################### -# Mutable Python Custom operators -# ------------------------------- -# You can also wrap a Python function that mutates its inputs into a custom -# operator. -# Functions that mutate inputs are common because that is how many low-level -# kernels are written; for example, a kernel that computes ``sin`` may take in -# the input and an output tensor and write ``input.sin()`` to the output tensor. -# -# We'll use ``numpy.sin`` to demonstrate an example of a mutable Python -# custom operator. - -import numpy as np - -@torch.library.custom_op("mylib::numpy_sin", mutates_args={"output"}, device_types="cpu") -def numpy_sin(input: torch.Tensor, output: torch.Tensor) -> None: - assert input.device == output.device - assert input.device.type == "cpu" - input_np = input.numpy() - output_np = output.numpy() - np.sin(input_np, out=output_np) - -###################################################################### -# Because the operator doesn't return anything, there is no need to register -# a ``FakeTensor`` kernel (meta kernel) to get it to work with ``torch.compile``. 
- -@torch.compile(fullgraph=True) -def f(x): - out = torch.empty(3) - numpy_sin(x, out) - return out - -x = torch.randn(3) -y = f(x) -assert torch.allclose(y, x.sin()) - -###################################################################### -# And here's an ``opcheck`` run telling us that we did indeed register the operator correctly. -# ``opcheck`` would error out if we forgot to add the output to ``mutates_args``, for example. - -example_inputs = [ - [torch.randn(3), torch.empty(3)], - [torch.randn(0, 3), torch.empty(0, 3)], - [torch.randn(1, 2, 3, 4, dtype=torch.double), torch.empty(1, 2, 3, 4, dtype=torch.double)], -] - -for example in example_inputs: - torch.library.opcheck(numpy_sin, example) - -###################################################################### -# Conclusion -# ---------- -# In this tutorial, we learned how to use ``torch.library.custom_op`` to -# create a custom operator in Python that works with PyTorch subsystems -# such as ``torch.compile`` and autograd. -# -# This tutorial provides a basic introduction to custom operators. -# For more detailed information, see: -# -# - `the torch.library documentation `_ -# - `the Custom Operators Manual `_ -# diff --git a/advanced_source/python_custom_ops.rst b/advanced_source/python_custom_ops.rst new file mode 100644 index 00000000000..915b52709fa --- /dev/null +++ b/advanced_source/python_custom_ops.rst @@ -0,0 +1,158 @@ +.. _python-custom-ops-tutorial: + +Custom Python Operators +======================= + +.. grid:: 1 + + .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn + :class-card: card-prerequisites + + * When to create a Python custom operator + * How to choose between functional and mutable operator contracts + * Why the schema and mutation/aliasing contract are required + * Where fake kernels, autograd, and other registrations fit + +PyTorch offers a large library of operators that work on Tensors, such as +``torch.add`` and ``torch.sum``. 
However, you might wish to use a new customized +operator with PyTorch, perhaps written by a third-party library. This guide +shows how to wrap Python functions so that they behave like PyTorch native +operators. + +Reasons why you may wish to create a custom operator in PyTorch include: + +* treating an arbitrary Python function as an opaque callable with respect to + ``torch.compile``; and +* adding training support to an arbitrary Python function. + +Please note that if your operation can be expressed as a composition of +existing PyTorch operators, then there is usually no need to use the custom +operator API. ``torch.compile``, training support, and other PyTorch subsystems +should usually work. + +Every custom operator needs: + +* a stable schema and mutation/aliasing contract; +* validation with ``torch.library.opcheck``; +* a fake kernel if it returns tensors and must work with ``torch.compile`` or + ``torch.export``. + +Choose one path: + +* :ref:`Functional custom operators <python-custom-ops-functional>`: the + operator returns fresh tensors and mutates no inputs. +* :ref:`Mutable custom operators <python-custom-ops-mutable>`: the operator + mutates an input or writes into an output buffer. Starting in PyTorch 2.13, + this includes PyTorch-style in-place and ``out=`` custom operators. +* :ref:`Optional registrations <python-custom-ops-registrations>`: add + autograd, ``torch.vmap``, Tensor subclass behavior, or other subsystem + support after the base operator passes ``opcheck``. + +.. dropdown:: Choose your path + :open: + + .. list-table:: + :header-rows: 1 + :widths: 30 30 40 + + * - If you have... + - Read... 
+ - Required pieces + * - Any custom operator + - :ref:`Schema and mutation/aliasing contract <python-custom-ops-schema-contract>` and + :ref:`Validation <python-custom-ops-validation>` + - A stable schema, representative examples, and ``opcheck`` + * - Code that returns new tensors and does not mutate inputs + - :ref:`Functional custom operators <python-custom-ops-functional>` + - ``custom_op(..., mutates_args=())``, a fake kernel for + ``torch.compile``, and ``opcheck`` + * - A kernel that writes into existing memory + - :ref:`Mutable custom operators <python-custom-ops-mutable>` + - Accurate ``mutates_args`` and one clear mutation pattern + * - In-place, ``out=``, or maybe-out behavior + - :ref:`Mutable custom operators <python-custom-ops-mutable>` and + :ref:`Schema contract <python-custom-ops-schema-contract>` + - PyTorch 2.13 or later for tagged in-place/``out=`` custom + operators; split maybe-out behavior into separate operators + * - Training support, ``vmap``, or Tensor subclass behavior + - :ref:`Adding registrations <python-custom-ops-registrations>` + - A validated base operator plus the registration for that subsystem + +For Python-less environments or AOTInductor, define the operator and backend +kernels in C++ instead. See the +:ref:`C++ custom operator tutorial `. + +.. toctree:: + :maxdepth: 1 + :hidden: + + python_custom_ops_functional + python_custom_ops_mutable + python_custom_ops_registrations + +Before you start +---------------- + +A kernel is the implementation. An operator is the PyTorch-facing contract: +name, inputs, outputs, mutation behavior, and subsystem registrations. + +A custom operator gives PyTorch an explicit boundary. Use it when tracing into +the implementation is impossible or undesirable. + +.. _python-custom-ops-schema-contract: + +Required: schema and mutation/aliasing contract +------------------------------------------------ + +Decide the schema and mutation/aliasing contract before writing registrations. +PyTorch uses the schema and registrations to reason about aliasing; it does not +infer the contract from the Python body. + +* The schema must be stable. 
Do not change mutation or aliasing behavior based + on values, shapes, devices, dtypes, or optional arguments. +* A functional custom operator must return fresh tensors. Do not return an + input tensor, a view of an input, or two outputs that alias each other. +* A mutable custom operator must list every mutated argument in ``mutates_args``. +* A fake kernel must return tensors with the same metadata as the real kernel: + shape, dtype, device, layout, strides, and storage offset when relevant. + ``empty_like(x)`` is only correct when the real output has the same metadata + as ``x``. +* Fake kernels may inspect metadata, but must not read tensor data. +* Avoid "maybe-out" operators. An operator that sometimes allocates a new + tensor and sometimes writes into an output buffer has different aliasing + contracts for different calls. + +Split maybe-out behavior into two operators: one functional operator that +allocates and one mutable operator that writes into an output buffer. Starting +in PyTorch 2.13, the mutable page shows the executable ``out=`` form. On older +versions, expose a mutable operator that writes to an output buffer and returns +``None``. + +.. _python-custom-ops-validation: + +Required: validate with opcheck +------------------------------- + +``torch.library.opcheck`` validates the registration contract: schema, fake +kernel, autograd registration, and behavior under compilation APIs. + +Run ``opcheck`` on representative inputs: + +* each supported device; +* important dtypes; +* edge shapes such as empty tensors; +* important memory formats or non-contiguous strides; +* inputs with ``requires_grad=True`` if the operator supports training. + +``opcheck`` is not a numerical correctness test. Use +``torch.testing.assert_close`` or ordinary unit tests for forward correctness, +and ``torch.autograd.gradcheck`` for gradient formulas. 
+ +Next steps +---------- + +Read one base-contract page first, then add registrations only if needed: + +* :ref:`Functional custom operators for torch.compile <python-custom-ops-functional>` +* :ref:`Mutable custom operators <python-custom-ops-mutable>` +* :ref:`Adding training and other registrations <python-custom-ops-registrations>` diff --git a/advanced_source/python_custom_ops_functional.py b/advanced_source/python_custom_ops_functional.py new file mode 100644 index 00000000000..823444fbca2 --- /dev/null +++ b/advanced_source/python_custom_ops_functional.py @@ -0,0 +1,133 @@ +# -*- coding: utf-8 -*- + +""" +.. _python-custom-ops-functional: + +Functional Python Custom Operators for torch.compile +==================================================== + +Use this path when the operator mutates no Tensor inputs and returns fresh +Tensor outputs. + +If the operator must work with ``torch.compile`` or ``torch.export``, register a +fake kernel. The fake kernel describes output metadata without running the real +kernel. + +Before writing the operator, read the required schema and mutation/aliasing +contract rules in :ref:`python-custom-ops-schema-contract`. + +Checklist: + +* use ``mutates_args=()``; +* return tensors that do not alias any input; +* register a fake kernel for ``torch.compile`` and ``torch.export``; +* validate the operator with ``torch.library.opcheck``. +""" + +###################################################################### +# Example: wrapping NumPy sin into a custom operator +# -------------------------------------------------- +# Let's say that we are using NumPy's ``sin`` operation. This is an ordinary +# Python function from PyTorch's point of view: it converts the Tensor to a +# NumPy array, calls NumPy, and returns a fresh Tensor. 
+ +import numpy as np +import torch +from torch import Tensor + + +def numpy_sin_impl(x: Tensor) -> Tensor: + result = torch.empty_like(x) + np.sin(x.detach().numpy(), out=result.numpy()) + return result + + +x = torch.randn(5) +torch.testing.assert_close(numpy_sin_impl(x), x.sin()) + +# This small example focuses on the custom-operator mechanics. More complex +# Python or third-party library calls may not be handled effectively +# out-of-the-box by ``torch.compile``: ``torch.compile`` may induce a +# `"graph break" `_ +# on functions it is unable to handle, and graph breaks are bad for performance. +# A custom operator gives PyTorch an explicit boundary for such code. +# +# To make ``numpy_sin_impl`` available as a custom operator that works with +# ``torch.compile`` and ``torch.export``, we need to do two things: +# +# 1. wrap the function into a PyTorch custom operator. +# 2. add a "``FakeTensor`` kernel" (aka "meta kernel") to the operator. +# Given some ``FakeTensor`` inputs (dummy Tensors that don't have storage), +# this function should return dummy Tensors of your choice with the correct +# Tensor metadata (shape/strides/``dtype``/device). + + +@torch.library.custom_op( + "mylib_functional::numpy_sin", + mutates_args=(), + device_types="cpu", +) +def numpy_sin(x: Tensor) -> Tensor: + result = torch.empty_like(x) + np.sin(x.detach().numpy(), out=result.numpy()) + return result + + +###################################################################### +# Use ``register_fake`` to add a ``FakeTensor`` kernel for the operator. +# ``numpy_sin`` returns one Tensor with the same shape, strides, dtype, device, +# and storage offset as ``torch.empty_like(x)``, so the fake kernel can return +# ``empty_like(x)``. In general, the fake kernel must match all output metadata, +# including storage offset when relevant. 
+ + +@numpy_sin.register_fake +def _(x): + return torch.empty_like(x) + + +###################################################################### +# After this, ``numpy_sin`` can be used under ``torch.compile``: + + +@torch.compile(fullgraph=True) +def f(x): + return numpy_sin(x) + + +result = f(x) +torch.testing.assert_close(result, x.sin()) + +###################################################################### +# A PIL image transform, Python binding to a C++ extension, or another +# third-party library call follows the same pattern. If it returns tensors, +# write the fake kernel to match the real output metadata exactly: shape, +# strides, dtype, device, layout, and storage offset when relevant. + +###################################################################### +# Testing Python custom operators +# ------------------------------- +# Use ``torch.library.opcheck`` to test that the custom operator was registered +# correctly. This does not test numerical correctness; write separate tests for +# that. +# +# To use ``opcheck``, pass it a set of example inputs to test against. If your +# operator supports training, then the examples should include Tensors that +# require grad. If your operator supports multiple devices, then the examples +# should include Tensors from each device. + + +examples = [ + (torch.randn(5),), + (torch.randn(0, 3),), + (torch.randn(2, 3, dtype=torch.double),), + (torch.randn(2, 3).t(),), + (torch.randn(8)[1:],), +] + +for example in examples: + torch.library.opcheck(numpy_sin, example) + +###################################################################### +# To add autograd, ``torch.vmap``, or other subsystem support, continue to +# :ref:`python-custom-ops-registrations`. 
diff --git a/advanced_source/python_custom_ops_mutable.py b/advanced_source/python_custom_ops_mutable.py new file mode 100644 index 00000000000..d25bcc55a53 --- /dev/null +++ b/advanced_source/python_custom_ops_mutable.py @@ -0,0 +1,208 @@ +# -*- coding: utf-8 -*- + +""" +.. _python-custom-ops-mutable: + +Mutable Python Custom Operators +=============================== + +The functional page wrapped ``numpy.sin`` as an operator that returns a fresh +Tensor. This page shows the mutable version: a kernel that writes +``sin(x)`` into an existing output Tensor. Mutable operators have a different +contract from functional operators. + +Before writing the operator, read the required schema and mutation/aliasing +contract rules in :ref:`python-custom-ops-schema-contract`. + +Checklist: + +* choose one mutation pattern and keep it stable; +* list every mutated Tensor argument in ``mutates_args``; +* do not return mutated inputs unless you are using a tagged in-place or + ``out=`` operator, starting in PyTorch 2.13; +* validate the operator with ``torch.library.opcheck``. +""" + +###################################################################### +# Choose one mutation contract +# ---------------------------- +# Choose the mutation behavior before adding optional registrations. PyTorch +# needs this contract for functionalization, fake tensors, ``torch.compile``, +# and autograd. +# +# If the operator does not mutate a Tensor input, use the functional operator +# path instead. +# +# If the operator mutates the first positional Tensor and returns it, use a +# tagged in-place operator, starting in PyTorch 2.13. +# +# If the operator mutates keyword-only ``out=`` Tensor arguments and returns +# them, use a tagged ``out=`` operator, starting in PyTorch 2.13. Do not read +# from the ``out=`` tensors. +# +# For other mutable operators, list every mutated argument in ``mutates_args`` +# and do not return mutated inputs or their aliases. 
+ +import numpy as np +import torch +from torch import Tensor + + +###################################################################### +# Example: write NumPy sin into an output buffer +# ---------------------------------------------- +# Functions that mutate inputs are common because that is how many low-level +# kernels are written; for example, a kernel that computes ``sin`` may take in +# the input and an output tensor and write ``input.sin()`` to the output tensor. +# +# This operator writes ``sin(x)`` into ``out`` and returns ``None``. + + +@torch.library.custom_op( + "mylib_mutable::numpy_sin_out", + mutates_args={"out"}, + device_types="cpu", +) +def numpy_sin_out(x: Tensor, out: Tensor) -> None: + if x.shape != out.shape: + raise RuntimeError("x and out must have the same shape") + if x.dtype != out.dtype: + raise RuntimeError("x and out must have the same dtype") + if x.device != out.device: + raise RuntimeError("x and out must be on the same device") + np.sin(x.detach().numpy(), out=out.numpy()) + + +x = torch.randn(5) +out = torch.empty_like(x) +numpy_sin_out(x, out) +torch.testing.assert_close(out, x.sin()) + +###################################################################### +# Because the operator doesn't return anything, there is no need to register a +# ``FakeTensor`` kernel (meta kernel) to get it to work with ``torch.compile``. +# If a mutable operator also returns a fresh Tensor, register a fake kernel for +# that output. + + +@torch.compile(fullgraph=True) +def compiled_numpy_sin_out(x): + out = torch.empty_like(x) + numpy_sin_out(x, out) + return out + + +torch.testing.assert_close(compiled_numpy_sin_out(x), x.sin()) + +###################################################################### +# PyTorch-style in-place and out= operators +# ----------------------------------------- +# Starting in PyTorch 2.13, ``torch.library.custom_op`` supports tagged +# in-place and ``out=`` custom operators. 
+# Tagged in-place operators return the same Tensor they mutate. Tagged ``out=`` +# operators return their keyword-only output buffers in declaration order. +# We switch to ``add`` in this section because it naturally demonstrates both +# binary broadcasting and an ``out=`` overload. +# Use ``mylib_mutable::add.out`` to register the ``out`` overload of +# ``mylib_mutable::add``; this lets PyTorch associate the functional operator +# with its ``out=`` variant. Other mutable operators should not return mutated +# inputs or aliases of inputs. + + +supports_tagged_mutable_ops = ( + hasattr(torch, "Tag") + and hasattr(torch.Tag, "inplace") + and hasattr(torch.Tag, "out") +) + +if supports_tagged_mutable_ops: + + @torch.library.custom_op( + "mylib_mutable::add_inplace", + mutates_args={"x"}, + tags=torch.Tag.inplace, + ) + def add_inplace(x: Tensor, y: Tensor) -> Tensor: + x.add_(y) + return x + + + @torch.library.custom_op("mylib_mutable::add", mutates_args=()) + def add(x: Tensor, y: Tensor) -> Tensor: + return x + y + + + # The fake kernel must match the real output metadata. ``x + y`` accounts + # for broadcasting, dtype promotion, strides, and storage offset. 
+ @add.register_fake + def _(x, y): + return torch.empty_like(x + y) + + + @torch.library.custom_op( + "mylib_mutable::add.out", + mutates_args={"out"}, + tags=torch.Tag.out, + ) + def add_out(x: Tensor, y: Tensor, *, out: Tensor) -> Tensor: + out.copy_(x + y) + return out + + + x_for_inplace = torch.ones(3) + y_for_add = torch.arange(3.0) + torch.testing.assert_close(add_inplace(x_for_inplace, y_for_add), y_for_add + 1) + + out_for_add = torch.empty(3) + torch.testing.assert_close( + add_out(torch.ones(3), y_for_add, out=out_for_add), + y_for_add + 1, + ) + torch.testing.assert_close(out_for_add, y_for_add + 1) + + torch.library.opcheck(add_inplace, (torch.ones(3), y_for_add)) + torch.library.opcheck(add, (torch.ones(3), y_for_add)) + torch.library.opcheck(add, (torch.ones(1, 3), torch.ones(2, 3))) + torch.library.opcheck( + add, + ( + torch.ones(2, 3, dtype=torch.float32), + torch.ones(2, 3, dtype=torch.float64), + ), + ) + torch.library.opcheck( + add_out, + (torch.ones(3), y_for_add), + {"out": torch.empty(3)}, + ) +else: + print("Tagged in-place and out= custom operators require PyTorch 2.13 or later.") + + +###################################################################### +# Validate the operator +# --------------------- +# And here's an ``opcheck`` run telling us that we did indeed register the +# operator correctly. ``opcheck`` would error out if we forgot to add ``out`` to +# ``mutates_args``, for example. + + +examples = [ + (torch.randn(5), torch.empty(5)), + (torch.randn(0, 3), torch.empty(0, 3)), + ( + torch.randn(2, 3, dtype=torch.double), + torch.empty(2, 3, dtype=torch.double), + ), + ( + torch.randn(2, 3).t(), + torch.empty_strided((3, 2), (1, 3)), + ), +] + +for example in examples: + torch.library.opcheck(numpy_sin_out, example) + +###################################################################### +# For autograd, ``torch.vmap``, or other subsystem behavior, continue to +# :ref:`python-custom-ops-registrations`. 
diff --git a/advanced_source/python_custom_ops_registrations.py b/advanced_source/python_custom_ops_registrations.py new file mode 100644 index 00000000000..5589e3f7944 --- /dev/null +++ b/advanced_source/python_custom_ops_registrations.py @@ -0,0 +1,136 @@ +# -*- coding: utf-8 -*- + +""" +.. _python-custom-ops-registrations: + +Adding Training and Other Registrations to Python Custom Operators +================================================================== + +Start here after a base operator passes ``torch.library.opcheck``: + +* :ref:`python-custom-ops-functional` +* :ref:`python-custom-ops-mutable` + +Registrations do not change the base contract. After adding one, rerun +``torch.library.opcheck`` on representative inputs for that subsystem. +""" + +###################################################################### +# Adding training support for NumPy sin +# ------------------------------------- +# Use ``torch.library.register_autograd`` to add training support for an +# operator. Prefer this over directly using ``torch.autograd.Function``; some +# compositions of ``autograd.Function`` with PyTorch operator registration APIs +# can lead to (and have led to) silent incorrectness when used with +# ``torch.compile``. +# +# If you don't need training support, there is no need to use +# ``torch.library.register_autograd``. If you end up training with a +# ``custom_op`` that doesn't have an autograd registration, PyTorch raises an +# error. +# +# This page uses the same ``numpy.sin`` operation as the functional and mutable +# pages so the only new concept is the autograd registration.
+ +import numpy as np +import torch +from torch import Tensor + + +@torch.library.custom_op( + "mylib_training::numpy_sin", + mutates_args=(), + device_types="cpu", +) +def numpy_sin(x: Tensor) -> Tensor: + result = torch.empty_like(x) + np.sin(x.detach().numpy(), out=result.numpy()) + return result + + +@numpy_sin.register_fake +def _(x): + return torch.empty_like(x) + + +###################################################################### +# The fake kernel must describe the same output metadata as the real kernel, +# including shape, strides, dtype, device, layout, and storage offset when +# relevant. Here the real kernel returns ``torch.empty_like(x)``, so the fake +# kernel does the same. +# +# The gradient formula for ``sin(x)`` is ``cos(x)``. The backward formula must +# be written in terms of PyTorch-understood operations or other custom +# operators. Do not directly use non-traceable Python or NumPy code from the +# backward formula. + + +def numpy_sin_setup_context(ctx, inputs, output): + (x,) = inputs + ctx.save_for_backward(x) + + +def numpy_sin_backward(ctx, grad_output): + (x,) = ctx.saved_tensors + return grad_output * x.cos() + + +###################################################################### +# Register the backward formula and the context setup function: + + +numpy_sin.register_autograd( + numpy_sin_backward, + setup_context=numpy_sin_setup_context, +) + + +x = torch.randn(5, requires_grad=True) +y = numpy_sin(x) +y.sum().backward() +torch.testing.assert_close(x.grad, x.detach().cos()) + +###################################################################### +# Testing autograd registration +# ----------------------------- +# ``opcheck`` verifies that autograd was registered in a supported way, but it +# does not prove that the gradient formula is mathematically correct. Use +# separate numerical tests for that, either manual ones or +# ``torch.autograd.gradcheck``. 
+ + +gradcheck_input = torch.randn(3, dtype=torch.double, requires_grad=True) +torch.autograd.gradcheck(numpy_sin, (gradcheck_input,)) + +examples = [ + (torch.randn(5),), + (torch.randn(0, 3),), + (torch.randn(4, requires_grad=True),), + (torch.randn(2, dtype=torch.double, requires_grad=True),), + (torch.randn(2, 3).t(),), + (torch.randn(8)[1:],), +] + +for example in examples: + torch.library.opcheck(numpy_sin, example) + + +###################################################################### +# Other registrations +# ------------------- +# Add these only when users need them. +# +# * **Multiple device kernels:** pass ``device_types="cpu"`` or +# ``device_types="cuda"`` if the implementation only works on one device. +# Register device-specific kernels when devices need different code. +# * **``torch.vmap``:** register a vmap rule with ``torch.library.register_vmap`` +# when batching over the operator should do something different from a Python +# loop over the batch dimension. +# * **Tensor subclasses or modes:** use ``torch.library.register_torch_dispatch`` +# when a Tensor subclass or ``TorchDispatchMode`` needs special behavior. +# * **Autocast:** for C++/CUDA operators that should participate in autocast, +# add an autocast registration as described in the C++ custom operator guide. +# +# A ``vmap`` rule should preserve the meaning of the operator. In particular, +# ``grad(vmap(op))`` should agree with ``grad(map(op))``. Test the rule against +# a Python loop over the batch dimension. diff --git a/extension.rst b/extension.rst index ee4d4524418..aa2d9bd8d28 100644 --- a/extension.rst +++ b/extension.rst @@ -38,7 +38,7 @@ C++ extensions and dispatcher usage. .. customcarditem:: :header: Custom Python Operators - :card_description: Create Custom Operators in Python. Useful for black-boxing a Python function for use with torch.compile. + :card_description: Create Python custom operators with correct mutation behavior, fake kernels, autograd, and opcheck. 
:image: _static/img/thumbnails/cropped/Custom-Cpp-and-CUDA-Extensions.png :link: advanced/python_custom_ops.html :tags: Extending-PyTorch,Frontend-APIs,C++,CUDA @@ -90,14 +90,12 @@ C++ extensions and dispatcher usage. .. Page TOC .. ----------------------------------------- .. toctree:: - :maxdepth: 2 + :maxdepth: 3 :includehidden: :hidden: :caption: Extending PyTorch advanced/custom_ops_landing_page - advanced/python_custom_ops - advanced/cpp_custom_ops intermediate/custom_function_double_backward_tutorial intermediate/custom_function_conv_bn_tutorial advanced/cpp_extension diff --git a/index.rst b/index.rst index ffdefe21e3a..ec4e809d1e8 100644 --- a/index.rst +++ b/index.rst @@ -410,7 +410,7 @@ Welcome to PyTorch Tutorials .. customcarditem:: :header: Custom Python Operators - :card_description: Create Custom Operators in Python. Useful for black-boxing a Python function for use with torch.compile. + :card_description: Create Python custom operators with correct mutation behavior, fake kernels, autograd, and opcheck. :image: _static/img/thumbnails/cropped/Custom-Cpp-and-CUDA-Extensions.png :link: advanced/python_custom_ops.html :tags: Extending-PyTorch,Frontend-APIs,C++,CUDA