Adding all project files

Martina Burlando 2025-08-02 02:00:33 +02:00
parent 6c9e127bdc
commit cd4316ad0f
42289 changed files with 8009643 additions and 0 deletions


@@ -0,0 +1,775 @@
# copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py
# and https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py
# PLEASE DON'T MODIFY THIS FILE SO THAT WE DON'T GET OUT OF SYNC
import logging
from abc import ABCMeta
from typing import Any, Optional, Union
import torch
from torch.ao.quantization.observer import (
AffineQuantizedObserverBase,
get_block_size,
MappingType,
TorchAODType,
ZeroPointDomain,
)
from torch.fx import Node
ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3:
logger = logging.getLogger(__name__)
FP8_TYPES = {
torch.float8_e4m3fn,
torch.float8_e5m2,
torch.float8_e4m3fnuz,
torch.float8_e5m2fnuz,
}
_SUB_BYTE_UINT_BOUNDS = {
torch.uint1: (0, 2**1 - 1),
torch.uint2: (0, 2**2 - 1),
torch.uint3: (0, 2**3 - 1),
torch.uint4: (0, 2**4 - 1),
torch.uint5: (0, 2**5 - 1),
torch.uint6: (0, 2**6 - 1),
torch.uint7: (0, 2**7 - 1),
}
"""
Map from dtype to the bound value of integers
TODO: maybe can replace this with call to torch.iinfo
"""
_DTYPE_TO_QVALUE_BOUNDS: dict[Union[torch.dtype, TorchAODType], tuple[int, int]] = {
torch.uint8: (0, 255),
torch.int8: (-128, 127),
torch.int16: (-(2**15), 2**15 - 1),
torch.int32: (-(2**31), 2**31 - 1),
}
_DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_UINT_BOUNDS)
def _is_float8_type(dtype: torch.dtype) -> bool:
fp8_types = {
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2,
torch.float8_e5m2fnuz,
}
return dtype in fp8_types
# TODO: decide on if we want to allow custom quant_min/quant_max here
def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
"""Get quant_min and quant_max args based on dtype and also
verify that they are within the range of possible quant_min/quant_max
for dtype
"""
if dtype in FP8_TYPES:
quant_min_lower_bound, quant_max_upper_bound = (
torch.finfo(dtype).min,
torch.finfo(dtype).max,
)
elif dtype not in _DTYPE_TO_QVALUE_BOUNDS:
raise ValueError(f"Unsupported dtype: {dtype}")
else:
quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype]
if quant_min is None:
quant_min = quant_min_lower_bound
if quant_max is None:
quant_max = quant_max_upper_bound
assert quant_min >= quant_min_lower_bound, (
"quant_min out of bound for dtype, "
f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}"
)
assert quant_max <= quant_max_upper_bound, (
"quant_max out of bound for dtype, "
f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}"
)
return quant_min, quant_max
def _get_reduction_params(block_size, input_size):
"""Given block_size and input size find the parameters for reduction:
Output:
shape_for_reduction: the shape we use to `view` input to prepare it for reduction
reduction_dims: the dims we'll do reduction over
Example::
Input:
block_size: (3, 3, 2, 10)
input_size: (3, 3, 10, 10)
Output:
shape_for_reduction: (3, 3, 5, 2, 10)
reduction_dims: [0, 1, 3, 4]
"""
assert len(block_size) == len(input_size)
shape_for_reduction = []
reduction_dims = []
cur_dim = 0
for i in range(len(block_size)):
if block_size[i] != input_size[i] and block_size[i] > 1:
assert input_size[i] % block_size[i] == 0, (
f"Expecting input size at {i} dimension: "
f"{input_size[i]} to be divisible by block_size at {i} dimension: {block_size[i]}"
)
shape_for_reduction.append(input_size[i] // block_size[i])
shape_for_reduction.append(block_size[i])
# reduce over the block_size[i] dim
reduction_dims.append(cur_dim + 1)
cur_dim += 2
else:
# block_size[i] == input_size[i] or block_size[i] == 1
shape_for_reduction.append(input_size[i])
# we only need to reduce over the dimension if block_size is greater than 1
# otherwise it's already the same as reduced dimension
if block_size[i] != 1:
reduction_dims.append(cur_dim)
cur_dim += 1
return shape_for_reduction, reduction_dims
def _register_custom_op(lib):
"""This decorator is used to preserve some high level operators for torch.export.export
while still allowing them to be decomposed in the inductor path
requirement: make sure `fn.__name__[1:]` is the operator name you want to register
NOTE: This should be applied at the top, after all other decorators have been applied
NOTE: We haven't tested the case when `fn` accepts tensor subclass instance as input,
e.g. uint4 tensor subclass instance, and we'll probably need to figure out what would make
sense for downstream system (like executorch) to accept as well
Example:
lib = torch.library.Library("my_namespace", "FRAGMENT")
register_custom_op = _register_custom_op(lib)
@register_custom_op
def _the_op_that_needs_to_be_preserved(...)
...
# after this, `_the_op_that_needs_to_be_preserved` will be preserved as
# torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after
# torch.export.export / torch._export.export_for_training
"""
from torch._inductor.decomposition import register_decomposition
def decorator(fn):
from torch._library.infer_schema import infer_schema
# expecting fn.__name__ starts with `_` and we want to take the rest
# to be the name of the custom op
assert (
fn.__name__[0] == "_"
), f"Expecting function name starts with `_`, got {fn.__name__}"
assert not any(
c in fn.__name__ for c in ".<>"
), f"Expecting op to be defined in normal functions, not lambda or local: {fn.__name__}"
op_name = fn.__name__[1:]
schema = op_name + infer_schema(fn, mutates_args={})
lib.define(schema)
lib.impl(op_name, fn, "CompositeImplicitAutograd")
lib_namespace = lib.ns
op = getattr(getattr(torch.ops, lib_namespace), op_name)
register_decomposition([op])(fn)
return op
return decorator
quant_lib = torch.library.Library("pt2e_quant", "FRAGMENT") # noqa: TOR901
register_custom_op = _register_custom_op(quant_lib)
def choose_qparams_affine_with_min_max(
min_val: torch.Tensor,
max_val: torch.Tensor,
mapping_type: MappingType,
block_size: tuple[int, ...],
target_dtype: torch.dtype,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
eps: Optional[float] = None,
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = None,
preserve_zero: bool = True,
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
) -> tuple[torch.Tensor, torch.Tensor]:
"""A variant of :func:`~torchao.quantization.quant_primitives.choose_qparams_affine`
operator that passes in min_val and max_val directly instead of deriving them from a single input.
This is used for observers in static quantization where min_val and max_val may be obtained through
tracking all the data in the calibration data set.
Args:
Mostly the same as :func:`~torchao.quantization.quant_primitives.choose_qparams_affine`, with one
difference: instead of passing in an `input` Tensor and using that to calculate min_val/max_val
and then scale/zero_point, we pass in min_val/max_val directly
"""
return _choose_qparams_affine(
None,
mapping_type.name,
block_size,
target_dtype,
quant_min,
quant_max,
eps,
scale_dtype,
zero_point_dtype,
preserve_zero,
zero_point_domain.name if zero_point_domain is not None else None,
min_val,
max_val,
)
@register_custom_op
def _choose_qparams_affine(
input: Optional[torch.Tensor],
mapping_type: str,
block_size: list[int],
target_dtype: torch.dtype,
quant_min: Optional[Union[int, float, bool]] = None,
quant_max: Optional[Union[int, float, bool]] = None,
eps: Optional[float] = None,
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = None,
preserve_zero: bool = True,
zero_point_domain: Optional[str] = "INT",
min_val: Optional[torch.Tensor] = None,
max_val: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""op definition that has compatible signatures with custom op library
The op does the following:
1. figure out the dimension for reduction based on block_size
2. find min_val/max_val based on the dimension for reduction
3. calculate quantization parameters based on min_val/max_val based on args like `preserve_zero`
and `zero_point_domain`
"""
quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
assert mapping_type in [
MappingType.SYMMETRIC.name,
MappingType.SYMMETRIC_NO_CLIPPING_ERR.name,
MappingType.ASYMMETRIC.name,
], f"Unsupported mapping type: {mapping_type}"
if target_dtype in FP8_TYPES:
assert (
mapping_type == MappingType.SYMMETRIC.name
), f"Only symmetric quantization is supported for FP8 types, got {mapping_type}"
if input is not None:
if scale_dtype is None:
scale_dtype = input.dtype
if zero_point_dtype is None:
zero_point_dtype = input.dtype
if eps is None:
eps = torch.finfo(input.dtype).eps
assert (
len(block_size) == input.dim()
), f"Got input dim:{input.dim()}, block_size: {block_size}"
shape_for_reduction, reduction_dims = _get_reduction_params(
block_size, input.size()
)
input = input.view(shape_for_reduction)
min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
max_val = torch.amax(input, dim=reduction_dims, keepdim=False)
else:
assert (
min_val is not None and max_val is not None
), "Need to provide `min_val` and `max_val` when `input` is None, got: {min_val, max_val}"
assert (
min_val.dtype == max_val.dtype
), "Expecting `min_val` and `max_val` to have the same dtype, got: {min_val.dtype, max_val.dtype}"
if scale_dtype is None:
scale_dtype = min_val.dtype
if zero_point_dtype is None:
zero_point_dtype = min_val.dtype
if eps is None:
eps = torch.finfo(min_val.dtype).eps
if preserve_zero:
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
else:
min_val_neg = min_val
max_val_pos = max_val
if (
mapping_type == MappingType.SYMMETRIC.name
or mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR.name
):
# scales
if mapping_type == MappingType.SYMMETRIC.name:
max_val_pos = torch.max(-min_val_neg, max_val_pos)
scale = max_val_pos / (float(quant_max - quant_min) / 2)
else:
assert mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR.name
# calculate smin and smax individually and choose the larger one. For example, if quant_min = -8 and
# quant_max = 7.
# - If smin is bigger: There would be coverage on negative values down to -8, and less rounding
# error than the existing SYMMETRIC case.
# - If smax is bigger: it covers the positive values up to 7. The rounding
# error may be bigger than in the existing SYMMETRIC case. Either way, there are no out-of-range fp values after
# quantization.
smin = min_val_neg / float(quant_min)
smax = max_val_pos / float(quant_max)
mask = smin > smax
scale = torch.where(mask, smin, smax)
# zeros
if not preserve_zero:
raise ValueError(
"preserve_zero == False is not supported for symmetric quantization"
)
if (
zero_point_domain is not None
and zero_point_domain != ZeroPointDomain.INT.name
):
raise ValueError(
"zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization"
)
scale = torch.clamp(scale, min=eps)
zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
else:
assert mapping_type == MappingType.ASYMMETRIC.name
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
scale = torch.clamp(scale, min=eps)
if zero_point_domain == ZeroPointDomain.NONE.name:
zero_point = None
else:
if preserve_zero:
zero_point = quant_min - torch.round(min_val_neg / scale)
zero_point = torch.clamp(zero_point, quant_min, quant_max)
else:
assert (
zero_point_domain == ZeroPointDomain.FLOAT.name
), "if not preserve_zero, zero_point must be in FLOAT domain"
mid_point = (quant_max + quant_min + 1) / 2
zero_point = min_val_neg + scale * mid_point
if zero_point is not None:
zero_point = zero_point.to(dtype=zero_point_dtype)
return scale.to(dtype=scale_dtype), zero_point
@torch.no_grad()
def quantize_affine(
input: torch.Tensor,
block_size: tuple[int, ...],
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
output_dtype: torch.dtype,
quant_min: Optional[Union[int, float]] = None,
quant_max: Optional[Union[int, float]] = None,
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
) -> torch.Tensor:
"""
Args:
input (torch.Tensor): original float32, float16 or bfloat16 Tensor
block_size (Tuple[int, ...]): granularity of quantization,
this means the size of the block of tensor elements that share the same qparam
e.g. when block_size is the same as the input tensor shape, we are using per tensor quantization
scale (torch.Tensor): quantization parameter for affine quantization
zero_point (Optional[torch.Tensor]): quantization parameter for affine quantization
output_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype
quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
if zero_point is in integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
value during quantization
default is ZeroPointDomain.INT
Note:
How can block_size represent different granularities?
let's say we have a Tensor of size: (3, 3, 10, 10), here is the table showing how block_size represents different
granularities:
granularity type | block_size
per_tensor | (3, 3, 10, 10)
per_axis (axis=0) | (1, 3, 10, 10)
per_axis (axis=1) | (3, 1, 10, 10)
per_group (groupsize=2) | (3, 3, 10, 2)
per_group (groupsize=2) for axis = 2 | (3, 3, 2, 10)
Output:
quantized tensor with requested dtype
"""
return _quantize_affine(
input,
block_size,
scale,
zero_point,
output_dtype,
quant_min,
quant_max,
zero_point_domain.name if zero_point_domain is not None else None,
)
@register_custom_op
def _quantize_affine(
input: torch.Tensor,
block_size: list[int],
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
output_dtype: torch.dtype,
quant_min: Optional[Union[int, float, bool]] = None,
quant_max: Optional[Union[int, float, bool]] = None,
zero_point_domain: Optional[str] = ZeroPointDomain.INT.name,
) -> torch.Tensor:
"""op definition that has compatible signatures with custom op library
Note:
zero_point_domain optionally specifies how we quantize the floating point values to quantized data:
INT: quantized_val = (float_val / scale) (integer) + zero_point (integer)
FLOAT: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale
None: quantized_val = (float_val / scale) | this is primarily used for floatx quantization,
where we do not want to round values to the nearest integer and instead just scale and cast.
"""
quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max)
# workaround for uintx dtypes, since we don't have native Uintx dtype connected with
# torch.uintx dtypes yet
if output_dtype in _SUB_BYTE_UINT_BOUNDS:
output_dtype = torch.uint8
return _quantize_affine_no_dtype_cast(
input,
block_size,
scale,
zero_point,
quant_min,
quant_max,
zero_point_domain,
).to(output_dtype)
def _quantize_affine_no_dtype_cast(
input: torch.Tensor,
block_size: list[int],
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
quant_min: Union[int, float],
quant_max: Union[int, float],
zero_point_domain: Optional[str] = ZeroPointDomain.INT.name,
) -> torch.Tensor:
"""
The op does the following:
1. figure out the dimension for reduction based on block_size, also reshape the input to align with
the shape after reduction
2. quantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain
3. reshape the quantized result to the original shape
"""
# TODO: validations
# TODO: validate scale/zero_point dimensions are compatible with block_size
assert input.dtype in [
torch.float32,
torch.float16,
torch.bfloat16,
], f"Unsupported input dtype: {input.dtype}"
assert (
len(block_size) == input.dim()
), f"Got input dim:{input.dim()}, block_size: {block_size}"
shape_for_reduction, reduction_dims = _get_reduction_params(
block_size, input.size()
)
original_shape = input.shape
input = input.view(shape_for_reduction)
shape_after_reduction = shape_for_reduction
for i in reduction_dims:
shape_after_reduction[i] = 1
scale = scale.view(shape_after_reduction)
if zero_point is not None:
zero_point = zero_point.view(shape_after_reduction)
if zero_point_domain == ZeroPointDomain.INT.name:
quant = torch.clamp(
torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
)
elif zero_point_domain == ZeroPointDomain.NONE.name:
assert (
zero_point is None
), "zero_point should be None when zero_point_domain is NONE"
quant = torch.clamp(torch.round(input * (1.0 / scale)), quant_min, quant_max)
elif zero_point_domain is None:
# This case handles quantization for float8; we expect no zero point and no zero point domain
assert (
zero_point is None
), "zero_point should be None when zero_point_domain is None"
quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max)
else:
assert zero_point_domain == ZeroPointDomain.FLOAT.name
mid_point = (quant_max + quant_min + 1) / 2
min_val = zero_point - scale * mid_point
quant = torch.clamp(
torch.round((input - min_val) / scale), quant_min, quant_max
)
quant = quant.view(original_shape)
return quant
def dequantize_affine(
input: torch.Tensor,
block_size: tuple[int, ...],
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
input_dtype: torch.dtype,
quant_min: Optional[Union[int, float]] = None,
quant_max: Optional[Union[int, float]] = None,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
*,
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Args:
input (torch.Tensor): quantized tensor, should match the dtype given by the `input_dtype` argument
block_size (List[int]): granularity of quantization,
this means the size of the block of tensor elements that share the same qparam
e.g. when block_size is the same as the input tensor shape, we are using per tensor quantization
scale (Tensor): quantization parameter for affine quantization
zero_point (Tensor): quantization parameter for affine quantization
input_dtype (torch.dtype): dtype (e.g. torch.uint8) of the quantized input Tensor
quant_min (Optional[int]): minimum quantized value for input Tensor
quant_max (Optional[int]): maximum quantized value for input Tensor
output_dtype (torch.dtype): dtype for output Tensor, default is fp32
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
if zero_point is in integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
value during quantization
default is ZeroPointDomain.INT
Output:
dequantized Tensor, with requested dtype or fp32
"""
return _dequantize_affine(
input,
block_size,
scale,
zero_point,
input_dtype,
quant_min,
quant_max,
zero_point_domain.name if zero_point_domain is not None else None,
output_dtype=output_dtype,
)
@register_custom_op
def _dequantize_affine(
input: torch.Tensor,
block_size: list[int],
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
input_dtype: torch.dtype,
quant_min: Optional[Union[int, float, bool]] = None,
quant_max: Optional[Union[int, float, bool]] = None,
zero_point_domain: Optional[str] = ZeroPointDomain.INT.name,
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""op definition that has compatible signatures with custom op library"""
# TODO: validate scale/zero_point dimensions are compatible with block_size
if input_dtype not in _SUB_BYTE_UINT_BOUNDS:
assert (
input.dtype == input_dtype
), f"Expected: {input_dtype}, got: {input.dtype}"
assert output_dtype in [
torch.float32,
torch.float16,
torch.bfloat16,
], f"Unsupported output dtype: {output_dtype}"
quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max)
return _dequantize_affine_no_dtype_check(
input,
block_size,
scale,
zero_point,
quant_min,
quant_max,
zero_point_domain,
output_dtype,
)
def _dequantize_affine_no_dtype_check(
input: torch.Tensor,
block_size: list[int],
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
quant_min: Union[int, float],
quant_max: Union[int, float],
zero_point_domain: Optional[str] = ZeroPointDomain.INT.name,
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""This function converts AQT tensors to their high precision floating point representation
The op does the following:
1. figure out the dimension for reduction based on block_size, also reshape the input to align with
the shape after reduction
2. dequantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain
3. reshape the dequantized result to the original shape and change dtype to the output_dtype
"""
assert (
len(block_size) == input.dim()
), f"Got input dim:{input.dim()}, block_size: {block_size}"
shape_for_reduction, reduction_dims = _get_reduction_params(
block_size, input.size()
)
original_shape = input.shape
input = input.view(shape_for_reduction)
shape_after_reduction = shape_for_reduction
for i in reduction_dims:
shape_after_reduction[i] = 1
scale = scale.view(shape_after_reduction)
if zero_point is not None:
zero_point = zero_point.view(shape_after_reduction)
if zero_point_domain == ZeroPointDomain.INT.name:
# Force a copy to avoid input modification due
# to upcoming in-place operations.
dequant = input.to(torch.int32, copy=True)
if zero_point is not None:
dequant = dequant - zero_point.to(torch.int32)
dequant = dequant.to(output_dtype)
dequant = dequant * scale
elif zero_point_domain == ZeroPointDomain.NONE.name:
assert (
zero_point is None
), "zero_point should be None when zero_point_domain is NONE"
dequant = input.to(output_dtype)
dequant = dequant * scale
elif zero_point_domain is None:
# This case handles dequantization for float8; we expect no zero point and no zero point domain
assert (
zero_point is None
), "zero_point should be None when zero_point_domain is None"
assert _is_float8_type(
input.dtype
), f"dequantiztion with no zero point domain is only supported with FP8 types, got {input.dtype}"
dequant = input.to(output_dtype)
dequant = dequant * scale
else:
assert (
zero_point_domain == ZeroPointDomain.FLOAT.name
), f"Unexpected zero point domain: {zero_point_domain}"
# TODO: this seems to be a detail for tinygemm (converting from uint to int, probably need to refactor this)
mid_point = (quant_max + quant_min + 1) / 2
# This should allocate new memory and avoid input modification
dequant = input - mid_point
dequant = dequant.to(output_dtype)
dequant *= scale
if zero_point is not None:
dequant += zero_point
return dequant.view(original_shape).to(output_dtype)
class AffineQuantizedMinMaxObserver(AffineQuantizedObserverBase):
def forward(self, input: torch.Tensor):
if input.numel() == 0:
return input
input_detached = input.detach()
self.original_dtype = input_detached.dtype
assert self.granularity is not None, "granularity is None"
self.block_size = get_block_size(input_detached.shape, self.granularity)
shape_for_reduction, reduction_dims = _get_reduction_params(
self.block_size, input_detached.size()
)
input_detached = input_detached.view(shape_for_reduction)
min_val = torch.amin(input_detached, dim=reduction_dims, keepdim=False)
max_val = torch.amax(input_detached, dim=reduction_dims, keepdim=False)
if not hasattr(self, "min_val") or not hasattr(self, "max_val"):
self.min_val = min_val
self.max_val = max_val
else:
assert (
self.min_val.shape == min_val.shape
), f"Can't update existing min_val - shape mismatch, self.min_val:{self.min_val.shape} != min_val:{min_val.shape}"
assert (
self.max_val.shape == max_val.shape
), f"Can't update existing max_val - shape mismatch, self.max_val {self.max_val.shape} != max_val:{max_val.shape}"
min_val = torch.min(self.min_val, min_val)
max_val = torch.max(self.max_val, max_val)
self.min_val.copy_(min_val)
self.max_val.copy_(max_val)
# returning original input
return input
def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]:
assert hasattr(self, "min_val") and hasattr(
self, "max_val"
), "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams"
return choose_qparams_affine_with_min_max(
self.min_val,
self.max_val,
self.mapping_type,
[], # BlockSize is not needed because the min/max are already reduced
self.target_dtype,
self.quant_min,
self.quant_max,
self.eps,
self.scale_dtype,
self.zero_point_dtype,
self.preserve_zero,
self.zero_point_domain,
)
def convert(self, model: torch.fx.GraphModule, observer_node: Node):
print("calling convert")
from torch.ao.quantization.fx.utils import create_getattr_from_value
scale, zero_point = self.calculate_qparams()
with model.graph.inserting_before(observer_node):
assert self.block_size is not None, "Expecting block_size to be populated"
assert (
self.original_dtype is not None
), "Expecting original_dtype to be populated"
scale_node = create_getattr_from_value(model, model.graph, "_scale", scale)
zero_point_node = create_getattr_from_value(
model, model.graph, "_zero_point", zero_point
)
q_node = model.graph.call_function(
torch.ops.pt2e_quant.quantize_affine,
(
observer_node.args[0],
self.block_size,
scale_node,
zero_point_node,
self.target_dtype,
self.quant_min,
self.quant_max,
self.zero_point_domain.name,
),
{},
)
dq_node = model.graph.call_function(
torch.ops.pt2e_quant.dequantize_affine,
(
q_node,
self.block_size,
scale_node,
zero_point_node,
self.target_dtype,
self.quant_min,
self.quant_max,
self.zero_point_domain.name,
),
{"output_dtype": self.original_dtype},
)
observer_node.replace_all_uses_with(dq_node)
model.graph.erase_node(observer_node)
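
A minimal usage sketch of the primitives above, assuming this module's names (choose_qparams_affine_with_min_max, quantize_affine, dequantize_affine, MappingType) are in scope; the tensor shape, block_size and dtype are illustrative assumptions, not part of the file.

import torch

x = torch.randn(4, 16)
block_size = (1, 16)  # one scale/zero_point per row of x

# Derive qparams from externally tracked min/max, the way an observer would.
min_val = torch.amin(x, dim=1)
max_val = torch.amax(x, dim=1)
scale, zero_point = choose_qparams_affine_with_min_max(
    min_val, max_val, MappingType.ASYMMETRIC, block_size, torch.int8
)

q = quantize_affine(x, block_size, scale, zero_point, torch.int8)
dq = dequantize_affine(q, block_size, scale, zero_point, torch.int8)
# dq approximates x; the per-element error is bounded by roughly half a quantization step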


@@ -0,0 +1,342 @@
import copy
import logging
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Callable, Optional
import torch
from torch.ao.ns.fx.utils import compute_sqnr
from torch.ao.quantization.pt2e.graph_utils import bfs_trace_with_node_process
from torch.export import ExportedProgram
from torch.fx import GraphModule, Node
from torch.nn import functional as F
NUMERIC_DEBUG_HANDLE_KEY = "numeric_debug_handle"
CUSTOM_KEY = "custom"
log = logging.getLogger(__name__)
def generate_numeric_debug_handle(ep: ExportedProgram) -> None:
"""
Attach a numeric_debug_handle_id to all nodes in the graph module of the given
ExportedProgram, like conv2d, squeeze, conv1d, etc., except for placeholders.
Notice that nodes like getattr are out of scope since they are not in the graph.
The graph nodes of the input exported program are modified in place.
Here's an example of using the debug handle in the quantize flow::
ep = export_for_training(eager_model, example_inputs)
generate_numeric_debug_handle(ep)
m = ep.module()
quantizer = XNNPACKQuantizer()
m = prepare_pt2e(m, quantizer)
m = convert_pt2e(m)
"""
# Sanity check the input data type
if not isinstance(ep, ExportedProgram):
raise ValueError(
f"Expected ep to be ExportedProgram, got {type(ExportedProgram)}"
)
unique_id = 0
def _find_max_id(node: torch.fx.Node) -> None:
nonlocal unique_id
unique_id = max(
unique_id, node.meta.get(CUSTOM_KEY, {}).get(NUMERIC_DEBUG_HANDLE_KEY, 0)
)
def _assign_debug_handle(node: torch.fx.Node) -> None:
nonlocal unique_id
if CUSTOM_KEY not in node.meta:
node.meta[CUSTOM_KEY] = {}
if NUMERIC_DEBUG_HANDLE_KEY not in node.meta[CUSTOM_KEY]:
node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = unique_id
unique_id += 1
# Find the max ID that exists in the graph first, in case part of the graph
# has already been annotated. This way we guarantee there are no duplicate
# handle IDs.
bfs_trace_with_node_process(ep, _find_max_id)
unique_id += 1
# Assign debug handles to all nodes in the graph that don't have one based on the
# max ID found in the previous step.
bfs_trace_with_node_process(ep, _assign_debug_handle)
def _detach(x: object) -> object:
detached: object = None
if isinstance(x, torch.Tensor):
detached = x.detach()
elif isinstance(x, (list, tuple)):
detached = type(x)([_detach(e) for e in x])
elif isinstance(x, dict):
detached = {k: _detach(e) for k, e in x.items()}
else:
detached = x
return detached
def _tensor_shape_equals(x: object, y: object) -> bool:
if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
return x.shape == y.shape
elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
return all(_tensor_shape_equals(e1, e2) for e1, e2 in zip(x, y))
elif isinstance(x, dict) and isinstance(y, dict):
all_equal = True
for k in x:
all_equal = all_equal and k in y and (_tensor_shape_equals(x[k], y[k]))
return all_equal
else:
log.debug("Comparing non Tensors: %s and %s, they must be equal", x, y)
return type(x) == type(y) and x == y
def _loss_fn(
loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], x: object, y: object
) -> object:
"""The returned loss will have the same structure as `x` and `y`, e.g.
if both are Tensor, we'll return a Tensor
if both are list, we'll return a list of Tensors
if both are dict, we'll return a dict with the same key, and value being the loss between the
two Tensors
"""
if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
return loss(x.to(torch.float32), y.to(torch.float32))
elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
return type(x)([_loss_fn(loss, e1, e2) for e1, e2 in zip(x, y)])
elif isinstance(x, dict) and isinstance(y, dict):
return {k: _loss_fn(loss, e, y[k]) for k, e in x.items()}
else:
return None
class OutputLogger(torch.nn.Module):
"""
Base class for capturing output values for nodes in a GraphModule. It currently only captures
Tensor outputs, but we can extend it to work for other types of inputs later if needed
"""
# Mark as impure so that calls to it will not be removed during DCE.
_is_impure = True
def __init__(
self,
debug_handle: int,
node_name: Optional[str] = None,
nn_module_stack: Optional[object] = None,
) -> None:
super().__init__()
self.node_name = node_name
self.nn_module_stack = nn_module_stack
self.debug_handle = debug_handle
self.stats: list[object] = []
def forward(self, x: object) -> object:
self.stats.append(_detach(x))
return x
def __extra_repr__(self) -> str:
return (
f"debug_handle={self.debug_handle}, node_name={self.node_name}, "
"nn_module_stack={self.nn_module_stack}, num_stats={len(self.stats)})"
)
def _insert_logger(model: GraphModule, node: Node, debug_handle: int) -> Node:
"""For a given node, adds an OutputLogger that observes the output of that node,
and all its users use the OutputLogger output instead.
The OutputLogger will contain the debug_handle which can be used to compare
graphs after transforms"""
# to avoid circular dep
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
# add a logger after the node
with model.graph.inserting_after(node):
get_new_attr_name = get_new_attr_name_with_prefix(f"{node.name}_logger")
logger_name = get_new_attr_name(model)
setattr(
model,
logger_name,
OutputLogger(debug_handle, node.name, node.meta.get("nn_module_stack")),
)
logger_node = model.graph.call_module(logger_name, (node,), {})
orig_users = list(node.users.keys())
for user_node in orig_users:
if user_node is logger_node:
continue
user_node.replace_input_with(node, logger_node)
return logger_node
def prepare_for_propagation_comparison(model: GraphModule) -> GraphModule:
"""Add output loggers to node that has numeric_debug_handle
Args:
model (GraphModule): original model
Returns:
a model with output loggers for all nodes that have a numeric_debug_handle_id
"""
# don't change the original model
model = copy.deepcopy(model)
for n in model.graph.nodes:
if (
CUSTOM_KEY not in n.meta
or NUMERIC_DEBUG_HANDLE_KEY not in n.meta[CUSTOM_KEY]
):
continue
numeric_debug_handle = n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]
_insert_logger(model, n, numeric_debug_handle)
model.recompile()
return model
@dataclass(frozen=True)
class QuantizationComparisonResult:
actual: torch.Tensor
ref: torch.Tensor
@property
def mse_loss(self) -> object:
return self.loss(F.mse_loss)
@property
def sqnr(self) -> object:
return self.loss(compute_sqnr)
def loss(
self, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
) -> object:
return _loss_fn(loss_function, self.actual, self.ref)
def __repr__(self) -> str:
# Don't include the tensors themselves as they are quite large to print
# out.
return (
f"QuantizationComparisonResult(mse_loss={self.mse_loss}, sqnr={self.sqnr})"
)
def __post_init__(self) -> None:
if not isinstance(self.actual, (torch.Tensor, list, tuple, dict)):
raise ValueError(
f"`self.actual` value must be a Tensor, list, tuple or dict, got: {self.actual}"
)
if not isinstance(self.ref, (torch.Tensor, list, tuple, dict)):
raise ValueError(
f"`self.ref` value must be a Tensor, list, tuple or dict, got: {self.ref}"
)
if not _tensor_shape_equals(self.ref, self.actual):
raise ValueError(
f"Cannot compare tensors with different shapes: ref={self.ref} vs actual={self.actual}"
)
@dataclass(frozen=True)
class NodeAccuracySummary:
handle: int
actual_node_name: str
actual_module_stack: str
ref_node_name: str
ref_module_stack: str
results: Sequence[QuantizationComparisonResult]
def _module_stack_to_str(module_stack: object) -> str:
"""Simplifies the stack from ("mod", "mod.foo", "mod.foo.0", "mod.foo.0.linear")
to "mod.foo.0.linear"
"""
if not isinstance(module_stack, dict):
return str(module_stack)
module_values_list = list(module_stack.values())
if len(module_values_list) > 0:
owning_module = module_values_list[-1][0]
return str(owning_module)
else:
return str(module_stack)
def extract_results_from_loggers(
model: GraphModule,
) -> dict[int, tuple[Optional[str], object, list[object]]]:
"""For a given model, extract the tensors stats and related information for each debug handle.
The reason we have a list of object, instead of Tensor is because the output of node may not be
a Tensor, it could be (nested) list, tuple or dict as well.
Returns:
A dict is keyed by the debug_handle id and the values are a list of object recorded
in loggers
"""
# Results maps debug handle to a tensor list for each model being compared.
handles: dict[int, tuple[Optional[str], object, list[object]]] = {}
for _name, module in model.named_children():
if isinstance(module, OutputLogger) and len(module.stats) > 0:
handles[module.debug_handle] = (
module.node_name,
module.nn_module_stack,
module.stats,
)
return handles
def compare_results(
ref_results: dict[int, tuple[Optional[str], object, list[torch.Tensor]]],
actual_results: dict[int, tuple[Optional[str], object, list[torch.Tensor]]],
) -> dict[int, NodeAccuracySummary]:
"""Given two dict mapping from `debug_handle_id` (int) to list of tensors
return a map from `debug_handle_id` to `NodeAccuracySummary` that contains
comparison information like SQNR, MSE etc.
Args:
ref_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]): reference results for each debug_handle_id
actual_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]): actual results for each debug_handle_id
Returns:
Dict[int, NodeAccuracySummary]
"""
comparisons = {}
for debug_handle, (ref_name, ref_stack, ref_stats) in ref_results.items():
if debug_handle not in actual_results:
log.debug(
"Cannot compare for handle %s because it wasn't found in the transformed model",
debug_handle,
)
continue
actual_name, actual_stack, actual_stats = actual_results[debug_handle]
try:
results = [
QuantizationComparisonResult(actual=a, ref=b)
for a, b in zip(actual_stats, ref_stats)
]
except Exception as e:
# Add extra information for an exception from QuantizationComparisonResult
# if the shapes didn't match, to include the handle and the node names.
raise ValueError(
f"For numeric_debug_handle={debug_handle} from ref node {ref_name} and actual node {actual_name}"
) from e
comparisons[debug_handle] = NodeAccuracySummary(
handle=debug_handle,
actual_node_name=actual_name or "",
actual_module_stack=_module_stack_to_str(actual_stack),
ref_node_name=ref_name or "",
ref_module_stack=_module_stack_to_str(ref_stack),
results=results,
)
return comparisons
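
For orientation, a hedged sketch of how these helpers are meant to be chained to compare a float model against its quantized counterpart; `eager_model`, `example_inputs` and `quantizer` are placeholders, and the prepare/convert entry points may differ slightly by PyTorch version.

from torch.export import export_for_training
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e

ep = export_for_training(eager_model, example_inputs)
generate_numeric_debug_handle(ep)

# Float reference: log every node that carries a debug handle.
ref = prepare_for_propagation_comparison(ep.module())
ref(*example_inputs)
ref_results = extract_results_from_loggers(ref)

# Quantized model: the handles survive prepare/convert, so stats can be matched up.
m = prepare_pt2e(ep.module(), quantizer)
m(*example_inputs)  # calibration
m = convert_pt2e(m)
quant = prepare_for_propagation_comparison(m)
quant(*example_inputs)
quant_results = extract_results_from_loggers(quant)

for handle, summary in compare_results(ref_results, quant_results).items():
    print(handle, summary.results[0].sqnr)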


@@ -0,0 +1,82 @@
# mypy: allow-untyped-defs
import logging
import operator
import torch
from torch.ao.quantization.pt2e.utils import (
_filter_sym_size_users,
_is_valid_annotation,
)
from torch.fx.node import map_arg
from torch.fx.passes.infra.pass_base import PassBase, PassResult
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
__all__ = ["DuplicateDQPass"]
_QUANTIZE_OPS = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
torch.ops.quantized_decomposed.quantize_per_channel.default,
]
_DEQUANTIZE_OPS = [
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
]
def _maybe_duplicate_dq(
gm: torch.fx.GraphModule, dq_node: torch.fx.Node, user: torch.fx.Node
):
annotation = user.meta.get("quantization_annotation", None)
if not _is_valid_annotation(annotation):
return
with gm.graph.inserting_after(dq_node):
new_node = gm.graph.node_copy(dq_node)
def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node:
if n == dq_node:
return new_node
else:
return n
new_args = map_arg(user.args, maybe_replace_node)
new_kwargs = map_arg(user.kwargs, maybe_replace_node)
user.args = new_args # type: ignore[assignment]
user.kwargs = new_kwargs # type: ignore[assignment]
class DuplicateDQPass(PassBase):
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
for node in graph_module.graph.nodes:
if node.op == "call_function" and node.target in _DEQUANTIZE_OPS:
dq_users = _filter_sym_size_users(node)
if len(dq_users) <= 1:
continue
# Do not duplicate dq for dynamic quantization
# Pattern: choose_qparam - getitem - q - dq
q_node = node.args[0]
if q_node.op == "call_function" and q_node.target in _QUANTIZE_OPS:
getitem_node = q_node.args[1]
if (
isinstance(getitem_node, torch.fx.node.Node)
and getitem_node.op == "call_function"
and getitem_node.target == operator.getitem
):
choose_qparam_node = getitem_node.args[0]
if (
isinstance(choose_qparam_node, torch.fx.node.Node)
and choose_qparam_node.op == "call_function"
and choose_qparam_node.target
== torch.ops.quantized_decomposed.choose_qparams.tensor
):
continue
for user in dq_users:
_maybe_duplicate_dq(graph_module, node, user)
graph_module.graph.eliminate_dead_code()
graph_module.recompile()
return PassResult(graph_module, True)
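
As an illustration (assuming `gm` is a GraphModule produced by the PT2E convert flow, not defined here), the pass would typically be run through the standard PassBase call protocol:

pass_result = DuplicateDQPass()(gm)  # PassBase.__call__ forwards to .call()
gm = pass_result.graph_module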


@@ -0,0 +1,240 @@
# mypy: allow-untyped-defs
import types
import torch
import torch.nn.functional as F
from torch.ao.quantization.utils import _assert_and_get_unique_device
__all__ = [
"model_is_exported",
]
_EXPORTED_TRAINING_ATTR = "_exported_training"
class _WrapperModule(torch.nn.Module):
"""Class to wrap a callable in an :class:`torch.nn.Module`. Use this if you
are trying to export a callable.
"""
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, *args, **kwargs):
"""Simple forward that just calls the ``fn`` provided to :meth:`WrapperModule.__init__`."""
return self.fn(*args, **kwargs)
def model_is_exported(m: torch.nn.Module) -> bool:
"""
Return True if the `torch.nn.Module` was exported, False otherwise
(e.g. if the model was FX symbolically traced or not traced at all).
"""
return isinstance(m, torch.fx.GraphModule) and any(
"val" in n.meta for n in m.graph.nodes
)
def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool):
"""
Switch dropout patterns in the model between train and eval modes.
Dropout has different behavior in train vs eval mode. For exported models,
however, calling `model.train()` or `model.eval()` does not automatically switch
the dropout behavior between the two modes, so here we need to rewrite the aten
dropout patterns manually to achieve the same effect.
See https://github.com/pytorch/pytorch/issues/103681.
"""
# Avoid circular dependencies
from .utils import _get_aten_graph_module_for_pattern
# Needed to ensure subgraph matches are self-contained
m.graph.eliminate_dead_code()
m.recompile()
for inplace in [False, True]:
def dropout_train(x):
return F.dropout(x, p=0.5, training=True, inplace=inplace)
def dropout_eval(x):
return F.dropout(x, p=0.5, training=False, inplace=inplace)
example_inputs = (torch.randn(1),)
if train_to_eval:
match_pattern = _get_aten_graph_module_for_pattern(
_WrapperModule(dropout_train),
example_inputs,
)
replacement_pattern = _get_aten_graph_module_for_pattern(
_WrapperModule(dropout_eval),
example_inputs,
)
else:
match_pattern = _get_aten_graph_module_for_pattern(
_WrapperModule(dropout_eval),
example_inputs,
)
replacement_pattern = _get_aten_graph_module_for_pattern(
_WrapperModule(dropout_train),
example_inputs,
)
from torch.fx.subgraph_rewriter import replace_pattern_with_filters
replace_pattern_with_filters(
m,
match_pattern,
replacement_pattern,
match_filters=[],
ignore_literals=True,
)
m.recompile()
def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool):
"""
Switch batchnorm patterns in the model between train and eval modes.
Batchnorm has different behavior in train vs eval mode. For exported models,
however, calling `model.train()` or `model.eval()` does not automatically switch
the batchnorm behavior between the two modes, so here we need to rewrite the aten
batchnorm patterns manually to achieve the same effect.
"""
# TODO(Leslie): This function still fails to support custom momentum and eps value.
# Enable this support in future updates.
# Avoid circular dependencies
from .utils import _get_aten_graph_module_for_pattern
# Needed to ensure subgraph matches are self-contained
m.graph.eliminate_dead_code()
m.recompile()
def bn_train(
x: torch.Tensor,
bn_weight: torch.Tensor,
bn_bias: torch.Tensor,
bn_running_mean: torch.Tensor,
bn_running_var: torch.Tensor,
):
return F.batch_norm(
x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True
)
def bn_eval(
x: torch.Tensor,
bn_weight: torch.Tensor,
bn_bias: torch.Tensor,
bn_running_mean: torch.Tensor,
bn_running_var: torch.Tensor,
):
return F.batch_norm(
x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=False
)
example_inputs = (
torch.randn(1, 1, 3, 3), # x
torch.randn(1), # bn_weight
torch.randn(1), # bn_bias
torch.randn(1), # bn_running_mean
torch.randn(1), # bn_running_var
)
device = _assert_and_get_unique_device(m)
is_cuda = device is not None and device.type == "cuda"
bn_train_aten = _get_aten_graph_module_for_pattern(
_WrapperModule(bn_train),
example_inputs,
is_cuda,
)
bn_eval_aten = _get_aten_graph_module_for_pattern(
_WrapperModule(bn_eval),
example_inputs,
is_cuda,
)
if train_to_eval:
match_pattern = bn_train_aten
replacement_pattern = bn_eval_aten
else:
match_pattern = bn_eval_aten
replacement_pattern = bn_train_aten
from torch.fx.subgraph_rewriter import replace_pattern_with_filters
replace_pattern_with_filters(
m,
match_pattern,
replacement_pattern,
match_filters=[],
ignore_literals=True,
)
m.recompile()
# TODO: expose these under this namespace?
def _move_exported_model_to_eval(model: torch.fx.GraphModule):
"""
Move an exported GraphModule to eval mode.
This is equivalent to model.eval() but only for certain special ops like dropout, batchnorm.
QAT users should call this before performing inference on the model.
This call is idempotent; if the model is already in eval mode, nothing will happen.
"""
is_training = getattr(model, _EXPORTED_TRAINING_ATTR, True)
if not is_training:
return model
setattr(model, _EXPORTED_TRAINING_ATTR, False)
_replace_dropout(model, train_to_eval=True)
_replace_batchnorm(model, train_to_eval=True)
return model
def _move_exported_model_to_train(model: torch.fx.GraphModule):
"""
Move an exported GraphModule to train mode.
This is equivalent to model.train() but only for certain special ops like dropout, batchnorm.
QAT users should call this before performing training on the model.
This call is idempotent; if the model is already in train mode, nothing will happen.
"""
is_training = getattr(model, _EXPORTED_TRAINING_ATTR, False)
if is_training:
return model
setattr(model, _EXPORTED_TRAINING_ATTR, True)
_replace_dropout(model, train_to_eval=False)
_replace_batchnorm(model, train_to_eval=False)
return model
def _allow_exported_model_train_eval(model: torch.fx.GraphModule):
"""
Allow users to call `model.train()` and `model.eval()` on an exported model,
but with the effect of changing behavior between the two modes limited to special
ops only, which are currently dropout and batchnorm.
Note: This does not achieve the same effect as what `model.train()` and `model.eval()`
does in eager models, but only provides an approximation. In particular, user code
branching on `training` flag will not function correctly in general because the branch
is already specialized at export time. Additionally, other ops beyond dropout and batchnorm
that have different train/eval behavior will also not be converted properly.
"""
def _train(self, mode: bool = True):
if mode:
_move_exported_model_to_train(self)
else:
_move_exported_model_to_eval(self)
def _eval(self):
_move_exported_model_to_eval(self)
model.train = types.MethodType(_train, model) # type: ignore[method-assign]
model.eval = types.MethodType(_eval, model) # type: ignore[method-assign]
return model
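
A small hedged sketch of the intended call pattern; `model` and `example_inputs` are placeholders not defined in this commit.

from torch.export import export_for_training

exported_gm = export_for_training(model, example_inputs).module()
assert model_is_exported(exported_gm)

exported_gm = _allow_exported_model_train_eval(exported_gm)
exported_gm.eval()   # rewrites the aten dropout/batchnorm patterns to their eval form
exported_gm.train()  # and back to their training form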


@@ -0,0 +1,181 @@
# mypy: allow-untyped-defs
import itertools
import operator
from collections import OrderedDict
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
import torch
from torch.export import ExportedProgram
from torch.fx import Node
from torch.fx.passes.utils.source_matcher_utils import (
check_subgraphs_connected,
get_source_partitions,
SourcePartition,
)
__all__ = [
"find_sequential_partitions",
"get_equivalent_types",
"update_equivalent_types_dict",
"bfs_trace_with_node_process",
]
_EQUIVALENT_TYPES: list[set] = [
{torch.nn.Conv1d, torch.nn.functional.conv1d},
{torch.nn.Conv2d, torch.nn.functional.conv2d},
{torch.nn.AdaptiveAvgPool2d, torch.nn.functional.adaptive_avg_pool2d},
{torch.nn.ReLU, torch.nn.functional.relu, torch.nn.functional.relu_},
{torch.nn.BatchNorm2d, torch.nn.functional.batch_norm},
{torch.nn.Hardtanh, torch.nn.functional.hardtanh, torch.nn.functional.hardtanh_},
{torch.add, operator.add, operator.iadd, "add", "add_"},
{torch.mul, operator.mul, operator.imul, "mul", "mul_"},
]
def _create_equivalent_types_dict():
_DICT = {}
for values in _EQUIVALENT_TYPES:
for v in values:
_DICT[v] = list(values)
return _DICT
_EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict()
def get_equivalent_types() -> list[set]:
return _EQUIVALENT_TYPES
def update_equivalent_types_dict(customized_equivalent_types=None):
"""Help function for user who wants to customize the _EQUIVALENT_TYPES and _EQUIVALENT_TYPES_DICT.
When customized_equivalent_types passes in,
re-generate _EQUIVALENT_TYPES and _EQUIVALENT_TYPES_DICT.
"""
if customized_equivalent_types is None:
raise ValueError("customized_equivalent_types should not be None")
global _EQUIVALENT_TYPES
global _EQUIVALENT_TYPES_DICT
_EQUIVALENT_TYPES = customized_equivalent_types
_EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict()
def _partitions_sequential(partitions: Sequence[SourcePartition]):
prev_partition = None
for partition in partitions:
if prev_partition is not None and not check_subgraphs_connected(
prev_partition, partition
):
return False
prev_partition = partition
return True
def _get_matching_types(partition_type):
matching_types = [partition_type]
if partition_type in _EQUIVALENT_TYPES_DICT:
matching_types.extend(_EQUIVALENT_TYPES_DICT[partition_type])
return matching_types
def _valid_type_sequence(partition_types: list[Any]):
partition_types_set = set() # type: ignore[var-annotated]
for partition_type in partition_types:
matching_types = _get_matching_types(partition_type)
matching_types_set = set(matching_types)
if len(partition_types_set & matching_types_set) > 0:
return False
partition_types_set |= matching_types_set
return True
def find_sequential_partitions(
gm: torch.fx.GraphModule,
partition_types: list[Any],
include_functional_equivalent=True,
filter_fn: Optional[Callable[[Node], bool]] = None,
):
if not _valid_type_sequence(partition_types):
raise ValueError(
f"Invalid partition types: {partition_types}. Each type in the sequence must be unique"
)
typed_partitions: OrderedDict[Any, list[SourcePartition]] = OrderedDict()
for partition_type in partition_types:
types_to_match = _get_matching_types(partition_type)
partitions = get_source_partitions(gm.graph, types_to_match, filter_fn)
typed_partitions[partition_type] = list(
itertools.chain.from_iterable(partitions.values())
)
typed_partitions_list = list(typed_partitions.values())
fusion_candidates = itertools.product(*typed_partitions_list)
fused_partitions = [
candidate
for candidate in fusion_candidates
if _partitions_sequential(candidate)
]
return fused_partitions
def _get_submodule(
graph_module: torch.fx.GraphModule, node: torch.fx.Node, arg_index: int
) -> tuple[str, torch.nn.Module, torch.fx.Node]:
submod_node = node.args[arg_index]
assert isinstance(submod_node, torch.fx.Node)
assert submod_node.op == "get_attr"
assert isinstance(submod_node.target, str)
submodule = graph_module.get_submodule(submod_node.target)
# pyre-ignore
return submod_node.target, submodule, node
def _get_control_flow_submodules(
graph_module: torch.fx.GraphModule,
) -> list[tuple[str, torch.nn.Module, torch.fx.Node]]:
"""
Returns a list of submodules used for control flow operations
(torch.ops.higher_order.cond/map) that are in the given toplevel graph (does not look
into submodules). Specifically, the returned value is a list containing a
tuple of (name of the submodule that's stored in the graph module, the
submodule itself, and the fx node that uses this submodule).
"""
control_flow_submodules = []
for node in graph_module.graph.nodes:
if node.op != "call_function":
continue
if node.target is torch.ops.higher_order.cond:
control_flow_submodules.append(_get_submodule(graph_module, node, 1))
control_flow_submodules.append(_get_submodule(graph_module, node, 2))
if node.target is torch.ops.higher_order.map_impl:
control_flow_submodules.append(_get_submodule(graph_module, node, 0))
return control_flow_submodules
def bfs_trace_with_node_process(
model: Union[ExportedProgram, torch.fx.GraphModule], node_op: Callable
) -> None:
"""Traverse the graph module and apply node_op to each node."""
assert isinstance(
model, (ExportedProgram, torch.fx.GraphModule)
), f"Expected GraphModule or ExportedProgram, got {type(model)}"
gm = model.graph_module if isinstance(model, ExportedProgram) else model
queue = [gm]
while queue:
current_graph_module = queue.pop(0)
for node in current_graph_module.graph.nodes:
if node.op in ["output", "placeholder"]:
continue
node_op(node)
control_flow_submodules = [
submodule
for _, submodule, _ in _get_control_flow_submodules(current_graph_module)
]
queue.extend(control_flow_submodules)
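
For example (a hedged sketch, with `gm` assumed to be an FX-traced or exported GraphModule), Conv2d-followed-by-ReLU pairs can be located like this:

import torch

conv_relu_pairs = find_sequential_partitions(gm, [torch.nn.Conv2d, torch.nn.ReLU])
for conv_partition, relu_partition in conv_relu_pairs:
    # each SourcePartition exposes .nodes, .input_nodes and .output_nodes
    print(conv_partition.output_nodes, relu_partition.nodes)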


@@ -0,0 +1,215 @@
# mypy: allow-untyped-defs
import logging
from typing import Optional
import torch
from torch._export.error import InternalError
from torch.ao.quantization.pt2e.utils import (
_filter_sym_size_users,
_find_q_dq_node_for_user,
_is_valid_annotation,
)
from torch.ao.quantization.quantizer import QuantizationSpecBase
from torch.fx.passes.infra.pass_base import PassBase, PassResult
logger = logging.getLogger(__name__)
logger.setLevel(logging.ERROR)
__all__ = ["PortNodeMetaForQDQ"]
_METADATA_TO_PORT = [
"stack_trace",
"quantization_tag",
]
_QUANTIZE_OPS = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
torch.ops.quantized_decomposed.quantize_per_channel.default,
]
_DEQUANTIZE_OPS = [
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
]
_CHOOSE_QPARAMS_OPS = [
torch.ops.quantized_decomposed.choose_qparams.tensor,
torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor,
]
def _add_metadata(to_node: torch.fx.Node, from_node: torch.fx.Node) -> None:
from_meta = from_node.meta
for meta_name in _METADATA_TO_PORT:
if meta_name in from_meta:
to_node.meta[meta_name] = from_meta[meta_name]
def _has_quant_annotation(node: torch.fx.Node) -> bool:
return "quantization_annotation" in node.meta
def _find_choose_qparams_node(node: torch.fx.Node) -> Optional[torch.fx.Node]:
# BFS to look for choose qparams
from collections import deque
queue = deque(list(node.users.keys()))
while len(queue):
n = queue.popleft()
if n.op == "output":
continue
if n.op == "call_function" and n.target in _CHOOSE_QPARAMS_OPS:
return n
for k in n.users.keys():
queue.append(k)
return None
def _port_metadata_for_input_quant_nodes(
input_node: torch.fx.Node,
node: torch.fx.Node,
qspec: Optional[QuantizationSpecBase],
):
if qspec is None:
return
is_dynamic_quant = getattr(qspec, "is_dynamic", None)
if is_dynamic_quant is not None and is_dynamic_quant is True:
choose_qparams_node = _find_choose_qparams_node(input_node)
if choose_qparams_node is None:
raise ValueError(f"No chose qparams node found for {node}")
choose_qparam_users = _filter_sym_size_users(choose_qparams_node)
if len(choose_qparam_users) != 2:
raise InternalError(f"Expecting exactly two user for {choose_qparams_node}")
scale_node = choose_qparam_users.pop()
dynamic_q_node = next(iter(scale_node.users.keys()))
dynamic_q_node_users = _filter_sym_size_users(dynamic_q_node)
if len(dynamic_q_node_users) > 1:
raise InternalError(f"Expecting single user for {dynamic_q_node}")
dynamic_dq_node = dynamic_q_node_users.pop()
_add_metadata(choose_qparams_node, node)
_add_metadata(dynamic_q_node, node)
_add_metadata(dynamic_dq_node, node)
else:
q_node, dq_node = _find_q_dq_node_for_user(input_node, node)
if q_node is None or dq_node is None:
return
# add metadata for all the nodes between q_node and the get_attr node
# if the q_node can be traced back to a get_attr node
q_to_get_attr_nodes = [q_node]
q_node_input = q_node.args[0]
while (
isinstance(q_node_input, torch.fx.Node)
and q_node_input.op == "call_function"
and q_node_input.target
in [
torch.ops.aten.flatten.using_ints,
torch.ops.aten.permute.default,
torch.ops.aten.permute_copy.default,
torch.ops.aten.slice_copy.Tensor,
torch.ops.aten.squeeze.dim,
torch.ops.aten.squeeze_copy.dim,
torch.ops.aten.transpose.Dimname,
torch.ops.aten.transpose.int,
torch.ops.aten.transpose_,
torch.ops.aten.view_copy.default,
torch.ops.aten.view.default,
torch.ops.aten._mkldnn_transpose,
]
):
q_to_get_attr_nodes.append(q_node_input)
q_node_input = q_node_input.args[0]
if isinstance(q_node_input, torch.fx.Node) and q_node_input.op == "get_attr":
for n in q_to_get_attr_nodes:
_add_metadata(n, q_node_input)
_add_metadata(dq_node, node)
def _port_metadata_for_output_quant_nodes(
node: torch.fx.Node, qspec: Optional[QuantizationSpecBase]
):
if qspec is None:
return
node_users = _filter_sym_size_users(node)
if len(node.users) == 0:
return
if len(node_users) != 1:
logger.warning(f"Expecting {node} to have single user") # noqa: G004
q_node = node_users.pop()
if q_node.op != "call_function" or q_node.target not in _QUANTIZE_OPS:
logger.warning(
f"Expecting {node} user to be a quantized op but got {q_node}" # noqa: G004
) # noqa: G004
return
_add_metadata(q_node, node)
class PortNodeMetaForQDQ(PassBase):
"""
Port metadata for nodes added by quantization flow.
For static quant these are:
- quantize_per_tensor.default, dequantize_per_tensor.default
- quantize_per_channel.default, dequantize_per_channel.default
For dynamic quant these are:
- choose_qparams.tensor
- quantize_per_tensor.tensor, dequantize_per_tensor.tensor
- quantize_per_channel.default, dequantize_per_channel.default
Rules of porting metadata:
- Metadata to be ported:
- nn_module_stack
- stack_trace
- quantization_tag
- Metadata to NOT be ported:
- Everything else
- Rules:
- Statically quantized patterns:
- Dequantize nodes on the inputs to be quantized inherit metadata of the consumer node.
- Quantize nodes on the outputs inherit metadata of the producer node.
- Example 1:
- Original: [Conv -> AvgPool -> Linear]
- Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ]
- Inner brackets specify which nodes Q/DQ inherit metadata from
- [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> [DQ -> Linear -> Q] -> DQ]
- Note first Q and last DQ do not inherit metadata from any nodes
- Example 2:
- Original: [Conv -> AvgPool -> Linear]
- AvgPool is not quantized
- Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ]
- Inner brackets specify which nodes Q/DQ inherit metadata from
- [Q-> [DQ -> Conv -> Q] -> DQ -> [AvgPool] -> Q -> [DQ -> Linear -> Q] -> DQ]
- Note DQ and Q nodes around AvgPool do not inherit metadata from AvgPool because
AvgPool was not supposed to be quantized. Metadata porting relies on the quantization_annotation
on the nodes (in this case the AvgPool node) to conclude whether the node or pattern was
supposed to be quantized, and subsequently to decide whether the preceding Q, if any, should
inherit metadata from AvgPool.
- Dynamically quantized patterns:
- Inputs that are dynamically quantized have choose_qparams, quantize and dequantize nodes
- For example, below Linear is dynamically quantized while the rest is statically quantized:
- Original: [Conv -> AvgPool -> Linear]
- Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> choose_qparams -> Q -> DQ -> Linear]
- Quantized [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> DQ -> [choose_qparams -> Q -> DQ -> Linear]]
- Note first Q does not inherit metadata from any nodes
NB:
- The best place for porting metadata is during observer conversion to q/dq. This is because it precisely
knows which quantization spec is converted to q/dq and thus from where the metadata should be ported.
However, since FX and PT2E quant workflow are on a common code-base, this hurts readability quite a bit.
Doing it via a separate pass helps the readability of the code. Once we are able to refactor the PT2E quant
code, this pass should be integrated into the refactored variant of the "convert" step.
"""
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
for node in graph_module.graph.nodes:
annotation = node.meta.get("quantization_annotation", None)
if _is_valid_annotation(annotation):
input_qspec_map = node.meta["quantization_annotation"].input_qspec_map
output_qspec = node.meta["quantization_annotation"].output_qspec
for input_node, qspec in input_qspec_map.items():
_port_metadata_for_input_quant_nodes(input_node, node, qspec)
_port_metadata_for_output_quant_nodes(node, output_qspec)
return PassResult(graph_module, True)
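# A minimal usage sketch (an assumption for illustration, not part of the original file):
# like other fx passes, the metadata-porting pass is applied to the converted GraphModule
# and the updated module is read back from the PassResult.
#
#     result = PortNodeMetaForQDQ()(converted_gm)   # `converted_gm` is a hypothetical name
#     converted_gm = result.graph_module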

View file

@ -0,0 +1,574 @@
# mypy: allow-untyped-defs
from typing import Any, Optional, Union
import torch
from torch._subclasses import FakeTensor
from torch.ao.quantization import (
CUSTOM_KEY,
NUMERIC_DEBUG_HANDLE_KEY,
ObserverOrFakeQuantize,
QConfigMapping,
)
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
from torch.ao.quantization.fx.prepare import (
_create_obs_or_fq_from_qspec,
_insert_obs_or_fq,
_is_activation_post_process_node,
_save_state,
)
from torch.ao.quantization.qconfig import QConfigAny
from torch.ao.quantization.quantizer import (
EdgeOrNode,
QuantizationSpecBase,
SharedQuantizationSpec,
)
from torch.fx import Graph, GraphModule, Node
from torch.fx.node import Argument
# TODO: make pt2e folder private?
__all__ = [
"prepare",
]
def _find_root_edge_or_node(
edge_or_node: EdgeOrNode, shared_with_map: dict[EdgeOrNode, EdgeOrNode]
) -> EdgeOrNode:
"""Find the root node for the sharing tree
Args:
edge_or_node: the edge/node whose root we want to find
shared_with_map: each edge/node points to its parent; the root node points to itself
Returns:
root edge/node
"""
parent = shared_with_map[edge_or_node]
if parent == edge_or_node:
return edge_or_node
root = _find_root_edge_or_node(parent, shared_with_map)
# path compression
shared_with_map[edge_or_node] = root
return root
def _union(
parent: EdgeOrNode,
child: EdgeOrNode,
shared_with_map: dict[EdgeOrNode, EdgeOrNode],
) -> None:
"""Merge the subtree for `child` with `parent`, the order is important here"""
root_parent = _find_root_edge_or_node(parent, shared_with_map)
root_child = _find_root_edge_or_node(child, shared_with_map)
# union the two trees by pointing the root of child to root of parent
shared_with_map[root_child] = root_parent
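# A minimal sketch (illustrative only, not part of the original module) of how the two
# union-find helpers above behave; plain strings stand in for EdgeOrNode keys.
def _union_find_example() -> None:
    # "c" -> "b" -> "a" (root), plus a separate singleton "d"
    shared_with_map = {"a": "a", "b": "a", "c": "b", "d": "d"}
    _union("a", "d", shared_with_map)  # merge "d"'s tree under root "a"
    assert _find_root_edge_or_node("d", shared_with_map) == "a"
    assert _find_root_edge_or_node("c", shared_with_map) == "a"
    # path compression rewrote "c" to point directly at the root
    assert shared_with_map["c"] == "a"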
def _update_shared_with(
child: EdgeOrNode,
qspec: QuantizationSpecBase,
shared_with_map: dict[EdgeOrNode, EdgeOrNode],
):
"""Update the `shared_with_map` based on the qspec, this applies the `SharedQuantizationSpec`
configuration and established the relationship between `edge_or_node` with the edge/node that it
is pointing to, we'll use this information in the end to get the group id
"""
if isinstance(qspec, SharedQuantizationSpec):
parent = qspec.edge_or_node
# we point from edge_or_node to the node that it is sharing_with, e.g.
# qspec for a = SharedQuantizationSpec(b) means `a` points to `b`
_union(parent, child, shared_with_map)
def _unwrap_shared_qspec(
qspec: QuantizationSpecBase,
edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase],
shared_with_map: dict[EdgeOrNode, EdgeOrNode],
) -> QuantizationSpecBase:
"""Unwraps qspec to get the final root qspec (non SharedQuantizationSpec)
if qspec is SharedQuantizationSpec
(1) tries to find the root edge or node for the node that the qspec points to
(2) recursively finds the root qspec based on the qspec for the root node
"""
if isinstance(qspec, SharedQuantizationSpec):
sharing_with = qspec.edge_or_node
root = _find_root_edge_or_node(sharing_with, shared_with_map)
qspec = edge_or_node_to_qspec[root]
return _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
return qspec
def _has_same_attr(
qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase, attr_name: str
):
return (
hasattr(qspec_a, attr_name)
and hasattr(qspec_b, attr_name)
and getattr(qspec_a, attr_name) == getattr(qspec_b, attr_name)
) or (not hasattr(qspec_a, attr_name) and not hasattr(qspec_b, attr_name))
def _get_edge_or_node_to_qspec(
model: torch.fx.GraphModule,
) -> dict[EdgeOrNode, QuantizationSpecBase]:
"""Get a map from EdgeOrNode to quantization spec based on annotations on the nodes"""
edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase] = {}
for n in model.graph.nodes:
if hasattr(n, "meta") and "quantization_annotation" in n.meta:
qa = n.meta["quantization_annotation"]
for input_to_n, qspec in qa.input_qspec_map.items():
input_edge = (input_to_n, n)
edge_or_node_to_qspec[input_edge] = qspec
if qa.output_qspec is not None:
output_node = n
qspec = qa.output_qspec
edge_or_node_to_qspec[output_node] = qspec
return edge_or_node_to_qspec
def _union_input_edge_with(
input_edge,
input_edge_root_qspec,
edge_or_node,
edge_or_node_to_qspec,
shared_with_map,
):
"""Union input edge with another edge or node, used in implicit sharing to point the current input
edge to other user edges of the producer node, or the output of producer node since these are
referring to the same Tensor
"""
root_qspec = None
if edge_or_node in edge_or_node_to_qspec:
qspec = edge_or_node_to_qspec[edge_or_node]
root_qspec = _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
# TODO: add assertions for types of root qspecs
if root_qspec is not None and all(
_has_same_attr(root_qspec, input_edge_root_qspec, attr)
for attr in [
"dtype",
"is_dynamic",
"quant_min",
"quant_max",
"qscheme",
"ch_axis",
"scale",
"zero_point",
]
):
# the input arg to the node should reuse the existing output observer for arg
# since dtype is the same (we may want to extend this to be a more strict check
# in the future)
# so we point from `input_edge` to `arg` (output of the argument)
_union(edge_or_node, input_edge, shared_with_map)
def _get_edge_or_node_to_group_id(
edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase]
) -> dict[EdgeOrNode, int]:
"""Map from edge/node to the group ID, generated from quantization annotations,
edge/node with the same group ID should use the same observer/fake_quant instance
This is applying SharedQuantizationSpec configuration and map each edge/node to a group
There is another implicit sharing that's built in the quantization, when we have the following:
* op1 -> op2
* output of op1: int8_qspec
* (op1 -> op2) input edge: int8_qspec
we'll assume sharing between the output of op1 and input of (op1 -> op2) since these are the same Tensor.
Figuring out the correct group ID for all edge/node is a standard union find problem:
https://www.geeksforgeeks.org/introduction-to-disjoint-set-data-structure-or-union-find-algorithm/
Args:
edge_or_node_to_qspec: Dictionary from edge_or_node to the qspec, derived from annotations
Returns:
edge_or_node_to_group_id: Dictionary from edge_or_node to group_id (int), all edge or node that
belongs to the same group should have the same id
Example:
op2 -> cat1 -> cat2
op1 / /
op3
edge_or_node_to_qspec: {
op1: int8_qspec,
op2: int8_qspec,
(op1, cat1): int8_qspec,
(op2, cat1): SharedQuantizationSpec((op1, cat1)),
cat1: SharedQuantizationSpec((op1, cat1)),
(op3, cat2): int8_qspec,
(cat1, cat2): SharedQuantizationSpec((op3, cat2)),
cat2: SharedQuantizationSpec((op3, cat2)),
}
edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
edge_or_node_to_group_id: {
op1: 1,
op2: 1,
(op1, cat1): 1,
(op2, cat1): 1,
cat1: 1,
(op3, cat2): 1,
(cat1, cat2): 1,
cat2: 1,
}
# everything is in the same group because cat1 and (cat1, cat2) are implicitly shared, which
# connects the two sharing groups around the cat1 and cat2 ops due to transitive sharing
"""
# means the observer of the key should be shared with the observer of the value; by default
# each key is shared with itself
shared_with_map: dict[EdgeOrNode, EdgeOrNode] = {
k: k for k in edge_or_node_to_qspec.keys()
}
for edge_or_node, qspec in edge_or_node_to_qspec.items():
if isinstance(edge_or_node, torch.fx.Node):
output_node = edge_or_node
_update_shared_with(output_node, qspec, shared_with_map)
else:
input_edge = edge_or_node
input_edge_root_qspec = _unwrap_shared_qspec(
qspec, edge_or_node_to_qspec, shared_with_map
)
assert isinstance(input_edge, tuple)
arg, n = input_edge
if n.meta["quantization_annotation"].allow_implicit_sharing:
# NOTE: the order is important here, we first share with other users and then share with previous
# output because the reverse order could cause circular dependency
# e.g. node1 -> node2
# \ -> node3
# when processing (node1, node2), if we first point (node1, node2) to node1
# Step 1. shared_map = {(node1, node2): node1}
# Step 2. after that, we point the (node1, node2) to its other user (node1, node3),
# which means shared_map = {(node1, node2): node1, node1: (node1, node3)}
# because we will point the root of (node1, node2) (in this case node1) to the root of (node1, node3)
# Step 3. and when we process (node1, node3), it can try to point to node1 as well, then we'll
# have a circular dependency
# the following order works around this issue, but this does not allow arbitrary configuration
# of sharing so it might break in a different case in the future, when it breaks
# quantizer writer can check the notes here to debug the issue
# sharing with other users of the producer node
# (arg, user)
if not isinstance(arg, Node) or not isinstance(n, Node):
raise Exception( # noqa: TRY002
f"Expected input_edge to have type Tuple[Node, Node], but got: {arg, n}"
)
for user in arg.users:
if user is n:
continue
arg_to_user_edge = (arg, user)
_union_input_edge_with(
input_edge,
input_edge_root_qspec,
arg_to_user_edge,
edge_or_node_to_qspec,
shared_with_map,
)
# sharing with output of producer node
_union_input_edge_with(
input_edge,
input_edge_root_qspec,
arg,
edge_or_node_to_qspec,
shared_with_map,
)
_update_shared_with(input_edge, qspec, shared_with_map)
# now that we have the sharing relations between all edges and nodes, we can assign group ids
cur_group_id = 0
edge_or_node_to_group_id: dict[EdgeOrNode, int] = {}
for edge_or_node in shared_with_map.keys():
root = _find_root_edge_or_node(edge_or_node, shared_with_map)
if root not in edge_or_node_to_group_id:
edge_or_node_to_group_id[root] = cur_group_id
cur_group_id += 1
edge_or_node_to_group_id[edge_or_node] = edge_or_node_to_group_id[root]
return edge_or_node_to_group_id
def _get_obs_or_fq_map(
edge_or_node_to_group_id: dict[EdgeOrNode, int],
edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase],
is_qat: bool,
) -> dict[EdgeOrNode, ObserverOrFakeQuantize]:
"""Generates the EdgeOrNode to observer/fake_quant instances
Makes sure that for EdgeOrNode that has the same group_id should have the same observer or fake quant
instances
"""
obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize] = {}
group_id_to_obs_or_fq: dict[int, ObserverOrFakeQuantize] = {}
for edge_or_node, qspec in edge_or_node_to_qspec.items():
group_id = edge_or_node_to_group_id[edge_or_node]
if group_id not in group_id_to_obs_or_fq:
# TODO: maybe edge_or_node_to_qspec should be edge_or_node_to_root_qspec, this will simplify
# the implementation for _create_obs_or_fq_from_qspec
group_id_to_obs_or_fq[group_id] = _create_obs_or_fq_from_qspec(
qspec, obs_or_fq_map, is_qat
)
obs_or_fq_map[edge_or_node] = group_id_to_obs_or_fq[group_id]
return obs_or_fq_map
def _maybe_insert_input_observer_for_arg_or_kwarg(
node: Union[Node, Any],
arg: Argument,
qconfig: QConfigAny,
model: torch.nn.Module,
named_modules: dict[str, torch.nn.Module],
obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
is_qat: bool,
) -> Argument:
"""
Given a `node` and an `arg`, inserts an input observer between
`node` and `arg` if necessary.
"""
# for ops such as torch.cat([x0, x1]),
# traverse through the list
if isinstance(arg, (list, tuple)):
new_arg_to_return = []
for inner_arg in arg:
new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
node,
inner_arg,
qconfig,
model,
named_modules,
obs_or_fq_map,
is_qat,
)
new_arg_to_return.append(new_inner_arg)
return type(arg)(new_arg_to_return)
if not isinstance(arg, Node):
return arg
assert isinstance(arg, Node)
# default (no observer)
new_arg = arg
# find the original `arg` node to the current node, skipping inserted observer/fake_quant nodes
original_arg = arg
while _is_activation_post_process_node(original_arg, named_modules):
original_arg = original_arg.args[0] # type: ignore[assignment]
assert isinstance(
original_arg, Node
), f"expect original argument to be a Node, but got: {type(original_arg)}"
input_edge = (original_arg, node)
if input_edge not in obs_or_fq_map:
return new_arg
# input_edge needs to be observed
input_edge_obs_or_fq = obs_or_fq_map[input_edge]
if input_edge_obs_or_fq is None:
return new_arg
arg_as_output_obs_or_fq = obs_or_fq_map.get(original_arg, None)
# the arg is observed as the output and is using the same instance as the input_edge
# we'll reuse the inserted observer/fake_quant
if arg_as_output_obs_or_fq is not None and id(arg_as_output_obs_or_fq) == id(
input_edge_obs_or_fq
):
return new_arg
# otherwise, we'll insert a new observer/fake_quant node
# skip inserting new observers if the same observer instance is inserted before for another user
# Example:
# conv1 -> obs1 -> existing_obs -> conv2
#               \-> conv3
#
# instead of inserting new observers we will have:
# conv1 -> obs1 -> existing_obs -> conv2
#                              \-> conv3
for maybe_obs_node in arg.users.keys():
if not _is_activation_post_process_node(maybe_obs_node, named_modules):
continue
maybe_obs_mod = named_modules[maybe_obs_node.target] # type: ignore[index]
if id(maybe_obs_mod) == id(input_edge_obs_or_fq):
return maybe_obs_node
assert isinstance(model.graph, Graph)
new_arg = _insert_obs_or_fq(
arg, input_edge_obs_or_fq, model, named_modules, model.graph
)
return new_arg
def _maybe_insert_input_observers_for_node(
node: Node,
qconfig: QConfigAny,
model: torch.nn.Module,
named_modules: dict[str, torch.nn.Module],
obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
is_qat: bool,
) -> None:
"""
If needed, inserts observers to the input args and kwargs of `node`.
Note: modifies `node` inplace.
For example, if cur_node needs an observer after prev_node, we change from
prev_node -> cur_node
To
prev_node -> obs -> cur_node
"""
# Look through every input arg. If that arg's target dtype does not
# match the current node's target dtype, insert an observer.
new_args = []
for arg in node.args:
new_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
node,
arg,
qconfig,
model,
named_modules,
obs_or_fq_map,
is_qat,
)
new_args.append(new_arg)
# Clone has a memory_format kwarg, zeros_like has a pin_memory kwarg, and
# gelu has an approximate kwarg; these persist in the exported graph.
# This is just a workaround for those cases.
assert (
node.target == torch.ops.aten.clone.default
or node.target == torch.ops.aten.zeros_like.default
or node.target == torch.ops.aten.gelu.default
or len(node.kwargs) == 0
), " expecting kwargs for aten op IR to be empty"
# assign the new args to the node, inplace
node.args = tuple(new_args)
def _maybe_insert_output_observer_for_node(
node: Node,
model: torch.nn.Module,
named_modules: dict[str, torch.nn.Module],
graph: Graph,
obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
is_qat: bool,
) -> Optional[Node]:
if node in obs_or_fq_map:
output_act_obs_or_fq = obs_or_fq_map[node]
new_output = _insert_obs_or_fq(
node, output_act_obs_or_fq, model, named_modules, graph
)
# propagate numeric debug handle from original node to observer/fake_quant node
if (
isinstance(node, Node)
and isinstance(new_output, Node)
and CUSTOM_KEY in node.meta
and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
):
if CUSTOM_KEY not in new_output.meta:
new_output.meta[CUSTOM_KEY] = {}
new_output.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = node.meta[
CUSTOM_KEY
][NUMERIC_DEBUG_HANDLE_KEY]
return new_output
return None
def _maybe_insert_input_and_output_observers_for_node(
node: Node,
model: torch.fx.GraphModule,
obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
is_qat: bool,
):
this_node_quantization_annotation = (
node.meta["quantization_annotation"]
if "quantization_annotation" in node.meta
else None
)
if this_node_quantization_annotation is None:
return
named_modules = dict(model.named_modules(remove_duplicate=False))
_maybe_insert_input_observers_for_node(
node,
None, # qconfig
model,
named_modules,
obs_or_fq_map,
is_qat,
)
output_is_a_tensor = "val" in node.meta and isinstance(node.meta["val"], FakeTensor)
if not output_is_a_tensor:
return
# this returns the new observer node if it was needed
maybe_output_obs_node = _maybe_insert_output_observer_for_node(
node, model, named_modules, model.graph, obs_or_fq_map, is_qat
)
if maybe_output_obs_node is None:
return
# Update users of original node to use the output observer
# instead. For example, change
#
# cur_node -> next_node
#         \-> obs
#
# to
#
# cur_node -> obs -> next_node
#
# We need to save orig users before updating uses because
# the list of users will change as we update uses
orig_users = list(node.users.keys())
for user_node in orig_users:
if user_node is maybe_output_obs_node:
continue
user_node.replace_input_with(node, maybe_output_obs_node)
def prepare(
model: GraphModule,
node_name_to_scope: dict[str, tuple[str, type]],
is_qat: bool,
obs_or_fq_callback=None,
) -> GraphModule:
# Since we are mutating the graph as we go, we iterate over the original
# nodes before observer insertion, instead of model.graph.nodes.
nodes_before_observation = list(model.graph.nodes)
# At the high level we construct a map from EdgeOrNode to a observer_or_fake_quant instance
# all edge/nodes that belongs to the same group will use the same instance
# and when we insert observers we'll just query this map to get the correct observer_or_fake_quant
# instance
edge_or_node_to_qspec = _get_edge_or_node_to_qspec(model)
edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
obs_or_fq_map = _get_obs_or_fq_map(
edge_or_node_to_group_id, edge_or_node_to_qspec, is_qat
)
if obs_or_fq_callback:
obs_or_fq_callback(model, obs_or_fq_map)
for node in nodes_before_observation:
# TODO: simplify logic for inserting observers
_maybe_insert_input_and_output_observers_for_node(
node, model, obs_or_fq_map, is_qat
)
model = GraphModule(model, model.graph)
_save_state(
model,
{}, # node_name_to_qconfig
node_name_to_scope,
PrepareCustomConfig(),
{}, # equalization_node_name_to_qconfig
QConfigMapping(),
is_qat,
set(), # observed_node_names
)
return model
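# A rough usage sketch (an assumption about the surrounding flow, not part of the
# original file): `prepare` is normally driven by the PT2E entry points after a
# quantizer has annotated the exported graph, roughly:
#
#     exported_gm = torch.export.export_for_training(model, example_inputs).module()
#     quantizer.annotate(exported_gm)     # fills node.meta["quantization_annotation"]
#     observed_gm = prepare(exported_gm, node_name_to_scope={}, is_qat=False)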

View file

@ -0,0 +1,991 @@
# mypy: allow-untyped-defs
import copy
import dataclasses
import itertools
import operator
from typing import Any, Callable, Optional, TYPE_CHECKING
import torch
import torch.nn.functional as F
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
from torch.ao.quantization.pt2e.export_utils import _WrapperModule
from torch.ao.quantization.quantizer import (
DerivedQuantizationSpec,
EdgeOrNode,
QuantizationSpecBase,
SharedQuantizationSpec,
)
from torch.fx import Graph, GraphModule, Node
from torch.fx.subgraph_rewriter import replace_pattern_with_filters, ReplacedPatterns
from .utils import (
_get_aten_graph_module_for_pattern,
_is_bn_node,
_is_conv_or_conv_transpose_node,
_is_conv_transpose_fn,
fold_bn_weights_into_conv_node,
)
if TYPE_CHECKING:
from torch.fx.passes.utils.matcher_with_name_node_map_utils import InternalMatch
__all__ = [] # type: ignore[var-annotated]
def _get_quantized_conv_bn_example_inputs_kwargs(
is_per_channel: bool,
has_bias: bool,
bias_is_quantized: bool,
is_cuda: bool,
) -> dict[str, Any]:
"""
Optional example inputs for quantized and folded conv-bn patterns
used in convert, expressed as kwargs.
"""
kwargs = {}
# Per tensor quantization uses literals to represent scale and zero
# point, so there is no need to include them here as kwargs
if is_per_channel:
kwargs["weight_scale"] = torch.tensor([1], dtype=torch.float)
kwargs["weight_zero_point"] = torch.tensor([0], dtype=torch.int)
if has_bias and bias_is_quantized:
kwargs["bias_scale"] = torch.tensor([1], dtype=torch.float)
kwargs["bias_zero_point"] = torch.tensor([0], dtype=torch.int)
if has_bias:
kwargs["conv_bias"] = torch.randn(1)
if is_cuda:
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
kwargs[k] = v.cuda()
return kwargs
def _get_conv_bn_pattern(conv_fn: Callable) -> Callable:
def _conv_bn_pattern(
x: torch.Tensor,
conv_weight: torch.Tensor,
conv_bias: torch.Tensor,
bn_weight: torch.Tensor,
bn_bias: torch.Tensor,
bn_running_mean: torch.Tensor,
bn_running_var: torch.Tensor,
) -> torch.Tensor:
x = conv_fn(x, conv_weight, conv_bias)
x = F.batch_norm(
x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True
)
return x
return _WrapperModule(_conv_bn_pattern)
# TODO: merge this with the `no_conv_bias` case
def _get_qat_conv_bn_pattern(conv_fn: Callable) -> Callable:
def _qat_conv_bn_pattern(
x: torch.Tensor,
conv_weight: torch.Tensor,
conv_bias: torch.Tensor,
bn_weight: torch.Tensor,
bn_bias: torch.Tensor,
bn_running_mean: torch.Tensor,
bn_running_var: torch.Tensor,
) -> torch.Tensor:
"""
Approximated method to fuse conv and bn. It requires only one forward pass.
conv_orig = conv / scale_factor where scale_factor = bn.weight / running_std.
This is based on `nniqat.ConvBn2d._forward_approximate`.
"""
# TODO: allow setting eps
bn_eps = 1e-5
running_std = torch.sqrt(bn_running_var + bn_eps)
scale_factor = bn_weight / running_std
weight_shape = [1] * len(conv_weight.shape)
weight_in_channel_axis = 1 if _is_conv_transpose_fn(conv_fn) else 0
weight_shape[weight_in_channel_axis] = -1
bias_shape = [1] * len(conv_weight.shape)
bias_shape[1] = -1
scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
zero_bias = torch.zeros_like(conv_bias, dtype=x.dtype)
x = conv_fn(x, scaled_weight, zero_bias)
x = x / scale_factor.reshape(bias_shape)
x = x + conv_bias.reshape(bias_shape)
x = F.batch_norm(
x,
bn_running_mean,
bn_running_var,
bn_weight,
bn_bias,
training=True,
eps=bn_eps,
)
return x
return _WrapperModule(_qat_conv_bn_pattern)
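# A small numeric sanity check (a sketch added for illustration, not part of the
# original module) of the identity the pattern above relies on: scaling the conv
# weight per output channel by `scale_factor` and dividing the conv output by the
# same factor is a no-op up to floating point error.
def _qat_conv_bn_scaling_example() -> None:
    x = torch.randn(1, 3, 8, 8)
    w = torch.randn(4, 3, 3, 3)
    scale_factor = torch.rand(4) + 0.5  # stands in for bn_weight / running_std
    y_ref = F.conv2d(x, w)
    y_scaled = F.conv2d(x, w * scale_factor.reshape(-1, 1, 1, 1))
    y_scaled = y_scaled / scale_factor.reshape(1, -1, 1, 1)
    assert torch.allclose(y_ref, y_scaled, atol=1e-5)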
def _get_qat_conv_bn_pattern_no_conv_bias(conv_fn: Callable) -> Callable:
def _qat_conv_bn_pattern_no_conv_bias(
x: torch.Tensor,
conv_weight: torch.Tensor,
# Not used, only for matching convenience
conv_bias: torch.Tensor,
bn_weight: torch.Tensor,
bn_bias: torch.Tensor,
bn_running_mean: torch.Tensor,
bn_running_var: torch.Tensor,
) -> torch.Tensor:
"""
Same as `_get_qat_conv_bn_pattern`, but handles the case with no conv bias.
"""
# TODO: allow setting eps
bn_eps = 1e-5
running_std = torch.sqrt(bn_running_var + bn_eps)
scale_factor = bn_weight / running_std
weight_shape = [1] * len(conv_weight.shape)
weight_in_channel_axis = 1 if _is_conv_transpose_fn(conv_fn) else 0
weight_shape[weight_in_channel_axis] = -1
bias_shape = [1] * len(conv_weight.shape)
bias_shape[1] = -1
scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
x = conv_fn(x, scaled_weight, None)
x = x / scale_factor.reshape(bias_shape)
x = F.batch_norm(
x,
bn_running_mean,
bn_running_var,
bn_weight,
bn_bias,
training=True,
eps=bn_eps,
)
return x
return _WrapperModule(_qat_conv_bn_pattern_no_conv_bias)
def _append_qdq(x, is_per_channel, is_bias, kwargs):
"""
Helper function to append q-dq ops after `x`, using dummy values for the qparams
and qmin/qmax. We use dummy values here because we match with `ignore_literals=True`
and will manually replace these values after subgraph rewriting.
Return the dq node.
"""
# Dummy args to be passed into q-dq ops
per_channel_axis = 0
scale_key = "bias_scale" if is_bias else "weight_scale"
zp_key = "bias_zero_point" if is_bias else "weight_zero_point"
scale = kwargs[scale_key] if is_per_channel else 1.0
zp = kwargs[zp_key] if is_per_channel else 0
qmin = -127
qmax = 127
dtype = torch.int8
qd = torch.ops.quantized_decomposed
if is_per_channel:
x = qd.quantize_per_channel(x, scale, zp, per_channel_axis, qmin, qmax, dtype)
x = qd.dequantize_per_channel(x, scale, zp, per_channel_axis, qmin, qmax, dtype)
else:
x = qd.quantize_per_tensor(x, scale, zp, qmin, qmax, dtype)
x = qd.dequantize_per_tensor(x, scale, zp, qmin, qmax, dtype)
return x
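# A tiny eager-mode illustration (a sketch, not part of the original module): with the
# dummy per-tensor qparams above (scale=1.0, zero_point=0, range [-127, 127]), the
# appended q-dq pair simply snaps values onto the int8 grid.
def _append_qdq_example() -> None:
    x = torch.randn(8) * 10
    x_qdq = _append_qdq(x, is_per_channel=False, is_bias=False, kwargs={})
    assert torch.allclose(x_qdq, torch.clamp(torch.round(x), -127.0, 127.0))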
def _get_quantized_qat_conv_bn_pattern(
is_per_channel: bool,
has_bias: bool,
bias_is_quantized: bool,
conv_fn: Callable,
bn_is_training: bool,
) -> Callable:
"""
Return the quantized version of QAT conv + BN pattern.
This is based on `nniqat.ConvBn2d._forward_approximate`,
used in QAT convert. We first match this pattern and replace
it with the normal [conv - bn] pattern, then fold the BN
weights into conv.
"""
# TODO: allow setting eps
bn_eps = 1e-5
def _quantized_qat_conv_bn_pattern(
x: torch.Tensor,
conv_weight: torch.Tensor,
bn_weight: torch.Tensor,
bn_bias: torch.Tensor,
bn_running_mean: torch.Tensor,
bn_running_var: torch.Tensor,
**kwargs,
) -> torch.Tensor:
running_std = torch.sqrt(bn_running_var + bn_eps)
scale_factor = bn_weight / running_std
weight_shape = [1] * len(conv_weight.shape)
weight_shape[0] = -1
bias_shape = [1] * len(conv_weight.shape)
bias_shape[1] = -1
scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
scaled_weight = _append_qdq(
scaled_weight,
is_per_channel,
is_bias=False,
kwargs=kwargs,
)
if has_bias:
zero_bias = torch.zeros_like(kwargs["conv_bias"], dtype=x.dtype)
if bias_is_quantized:
zero_bias = _append_qdq(
zero_bias,
is_per_channel,
is_bias=True,
kwargs=kwargs,
)
x = conv_fn(x, scaled_weight, zero_bias)
else:
x = conv_fn(x, scaled_weight, None)
x = x / scale_factor.reshape(bias_shape)
if has_bias:
x = x + kwargs["conv_bias"].reshape(bias_shape)
x = F.batch_norm(
x,
bn_running_mean,
bn_running_var,
bn_weight,
bn_bias,
training=bn_is_training,
eps=bn_eps,
)
return x
return _WrapperModule(_quantized_qat_conv_bn_pattern)
def _get_folded_quantized_qat_conv_bn_pattern(
is_per_channel: bool,
has_bias: bool,
bias_is_quantized: bool,
conv_fn: Callable,
bn_is_training: bool,
) -> Callable:
"""
Quantized QAT conv - bn pattern with bn weights being folded into conv.
"""
# TODO: allow setting eps
bn_eps = 1e-5
def _folded_quantized_qat_conv_bn_pattern(
x: torch.Tensor,
conv_weight: torch.Tensor,
bn_weight: torch.Tensor,
bn_bias: torch.Tensor,
bn_running_mean: torch.Tensor,
bn_running_var: torch.Tensor,
**kwargs,
) -> torch.Tensor:
conv_weight = _append_qdq(
conv_weight,
is_per_channel,
is_bias=False,
kwargs=kwargs,
)
if has_bias:
bias = kwargs["conv_bias"]
if bias_is_quantized:
bias = _append_qdq(
bias,
is_per_channel,
is_bias=True,
kwargs=kwargs,
)
else:
bias = None
x = conv_fn(x, conv_weight, bias)
x = F.batch_norm(
x,
bn_running_mean,
bn_running_var,
bn_weight,
bn_bias,
training=bn_is_training,
eps=bn_eps,
)
return x
return _WrapperModule(_folded_quantized_qat_conv_bn_pattern)
def _has_conv_bias_filter(
match: "InternalMatch",
original_graph: Graph,
pattern_graph: Graph,
) -> bool:
"""
Match filter for the subgraph rewriter that returns True if the conv node in
the original graph has bias.
"""
for n in match.nodes_map.values():
if _is_conv_or_conv_transpose_node(n):
return len(n.args) > 2 and n.args[2] is not None
raise ValueError("Could not find conv node in matched conv + bn pattern")
def _no_conv_bias_filter(
match: "InternalMatch",
original_graph: Graph,
pattern_graph: Graph,
) -> bool:
"""
Match filter for the subgraph rewriter that returns True if the conv node in
the original graph does NOT have bias.
"""
return not _has_conv_bias_filter(match, original_graph, pattern_graph)
def _is_quantize(n: Node) -> bool:
return n.target in [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
torch.ops.quantized_decomposed.quantize_per_channel.default,
]
def _is_dequantize(n: Node) -> bool:
return n.target in [
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
]
def _get_conv_bn_pattern_nodes(r: ReplacedPatterns) -> dict[str, tuple[Node, Node]]:
"""
Helper function to extract the nodes in the conv-bn fusion pattern after
subgraph rewriting, in the form of a map:
{name: (original_node, replacement_node)}
The following names must exist in the map:
"conv", "conv_weight", "conv_input", "bn", "getitem"
The following names may exist in the map:
"conv_weight_q", "conv_weight_dq", "conv_bias",
"conv_bias_q", "conv_bias_dq"
"""
def _get_nodes(nodes: list[Node]) -> tuple[Node, Node, Optional[Node]]:
"""
Return a 3-tuple of (conv_node, bn_node, getitem_node).
This asserts that the match contains exactly one of each node.
"""
conv_node, bn_node, getitem_node = None, None, None
for n in nodes:
if n.op != "call_function":
continue
if _is_conv_or_conv_transpose_node(n):
assert conv_node is None
conv_node = n
if _is_bn_node(n):
assert bn_node is None
bn_node = n
if n.target == operator.getitem:
assert getitem_node is None
getitem_node = n
assert conv_node is not None
assert bn_node is not None
return (conv_node, bn_node, getitem_node)
def _get_q_dq_nodes(n: Node) -> tuple[Node, Node, Node]:
"""
Return a 3-tuple of (orig_node, q_node, dq_node).
"""
assert _is_dequantize(n)
q_node = n.args[0]
assert isinstance(q_node, Node)
assert _is_quantize(q_node)
orig_node = q_node.args[0]
assert isinstance(orig_node, Node)
return (orig_node, q_node, n)
original_nodes = list(_filter_nodes_map(r.nodes_map).values())
o_conv, o_bn, o_getitem = _get_nodes(original_nodes)
r_conv, r_bn, r_getitem = _get_nodes(r.replacements)
# Create the mapping from original node to replacement node
assert o_getitem is None
assert r_getitem is None
mapping = {
"conv": (o_conv, r_conv),
"bn": (o_bn, r_bn),
}
# Extract conv input and weight
# Note: here we extract the original nodes indirectly through the pattern nodes
# because the args of the original nodes are no longer available after replacement
(p_conv, _, _) = _get_nodes(list(r.nodes_map.keys()))
(p_conv_input, p_conv_weight, *_) = p_conv.args
(r_conv_input, r_conv_weight, *_) = r_conv.args
assert isinstance(p_conv_input, Node)
assert isinstance(p_conv_weight, Node)
assert isinstance(r_conv_input, Node)
assert isinstance(r_conv_weight, Node)
o_conv_input = r.nodes_map[p_conv_input]
o_conv_weight = r.nodes_map[p_conv_weight]
# If conv weight is quantized, extract the q - dq nodes
if _is_dequantize(p_conv_weight):
p_conv_weight, p_conv_weight_q, p_conv_weight_dq = _get_q_dq_nodes(
p_conv_weight
)
r_conv_weight, r_conv_weight_q, r_conv_weight_dq = _get_q_dq_nodes(
r_conv_weight
)
o_conv_weight = r.nodes_map[p_conv_weight]
o_conv_weight_q = r.nodes_map[p_conv_weight_q]
o_conv_weight_dq = r.nodes_map[p_conv_weight_dq]
mapping["conv_weight_q"] = (o_conv_weight_q, r_conv_weight_q)
mapping["conv_weight_dq"] = (o_conv_weight_dq, r_conv_weight_dq)
mapping["conv_input"] = (o_conv_input, r_conv_input)
mapping["conv_weight"] = (o_conv_weight, r_conv_weight)
# Extract conv bias
if len(p_conv.args) > 2 and len(r_conv.args) > 2:
p_conv_bias = p_conv.args[2]
r_conv_bias = r_conv.args[2]
assert isinstance(p_conv_bias, Node)
assert isinstance(r_conv_bias, Node)
o_conv_bias = r.nodes_map[p_conv_bias]
# If conv bias is quantized, extract the q - dq nodes
if _is_dequantize(p_conv_bias):
p_conv_bias, p_conv_bias_q, p_conv_bias_dq = _get_q_dq_nodes(p_conv_bias)
r_conv_bias, r_conv_bias_q, r_conv_bias_dq = _get_q_dq_nodes(r_conv_bias)
o_conv_bias = r.nodes_map[p_conv_bias]
o_conv_bias_q = r.nodes_map[p_conv_bias_q]
o_conv_bias_dq = r.nodes_map[p_conv_bias_dq]
mapping["conv_bias_q"] = (o_conv_bias_q, r_conv_bias_q)
mapping["conv_bias_dq"] = (o_conv_bias_dq, r_conv_bias_dq)
mapping["conv_bias"] = (o_conv_bias, r_conv_bias)
return mapping
def _filter_nodes_map(nodes_map: dict[Node, Node]) -> dict[Node, Node]:
"""
Return a filtered `nodes_map` returned from the subgraph rewriter.
The filtered `nodes_map` will contain only nodes that are actually
matched in the pattern, excluding None or placeholder nodes.
"""
new_nodes_map: dict[Node, Node] = {}
for pattern_node, graph_node in nodes_map.items():
# bias can be None
if graph_node is None:
continue
# skip pattern placeholder nodes
if pattern_node.op == "placeholder":
continue
new_nodes_map[pattern_node] = graph_node
return new_nodes_map
# TODO: this is error prone, use the replace_literals_with_placeholders hack instead
def _copy_over_literal_conv_args(original_node: Node, new_node: Node):
"""
Copy over literal args in conv, such as stride and padding, from the matched node
in the original graph to its replacement in the new graph.
This is needed due to the following limitation in the subgraph rewriter when used
with dynamo export: literal (non-tensor) args are not supported in the match and
replacement patterns. This is because dynamo export automatically inlines these
literal args, making them dead placeholder nodes. In the future, we should check
if dynamo export can optionally disable this inlining, or if subgraph rewriter
can do the copying for us. See https://github.com/pytorch/pytorch/issues/100419.
Note: Unlike other tensor args like conv weights and biases, literal args are
preserved in the original nodes after replacement, so we can access them here.
"""
assert _is_conv_or_conv_transpose_node(original_node)
assert _is_conv_or_conv_transpose_node(new_node)
# x, weight, bias, [stride, padding, dilation, transposed, output_padding, groups]
new_args = list(new_node.args)
if len(new_args) < 3:
# bias is optional, when it is not present, it means it is None
new_args.append(None)
new_node.args = tuple(new_args[:3]) + original_node.args[3:]
def _update_conv_input_qspec_map_after_replacement(
original_node: Node, replacement_node: Node
):
"""
Update the `input_qspec_map` in the annotation after subgraph rewriting.
The original annotation referred to the nodes in the original graph,
so the keys in the `input_qspec_map` will need to be updated to reflect
the corresponding nodes in the replacement graph.
"""
assert _is_conv_or_conv_transpose_node(original_node)
assert _is_conv_or_conv_transpose_node(replacement_node)
if "quantization_annotation" not in original_node.meta:
return
original_input_qspec_map = original_node.meta[
"quantization_annotation"
].input_qspec_map
input_qspec_map = {}
# get the list of configs, it should be ordered as input, weight, bias
# note: this is really hacky, we need a better solution, hopefully
# in subgraph_rewriter, issue tracking the problem: https://github.com/pytorch/pytorch/issues/101820
all_configs = list(original_input_qspec_map.items())
# input activation
input_qspec_map[replacement_node.args[0]] = all_configs[0][1]
# weight
input_qspec_map[replacement_node.args[1]] = all_configs[1][1]
# bias
if len(replacement_node.args) > 2 and len(all_configs) > 2:
input_qspec_map[replacement_node.args[2]] = all_configs[2][1]
replacement_node.meta["quantization_annotation"].input_qspec_map = input_qspec_map
def _update_special_qspecs_after_replacement(
node: Node,
original_to_replacement_node: dict[Node, Node],
):
"""
Update the `SharedQuantizationSpec`s and `DerivedQuantizationSpec`s
used in `node`'s quantization annotation after subgraph rewriting.
The original annotation referred to the nodes in the original graph,
so the nodes used in these special quantization specs will need to
be updated to the corresponding nodes in the replacement graph.
"""
def _get_new_edge_or_node(edge_or_node: EdgeOrNode):
if isinstance(edge_or_node, Node):
_node = edge_or_node
return original_to_replacement_node.get(_node, _node)
elif (
isinstance(edge_or_node, tuple)
and len(edge_or_node) == 2
and all(isinstance(x, Node) for x in edge_or_node)
):
src, dest = edge_or_node
return (
original_to_replacement_node.get(src, src),
original_to_replacement_node.get(dest, dest),
)
else:
raise ValueError("unexpected type for edge_or_node: ", type(edge_or_node))
def _get_new_qspec(qspec: QuantizationSpecBase):
if isinstance(qspec, SharedQuantizationSpec):
new_edge_or_node = _get_new_edge_or_node(qspec.edge_or_node)
return SharedQuantizationSpec(new_edge_or_node)
elif isinstance(qspec, DerivedQuantizationSpec):
new_derived_from = [_get_new_edge_or_node(x) for x in qspec.derived_from]
return dataclasses.replace(qspec, derived_from=new_derived_from)
else:
return qspec
if "quantization_annotation" not in node.meta:
return
annotation = node.meta["quantization_annotation"]
for input_node, qspec in annotation.input_qspec_map.items():
annotation.input_qspec_map[input_node] = _get_new_qspec(qspec)
annotation.output_qspec = _get_new_qspec(annotation.output_qspec)
def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
# Example inputs for conv-bn1d patterns
_conv1d_bn_example_inputs = (
torch.randn(1, 1, 3), # x
torch.randn(1, 1, 1), # conv_weight
torch.randn(1), # conv_bias
torch.randn(1), # bn_weight
torch.randn(1), # bn_bias
torch.randn(1), # bn_running_mean
torch.randn(1), # bn_running_var
)
# Example inputs for conv-bn2d patterns
_conv2d_bn_example_inputs = (
torch.randn(1, 1, 3, 3), # x
torch.randn(1, 1, 1, 1), # conv_weight
torch.randn(1), # conv_bias
torch.randn(1), # bn_weight
torch.randn(1), # bn_bias
torch.randn(1), # bn_running_mean
torch.randn(1), # bn_running_var
)
has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
if not has_bn:
return m
is_cuda_options = [True, False] if torch.cuda.is_available() else [False]
for is_cuda in is_cuda_options:
m = _fuse_conv_bn_qat_helper(
m, F.conv1d, _conv1d_bn_example_inputs, is_cuda=is_cuda
)
m = _fuse_conv_bn_qat_helper(
m, F.conv2d, _conv2d_bn_example_inputs, is_cuda=is_cuda
)
m = _fuse_conv_bn_qat_helper(
m, F.conv_transpose1d, _conv1d_bn_example_inputs, is_cuda=is_cuda
)
m = _fuse_conv_bn_qat_helper(
m, F.conv_transpose2d, _conv2d_bn_example_inputs, is_cuda=is_cuda
)
return m
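# Rough placement in the overall QAT flow (a sketch based on how these helpers appear
# in this file; the surrounding calls are assumptions, not new behavior):
#
#     m = _fuse_conv_bn_qat(m)    # prepare: swap in the approximate QAT conv-bn pattern
#     ...                         # train / calibrate with fake quant
#     m = _fold_conv_bn_qat(m)    # convert: fold BN stats back into the conv weights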
def _fuse_conv_bn_qat_helper(
m: GraphModule,
conv_fn: Callable,
example_inputs: tuple[Any, ...],
is_cuda: bool,
) -> GraphModule:
"""
Given a graph of decomposed aten ops, replace the (conv + bn) pattern with
the fused QAT subgraph equivalent. The input graph should already be annotated.
The annotations in the original nodes will be preserved in the corresponding
nodes in the new subgraph.
Note: This also handles the (conv + bn + relu) pattern.
"""
m.graph.eliminate_dead_code()
m.recompile()
conv_bn_pattern = _get_conv_bn_pattern(conv_fn)
match_pattern = _get_aten_graph_module_for_pattern(
conv_bn_pattern,
example_inputs,
is_cuda,
)
# Step (1): Replace patterns with conv bias
#
# Here we do replacement separately for cases with and without conv bias, since
# the replacement patterns for these two cases are substantially different.
# TODO: use the public replace_pattern API once it also returns replacement nodes
qat_conv_bn_pattern = _get_qat_conv_bn_pattern(conv_fn)
replacement_pattern_with_conv_bias = _get_aten_graph_module_for_pattern(
qat_conv_bn_pattern,
example_inputs,
is_cuda,
)
replacements_with_conv_bias = replace_pattern_with_filters(
m,
match_pattern,
replacement_pattern_with_conv_bias,
match_filters=[_has_conv_bias_filter],
ignore_literals=True,
)
m.recompile()
# Step (2): Replace patterns without conv bias
qat_conv_bn_pattern_no_conv_bias = _get_qat_conv_bn_pattern_no_conv_bias(conv_fn)
replacement_pattern_no_conv_bias = _get_aten_graph_module_for_pattern(
qat_conv_bn_pattern_no_conv_bias,
example_inputs,
is_cuda,
)
replacements_no_conv_bias = replace_pattern_with_filters(
m,
match_pattern,
replacement_pattern_no_conv_bias,
match_filters=[_no_conv_bias_filter],
ignore_literals=True,
)
m.recompile()
# Step (3): Post processing
#
# Due to limited functionality in the subgraph rewriter, here we manually
# update the replacement graph as follows:
#
# (a) Copy over metadata from original subgraph. This ensures the stack traces
# and annotations are preserved in the new subgraph
#
# (b) Copy over literal args for conv from the original subgraph
# TODO: do this for literal args for batchnorm as well
#
# (c) Update all references of the old nodes in the original subgraph to refer
# to the corresponding nodes in the new subgraph in the annotations
#
# In the future, we should try to push as much of this functionality into the
# subgraph rewriter as possible, so we don't have to manually copy anything over.
# For more detail, see https://github.com/pytorch/pytorch/issues/100419.
all_original_to_replacement_nodes = {}
for r in replacements_with_conv_bias + replacements_no_conv_bias:
replacement_dict = _get_conv_bn_pattern_nodes(r)
# The original conv node's "nn_module_stack"
conv_nn_module = replacement_dict["conv"][0].meta.get("nn_module_stack", None)
for k, node_tuple in replacement_dict.items():
original_node, replacement_node = node_tuple
# Step (3a): Copy over metadata for all nodes in [conv - bn - getitem]
replacement_node.meta = original_node.meta
# If original_node is a get_attr node, it doesn't have nn_module_stack.
# In this case, we copy nn_module_stack from the original conv node.
if (
k in ["conv_input", "conv_weight"]
and conv_nn_module
and "nn_module_stack" not in replacement_node.meta
):
replacement_node.meta["nn_module_stack"] = copy.deepcopy(conv_nn_module)
if _is_conv_or_conv_transpose_node(original_node):
# Step (3b): Copy over conv literal args
_copy_over_literal_conv_args(original_node, replacement_node)
# Step (3c): Update old references in the conv node's input_qspec_map
_update_conv_input_qspec_map_after_replacement(
original_node, replacement_node
)
all_original_to_replacement_nodes[original_node] = replacement_node
# Step (3c): Update old references in the special qspecs for all nodes in the graph
for n in m.graph.nodes:
_update_special_qspecs_after_replacement(n, all_original_to_replacement_nodes)
return m
def _duplicate_dequantize_node(m: GraphModule):
"""
Helper function to duplicate all dequantize nodes in the graph if the
node has more than one user. For example:
Before:
quantize -> dequantize -> a
\\--> b
\\--> c
After:
quantize -> dequantize_1 -> a
\\--> dequantize_2 -> b
\\--> dequantize_3 -> c
This is useful for subgraph rewriting. E.g. if we wish to match the
pattern [dequantize - a] above, subgraph matching would fail because
the dequantize node has users outside the matched portion of the graph.
Instead, we match [dequantize_1 - a], which is safe.
"""
dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor
for n in m.graph.nodes:
if n.op != "call_function" or n.target != dq_op or len(n.users) == 1:
continue
for user in list(n.users):
with m.graph.inserting_before(n):
new_node = m.graph.create_node("call_function", dq_op, n.args, n.kwargs)
user.replace_input_with(n, new_node)
m.graph.erase_node(n)
m.recompile()
def _remove_extra_dequantize(m: GraphModule):
"""
Removes duplicate dequantize nodes in the graph: for an operator that has
multiple dequantize nodes as users, replace them with a single dequantize node
that can be shared across all the uses. This should be seen as the "reverse"
of `_duplicate_dequantize_node`.
"""
dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor
for n in m.graph.nodes:
dq_users = [
user
for user in n.users
if user.op == "call_function" and user.target == dq_op
]
if len(dq_users) > 1:
with m.graph.inserting_after(dq_users[0]):
new_node = m.graph.create_node(
"call_function", dq_op, dq_users[0].args, {}
)
for dq_user in dq_users:
dq_user.replace_all_uses_with(new_node)
m.graph.erase_node(dq_user)
m.recompile()
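# Note: `_duplicate_dequantize_node` and `_remove_extra_dequantize` bracket the subgraph
# rewriting in `_fold_conv_bn_qat_helper` below: dequantize nodes are first duplicated so
# every matched pattern owns its own dq node, and the duplicates are merged back together
# once the rewrite is done.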
def _copy_over_q_dq_args(original_node: Node, replacement_node: Node):
"""
Given a pair of quantize or dequantize nodes, copy over all literal args
from the original node to the replacement node.
"""
# For quantize_per_tensor, scale and zp are literals and need to be copied
# For quantize_per_channel, scale and zp are get_attr nodes and should be skipped
assert original_node.target == replacement_node.target
if original_node.target in (
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
):
# Args: input, [scale, zp, qmin, qmax, dtype]
start_copy_arg_index = 1
elif original_node.target in (
torch.ops.quantized_decomposed.quantize_per_channel.default,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
):
# Args: input, scale, zp, [axis, qmin, qmax, dtype]
start_copy_arg_index = 3
else:
raise ValueError(
f"Expected quantize/dequantize nodes, got '{original_node.target}'"
)
replacement_node.args = (
replacement_node.args[:start_copy_arg_index]
+ original_node.args[start_copy_arg_index:]
)
def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
# Example inputs for quantized and folded conv-bn1d patterns used in convert
_quantized_conv1d_bn_example_inputs = (
torch.randn(1, 1, 3), # x
torch.randn(1, 1, 1), # conv_weight
torch.randn(1), # bn_weight
torch.randn(1), # bn_bias
torch.randn(1), # bn_running_mean
torch.randn(1), # bn_running_var
)
# Example inputs for quantized and folded conv-bn2d patterns used in convert
_quantized_conv2d_bn_example_inputs = (
torch.randn(1, 1, 3, 3), # x
torch.randn(1, 1, 1, 1), # conv_weight
torch.randn(1), # bn_weight
torch.randn(1), # bn_bias
torch.randn(1), # bn_running_mean
torch.randn(1), # bn_running_var
)
has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
if not has_bn:
return m
is_cuda_options = [True, False] if torch.cuda.is_available() else [False]
for is_cuda in is_cuda_options:
m = _fold_conv_bn_qat_helper(
m, F.conv1d, _quantized_conv1d_bn_example_inputs, is_cuda=is_cuda
)
m = _fold_conv_bn_qat_helper(
m, F.conv2d, _quantized_conv2d_bn_example_inputs, is_cuda=is_cuda
)
m = _fold_conv_bn_qat_helper(
m, F.conv_transpose1d, _quantized_conv1d_bn_example_inputs, is_cuda=is_cuda
)
m = _fold_conv_bn_qat_helper(
m, F.conv_transpose2d, _quantized_conv2d_bn_example_inputs, is_cuda=is_cuda
)
# remove the in-place add that batchnorm uses to track training stats
for node in m.graph.nodes:
if (
node.target == torch.ops.aten.add_.Tensor
and node.args[0].op == "get_attr"
and node.args[1] == 1
and torch.nn.modules.batchnorm.BatchNorm2d
in [val[1] for val in node.meta["source_fn_stack"]]
):
m.graph.erase_node(node)
m.graph.eliminate_dead_code()
m.recompile()
return m
def _fold_conv_bn_qat_helper(
m: GraphModule,
conv_fn: Callable,
example_inputs: tuple[Any, ...],
is_cuda: bool,
) -> GraphModule:
"""
Replace the quantized (conv + bn) pattern with a conv node whose weights have the bn weights folded in.
"""
m.graph.eliminate_dead_code()
m.recompile()
_duplicate_dequantize_node(m)
# Step (1): Replace QAT pattern with simple [conv - bn] pattern
replacements = []
replacement_options = itertools.product(
[True, False], # is_per_channel
[True, False], # has_bias
[True, False], # bias_is_quantized
[True, False], # bn_is_training
)
for (
is_per_channel,
has_bias,
bias_is_quantized,
bn_is_training,
) in replacement_options:
# For the cases without bias, `bias_is_quantized` is irrelevant, so here we arbitrarily
# filter out one of the values for this flag to avoid having duplicate patterns
if not has_bias and bias_is_quantized:
continue
kwargs = _get_quantized_conv_bn_example_inputs_kwargs(
is_per_channel, has_bias, bias_is_quantized, is_cuda
)
match_pattern = _get_quantized_qat_conv_bn_pattern(
is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training
)
match_pattern = _get_aten_graph_module_for_pattern(
match_pattern,
example_inputs,
is_cuda,
**kwargs,
)
replacement_pattern = _get_folded_quantized_qat_conv_bn_pattern(
is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training
)
replacement_pattern = _get_aten_graph_module_for_pattern(
replacement_pattern,
example_inputs,
is_cuda,
**kwargs,
)
replacements.extend(
replace_pattern_with_filters(
m,
match_pattern,
replacement_pattern,
ignore_literals=True,
)
)
m.recompile()
_remove_extra_dequantize(m)
for r in replacements:
node_map = _get_conv_bn_pattern_nodes(r)
# Step (2): Copy over metadata from original subgraph
for original_node, replacement_node in node_map.values():
replacement_node.meta = original_node.meta
# Step (3): Copy over args for weight (and optionally bias) q - dq nodes
_copy_over_q_dq_args(*node_map["conv_weight_q"])
_copy_over_q_dq_args(*node_map["conv_weight_dq"])
if "conv_bias_q" in node_map:
assert "conv_bias_dq" in node_map
_copy_over_q_dq_args(*node_map["conv_bias_q"])
_copy_over_q_dq_args(*node_map["conv_bias_dq"])
# Step (4): Fold BN weights into conv
conv_bias = None
(_, conv_node) = node_map["conv"]
(_, bn_node) = node_map["bn"]
(_, conv_weight) = node_map["conv_weight"]
if "conv_bias" in node_map:
(_, conv_bias) = node_map["conv_bias"]
fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m)
# Copy over literal args for conv
for original_node in _filter_nodes_map(r.nodes_map).values():
if _is_conv_or_conv_transpose_node(original_node):
_copy_over_literal_conv_args(original_node, conv_node)
m.graph.eliminate_dead_code()
m.recompile()
return m

View file

@ -0,0 +1,6 @@
from .rewrite import reference_representation_rewrite
__all__ = [
"reference_representation_rewrite",
]

View file

@ -0,0 +1,819 @@
# mypy: allow-untyped-defs
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Optional
import torch
from torch._higher_order_ops.out_dtype import out_dtype
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
from torch.ao.quantization.pt2e.export_utils import _WrapperModule
from torch.ao.quantization.pt2e.utils import (
_get_aten_graph_module_for_pattern,
_replace_literals_with_existing_placeholders,
_replace_literals_with_new_placeholders,
remove_tensor_overload_for_qdq_ops,
)
from torch.fx import GraphModule
from torch.fx.subgraph_rewriter import replace_pattern
__all__ = [
"reference_representation_rewrite",
]
def _qdq_quantized_linear(
x_i8,
x_scale,
x_zero_point,
x_quant_min,
x_quant_max,
weight_i8,
weight_scale,
weight_zero_point,
weight_quant_min,
weight_quant_max,
bias_fp32,
out_scale,
out_zero_point,
out_quant_min,
out_quant_max,
):
x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8
)
weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
weight_i8,
weight_scale,
weight_zero_point,
weight_quant_min,
weight_quant_max,
torch.int8,
)
out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8
)
return out_i8
def _reference_quantized_linear(
x_i8,
x_scale,
x_zero_point,
x_quant_min,
x_quant_max,
weight_i8,
weight_scale,
weight_zero_point,
weight_quant_min,
weight_quant_max,
bias_fp32,
out_scale,
out_zero_point,
out_quant_min,
out_quant_max,
):
# without using quant_min/max in clamp, the traced graph will not have quant_min/max args.
# This results in failure to match the pattern.
# Therefore, we call torch.ops.aten.clamp here
x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max)
weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)
x_i16 = x_i8.to(torch.int16)
weight_i16 = weight_i8.to(torch.int16)
# always set bias to None so that the same representation works
# no matter whether bias_scale == x_scale * weight_scale or not
acc_i32 = out_dtype(
torch.ops.aten.linear.default,
torch.int32,
x_i16 - x_zero_point,
weight_i16 - weight_zero_point,
None,
)
# TODO: change to mul.Scalar
# Note: we are quantizing bias with these scales without signal from user, but it might be OK
bias_scale = x_scale * weight_scale
bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
acc_i32 = acc_i32 + bias_i32
# TODO: change to mul.Scalar when we make x_scale/weight_scale etc. Scalar values
acc_i32 = (
out_dtype(
torch.ops.aten.mul.Tensor,
torch.int32,
acc_i32,
x_scale * weight_scale / out_scale,
)
+ out_zero_point
)
out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8)
return out_i8
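# Worked form of the requantization above (a restatement of the code, not new math):
# with x_fp32 = (x_i8 - x_zp) * x_scale and w_fp32 = (w_i8 - w_zp) * w_scale,
#
#   out_fp32 = x_fp32 @ w_fp32.T + bias_fp32
#            = (x_scale * w_scale) * [(x_i8 - x_zp) @ (w_i8 - w_zp).T
#                                     + bias_fp32 / (x_scale * w_scale)]
#
# so quantizing with (out_scale, out_zero_point) gives
#
#   out_i8 = clamp(acc_i32 * (x_scale * w_scale / out_scale) + out_zero_point,
#                  out_quant_min, out_quant_max)
#
# which is what the int32 accumulation + out_dtype(mul) + clamp above computes
# (up to integer rounding in the int32 ops).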
def _qdq_dynamic_quantized_linear(
x_fp32,
x_quant_min,
x_quant_max,
x_eps,
weight_i8,
weight_scale,
weight_zero_point,
weight_quant_min,
weight_quant_max,
bias_fp32,
):
x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams(
x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8
)
x_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
x_fp32, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8
)
x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8
)
weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
weight_i8,
weight_scale,
weight_zero_point,
weight_quant_min,
weight_quant_max,
torch.int8,
)
out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
return out_fp32
def _reference_dynamic_quantized_linear(
x_fp32,
x_quant_min,
x_quant_max,
x_eps,
weight_i8,
weight_scale,
weight_zero_point,
weight_quant_min,
weight_quant_max,
bias_fp32,
):
x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams(
x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8
)
# decomposed representation for quantize_per_tensor
# TODO: use out_dtype(mul, ...) here when the op is ready
x_fp32 = x_fp32 / x_scale # fp32
# rounding modes might differ here
# pytorch rounds to even, which is also common for most backends
x_fp32 = torch.round(x_fp32) # fp32
x_i32 = x_fp32.to(dtype=torch.int32) # int32
x_i32 = x_i32 + x_zero_point # int32
# clamp works for fp32, int32 and int8 dtypes
x_i32 = torch.clamp(x_i32, x_quant_min, x_quant_max) # int32
x_i8 = x_i32.to(dtype=torch.int8)
weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)
x_i16 = x_i8.to(torch.int16)
weight_i16 = weight_i8.to(torch.int16)
# always set bias to None so that the same representation works
# no matter whether bias_scale == x_scale * weight_scale or not
acc_i32 = out_dtype(
torch.ops.aten.linear.default,
torch.int32,
x_i16 - x_zero_point,
weight_i16 - weight_zero_point,
None,
)
bias_scale = x_scale * weight_scale
bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
acc_i32 = acc_i32 + bias_i32
out_fp32 = acc_i32 * (x_scale * weight_scale)
return out_fp32
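# For context (a rough paraphrase of the decomposed op, not the exact kernel):
# `choose_qparams` above derives the dynamic qparams from the observed min/max of
# `x_fp32`, approximately
#   scale = max((max_val - min_val) / (x_quant_max - x_quant_min), x_eps)
#   zero_point = x_quant_min - round(min_val / scale), clamped to [x_quant_min, x_quant_max]
# with min_val/max_val adjusted to include 0.0 so that zero is exactly representable.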
def _qdq_quantized_conv2d(
x_i8,
x_scale,
x_zero_point,
x_quant_min,
x_quant_max,
weight_i8,
weight_scale,
weight_zero_point,
weight_quant_min,
weight_quant_max,
bias_fp32,
out_scale,
out_zero_point,
out_quant_min,
out_quant_max,
):
stride = [1, 1]
padding = [0, 0]
dilation = [1, 1]
transposed = False
output_padding = [0, 0]
groups = 1
x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8
)
weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
weight_i8,
weight_scale,
weight_zero_point,
weight_quant_min,
weight_quant_max,
torch.int8,
)
out_fp32 = torch.ops.aten.convolution.default(
x_fp32,
weight_fp32,
bias_fp32,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
)
out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8
)
return out_i8
def _reference_quantized_conv2d(
x_i8,
x_scale,
x_zero_point,
x_quant_min,
x_quant_max,
weight_i8,
weight_scale,
weight_zero_point,
weight_quant_min,
weight_quant_max,
bias_fp32,
out_scale,
out_zero_point,
out_quant_min,
out_quant_max,
):
stride = [1, 1]
padding = [0, 0]
dilation = [1, 1]
transposed = False
output_padding = [0, 0]
groups = 1
# without using quant_min/max in clamp, the traced graph will not have quant_min/max args.
# This results in failure to match the pattern.
# Therefore, we call torch.ops.aten.clamp here
x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max)
weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)
x_i16 = x_i8.to(torch.int16)
weight_i16 = weight_i8.to(torch.int16)
# always set bias to None so that the same representation works
# no matter whether bias_scale == x_scale * weight_scale or not
acc_i32 = out_dtype(
torch.ops.aten.convolution.default,
torch.int32,
x_i16 - x_zero_point,
weight_i16 - weight_zero_point,
None,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
)
    # Note: we are quantizing the bias with these scales without a signal from the user, but it might be OK
bias_scale = x_scale * weight_scale
    # bias quantization to int32 uses bias_scale = x_scale * weight_scale due to:
    # Take the linear calculation for example
    # Out_(i, j)_fp32 = Sum_(over k)[X_(i, k)_fp32 * W_(i, k)_fp32] + bias_(i)_fp32
    # Represent X, W fp32 via their dequant transforms
    # A_fp32 = (A_q - A_zero_point) * A_scale
    # Out_(i, j)_fp32 = Sum_(over k)[(X_(i, k)_q - X_zp) * X_scale * (W_(i, k)_q - W_zp) * W_scale] + bias_(i)_fp32
    # Factor out X_scale and W_scale
    # Out_(i, j)_fp32 = ((X_scale * W_scale) * Sum_(over k)[(X_(i, k)_q - X_zp) * (W_(i, k)_q - W_zp)]) + bias_(i)_fp32
    # In order to move the addition of bias_(i)_fp32 inside the scaled sum, we must write
    # Out_(i, j)_fp32 = (X_scale * W_scale) * (Sum_(over k)[(X_(i, k)_q - X_zp) * (W_(i, k)_q - W_zp)] + (1 / (X_scale * W_scale)) * bias_(i)_fp32)  # noqa: B950
    # Note we had to divide bias_fp32 by X_scale * W_scale = bias_scale
    # Thus bias quantization to int32 must use bias_scale = X_scale * W_scale
bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
# Unsqueeze to match broadcast dims
    # Unfortunately we cannot do bias_i32.unsqueeze(0) due to a literal matching nightmare
# in graph pattern replacement
bias_i32 = bias_i32.unsqueeze(-1)
bias_i32 = bias_i32.unsqueeze(-1)
acc_i32 = acc_i32 + bias_i32
# TODO: change to mul.Scalar when we make x_scale/weight_scale etc. Scalar values
acc_i32 = (
out_dtype(
torch.ops.aten.mul.Tensor,
torch.int32,
acc_i32,
x_scale * weight_scale / out_scale,
)
+ out_zero_point
)
out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8)
return out_i8
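# A tiny numeric sketch of the bias_scale = x_scale * weight_scale trick above
# (illustrative only; the helper name `_demo_bias_requantization` and the
# example scales are arbitrary, and zero points are assumed to be 0):
def _demo_bias_requantization():
    x_scale, w_scale, out_scale = 0.1, 0.2, 0.05
    x_q, w_q = 10, 5  # both dequantize to 1.0
    bias_fp32 = 0.3
    # fp32 reference path: dequantize, accumulate, add bias, requantize
    out_q_fp32_path = round((x_q * x_scale * w_q * w_scale + bias_fp32) / out_scale)
    # integer path: accumulate in int, fold bias at x_scale * w_scale, requantize once
    acc_i32 = x_q * w_q
    bias_i32 = round(bias_fp32 / (x_scale * w_scale))
    out_q_int_path = round((acc_i32 + bias_i32) * (x_scale * w_scale / out_scale))
    assert out_q_fp32_path == out_q_int_path == 26
    return out_q_int_path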
def _qdq_quantized_add_relu(
x_i8,
x_scale,
x_zero_point,
y_i8,
y_scale,
y_zero_point,
out_scale,
out_zero_point,
quant_min,
quant_max,
):
x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
x_i8, x_scale, x_zero_point, quant_min, quant_max, torch.int8
)
y_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
y_i8, y_scale, y_zero_point, quant_min, quant_max, torch.int8
)
out_fp32 = x_fp32 + y_fp32
out_fp32 = torch.ops.aten.relu(out_fp32)
out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
out_fp32, out_scale, out_zero_point, quant_min, quant_max, torch.int8
)
return out_i8
def _reference_quantized_add_relu(
x_i8,
x_scale,
x_zero_point,
y_i8,
y_scale,
y_zero_point,
out_scale,
out_zero_point,
quant_min,
quant_max,
):
"""
See comments for `_reference_quantized_add` for more information on
how to derive the formula for out_i8 based on x_i8 and y_i8
"""
x_i32 = x_i8.to(torch.int32)
y_i32 = y_i8.to(torch.int32)
# TODO: change this to mul.Scalar?
x_i32 = out_dtype(
torch.ops.aten.mul.Tensor,
torch.int32,
(x_i32 - x_zero_point),
(x_scale / out_scale),
)
y_i32 = out_dtype(
torch.ops.aten.mul.Tensor,
torch.int32,
(y_i32 - y_zero_point),
(y_scale / out_scale),
)
out_i32 = x_i32 + y_i32 + out_zero_point
# out_i32 = torch.ops.aten.clamp(out_i32, out_zero_point)
out_i8 = torch.ops.aten.clamp(out_i32, out_zero_point, quant_max).to(torch.int8)
return out_i8
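# A small sketch of why the clamp above uses out_zero_point as its lower bound:
# applying relu in fp32 is equivalent to clamping at the output zero point in
# the quantized domain, since relu((q - zp) * s) == (max(q, zp) - zp) * s for
# s > 0 (illustrative only; the helper name and the values are arbitrary):
def _demo_relu_as_clamp_at_zero_point():
    out_scale, out_zp = 0.1, 5
    q = torch.arange(-10, 10, dtype=torch.int32)
    fp32_relu = torch.relu((q - out_zp).float() * out_scale)
    quantized_relu = (torch.clamp(q, min=out_zp) - out_zp).float() * out_scale
    assert torch.equal(fp32_relu, quantized_relu)
    return quantized_relu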
def _qdq_quantized_add(
x_i8,
x_scale,
x_zero_point,
y_i8,
y_scale,
y_zero_point,
out_scale,
out_zero_point,
quant_min,
quant_max,
):
x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
x_i8, x_scale, x_zero_point, quant_min, quant_max, torch.int8
)
y_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
y_i8, y_scale, y_zero_point, quant_min, quant_max, torch.int8
)
out_fp32 = x_fp32 + y_fp32
out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
out_fp32, out_scale, out_zero_point, quant_min, quant_max, torch.int8
)
return out_i8
def _reference_quantized_add(
x_i8,
x_scale,
x_zero_point,
y_i8,
y_scale,
y_zero_point,
out_scale,
out_zero_point,
quant_min,
quant_max,
):
"""
# How to Derive the formula for out_i8 based on x_i8 and y_i8
# (since quantized add takes x_i8, y_i8 and their quantization parameters, and produce an out_i8)
# out_i8 is quantized output, we can write down the formula for it first:
out_i8 = out_f32 / out_scale + out_zero_point (1)
# then out_fp32 is computed from x_f32 + y_f32, and the x_fp32 and y_fp32 are the dequantized x_i8 and y_i8
out_f32 = x_f32 + y_f32 (2)
x_fp32 = (x_i8 - x_zero_point) * x_scale (3)
y_fp32 = (y_i8 - y_zero_point) * y_scale (4)
# applying the above fomula to the out_i8 equation we can get the following:
out_i8 = out_fp32 / out_scale + out_zero_point # (1)
= (x_f32 + y_f32) / out_scale + out_zero_point # applying (2) to substitute out_fp32 with x_fp32 + y_fp32
= ((x_i8 - x_zero_point) * x_scale + (y_i8 - y_zero_point) * y_scale) / out_scale + out_zero_point # apply (3) and (4)
"""
x_i32 = x_i8.to(torch.int32)
y_i32 = y_i8.to(torch.int32)
# TODO: use out_dtype op
x_i32 = torch.round((x_scale / out_scale) * (x_i32 - x_zero_point)).to(torch.int32)
y_i32 = torch.round((y_scale / out_scale) * (y_i32 - y_zero_point)).to(torch.int32)
out_i32 = x_i32 + y_i32 + out_zero_point
quant_min = -128
quant_max = 127
out_i8 = torch.ops.aten.clamp(out_i32, quant_min, quant_max).to(torch.int8)
return out_i8
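# A small numeric sketch of the out_i8 formula derived in the docstring above
# (illustrative only; the helper name and the scales/zero points are arbitrary):
def _demo_quantized_add_formula():
    x_scale, x_zp = 0.1, 0
    y_scale, y_zp = 0.2, 0
    out_scale, out_zp = 0.1, 0
    x_i8, y_i8 = 10, 5  # both dequantize to 1.0
    # fp32 reference path: dequantize, add, requantize
    out_fp32 = (x_i8 - x_zp) * x_scale + (y_i8 - y_zp) * y_scale
    out_ref = round(out_fp32 / out_scale) + out_zp
    # integer-domain formula from the docstring
    out_int = (
        round((x_scale / out_scale) * (x_i8 - x_zp))
        + round((y_scale / out_scale) * (y_i8 - y_zp))
        + out_zp
    )
    assert out_ref == out_int == 20
    return out_int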
def _qdq_quantized_max_pool2d(
x_i8,
x_scale,
x_zero_point,
x_quant_min,
x_quant_max,
out_scale,
out_zero_point,
out_quant_min,
out_quant_max,
):
kernel_size = 1
stride = 1
padding = 0
dilation = 1
ceil_mode = False
x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8
)
out_fp32, _ = torch.ops.aten.max_pool2d_with_indices.default(
x_fp32, kernel_size, stride, padding, dilation, ceil_mode
)
out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8
)
return out_i8
def _reference_quantized_max_pool2d(
x_i8,
x_scale,
x_zero_point,
x_quant_min,
x_quant_max,
out_scale,
out_zero_point,
out_quant_min,
out_quant_max,
):
kernel_size = 1
stride = 1
padding = 0
dilation = 1
ceil_mode = False
# to preserve x_quant_min, x_quant_max in the graph for pattern matching
x_i8 = torch.clamp(x_i8, x_quant_min, x_quant_max)
x_i32 = x_i8.to(torch.int32)
out_i32, _ = torch.ops.aten.max_pool2d_with_indices.default(
x_i32 - x_zero_point, kernel_size, stride, padding, dilation, ceil_mode
)
out_fp32 = out_i32 * (x_scale / out_scale) + out_zero_point
out_fp32 = torch.clamp(out_fp32, out_quant_min, out_quant_max)
out_i8 = out_fp32.to(torch.int8)
return out_i8
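# A small sketch of why the reference max_pool2d above can pool the shifted
# integers and rescale afterwards: max pooling commutes with a monotone
# (positive-scale) affine map (illustrative only; the helper name, shapes and
# values are arbitrary):
def _demo_max_pool_commutes_with_rescale():
    x_i32 = torch.randint(-128, 128, (1, 1, 4, 4), dtype=torch.int32)
    x_zp, scale_ratio = 3, 0.5  # scale_ratio stands in for x_scale / out_scale > 0
    pool_then_rescale = torch.nn.functional.max_pool2d((x_i32 - x_zp).float(), 2) * scale_ratio
    rescale_then_pool = torch.nn.functional.max_pool2d((x_i32 - x_zp).float() * scale_ratio, 2)
    assert torch.equal(pool_then_rescale, rescale_then_pool)
    return pool_then_rescale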
def _quantize_per_tensor_int8(x_fp32, scale, zero_point, quant_min, quant_max):
x = torch.ops.quantized_decomposed.quantize_per_tensor(
x_fp32, scale, zero_point, quant_min, quant_max, torch.int8
)
return x
def _reference_quantize_per_tensor_int8(
x_fp32, scale, zero_point, quant_min, quant_max
):
# TODO: use out_dtype(mul, ...) here when the op is ready
x = x_fp32 / scale # fp32
    # rounding modes might be different here
    # pytorch rounds to even, which is also common for most backends
x = torch.round(x) # fp32
x = x.to(dtype=torch.int32) # int32
x = x + zero_point # int32
# clamp works for fp32, int32 and int8 dtypes
x = torch.clamp(x, quant_min, quant_max) # int32
x = x.to(dtype=torch.int8)
return x
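# A small sketch of the round-to-even behavior mentioned in the comment above
# (illustrative only; the helper name is arbitrary): torch.round implements
# round-half-to-even, so ties go to the nearest even integer rather than
# always away from zero.
def _demo_round_half_to_even():
    ties = torch.tensor([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5])
    expected = torch.tensor([-2.0, -2.0, -0.0, 0.0, 2.0, 2.0])
    assert torch.equal(torch.round(ties), expected)
    return torch.round(ties)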
def _dequantize_per_tensor_int8(x_i8, scale, zero_point, quant_min, quant_max):
x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
x_i8, scale, zero_point, quant_min, quant_max, torch.int8
)
return x_fp32
def _reference_dequantize_per_tensor_int8(
x_i8, scale, zero_point, quant_min, quant_max
):
    # without using quant_min/quant_max in clamp, the traced graph will not have quant_min/quant_max args.
    # This results in a failure to match the pattern.
    # Therefore, we call torch.ops.aten.clamp here
x_i8 = torch.ops.aten.clamp(x_i8, quant_min, quant_max)
# TODO: use out_dtype op
# note: x_i8.to(torch.int32) does not work here
# TODO: debug the implementation later when torchdynamo time out issue is resolved
return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)
def _quantize_per_channel_int8(
x_fp32, scales, zero_points, ch_axis, quant_min, quant_max
):
out_i8 = torch.ops.quantized_decomposed.quantize_per_channel(
x_fp32, scales, zero_points, ch_axis, quant_min, quant_max, torch.int8
)
return out_i8
def _reference_quantize_per_channel_int8(
x_fp32, scales, zero_points, ch_axis, quant_min, quant_max
):
x_fp32 = torch.transpose(x_fp32, ch_axis, -1)
out_i32 = torch.ops.aten.clamp(
torch.round(x_fp32 / scales).to(torch.int32) + zero_points, quant_min, quant_max
)
out_i32 = torch.transpose(out_i32, ch_axis, -1)
return out_i32.to(torch.int8)
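# A small sketch of the transpose trick above: moving ch_axis to the last
# dimension lets the per-channel scales/zero_points (shape [num_channels])
# broadcast over all other dimensions (illustrative only; the helper name,
# shapes and values are arbitrary):
def _demo_per_channel_broadcast():
    x_fp32 = torch.randn(2, 3, 4, 4)
    scales = torch.tensor([0.1, 0.2, 0.3])
    zero_points = torch.zeros(3, dtype=torch.int32)
    ch_axis = 1
    x_t = torch.transpose(x_fp32, ch_axis, -1)  # (2, 4, 4, 3): channels last
    q_t = torch.clamp(
        torch.round(x_t / scales).to(torch.int32) + zero_points, -128, 127
    )
    q = torch.transpose(q_t, ch_axis, -1).to(torch.int8)  # back to (2, 3, 4, 4)
    assert q.shape == x_fp32.shape
    return q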
def _dequantize_per_channel_int8(
x_i8, scales, zero_points, ch_axis, quant_min, quant_max
):
    # the following will be replaced with placeholders
out_fp32 = torch.ops.quantized_decomposed.dequantize_per_channel(
x_i8, scales, zero_points, ch_axis, quant_min, quant_max, torch.int8
)
return out_fp32
def _reference_dequantize_per_channel_int8(
x_i8, scales, zero_points, ch_axis, quant_min, quant_max
):
    # the following will be replaced with placeholders
    # in order to preserve the quant_min/quant_max args for pattern matching (e.g. matching for int4 quantized ops),
    # we call torch.ops.aten.clamp here
x_i8 = torch.ops.aten.clamp(x_i8, quant_min, quant_max)
x_i8 = torch.transpose(x_i8, ch_axis, -1)
x_i32 = x_i8.to(torch.int32)
out_fp32 = (x_i32 - zero_points).to(torch.float) * scales
out_fp32 = torch.transpose(out_fp32, ch_axis, -1)
return out_fp32
def _replace_ph_qdq_per_channel_replacement(gm: torch.fx.GraphModule):
return _replace_literals_with_existing_placeholders(
gm, exclude_literals=[-1], literal_to_ph_idx={1: 3, -128: 4, 127: 5}
)
@dataclass
class _RewriteInfo:
"""Data needed for rewrite, this includes example inputs, pattern and replacement functions
and post transformation functions for the exported pattern and replacement GraphModule
"""
# example inputs used for exporting the pattern into GraphModule
example_inputs: tuple[Any, ...]
pattern: Callable
replacement: Callable
# post transformation on the exported pattern and replacement GraphModule
pattern_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
def reference_representation_rewrite(model: GraphModule) -> GraphModule:
_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
torch.randint(-128, 127, (2, 5), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
torch.randint(-128, 127, (5, 5), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-127], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
torch.randn(1, dtype=torch.float),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
)
_DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
torch.randn((2, 5), dtype=torch.float),
-128,
127,
torch.finfo(torch.float32).eps,
torch.randint(-128, 127, (5, 5), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-127], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
torch.randn(1, dtype=torch.float),
)
_QUANTIZED_CONV2d_EXAMPLE_INPUTS = (
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-127], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
torch.randn(1, dtype=torch.float),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
)
_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS = (
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
)
_QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS = (
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
)
_QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
torch.randn(1, 3, 3, 3, dtype=torch.float),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
)
_DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-128], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
)
_QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
torch.randn(1, 3, 3, 3, dtype=torch.float),
torch.randn(3, dtype=torch.float),
torch.zeros(3, dtype=torch.int),
1,
-128,
127,
)
_DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
torch.randn(3, dtype=torch.float),
torch.zeros(3, dtype=torch.int),
1,
-128,
127,
)
_REWRITE_INFO_LIST = [
_RewriteInfo(
_DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
_WrapperModule(_qdq_dynamic_quantized_linear),
_WrapperModule(_reference_dynamic_quantized_linear),
partial(
_replace_literals_with_existing_placeholders,
literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
),
partial(
_replace_literals_with_existing_placeholders,
literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
),
),
_RewriteInfo(
_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
_WrapperModule(_qdq_quantized_linear),
_WrapperModule(_reference_quantized_linear),
_replace_literals_with_new_placeholders,
_replace_literals_with_new_placeholders,
),
_RewriteInfo(
_QUANTIZED_CONV2d_EXAMPLE_INPUTS,
_WrapperModule(_qdq_quantized_conv2d),
_WrapperModule(_reference_quantized_conv2d),
partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]),
partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]),
),
_RewriteInfo(
_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS,
_WrapperModule(_qdq_quantized_add_relu),
_WrapperModule(_reference_quantized_add_relu),
),
_RewriteInfo(
_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS,
_WrapperModule(_qdq_quantized_add),
_WrapperModule(_reference_quantized_add),
),
_RewriteInfo(
_QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS,
_WrapperModule(_qdq_quantized_max_pool2d),
_WrapperModule(_reference_quantized_max_pool2d),
_replace_literals_with_new_placeholders,
_replace_literals_with_new_placeholders,
),
_RewriteInfo(
_QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
_WrapperModule(_quantize_per_tensor_int8),
_WrapperModule(_reference_quantize_per_tensor_int8),
),
_RewriteInfo(
_DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
_WrapperModule(_dequantize_per_tensor_int8),
_WrapperModule(_reference_dequantize_per_tensor_int8),
),
_RewriteInfo(
_QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS,
_WrapperModule(_quantize_per_channel_int8),
_WrapperModule(_reference_quantize_per_channel_int8),
_replace_ph_qdq_per_channel_replacement,
_replace_ph_qdq_per_channel_replacement,
),
_RewriteInfo(
_DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS,
_WrapperModule(_dequantize_per_channel_int8),
_WrapperModule(_reference_dequantize_per_channel_int8),
_replace_ph_qdq_per_channel_replacement,
_replace_ph_qdq_per_channel_replacement,
),
]
remove_tensor_overload_for_qdq_ops(model)
for rewrite_info in _REWRITE_INFO_LIST:
example_inputs = rewrite_info.example_inputs
pattern = rewrite_info.pattern
replacement = rewrite_info.replacement
pattern_post_trans = rewrite_info.pattern_post_trans
replacement_post_trans = rewrite_info.replacement_post_trans
pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs) # type: ignore[arg-type, assignment]
remove_tensor_overload_for_qdq_ops(pattern) # type: ignore[arg-type]
replacement = _get_aten_graph_module_for_pattern(replacement, example_inputs) # type: ignore[arg-type, assignment]
remove_tensor_overload_for_qdq_ops(replacement) # type: ignore[arg-type]
if pattern_post_trans:
pattern = pattern_post_trans(pattern)
if replacement_post_trans:
replacement = replacement_post_trans(replacement)
pattern.recompile() # type: ignore[attr-defined]
replacement.recompile() # type: ignore[attr-defined]
replace_pattern(model, pattern, replacement)
return model

View file

@ -0,0 +1,610 @@
# mypy: allow-untyped-defs
import operator
import types
from typing import Any, Callable, Optional, Union
import torch
import torch.ao.quantization.pt2e._affine_quantization # noqa: F401
import torch.nn.functional as F
# Makes sure that quantized_decomposed ops are registered
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
from torch.ao.quantization.quantizer import QuantizationAnnotation
from torch.export.unflatten import _assign_attr, _AttrKind
from torch.fx import GraphModule, Node
from torch.nn.utils.fusion import fuse_conv_bn_weights
from torch.utils._pytree import LeafSpec
__all__ = [
"fold_bn_weights_into_conv_node",
"remove_tensor_overload_for_qdq_ops",
]
_QUANTIZE_OPS = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
torch.ops.quantized_decomposed.quantize_per_channel.default,
]
_DEQUANTIZE_OPS = [
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
]
def _is_connected(source: torch.fx.Node, dest: torch.fx.Node) -> bool:
"""
    Assuming dest is one of the ops inserted by the quant workflow, this function
    determines whether source and dest are connected. The assumption is that only
    quant-workflow-inserted ops exist between source and dest.
"""
quant_workflow_ops = _QUANTIZE_OPS + _DEQUANTIZE_OPS
quant_workflow_ops.append(torch.ops.quantized_decomposed.choose_qparams.tensor)
while dest.target in quant_workflow_ops:
if not isinstance(dest.args[0], torch.fx.Node):
raise ValueError(
f"expected arg[0] of quant workflow ops to be a node but found {dest.args[0]}"
)
dest = dest.args[0]
return dest == source
def _find_q_dq_node_for_user(
    producer: torch.fx.Node, user: torch.fx.Node
) -> tuple[Any, Any]:
"""
Find q, dq pair corresponding to [producer -> q -> dq -> user]
    The util works by finding the dq arg of user and ensuring it is connected to the
producer
"""
dq_node = None
for n in user.args:
if (
isinstance(n, torch.fx.Node)
and n.op == "call_function"
and n.target in _DEQUANTIZE_OPS
):
            if _is_connected(producer, n):
dq_node = n
break
if dq_node is None:
for n in user.kwargs:
if (
isinstance(n, torch.fx.Node)
and n.op == "call_function"
and n.target in _DEQUANTIZE_OPS
):
                if _is_connected(producer, n):
dq_node = n
break
if dq_node is None:
return (None, None)
q_node = None
if (
dq_node.args[0].op == "call_function" # type: ignore[union-attr]
and dq_node.args[0].target in _QUANTIZE_OPS # type: ignore[union-attr]
):
q_node = dq_node.args[0]
return (q_node, dq_node)
def _is_sym_size_node(node: Node):
return (
node.op == "call_function"
and node.target == torch.ops.aten.sym_size.default
or node.target == torch.ops.aten.sym_numel.default
or node.target == torch.ops.aten.sym_numel
or node.target == torch.ops.aten.sym_size
)
def _filter_sym_size_users(node: torch.fx.Node) -> list[torch.fx.Node]:
node_users = list(filter((lambda x: (_is_sym_size_node(x) is False)), node.users))
return node_users
def _is_valid_annotation(annotation: QuantizationAnnotation) -> bool:
if annotation is None:
return False
input_qspec_map = annotation.input_qspec_map
output_qspec = annotation.output_qspec
if len(input_qspec_map) == 0 and output_qspec is None:
return False
return True
def _get_tensor_constant_from_node(node, m):
if node is None:
return None
assert node.op == "get_attr"
target_atoms = node.target.split(".")
attr_itr = m
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(
f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
)
attr_itr = getattr(attr_itr, atom)
return attr_itr
def _get_all_arguments(orig_args, orig_kwargs, args_schema):
all_args = []
for i, schema in enumerate(args_schema):
if schema.name in orig_kwargs:
all_args.append(orig_kwargs[schema.name])
elif not schema.kwarg_only and i < len(orig_args):
all_args.append(orig_args[i])
else:
all_args.append(schema.default_value)
return all_args
def _is_supported_batch_norm_for_training(node: Node):
"""
Return True if the given node refers to an aten batch norm op QAT supports.
"""
supported_ops = [
torch.ops.aten.batch_norm.default,
torch.ops.aten._native_batch_norm_legit.default,
# Note: we won't need this op anymore after batch norm consolidation
# For now, we need to continue to support it because it gives better
# training numerics than `_native_batch_norm_legit`
torch.ops.aten.cudnn_batch_norm.default,
torch.ops.aten.miopen_batch_norm.default,
]
return node.target in supported_ops
# TODO: move this to torch/ao/quantization/utils.py
def _is_conv_node(n: Node):
"""
Return whether the node refers to an aten conv op.
"""
return n.op == "call_function" and n.target in [
torch.ops.aten.conv1d.default,
torch.ops.aten.conv2d.default,
]
def _is_conv_transpose_node(n: Node):
"""
Return whether the node refers to an aten conv_transpose op.
"""
return n.op == "call_function" and n.target in [
torch.ops.aten.conv_transpose1d,
torch.ops.aten.conv_transpose1d.default,
torch.ops.aten.conv_transpose2d,
torch.ops.aten.conv_transpose2d.input,
]
def _is_conv_or_conv_transpose_node(n: Node):
"""
Return whether the node refers to an aten conv or conv transpose op.
"""
return _is_conv_node(n) or _is_conv_transpose_node(n)
def _is_conv_transpose_fn(conv_fn: Callable):
return conv_fn in [F.conv_transpose1d, F.conv_transpose2d]
def _is_bn_node(n: Node):
return (
_is_supported_batch_norm_for_training(n)
or n.target == torch.ops.aten._native_batch_norm_legit_no_training.default
)
def fold_bn_weights_into_conv_node(
conv_node: Node,
conv_weight_node: Node,
conv_bias_node: Optional[Node],
bn_node: Node,
m: GraphModule,
) -> None:
# conv args: input, weight, bias, stride, padding, dilation, ...
conv_w = _get_tensor_constant_from_node(conv_weight_node, m)
conv_b = _get_tensor_constant_from_node(conv_bias_node, m)
transpose = _is_conv_transpose_node(conv_node)
# eval bn args: input, weight, bias, running mean, running var, momentum, eps
# train bn args: input, weight, bias, running mean, running var, training, momentum, eps
bn_args_schema = bn_node.target._schema.arguments # type: ignore[union-attr]
bn_args = _get_all_arguments(bn_node.args, bn_node.kwargs, bn_args_schema)
bn_w = _get_tensor_constant_from_node(bn_args[1], m)
bn_b = _get_tensor_constant_from_node(bn_args[2], m)
bn_rm = _get_tensor_constant_from_node(bn_args[3], m)
bn_rv = _get_tensor_constant_from_node(bn_args[4], m)
if bn_node.target == torch.ops.aten._native_batch_norm_legit_no_training.default:
eps_arg_index = 6
elif _is_supported_batch_norm_for_training(bn_node):
eps_arg_index = 7
else:
raise ValueError("BN node target is unexpected ", bn_node.target)
bn_eps = bn_args[eps_arg_index]
fused_weight, fused_bias = fuse_conv_bn_weights(
conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose
)
# update the weight and bias for conv
conv_args = list(conv_node.args)
# filling in the default bias argument
if len(conv_args) == 2:
conv_args.append(None)
# calling data since the fused_weight and fused_bias are nn.Parameter
weight_attr_name = conv_weight_node.target
assert isinstance(weight_attr_name, str)
_assign_attr(fused_weight, m, weight_attr_name, _AttrKind.PARAMETER)
if conv_bias_node is not None:
bias_attr_name = conv_bias_node.target
_assign_attr(fused_bias, m, str(bias_attr_name), _AttrKind.PARAMETER)
else:
bias_attr_name = weight_attr_name + "_bias"
_assign_attr(fused_bias, m, bias_attr_name, _AttrKind.PARAMETER)
with m.graph.inserting_before(conv_node):
get_bias_node = m.graph.get_attr(bias_attr_name)
# NOTE: here we assume the bias of conv is not quantized!
conv_args[2] = get_bias_node
conv_node.args = tuple(conv_args)
    # native_batch_norm has 3 outputs; we expect getitem calls on the output
# and we want to replace the uses of getitem 0 with the output of conv
#
if bn_node.target == torch.ops.aten.batch_norm.default:
# With the new training ir, instead of batch_norm + getitem,
# we only have the batch_norm node.
#
# Before:
# conv -> bn -> users
# After:
# conv -> users
# bn has no users now
bn_node.replace_all_uses_with(conv_node)
else:
# Before:
# conv -> bn - (first output) -> users1
# \ - (second output) -> users2
# \ - (third output) -> users3
# After:
# conv -> (first output) -> users1
# bn -
# \ - (second output) -> users2
# \ - (third output) -> users3
# if users2 and users3 are empty then bn will be removed through dead code elimination
for user in bn_node.users:
if (
user.op != "call_function"
or user.target != operator.getitem
or user.args[1] != 0
):
continue
user.replace_all_uses_with(conv_node)
# If the BN node does not have users, erase it from the graph
# Note: we need to do this manually because the model can still be in train
# mode at this point, in which case DCE won't erase the BN node automatically
# since the node refers to a mutating op. Here we still need to call DCE first
# to get rid of the unused getitem nodes that consume the BN node.
m.graph.eliminate_dead_code()
if len(bn_node.users) == 0:
m.graph.erase_node(bn_node)
# fuse conv bn weights, inplace modification of the graph_module and graph
def _fuse_conv_bn_(m: GraphModule) -> None:
has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
if not has_bn:
return
for n in m.graph.nodes:
if n.op != "call_function" or n.target not in (
torch.ops.aten._native_batch_norm_legit_no_training.default,
torch.ops.aten.batch_norm.default,
):
continue
bn_node = n
n = bn_node.args[0]
if not _is_conv_or_conv_transpose_node(n):
continue
conv_node = n
conv_weight_node = conv_node.args[1]
conv_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None
fold_bn_weights_into_conv_node(
conv_node, conv_weight_node, conv_bias_node, bn_node, m
)
m.graph.eliminate_dead_code()
m.recompile()
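# A minimal numeric sketch of the weight fusion performed above (illustrative
# only; the helper name and module sizes are arbitrary): a conv followed by an
# eval-mode batch norm is equivalent to a single conv with fused weight/bias.
def _demo_fuse_conv_bn_weights():
    conv = torch.nn.Conv2d(3, 3, kernel_size=1, bias=True)
    bn = torch.nn.BatchNorm2d(3).eval()
    # make the running stats non-trivial so the check is meaningful
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 1.5)
    with torch.no_grad():
        fused_w, fused_b = fuse_conv_bn_weights(
            conv.weight,
            conv.bias,
            bn.running_mean,
            bn.running_var,
            bn.eps,
            bn.weight,
            bn.bias,
        )
        x = torch.randn(2, 3, 4, 4)
        ref = bn(conv(x))
        fused = F.conv2d(x, fused_w, fused_b)
    assert torch.allclose(ref, fused, atol=1e-6)
    return fused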
def _get_node_name_to_scope(model: GraphModule) -> dict[str, tuple[str, type]]:
# TODO: move this information to fx node itself
node_name_to_scope: dict[str, tuple[str, type]] = {}
for n in model.graph.nodes:
nn_module_stack = n.meta.get("nn_module_stack", None)
current_scope = ("", type(None))
if nn_module_stack:
bt = list(nn_module_stack.values())[-1]
current_scope = (bt[0].split(".")[-1], bt[1])
node_name_to_scope[n.name] = current_scope
return node_name_to_scope
def _get_aten_graph_module_for_pattern(
pattern: Callable,
example_inputs: tuple[Any, ...],
is_cuda: bool = False,
**kwargs,
) -> GraphModule:
"""
Convert the pattern to an FX graph with decomposed aten ops.
"""
if is_cuda:
example_inputs = tuple(
[x.cuda() if isinstance(x, torch.Tensor) else x for x in example_inputs]
)
aten_pattern = torch.export.export_for_training(
pattern, # type: ignore[arg-type]
example_inputs,
kwargs,
).module()
aten_pattern.graph.eliminate_dead_code() # type: ignore[operator, union-attr]
aten_pattern.recompile() # type: ignore[operator]
# ep.module() adds copy_ nodes for the mutated inputs.
# For patterns, it doesn't matter
for node in aten_pattern.graph.nodes: # type: ignore[union-attr]
if (
node.op == "call_function"
and node.target == torch.ops.aten.copy_.default
and len(node.users) == 0
):
aten_pattern.graph.erase_node(node) # type: ignore[operator, union-attr]
aten_pattern.graph.eliminate_dead_code() # type: ignore[operator, union-attr]
aten_pattern.recompile() # type: ignore[operator]
return aten_pattern # type: ignore[return-value]
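# A minimal usage sketch for the helper above (illustrative only; the toy
# module `_ToyAddPattern` and its shapes are arbitrary, not one of the real
# quantization patterns): export a small pattern and inspect the resulting
# decomposed aten graph.
class _ToyAddPattern(torch.nn.Module):
    def forward(self, x, y):
        return x + y
def _demo_get_aten_graph_module_for_pattern():
    example_inputs = (torch.randn(2, 3), torch.randn(2, 3))
    toy_gm = _get_aten_graph_module_for_pattern(_ToyAddPattern(), example_inputs)
    call_targets = [n.target for n in toy_gm.graph.nodes if n.op == "call_function"]
    assert torch.ops.aten.add.Tensor in call_targets
    return toy_gm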
def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
"""Remove .tensor overload for quantize/dequantize ops so that we can
use the match_pattern that we get from torchdynamo export to match the output of convert_pt2e
"""
_MAP = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: torch.ops.quantized_decomposed.quantize_per_tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: torch.ops.quantized_decomposed.dequantize_per_tensor,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor2: torch.ops.quantized_decomposed.quantize_per_tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor2: torch.ops.quantized_decomposed.dequantize_per_tensor,
torch.ops.quantized_decomposed.quantize_per_channel.default: torch.ops.quantized_decomposed.quantize_per_channel,
torch.ops.quantized_decomposed.dequantize_per_channel.default: torch.ops.quantized_decomposed.dequantize_per_channel,
torch.ops.aten.clamp.Tensor: torch.ops.aten.clamp,
}
for n in match_pattern.graph.nodes:
if n.op != "call_function":
continue
if n.target in _MAP:
n.target = _MAP[n.target]
def _is_literal(arg):
if isinstance(arg, (int, float)):
return True
if isinstance(arg, (tuple, list)):
return all(map(_is_literal, arg))
return False
def _replace_literals_with_new_placeholders(
gm: torch.fx.GraphModule,
merge_dup: bool = False,
exclude_literals: Optional[list[Any]] = None,
):
"""Replace the literals in the graph with placeholder nodes that's created on the fly while we
traverse the graph, so that the literal arguments in the graph can be matched and replaced
To use this, the pattern and replacement graph should have the exact same number of literal args
and they should be used in the exact same order in the pattern and replacement graph.
If the literal arguments are not used in the same order in pattern and replacement graph, please
use `_replace_literals_with_existing_placeholders` instead
Args:
`gm`: input GraphModule that we'll transform
        `merge_dup`: boolean flag to indicate whether multiple occurrences of the same literal
            in the graph should correspond to the same placeholder or not
`exclude_literals`: a list of literals that will not be replaced with placeholders
Example:
# 1. Original Graph
def pattern(self, x):
return x + 3
def replacement(self, x):
return x - 3
example_inputs = (torch.randn(1, 3, 3, 3),)
pattern_gm = _get_aten_graph_module_for_pattern(pattern, example_inputs)
    replacement_gm = _get_aten_graph_module_for_pattern(replacement, example_inputs)
# 2. Before calling replace literals we'll see the following graph:
def pattern(self, x):
return x + 3
def replacement(self, x):
return x - 3
pattern_gm = _replace_literals_with_new_placeholders(pattern_gm)
replacement_gm = _replace_literals_with_new_placeholders(replacement_gm)
# 3. After replacing literals with new placeholder nodes
def pattern(self, x, new_ph):
return x + new_ph
    def replacement(self, x, new_ph):
return x - new_ph
"""
last_ph = None
cnt = 0
literal_to_ph: dict[Union[float, bool, int, torch.dtype], Node] = {}
if exclude_literals is None:
exclude_literals = []
in_spec = gm._in_spec
args_spec = in_spec.children_specs[0]
for node in gm.graph.nodes:
if node.op == "placeholder":
last_ph = node
cnt += 1
continue
with gm.graph.inserting_after(last_ph):
new_args = []
for arg in node.args:
if _is_literal(arg) and arg not in exclude_literals:
if merge_dup and arg in literal_to_ph:
new_args.append(literal_to_ph[arg])
else:
ph_node = gm.graph.placeholder("arg" + str(cnt))
new_args.append(ph_node)
args_spec.children_specs.append(LeafSpec())
cnt += 1
if merge_dup:
literal_to_ph[arg] = ph_node
else:
new_args.append(arg)
new_args = tuple(new_args)
node.args = new_args
# Update `num_nodes`, `num_leaves`, `num_children`.
args_spec.__post_init__()
in_spec.__post_init__()
return gm
def _replace_literals_with_existing_placeholders(
gm: torch.fx.GraphModule,
exclude_literals: Optional[list[Any]] = None,
literal_to_ph_idx: Optional[dict[Union[float, int, bool, torch.dtype], int]] = None,
):
"""Replace the literals in the graph with **existing** placeholder nodes, so that the literal arguments
in the graph can be matched and replaced
To use this, all literal args in the graph should be unique and each of them should correspond
to exactly one placeholder node
# 1. Original Graph
def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):
return torch.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max)
def replacement(x_i8, scale, zero_point, quant_min, quant_max):
x_i8 = torch.clamp(x_i8, quant_min, quant_max)
return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)
example_inputs = (
torch.randn(1, 3, 3, 3),
1.0,
0,
-128,
127,
)
pattern_gm = _get_aten_graph_module_for_pattern(pattern, example_inputs)
    replacement_gm = _get_aten_graph_module_for_pattern(replacement, example_inputs)
# 2. Before calling replace literals we'll see the following graph:
def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):
# scale/zero_point/quant_min/quant_max are burnt in since they are scalar values
return torch.dequantize_per_tensor(x_i8, 1.0, 0, -128, 127)
def replacement(x_i8, scale, zero_point, quant_min, quant_max):
# scale/zero_point/quant_min/quant_max are burnt in since they are scalar values
x_i8 = torch.clamp(x_i8, -128, 127)
return ((x_i8.to(torch.float32) - 0) * 1.0).to(dtype=torch.float32)
# Note that literal args appear in different order in pattern and replacement graph, so
# we can't use _replace_literals_with_new_placeholders
literal_to_ph_idx = {1.0: 1, 0: 2, -128: 3, 127: 4}
    pattern_gm = _replace_literals_with_existing_placeholders(pattern_gm, literal_to_ph_idx=literal_to_ph_idx)
    replacement_gm = _replace_literals_with_existing_placeholders(replacement_gm, literal_to_ph_idx=literal_to_ph_idx)
# 3. After replacing literals with existing placeholder nodes
def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):
# scale/zero_point/quant_min/quant_max are burnt in since they are scalar values
return torch.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max)
def replacement(x_i8, scale, zero_point, quant_min, quant_max):
# scale/zero_point/quant_min/quant_max are burnt in since they are scalar values
x_i8 = torch.clamp(x_i8, quant_min, quant_max)
return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)
"""
if exclude_literals is None:
exclude_literals = []
if literal_to_ph_idx is None:
literal_to_ph_idx = {}
phs = [node for node in gm.graph.nodes if node.op == "placeholder"]
for node in gm.graph.nodes:
if node.op != "call_function":
continue
new_args = []
for arg in node.args:
if (
_is_literal(arg)
and arg not in exclude_literals
and arg in literal_to_ph_idx
):
ph_idx = literal_to_ph_idx[arg]
ph_node = phs[ph_idx]
new_args.append(ph_node)
else:
new_args.append(arg)
new_args = tuple(new_args)
node.args = new_args
return gm
# TODO: Handle this in export itself and don't wrap the model in another GraphModule
# in prepare and convert
def _disallow_eval_train(model: GraphModule):
"""
Disallow calling `model.train()` or `model.eval()` on the given GraphModule.
This is useful for exported models, where these methods don't actually behave as expected.
"""
error_message = """
Calling train() or eval() is not supported for exported models.
Please call `torch.ao.quantization.move_exported_model_to_train(model)` (or eval) instead.
If you cannot replace the calls to `model.train()` and `model.eval()`, you may override
the behavior for these methods by calling `torch.ao.quantization.allow_exported_model_train_eval(model)`,
which does the above automatically for you. Note that this has limited effect on switching
behavior between train and eval modes, and should be used only for special ops such as dropout
and batchnorm.
"""
def _train(self, mode: bool = True):
raise NotImplementedError(error_message)
def _eval(self, mode: bool = True):
raise NotImplementedError(error_message)
model.train = types.MethodType(_train, model) # type: ignore[method-assign]
model.eval = types.MethodType(_eval, model) # type: ignore[method-assign]
return model
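# A minimal usage sketch for the helper above (illustrative only; the traced
# module and the helper name are arbitrary): after wrapping, calling train()
# or eval() raises instead of silently misbehaving on an exported model.
def _demo_disallow_eval_train():
    gm = torch.fx.symbolic_trace(torch.nn.Linear(2, 2))
    gm = _disallow_eval_train(gm)
    try:
        gm.train()
    except NotImplementedError:
        return True
    return False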