Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
262 changes: 260 additions & 2 deletions backends/cadence/aot/quantizer/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import List, Optional, Tuple, Union

import torch
from executorch.backends.cadence.aot.compiler_utils import get_shape
from executorch.backends.cadence.aot.pass_utils import get_arg, replace_with_op
from executorch.backends.cadence.aot.quantizer.pattern_utils import (
DQ_PER_TENSOR,
Expand All @@ -24,6 +25,7 @@
from executorch.backends.cadence.aot.quantizer.utils import (
check_out_zero_point_is_min_range,
get_bias_qparams,
quantize_tensor_multiplier,
)
from torch import fx
from torch._ops import OpOverload
Expand Down Expand Up @@ -806,6 +808,40 @@ def get_anchors(
def replacement_op(self) -> OpOverload:
    """Return the Cadence op overload substituted for this pattern."""
    fused_op = torch.ops.cadence.quantized_max_pool2d_nchw.default
    return fused_op

def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
    """Fuse the max-pool anchor by delegating to ``_fuse_max_pool2d``.

    Returns:
        Whatever the shared helper returns: a node, or ``None``.
    """
    fused = _fuse_max_pool2d(gm, anchor_node)
    return fused


def _fuse_max_pool2d(gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
    """Fuse a max_pool2d anchor into ``quantized_max_pool2d_nchw``.

    Shared by both MaxPool2d pattern variants.

    Args:
        gm: Graph module passed through to ``replace_with_op``.
        anchor_node: The max-pool node whose first argument is expected to be
            a per-tensor dequantize node.

    Returns:
        The result of ``replace_with_op``, or ``None`` when the first argument
        is not a ``DQ_PER_TENSOR`` node or ``find_quant_user`` returns ``None``.
    """
    dq_node = anchor_node.args[0]
    if not isinstance(dq_node, fx.Node) or dq_node.target != DQ_PER_TENSOR:
        return None

    q_node = find_quant_user(anchor_node)
    if q_node is None:
        return None

    # The four list-valued pooling parameters are read in declaration order,
    # followed by the boolean ceil_mode flag.
    pool_kwargs = {
        name: get_arg(anchor_node, name, list[int])
        for name in ("kernel_size", "stride", "padding", "dilation")
    }
    pool_kwargs["ceil_mode"] = get_arg(anchor_node, "ceil_mode", bool)

    # The fused op reads the dequantize node's own input directly.
    quantized_input = get_arg(dq_node, "input", fx.Node)
    return replace_with_op(
        gm,
        anchor_node,
        torch.ops.cadence.quantized_max_pool2d_nchw.default,
        (quantized_input,),
        pool_kwargs,
        q_node,
    )


class MaxPool2dWithoutIndicesPattern(QuantizationPattern):
"""
Expand Down Expand Up @@ -845,8 +881,8 @@ def get_anchors(
def replacement_op(self) -> OpOverload:
    """Return the Cadence op overload substituted for this pattern."""
    fused_op = torch.ops.cadence.quantized_max_pool2d_nchw.default
    return fused_op


# This is a base class for ReLU
def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
    """Fuse the max-pool anchor by delegating to ``_fuse_max_pool2d``.

    Returns:
        Whatever the shared helper returns: a node, or ``None``.
    """
    fused = _fuse_max_pool2d(gm, anchor_node)
    return fused


# This is a base class for ReLU, since it can be used with two different aten ops
Expand Down Expand Up @@ -874,6 +910,28 @@ def get_anchors(
def replacement_op(self) -> OpOverload:
    """Return the Cadence op overload substituted for this pattern."""
    fused_op = torch.ops.cadence.quantized_relu.per_tensor
    return fused_op

def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
    """Fuse a relu anchor into the op returned by ``replacement_op``.

    Args:
        gm: Graph module passed through to ``replace_with_op``.
        anchor_node: The relu node whose first argument is expected to be a
            per-tensor dequantize node.

    Returns:
        The result of ``replace_with_op``, or ``None`` when the first argument
        is not a ``DQ_PER_TENSOR`` node or ``find_quant_user`` returns ``None``.
    """
    dq_node = anchor_node.args[0]
    if not isinstance(dq_node, fx.Node) or dq_node.target != DQ_PER_TENSOR:
        return None

    q_node = find_quant_user(anchor_node)
    if q_node is None:
        return None

    # The ratio of input scale to output scale is converted into a
    # (multiplier, shift) pair by quantize_tensor_multiplier.
    in_scale = get_arg(dq_node, "scale", float)
    out_scale = get_arg(q_node, "scale", float)
    scale_ratio = torch.tensor([in_scale / out_scale])
    multiplier, shift = quantize_tensor_multiplier(scale_ratio)

    op_args = (get_arg(dq_node, "input", fx.Node),)
    op_kwargs = {
        "X_zero_point": get_arg(dq_node, "zero_point", int),
        "out_zero_point": get_arg(q_node, "zero_point", int),
        "out_multiplier": multiplier[0].item(),
        "out_shift": shift[0].item(),
    }
    return replace_with_op(
        gm, anchor_node, self.replacement_op(), op_args, op_kwargs, q_node
    )


# Regular relu op
class ReluPattern0(ReluBasePattern):
Expand Down Expand Up @@ -933,6 +991,39 @@ def get_anchors(
def replacement_op(self) -> OpOverload:
    """Return the Cadence op overload substituted for this pattern."""
    fused_op = torch.ops.cadence.quantized_conv2d_nchw.per_tensor
    return fused_op

def anchor_ops(self) -> tuple[OpOverload, ...]:
    """Return a one-element tuple holding the first op of ``partition_types()``."""
    first_op = self.partition_types()[0]
    return (first_op,)

def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
    """Fuse a conv anchor followed by a relu through ``fuse_conv``.

    Args:
        gm: Graph module passed through to ``fuse_conv``.
        anchor_node: The conv node. It must have exactly one user, and that
            user's target must equal ``partition_types()[1]``.

    Returns:
        The result of ``fuse_conv``, or ``None`` when the anchor does not have
        a single matching relu user, when either of its first two arguments is
        not a ``DQ_PER_TENSOR`` node, or when ``find_quant_user`` returns
        ``None`` for the relu.
    """
    if len(anchor_node.users) != 1:
        return None
    relu_node = next(iter(anchor_node.users))
    if relu_node.target != self.partition_types()[1]:
        return None

    # Read both operands before validating either one.
    dq_input = anchor_node.args[0]
    dq_weight = anchor_node.args[1]
    if not isinstance(dq_input, fx.Node) or dq_input.target != DQ_PER_TENSOR:
        return None
    if not isinstance(dq_weight, fx.Node) or dq_weight.target != DQ_PER_TENSOR:
        return None

    # The quantize node is looked up after the relu, not after the conv.
    q_node = find_quant_user(relu_node)
    if q_node is None:
        return None

    out_zero_point = get_arg(q_node, "zero_point", int)
    out_dtype = get_arg(q_node, "dtype", torch.dtype)
    check_out_zero_point_is_min_range(out_zero_point, out_dtype)
    return fuse_conv(self, gm, anchor_node, dq_input, dq_weight, q_node)


# Conv1d + regular relu op fusion
class Conv1dReluPattern0(ConvReluBasePattern):
Expand Down Expand Up @@ -987,6 +1078,56 @@ def get_anchors(
def replacement_op(self) -> OpOverload:
    """Return the Cadence op overload substituted for this pattern."""
    fused_op = torch.ops.cadence.quantized_softmax.per_tensor
    return fused_op

def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
dq_input = anchor_node.args[0]
if not isinstance(dq_input, fx.Node) or dq_input.target != DQ_PER_TENSOR:
return None
quant_node = find_quant_user(anchor_node)
if quant_node is None:
return None
input_q = get_arg(dq_input, "input", fx.Node)
quant_input = get_arg(quant_node, "input", fx.Node)
mask_shape = get_shape(gm, quant_input)
if not mask_shape:
return None
mask_shape = list(mask_shape)
# Softmax mask is packed 16 elements per int32 word.
assert (
mask_shape[-1] % 16 == 0
), f"Softmax mask dimension must be divisible by 16, got {mask_shape[-1]}"
mask_shape[-1] = mask_shape[-1] // 16
mask_tensor = insert_node_with_meta(
gm,
torch.ops.aten.full.default,
(mask_shape, 0.0),
{"dtype": torch.int32},
anchor_node,
input_q,
)
# Initial position for streaming softmax (unused, set to 0).
pos_tensor = insert_node_with_meta(
gm,
torch.ops.aten.full.default,
([1], 0),
{"dtype": torch.int64},
anchor_node,
input_q,
)
args = (
input_q,
mask_tensor,
get_arg(anchor_node, "dim", int),
0,
pos_tensor,
get_arg(dq_input, "scale", float),
get_arg(dq_input, "zero_point", int),
get_arg(quant_node, "scale", float),
get_arg(quant_node, "zero_point", int),
)
return replace_with_op(
gm, anchor_node, self.replacement_op(), args, {}, quant_node
)


class MixedW8A32LinearPattern(QuantizationPattern):
def partition_types(self) -> List[OpOverload]:
Expand Down Expand Up @@ -1041,6 +1182,36 @@ def get_anchors(
def replacement_op(self) -> OpOverload:
    """Return the Cadence op overload substituted for this pattern."""
    fused_op = torch.ops.cadence.quantized_w8a32_linear.default
    return fused_op

def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
    """Fuse a linear anchor with dequantized weight and bias.

    The anchor itself is passed as the last argument of ``replace_with_op``;
    no downstream quantize node is looked up.

    Args:
        gm: Graph module passed through to ``replace_with_op``.
        anchor_node: The linear node. It must have exactly three positional
            arguments and no keyword arguments.

    Returns:
        The result of ``replace_with_op``, or ``None`` when the argument
        layout differs or the weight/bias is not a ``DQ_PER_TENSOR`` node.
    """
    if anchor_node.kwargs or len(anchor_node.args) != 3:
        return None

    activation, dq_weight, dq_bias = anchor_node.args
    if not isinstance(dq_weight, fx.Node) or dq_weight.target != DQ_PER_TENSOR:
        return None
    if not isinstance(dq_bias, fx.Node) or dq_bias.target != DQ_PER_TENSOR:
        return None
    assert isinstance(activation, fx.Node)

    quantized_weight = get_arg(dq_weight, "input", fx.Node)
    weight_scale = get_arg(dq_weight, "scale", float)
    quantized_bias = get_arg(dq_bias, "input", fx.Node)
    bias_scale = get_arg(dq_bias, "scale", float)
    op_args = (
        activation,
        quantized_weight,
        weight_scale,
        quantized_bias,
        bias_scale,
    )
    return replace_with_op(
        gm, anchor_node, self.replacement_op(), op_args, {}, anchor_node
    )


class MixedW8A32ConvPattern(QuantizationPattern):
def partition_types(self) -> List[OpOverload]:
Expand Down Expand Up @@ -1115,6 +1286,57 @@ def get_anchors(
def replacement_op(self) -> OpOverload:
    """Return the Cadence op overload substituted for this pattern."""
    fused_op = torch.ops.cadence.quantized_w8a32_conv.default
    return fused_op

def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
if len(anchor_node.args) != 3 or len(anchor_node.kwargs) > 0:
return None
_arg1 = anchor_node.args[1]
dq_weight = (
_arg1
if isinstance(_arg1, fx.Node) and _arg1.target == DQ_PER_TENSOR
else None
)
_arg2 = anchor_node.args[2]
dq_bias = (
_arg2
if isinstance(_arg2, fx.Node) and _arg2.target == DQ_PER_TENSOR
else None
)
if dq_weight is None or dq_bias is None:
return None
input_node = anchor_node.args[0]
assert isinstance(input_node, fx.Node)
assert get_arg(anchor_node, "stride", list[int]) == [1]
assert get_arg(anchor_node, "padding", list[int]) == [0]
assert get_arg(anchor_node, "dilation", list[int]) == [1]
assert get_arg(anchor_node, "groups", int) == 1
weight_q = get_arg(dq_weight, "input", fx.Node)
transposed_inputs = insert_node_with_meta(
gm,
torch.ops.aten.permute.default,
(input_node, [0, 2, 1]),
None,
anchor_node,
input_node,
)
transposed_weights = insert_node_with_meta(
gm,
torch.ops.aten.permute.default,
(weight_q, [2, 0, 1]),
None,
anchor_node,
weight_q,
)
args = (
transposed_inputs,
transposed_weights,
get_arg(dq_weight, "scale", float),
get_arg(dq_bias, "input", fx.Node),
get_arg(dq_bias, "scale", float),
)
return replace_with_op(
gm, anchor_node, self.replacement_op(), args, {}, anchor_node
)


class MixedW8A32GruPattern(QuantizationPattern):
def partition_types(self) -> List[OpOverload]:
Expand Down Expand Up @@ -1187,6 +1409,42 @@ def __init__(self, args, meta):
def replacement_op(self) -> OpOverload:
    """Return the Cadence op overload substituted for this pattern."""
    fused_op = torch.ops.cadence.quantized_w8a32_gru.default
    return fused_op

def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> fx.Node | None:
if len(anchor_node.kwargs) > 0:
return None
params = anchor_node.args[2]
# GRU requires 4 weight/bias params: w_ih, w_hh, b_ih, b_hh
if not isinstance(params, (list, tuple)) or len(params) < 4:
return None
dq_w_ih = params[0]
if not isinstance(dq_w_ih, fx.Node) or dq_w_ih.target != DQ_PER_TENSOR:
return None
dq_w_hh = params[1]
if not isinstance(dq_w_hh, fx.Node) or dq_w_hh.target != DQ_PER_TENSOR:
return None
dq_b_ih = params[2]
if not isinstance(dq_b_ih, fx.Node) or dq_b_ih.target != DQ_PER_TENSOR:
return None
dq_b_hh = params[3]
if not isinstance(dq_b_hh, fx.Node) or dq_b_hh.target != DQ_PER_TENSOR:
return None
input_node = anchor_node.args[0]
hidden_node = anchor_node.args[1]
args = (
input_node,
hidden_node,
get_arg(dq_w_ih, "input", fx.Node),
get_arg(dq_w_ih, "scale", float),
get_arg(dq_w_hh, "input", fx.Node),
get_arg(dq_w_hh, "scale", float),
get_arg(dq_b_ih, "input", fx.Node),
get_arg(dq_b_ih, "scale", float),
get_arg(dq_b_hh, "input", fx.Node),
)
return replace_with_op(
gm, anchor_node, self.replacement_op(), args, {}, anchor_node
)


class RmsNormPattern(QuantizationPattern):
"""Pattern that preserves rms_norm from decomposition without matching anything."""
Expand Down
Loading