Adding all project files
This commit is contained in:
parent
6c9e127bdc
commit
cd4316ad0f
42289 changed files with 8009643 additions and 0 deletions
Binary files not shown (10 files).
@@ -0,0 +1,775 @@
# copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py
# and https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py
# PLEASE DON'T MODIFY THIS FILE SO THAT WE DON'T GET OUT OF SYNC
import logging
from abc import ABCMeta
from typing import Any, Optional, Union

import torch
from torch.ao.quantization.observer import (
    AffineQuantizedObserverBase,
    get_block_size,
    MappingType,
    TorchAODType,
    ZeroPointDomain,
)
from torch.fx import Node


ABC: Any = ABCMeta("ABC", (object,), {})  # compatible with Python 2 *and* 3

logger = logging.getLogger(__name__)
|
||||
|
||||
FP8_TYPES = {
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e5m2,
|
||||
torch.float8_e4m3fnuz,
|
||||
torch.float8_e5m2fnuz,
|
||||
}
|
||||
_SUB_BYTE_UINT_BOUNDS = {
|
||||
torch.uint1: (0, 2**1 - 1),
|
||||
torch.uint2: (0, 2**2 - 1),
|
||||
torch.uint3: (0, 2**3 - 1),
|
||||
torch.uint4: (0, 2**4 - 1),
|
||||
torch.uint5: (0, 2**5 - 1),
|
||||
torch.uint6: (0, 2**6 - 1),
|
||||
torch.uint7: (0, 2**7 - 1),
|
||||
}
|
||||
|
||||
"""
|
||||
Map from dtype to the bound value of integers
|
||||
TODO: maybe can replace this with call to torch.iinfo
|
||||
"""
|
||||
_DTYPE_TO_QVALUE_BOUNDS: dict[Union[torch.dtype, TorchAODType], tuple[int, int]] = {
|
||||
torch.uint8: (0, 255),
|
||||
torch.int8: (-128, 127),
|
||||
torch.int16: (-(2**15), 2**15 - 1),
|
||||
torch.int32: (-(2**31), 2**31 - 1),
|
||||
}
|
||||
_DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_UINT_BOUNDS)
|
||||
|
||||
|
||||
def _is_float8_type(dtype: torch.dtype) -> bool:
|
||||
fp8_types = {
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e4m3fnuz,
|
||||
torch.float8_e5m2,
|
||||
torch.float8_e5m2fnuz,
|
||||
}
|
||||
return dtype in fp8_types
|
||||
|
||||
|
||||
# TODO: decide on if we want to allow custom quant_min/quant_max here
|
||||
def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
|
||||
"""Get quant_min and quant_max args based on dtype and also
|
||||
verify that they are within the range of possible quant_min/quant_max
|
||||
for dtype
|
||||
"""
|
||||
if dtype in FP8_TYPES:
|
||||
quant_min_lower_bound, quant_max_upper_bound = (
|
||||
torch.finfo(dtype).min,
|
||||
torch.finfo(dtype).max,
|
||||
)
|
||||
elif dtype not in _DTYPE_TO_QVALUE_BOUNDS:
|
||||
raise ValueError(f"Unsupported dtype: {dtype}")
|
||||
else:
|
||||
quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype]
|
||||
if quant_min is None:
|
||||
quant_min = quant_min_lower_bound
|
||||
if quant_max is None:
|
||||
quant_max = quant_max_upper_bound
|
||||
|
||||
assert quant_min >= quant_min_lower_bound, (
|
||||
"quant_min out of bound for dtype, "
|
||||
f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}"
|
||||
)
|
||||
|
||||
assert quant_max <= quant_max_upper_bound, (
|
||||
"quant_max out of bound for dtype, "
|
||||
f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}"
|
||||
)
|
||||
return quant_min, quant_max
|
||||
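# Illustrative sketch, not part of the copied torchao file: a hypothetical helper
# showing the defaults derived from _DTYPE_TO_QVALUE_BOUNDS (e.g. torch.int8 -> (-128, 127))
# and that explicit values are accepted as long as they stay within the dtype bounds.
def _example_qmin_qmax_defaults() -> None:
    qmin, qmax = _get_and_check_qmin_qmax(torch.int8, None, None)
    assert (qmin, qmax) == (-128, 127)
    qmin, qmax = _get_and_check_qmin_qmax(torch.int8, -127, 127)
    assert (qmin, qmax) == (-127, 127)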
|
||||
|
||||
def _get_reduction_params(block_size, input_size):
|
||||
"""Given block_size and input size find the parameters for reduction:
|
||||
|
||||
Output:
|
||||
shape_for_reduction: the shape we use to `view` input to prepare it for reduction
|
||||
reduction_dims: the dims we'll do reduction over
|
||||
|
||||
Example::
|
||||
Input:
|
||||
block_size: (3, 3, 2, 10)
|
||||
input_size: (3, 3, 10, 10)
|
||||
|
||||
Output:
|
||||
shape_for_reduction: (3, 3, 5, 2, 10)
|
||||
reduction_dim: [0, 1, 3, 4]
|
||||
"""
|
||||
assert len(block_size) == len(input_size)
|
||||
shape_for_reduction = []
|
||||
reduction_dims = []
|
||||
cur_dim = 0
|
||||
for i in range(len(block_size)):
|
||||
if block_size[i] != input_size[i] and block_size[i] > 1:
|
||||
assert input_size[i] % block_size[i] == 0, (
|
||||
f"Expecting input size at {i} dimension: "
|
||||
f"{input_size[i]} to be divisible by block_size at {i} dimension: {block_size[i]}"
|
||||
)
|
||||
shape_for_reduction.append(input_size[i] // block_size[i])
|
||||
shape_for_reduction.append(block_size[i])
|
||||
# reduce over the block_size[i] dim
|
||||
reduction_dims.append(cur_dim + 1)
|
||||
cur_dim += 2
|
||||
else:
|
||||
# block_size[i] == input_size[i] or block_size[i] == 1
|
||||
shape_for_reduction.append(input_size[i])
|
||||
# we only need to reduce over the dimension if block_size is greater than 1
|
||||
# otherwise it's already the same as reduced dimension
|
||||
if block_size[i] != 1:
|
||||
reduction_dims.append(cur_dim)
|
||||
cur_dim += 1
|
||||
return shape_for_reduction, reduction_dims
|
||||
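# Illustrative sketch, not part of the copied torchao file: a hypothetical helper
# reproducing the example from the docstring of _get_reduction_params above.
def _example_reduction_params() -> None:
    shape_for_reduction, reduction_dims = _get_reduction_params(
        block_size=(3, 3, 2, 10), input_size=(3, 3, 10, 10)
    )
    # input (3, 3, 10, 10) is viewed as (3, 3, 5, 2, 10) and reduced over dims [0, 1, 3, 4]
    assert shape_for_reduction == [3, 3, 5, 2, 10]
    assert reduction_dims == [0, 1, 3, 4]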
|
||||
|
||||
def _register_custom_op(lib):
    """This decorator is used to preserve some high level operators for torch.export.export
    while still allowing them to be decomposed for the inductor path

    requirement: make sure `fn.__name__[1:]` is the operator name you want to register

    NOTE: This should be applied at the top, after all other decorators have been applied
    NOTE: We haven't tested the case when `fn` accepts a tensor subclass instance as input,
    e.g. a uint4 tensor subclass instance, and we'll probably need to figure out what would make
    sense for downstream systems (like executorch) to accept as well

    Example:
        lib = torch.library.Library("my_namespace", "FRAGMENT")

        register_custom_op = _register_custom_op(lib)

        @register_custom_op
        def _the_op_that_needs_to_be_preserved(...):
            ...

        # after this, `_the_op_that_needs_to_be_preserved` will be preserved as the
        # torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after
        # torch.export.export / torch._export.export_for_training

    """
|
||||
from torch._inductor.decomposition import register_decomposition
|
||||
|
||||
def decorator(fn):
|
||||
from torch._library.infer_schema import infer_schema
|
||||
|
||||
# expecting fn.__name__ starts with `_` and we want to take the rest
|
||||
# to be the name of the custom op
|
||||
assert (
|
||||
fn.__name__[0] == "_"
|
||||
), f"Expecting function name starts with `_`, got {fn.__name__}"
|
||||
assert not any(
|
||||
c in fn.__name__ for c in ".<>"
|
||||
), f"Expecting op to be defined in normal functions, not lambda or local: {fn.__name__}"
|
||||
op_name = fn.__name__[1:]
|
||||
schema = op_name + infer_schema(fn, mutates_args={})
|
||||
lib.define(schema)
|
||||
lib.impl(op_name, fn, "CompositeImplicitAutograd")
|
||||
|
||||
lib_namespace = lib.ns
|
||||
op = getattr(getattr(torch.ops, lib_namespace), op_name)
|
||||
register_decomposition([op])(fn)
|
||||
return op
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
quant_lib = torch.library.Library("pt2e_quant", "FRAGMENT") # noqa: TOR901
|
||||
|
||||
register_custom_op = _register_custom_op(quant_lib)
|
||||
|
||||
|
||||
def choose_qparams_affine_with_min_max(
|
||||
min_val: torch.Tensor,
|
||||
max_val: torch.Tensor,
|
||||
mapping_type: MappingType,
|
||||
block_size: tuple[int, ...],
|
||||
target_dtype: torch.dtype,
|
||||
quant_min: Optional[int] = None,
|
||||
quant_max: Optional[int] = None,
|
||||
eps: Optional[float] = None,
|
||||
scale_dtype: Optional[torch.dtype] = None,
|
||||
zero_point_dtype: Optional[torch.dtype] = None,
|
||||
preserve_zero: bool = True,
|
||||
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""A variant of :func:`~torchao.quantization.quant_primitives.choose_qparams_affine`
|
||||
operator that pass in min_val and max_val directly instead of deriving these from a single input.
|
||||
This is used for observers in static quantization where min_val and max_val may be obtained through
|
||||
tracking all the data in calibration data set.
|
||||
|
||||
Args:
|
||||
Mostly same as :func:`~torchao.quantization.quant_primitives.choose_qparams_affine`. with one
|
||||
difference: instead of passing in `input` Tensor and use that to calculate min_val/max_val
|
||||
and then scale/zero_point, we pass in min_val/max_val directly
|
||||
"""
|
||||
return _choose_qparams_affine(
|
||||
None,
|
||||
mapping_type.name,
|
||||
block_size,
|
||||
target_dtype,
|
||||
quant_min,
|
||||
quant_max,
|
||||
eps,
|
||||
scale_dtype,
|
||||
zero_point_dtype,
|
||||
preserve_zero,
|
||||
zero_point_domain.name if zero_point_domain is not None else None,
|
||||
min_val,
|
||||
max_val,
|
||||
)
|
||||
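# Illustrative sketch, not part of the copied torchao file: a hypothetical helper
# deriving per-tensor int8 qparams from pre-computed min/max values, the way an
# observer would after calibration. block_size is empty because min/max are
# already reduced.
def _example_qparams_from_min_max() -> tuple[torch.Tensor, torch.Tensor]:
    min_val = torch.tensor(-1.0)
    max_val = torch.tensor(3.0)
    return choose_qparams_affine_with_min_max(
        min_val, max_val, MappingType.ASYMMETRIC, (), torch.int8
    )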
|
||||
|
||||
@register_custom_op
|
||||
def _choose_qparams_affine(
|
||||
input: Optional[torch.Tensor],
|
||||
mapping_type: str,
|
||||
block_size: list[int],
|
||||
target_dtype: torch.dtype,
|
||||
quant_min: Optional[Union[int, float, bool]] = None,
|
||||
quant_max: Optional[Union[int, float, bool]] = None,
|
||||
eps: Optional[float] = None,
|
||||
scale_dtype: Optional[torch.dtype] = None,
|
||||
zero_point_dtype: Optional[torch.dtype] = None,
|
||||
preserve_zero: bool = True,
|
||||
zero_point_domain: Optional[str] = "INT",
|
||||
min_val: Optional[torch.Tensor] = None,
|
||||
max_val: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""op definition that has compatible signatures with custom op library
|
||||
|
||||
The op does the following:
|
||||
1. figure out the dimension for reduction based on block_size
|
||||
2. find min_val/max_val based on the dimension for reduction
|
||||
3. calculate quantization parameters based on min_val/max_val based on args like `preserve_zero`
|
||||
and `zero_point_domain`
|
||||
"""
|
||||
quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
|
||||
assert mapping_type in [
|
||||
MappingType.SYMMETRIC.name,
|
||||
MappingType.SYMMETRIC_NO_CLIPPING_ERR.name,
|
||||
MappingType.ASYMMETRIC.name,
|
||||
], f"Unsupported mapping type: {mapping_type}"
|
||||
if target_dtype in FP8_TYPES:
|
||||
assert (
|
||||
mapping_type == MappingType.SYMMETRIC.name
|
||||
), f"Only symmetric quantization is supported for FP8 types, got {mapping_type}"
|
||||
|
||||
if input is not None:
|
||||
if scale_dtype is None:
|
||||
scale_dtype = input.dtype
|
||||
if zero_point_dtype is None:
|
||||
zero_point_dtype = input.dtype
|
||||
if eps is None:
|
||||
eps = torch.finfo(input.dtype).eps
|
||||
|
||||
assert (
|
||||
len(block_size) == input.dim()
|
||||
), f"Got input dim:{input.dim()}, block_size: {block_size}"
|
||||
shape_for_reduction, reduction_dims = _get_reduction_params(
|
||||
block_size, input.size()
|
||||
)
|
||||
input = input.view(shape_for_reduction)
|
||||
|
||||
min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
|
||||
max_val = torch.amax(input, dim=reduction_dims, keepdim=False)
|
||||
    else:
        assert (
            min_val is not None and max_val is not None
        ), f"Need to provide `min_val` and `max_val` when `input` is None, got: {min_val}, {max_val}"
        assert (
            min_val.dtype == max_val.dtype
        ), f"Expecting `min_val` and `max_val` to have the same dtype, got: {min_val.dtype}, {max_val.dtype}"
|
||||
|
||||
if scale_dtype is None:
|
||||
scale_dtype = min_val.dtype
|
||||
if zero_point_dtype is None:
|
||||
zero_point_dtype = min_val.dtype
|
||||
if eps is None:
|
||||
eps = torch.finfo(min_val.dtype).eps
|
||||
|
||||
if preserve_zero:
|
||||
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
|
||||
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
|
||||
else:
|
||||
min_val_neg = min_val
|
||||
max_val_pos = max_val
|
||||
|
||||
if (
|
||||
mapping_type == MappingType.SYMMETRIC.name
|
||||
or mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR.name
|
||||
):
|
||||
# scales
|
||||
if mapping_type == MappingType.SYMMETRIC.name:
|
||||
max_val_pos = torch.max(-min_val_neg, max_val_pos)
|
||||
scale = max_val_pos / (float(quant_max - quant_min) / 2)
|
||||
else:
|
||||
            assert mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR.name
            # calculate smin and smax individually and choose the larger one. For example, if quant_min = -8 and
            # quant_max = 7:
            # - If smin is bigger: there is coverage of negative values down to -8, and less rounding
            #   error than the existing SYMMETRIC case.
            # - If smax is bigger: it covers the positive values up to 7. The rounding
            #   error may be bigger than in the existing SYMMETRIC case. Either way, there are no out-of-range fp values after
            #   quantization.
|
||||
smin = min_val_neg / float(quant_min)
|
||||
smax = max_val_pos / float(quant_max)
|
||||
mask = smin > smax
|
||||
scale = torch.where(mask, smin, smax)
|
||||
# zeros
|
||||
if not preserve_zero:
|
||||
raise ValueError(
|
||||
"preserve_zero == False is not supported for symmetric quantization"
|
||||
)
|
||||
if (
|
||||
zero_point_domain is not None
|
||||
and zero_point_domain != ZeroPointDomain.INT.name
|
||||
):
|
||||
raise ValueError(
|
||||
"zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization"
|
||||
)
|
||||
scale = torch.clamp(scale, min=eps)
|
||||
zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
|
||||
else:
|
||||
assert mapping_type == MappingType.ASYMMETRIC.name
|
||||
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
|
||||
scale = torch.clamp(scale, min=eps)
|
||||
if zero_point_domain == ZeroPointDomain.NONE.name:
|
||||
zero_point = None
|
||||
else:
|
||||
if preserve_zero:
|
||||
zero_point = quant_min - torch.round(min_val_neg / scale)
|
||||
zero_point = torch.clamp(zero_point, quant_min, quant_max)
|
||||
else:
|
||||
assert (
|
||||
zero_point_domain == ZeroPointDomain.FLOAT.name
|
||||
), "if not preserve_zero, zero_point must be in FLOAT domain"
|
||||
mid_point = (quant_max + quant_min + 1) / 2
|
||||
zero_point = min_val_neg + scale * mid_point
|
||||
|
||||
if zero_point is not None:
|
||||
zero_point = zero_point.to(dtype=zero_point_dtype)
|
||||
return scale.to(dtype=scale_dtype), zero_point
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def quantize_affine(
|
||||
input: torch.Tensor,
|
||||
block_size: tuple[int, ...],
|
||||
scale: torch.Tensor,
|
||||
zero_point: Optional[torch.Tensor],
|
||||
output_dtype: torch.dtype,
|
||||
quant_min: Optional[Union[int, float]] = None,
|
||||
quant_max: Optional[Union[int, float]] = None,
|
||||
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
input (torch.Tensor): original float32, float16 or bfloat16 Tensor
|
||||
        block_size (Tuple[int, ...]): granularity of quantization,
            this means the size of the block of tensor elements that share the same qparam,
            e.g. when the block size is the same as the input tensor shape, we are using per-tensor quantization
        scale (torch.Tensor): quantization parameter for affine quantization
        zero_point (Optional[torch.Tensor]): quantization parameter for affine quantization
|
||||
output_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
|
||||
quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype
|
||||
quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype
|
||||
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
|
||||
if zero_point is in integer domain, zero point is added to the quantized integer value during
|
||||
quantization
|
||||
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
|
||||
value during quantization
|
||||
default is ZeroPointDomain.INT
|
||||
|
||||
Note:
|
||||
How can block_size represent different granularities?
|
||||
let's say we have a Tensor of size: (3, 3, 10, 10), here is the table showing how block_size represents different
|
||||
granularities:
|
||||
|
||||
granularity type | block_size
|
||||
per_tensor | (3, 3, 10, 10)
|
||||
per_axis (axis=0) | (1, 3, 10, 10)
|
||||
per_axis (axis=1) | (3, 1, 10, 10)
|
||||
per_group (groupsize=2) | (3, 3, 10, 2)
|
||||
per_group (groupsize=2) for axis = 3 | (3, 3, 2, 10)
|
||||
|
||||
|
||||
Output:
|
||||
quantized tensor with requested dtype
|
||||
"""
|
||||
return _quantize_affine(
|
||||
input,
|
||||
block_size,
|
||||
scale,
|
||||
zero_point,
|
||||
output_dtype,
|
||||
quant_min,
|
||||
quant_max,
|
||||
zero_point_domain.name if zero_point_domain is not None else None,
|
||||
)
|
||||
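# Illustrative sketch, not part of the copied torchao file: hypothetical per-axis
# (axis=0) int8 quantization of a 2D weight, i.e. block_size (1, in_features) so
# each output row shares one scale/zero_point, matching the table above.
def _example_per_axis_int8_quant(weight: torch.Tensor) -> torch.Tensor:
    block_size = (1, weight.shape[-1])
    scale, zero_point = _choose_qparams_affine(
        weight,
        MappingType.ASYMMETRIC.name,
        list(block_size),
        torch.int8,
    )
    return quantize_affine(weight, block_size, scale, zero_point, torch.int8)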
|
||||
|
||||
@register_custom_op
|
||||
def _quantize_affine(
|
||||
input: torch.Tensor,
|
||||
block_size: list[int],
|
||||
scale: torch.Tensor,
|
||||
zero_point: Optional[torch.Tensor],
|
||||
output_dtype: torch.dtype,
|
||||
quant_min: Optional[Union[int, float, bool]] = None,
|
||||
quant_max: Optional[Union[int, float, bool]] = None,
|
||||
zero_point_domain: Optional[str] = ZeroPointDomain.INT.name,
|
||||
) -> torch.Tensor:
|
||||
"""op definition that has compatible signatures with custom op library
|
||||
|
||||
    Note:
        zero_point_domain is optional and specifies how we quantize the floating point to quantized data:
        INT: quantized_val = (float_val / scale) (integer) + zero_point (integer)
        FLOAT: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale
        None: quantized_val = (float_val / scale) | this is primarily used for floatx quantization,
            where we do not want to round values to the nearest integer and instead scale and cast.
    """
|
||||
quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max)
|
||||
# workaround for uintx dtypes, since we don't have native Uintx dtype connected with
|
||||
# torch.uintx dtypes yet
|
||||
if output_dtype in _SUB_BYTE_UINT_BOUNDS:
|
||||
output_dtype = torch.uint8
|
||||
return _quantize_affine_no_dtype_cast(
|
||||
input,
|
||||
block_size,
|
||||
scale,
|
||||
zero_point,
|
||||
quant_min,
|
||||
quant_max,
|
||||
zero_point_domain,
|
||||
).to(output_dtype)
|
||||
|
||||
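# Illustrative sketch, not part of the copied torchao file: the INT-domain formula
# above on a single value. With scale=0.1 and zero_point=3, 0.37 maps to
# round(0.37 / 0.1) + 3 = 4 + 3 = 7, then the result is clamped to [quant_min, quant_max].
def _example_int_domain_value() -> torch.Tensor:
    x = torch.tensor([[0.37]])
    scale = torch.tensor([0.1])
    zero_point = torch.tensor([3])
    # returns tensor([[7]], dtype=torch.int8)
    return quantize_affine(x, (1, 1), scale, zero_point, torch.int8)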
|
||||
def _quantize_affine_no_dtype_cast(
|
||||
input: torch.Tensor,
|
||||
block_size: list[int],
|
||||
scale: torch.Tensor,
|
||||
zero_point: Optional[torch.Tensor],
|
||||
quant_min: Union[int, float],
|
||||
quant_max: Union[int, float],
|
||||
zero_point_domain: Optional[str] = ZeroPointDomain.INT.name,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
The op does the following:
|
||||
1. figure out the dimension for reduction based on block_size, also reshape the input to align with
|
||||
the shape after reduction
|
||||
2. quantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain
|
||||
3. reshape the quantized result to origianl shape
|
||||
"""
|
||||
# TODO: validations
|
||||
# TODO: validate scale/zero_point dimensions are compatible with block_size
|
||||
assert input.dtype in [
|
||||
torch.float32,
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
], f"Unsupported input dtype: {input.dtype}"
|
||||
assert (
|
||||
len(block_size) == input.dim()
|
||||
), f"Got input dim:{input.dim()}, block_size: {block_size}"
|
||||
shape_for_reduction, reduction_dims = _get_reduction_params(
|
||||
block_size, input.size()
|
||||
)
|
||||
original_shape = input.shape
|
||||
input = input.view(shape_for_reduction)
|
||||
shape_after_reduction = shape_for_reduction
|
||||
for i in reduction_dims:
|
||||
shape_after_reduction[i] = 1
|
||||
scale = scale.view(shape_after_reduction)
|
||||
if zero_point is not None:
|
||||
zero_point = zero_point.view(shape_after_reduction)
|
||||
|
||||
if zero_point_domain == ZeroPointDomain.INT.name:
|
||||
quant = torch.clamp(
|
||||
torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
|
||||
)
|
||||
elif zero_point_domain == ZeroPointDomain.NONE.name:
|
||||
assert (
|
||||
zero_point is None
|
||||
), "zero_point should be None when zero_point_domain is NONE"
|
||||
quant = torch.clamp(torch.round(input * (1.0 / scale)), quant_min, quant_max)
|
||||
elif zero_point_domain is None:
|
||||
# This case handles quantization for float8 we expect no zero point and no zero point domain
|
||||
assert (
|
||||
zero_point is None
|
||||
), "zero_point should be None when zero_point_domain is None"
|
||||
quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max)
|
||||
else:
|
||||
assert zero_point_domain == ZeroPointDomain.FLOAT.name
|
||||
mid_point = (quant_max + quant_min + 1) / 2
|
||||
min_val = zero_point - scale * mid_point
|
||||
quant = torch.clamp(
|
||||
torch.round((input - min_val) / scale), quant_min, quant_max
|
||||
)
|
||||
quant = quant.view(original_shape)
|
||||
|
||||
return quant
|
||||
|
||||
|
||||
def dequantize_affine(
|
||||
input: torch.Tensor,
|
||||
block_size: tuple[int, ...],
|
||||
scale: torch.Tensor,
|
||||
zero_point: Optional[torch.Tensor],
|
||||
input_dtype: torch.dtype,
|
||||
quant_min: Optional[Union[int, float]] = None,
|
||||
quant_max: Optional[Union[int, float]] = None,
|
||||
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
|
||||
*,
|
||||
output_dtype: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
input (torch.Tensor): quantized tensor, should match the dtype `dtype` argument
|
||||
block_size: (List[int]): granularity of quantization,
|
||||
this means the size of the tensor elements that's sharing the same qparam
|
||||
e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
|
||||
scale (Tensor): quantization parameter for affine quantization
|
||||
zero_point (Tensor): quantization parameter for affine quantization
|
||||
input_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
|
||||
quant_min (Optional[int]): minimum quantized value for input Tensor
|
||||
quant_max (Optional[int]): maximum quantized value for input Tensor
|
||||
output_dtype (torch.dtype): dtype for output Tensor, default is fp32
|
||||
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
|
||||
if zero_point is in integer domain, zero point is added to the quantized integer value during
|
||||
quantization
|
||||
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
|
||||
value during quantization
|
||||
default is ZeroPointDomain.INT
|
||||
|
||||
Output:
|
||||
dequantized Tensor, with requested dtype or fp32
|
||||
"""
|
||||
return _dequantize_affine(
|
||||
input,
|
||||
block_size,
|
||||
scale,
|
||||
zero_point,
|
||||
input_dtype,
|
||||
quant_min,
|
||||
quant_max,
|
||||
zero_point_domain.name if zero_point_domain is not None else None,
|
||||
output_dtype=output_dtype,
|
||||
)
|
||||
|
||||
|
||||
@register_custom_op
|
||||
def _dequantize_affine(
|
||||
input: torch.Tensor,
|
||||
block_size: list[int],
|
||||
scale: torch.Tensor,
|
||||
zero_point: Optional[torch.Tensor],
|
||||
input_dtype: torch.dtype,
|
||||
quant_min: Optional[Union[int, float, bool]] = None,
|
||||
quant_max: Optional[Union[int, float, bool]] = None,
|
||||
zero_point_domain: Optional[str] = ZeroPointDomain.INT.name,
|
||||
output_dtype: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
"""op definition that has compatible signatures with custom op library"""
|
||||
# TODO: validate scale/zero_point dimensions are compatible with block_size
|
||||
if input_dtype not in _SUB_BYTE_UINT_BOUNDS:
|
||||
assert (
|
||||
input.dtype == input_dtype
|
||||
), f"Expected: {input_dtype}, got: {input.dtype}"
|
||||
assert output_dtype in [
|
||||
torch.float32,
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
], f"Unsupported output dtype: {output_dtype}"
|
||||
quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max)
|
||||
return _dequantize_affine_no_dtype_check(
|
||||
input,
|
||||
block_size,
|
||||
scale,
|
||||
zero_point,
|
||||
quant_min,
|
||||
quant_max,
|
||||
zero_point_domain,
|
||||
output_dtype,
|
||||
)
|
||||
|
||||
|
||||
def _dequantize_affine_no_dtype_check(
|
||||
input: torch.Tensor,
|
||||
block_size: list[int],
|
||||
scale: torch.Tensor,
|
||||
zero_point: Optional[torch.Tensor],
|
||||
quant_min: Union[int, float],
|
||||
quant_max: Union[int, float],
|
||||
zero_point_domain: Optional[str] = ZeroPointDomain.INT.name,
|
||||
output_dtype: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
"""This function converts AQT tensors to their high precision floating point representation
|
||||
|
||||
The op does the following:
|
||||
1. figure out the dimension for reduction based on block_size, also reshape the input to align with
|
||||
the shape after reduction
|
||||
2. dequantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain
|
||||
3. reshape the quantized result to origianl shape and change dtype to the output_dtype
|
||||
"""
|
||||
assert (
|
||||
len(block_size) == input.dim()
|
||||
), f"Got input dim:{input.dim()}, block_size: {block_size}"
|
||||
shape_for_reduction, reduction_dims = _get_reduction_params(
|
||||
block_size, input.size()
|
||||
)
|
||||
original_shape = input.shape
|
||||
input = input.view(shape_for_reduction)
|
||||
shape_after_reduction = shape_for_reduction
|
||||
for i in reduction_dims:
|
||||
shape_after_reduction[i] = 1
|
||||
scale = scale.view(shape_after_reduction)
|
||||
|
||||
if zero_point is not None:
|
||||
zero_point = zero_point.view(shape_after_reduction)
|
||||
|
||||
if zero_point_domain == ZeroPointDomain.INT.name:
|
||||
# Force a copy to avoid input modification due
|
||||
# to upcoming in-place operations.
|
||||
dequant = input.to(torch.int32, copy=True)
|
||||
if zero_point is not None:
|
||||
dequant = dequant - zero_point.to(torch.int32)
|
||||
dequant = dequant.to(output_dtype)
|
||||
dequant = dequant * scale
|
||||
elif zero_point_domain == ZeroPointDomain.NONE.name:
|
||||
assert (
|
||||
zero_point is None
|
||||
), "zero_point should be None when zero_point_domain is NONE"
|
||||
dequant = input.to(output_dtype)
|
||||
dequant = dequant * scale
|
||||
elif zero_point_domain is None:
|
||||
# This case handles dequantization for float8 we expect no zero point and no zero point domain
|
||||
assert (
|
||||
zero_point is None
|
||||
), "zero_point should be None when zero_point_domain is None"
|
||||
        assert _is_float8_type(
            input.dtype
        ), f"dequantization with no zero point domain is only supported with FP8 types, got {input.dtype}"
|
||||
dequant = input.to(output_dtype)
|
||||
dequant = dequant * scale
|
||||
else:
|
||||
assert (
|
||||
zero_point_domain == ZeroPointDomain.FLOAT.name
|
||||
), f"Unexpected zero point domain: {zero_point_domain}"
|
||||
# TODO: this seems to be a detail for tinygemm (converting from uint to int, probably need to refactor this)
|
||||
mid_point = (quant_max + quant_min + 1) / 2
|
||||
# This should allocate new memory and avoid input modification
|
||||
dequant = input - mid_point
|
||||
dequant = dequant.to(output_dtype)
|
||||
dequant *= scale
|
||||
if zero_point is not None:
|
||||
dequant += zero_point
|
||||
|
||||
return dequant.view(original_shape).to(output_dtype)
|
||||
|
||||
|
||||
class AffineQuantizedMinMaxObserver(AffineQuantizedObserverBase):
|
||||
def forward(self, input: torch.Tensor):
|
||||
if input.numel() == 0:
|
||||
return input
|
||||
|
||||
input_detached = input.detach()
|
||||
self.original_dtype = input_detached.dtype
|
||||
assert self.granularity is not None, "granularity is None"
|
||||
self.block_size = get_block_size(input_detached.shape, self.granularity)
|
||||
|
||||
shape_for_reduction, reduction_dims = _get_reduction_params(
|
||||
self.block_size, input_detached.size()
|
||||
)
|
||||
input_detached = input_detached.view(shape_for_reduction)
|
||||
min_val = torch.amin(input_detached, dim=reduction_dims, keepdim=False)
|
||||
max_val = torch.amax(input_detached, dim=reduction_dims, keepdim=False)
|
||||
if not hasattr(self, "min_val") or not hasattr(self, "max_val"):
|
||||
self.min_val = min_val
|
||||
self.max_val = max_val
|
||||
else:
|
||||
assert (
|
||||
self.min_val.shape == min_val.shape
|
||||
), f"Can't update existing min_val - shape mismatch, self.min_val:{self.min_val.shape} != min_val:{min_val.shape}"
|
||||
assert (
|
||||
self.max_val.shape == max_val.shape
|
||||
), f"Can't update existing max_val - shape mismatch, self.max_val {self.max_val.shape} != max_val:{max_val.shape}"
|
||||
min_val = torch.min(self.min_val, min_val)
|
||||
max_val = torch.max(self.max_val, max_val)
|
||||
self.min_val.copy_(min_val)
|
||||
self.max_val.copy_(max_val)
|
||||
# returning original input
|
||||
return input
|
||||
|
||||
def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert hasattr(self, "min_val") and hasattr(
|
||||
self, "max_val"
|
||||
), "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams"
|
||||
return choose_qparams_affine_with_min_max(
|
||||
self.min_val,
|
||||
self.max_val,
|
||||
self.mapping_type,
|
||||
[], # BlockSize is not needed because the min/max are already reduced
|
||||
self.target_dtype,
|
||||
self.quant_min,
|
||||
self.quant_max,
|
||||
self.eps,
|
||||
self.scale_dtype,
|
||||
self.zero_point_dtype,
|
||||
self.preserve_zero,
|
||||
self.zero_point_domain,
|
||||
)
|
||||
|
||||
def convert(self, model: torch.fx.GraphModule, observer_node: Node):
|
||||
print("calling convert")
|
||||
from torch.ao.quantization.fx.utils import create_getattr_from_value
|
||||
|
||||
scale, zero_point = self.calculate_qparams()
|
||||
with model.graph.inserting_before(observer_node):
|
||||
assert self.block_size is not None, "Expecting block_size to be populated"
|
||||
assert (
|
||||
self.original_dtype is not None
|
||||
), "Expecting original_dtype to be populated"
|
||||
scale_node = create_getattr_from_value(model, model.graph, "_scale", scale)
|
||||
zero_point_node = create_getattr_from_value(
|
||||
model, model.graph, "_zero_point", zero_point
|
||||
)
|
||||
q_node = model.graph.call_function(
|
||||
torch.ops.pt2e_quant.quantize_affine,
|
||||
(
|
||||
observer_node.args[0],
|
||||
self.block_size,
|
||||
scale_node,
|
||||
zero_point_node,
|
||||
self.target_dtype,
|
||||
self.quant_min,
|
||||
self.quant_max,
|
||||
self.zero_point_domain.name,
|
||||
),
|
||||
{},
|
||||
)
|
||||
dq_node = model.graph.call_function(
|
||||
torch.ops.pt2e_quant.dequantize_affine,
|
||||
(
|
||||
q_node,
|
||||
self.block_size,
|
||||
scale_node,
|
||||
zero_point_node,
|
||||
self.target_dtype,
|
||||
self.quant_min,
|
||||
self.quant_max,
|
||||
self.zero_point_domain.name,
|
||||
),
|
||||
{"output_dtype": self.original_dtype},
|
||||
)
|
||||
observer_node.replace_all_uses_with(dq_node)
|
||||
model.graph.erase_node(observer_node)
|
@@ -0,0 +1,342 @@
import copy
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch.ao.ns.fx.utils import compute_sqnr
|
||||
from torch.ao.quantization.pt2e.graph_utils import bfs_trace_with_node_process
|
||||
from torch.export import ExportedProgram
|
||||
from torch.fx import GraphModule, Node
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
NUMERIC_DEBUG_HANDLE_KEY = "numeric_debug_handle"
|
||||
CUSTOM_KEY = "custom"
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_numeric_debug_handle(ep: ExportedProgram) -> None:
|
||||
"""
|
||||
Attach numeric_debug_handle_id for all nodes in the graph module of the given
|
||||
ExportedProgram, like conv2d, squeeze, conv1d, etc, except for placeholder.
|
||||
Notice that nodes like getattr are out of scope since they are not in the graph.
|
||||
|
||||
The graph nodes of input exported program are modified inplace.
|
||||
|
||||
Here's an example of using debug handle quantize flow::
|
||||
|
||||
ep = export_for_training(eager_model, example_inputs)
|
||||
generate_numeric_debug_handle(ep)
|
||||
|
||||
m = ep.module()
|
||||
quantizer = XNNPACKQuantizer()
|
||||
m = prepare_pt2e(m, quantizer)
|
||||
m = convert_pt2e(m)
|
||||
"""
|
||||
|
||||
# Sanity check the input data type
|
||||
    if not isinstance(ep, ExportedProgram):
        raise ValueError(
            f"Expected ep to be ExportedProgram, got {type(ep)}"
        )
|
||||
|
||||
unique_id = 0
|
||||
|
||||
def _find_max_id(node: torch.fx.Node) -> None:
|
||||
nonlocal unique_id
|
||||
unique_id = max(
|
||||
unique_id, node.meta.get(CUSTOM_KEY, {}).get(NUMERIC_DEBUG_HANDLE_KEY, 0)
|
||||
)
|
||||
|
||||
def _assign_debug_handle(node: torch.fx.Node) -> None:
|
||||
nonlocal unique_id
|
||||
if CUSTOM_KEY not in node.meta:
|
||||
node.meta[CUSTOM_KEY] = {}
|
||||
|
||||
if NUMERIC_DEBUG_HANDLE_KEY not in node.meta[CUSTOM_KEY]:
|
||||
node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = unique_id
|
||||
unique_id += 1
|
||||
|
||||
# Find the max ID that exists in the graph first, in case part of the graph
|
||||
# has already been annotated. This way we guarantee there are no duplicate
|
||||
# handle IDs.
|
||||
bfs_trace_with_node_process(ep, _find_max_id)
|
||||
|
||||
unique_id += 1
|
||||
|
||||
# Assign debug handles to all nodes in the graph that don't have one based on the
|
||||
# max ID found in the previous step.
|
||||
bfs_trace_with_node_process(ep, _assign_debug_handle)
|
||||
|
||||
|
||||
def _detach(x: object) -> object:
|
||||
detached: object = None
|
||||
if isinstance(x, torch.Tensor):
|
||||
detached = x.detach()
|
||||
elif isinstance(x, (list, tuple)):
|
||||
detached = type(x)([_detach(e) for e in x])
|
||||
elif isinstance(x, dict):
|
||||
detached = {k: _detach(e) for k, e in x.items()}
|
||||
else:
|
||||
detached = x
|
||||
return detached
|
||||
|
||||
|
||||
def _tensor_shape_equals(x: object, y: object) -> bool:
|
||||
if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
|
||||
return x.shape == y.shape
|
||||
elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
|
||||
return all(_tensor_shape_equals(e1, e2) for e1, e2 in zip(x, y))
|
||||
elif isinstance(x, dict) and isinstance(y, dict):
|
||||
all_equal = True
|
||||
for k in x:
|
||||
all_equal = all_equal and k in y and (_tensor_shape_equals(x[k], y[k]))
|
||||
return all_equal
|
||||
else:
|
||||
log.debug("Comparing non Tensors: %s and %s, they must be equal", x, y)
|
||||
return type(x) == type(y) and x == y
|
||||
|
||||
|
||||
def _loss_fn(
|
||||
loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], x: object, y: object
|
||||
) -> object:
|
||||
"""The returned loss will have the same structure as `x` and `y`, e.g.
|
||||
if both are Tensor, we'll return a Tensor
|
||||
if both are list, we'll return a list of Tensors
|
||||
if both are dict, we'll return a dict with the same key, and value being the loss between the
|
||||
two Tensors
|
||||
"""
|
||||
if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
|
||||
return loss(x.to(torch.float32), y.to(torch.float32))
|
||||
elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
|
||||
return type(x)([_loss_fn(loss, e1, e2) for e1, e2 in zip(x, y)])
|
||||
elif isinstance(x, dict) and isinstance(y, dict):
|
||||
return {k: _loss_fn(loss, e, y[k]) for k, e in x.items()}
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class OutputLogger(torch.nn.Module):
|
||||
"""
|
||||
Base class for capturing output values for nodes in a GraphModule, it only captures
|
||||
Tensor output currently, but we can extend it to work for other types of inputs later if needed
|
||||
"""
|
||||
|
||||
# Mark as impure so that calls to it will not be removed during DCE.
|
||||
_is_impure = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
debug_handle: int,
|
||||
node_name: Optional[str] = None,
|
||||
nn_module_stack: Optional[object] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.node_name = node_name
|
||||
self.nn_module_stack = nn_module_stack
|
||||
self.debug_handle = debug_handle
|
||||
self.stats: list[object] = []
|
||||
|
||||
def forward(self, x: object) -> object:
|
||||
self.stats.append(_detach(x))
|
||||
return x
|
||||
|
||||
    def extra_repr(self) -> str:
        return (
            f"debug_handle={self.debug_handle}, node_name={self.node_name}, "
            f"nn_module_stack={self.nn_module_stack}, num_stats={len(self.stats)}"
        )
|
||||
|
||||
|
||||
def _insert_logger(model: GraphModule, node: Node, debug_handle: int) -> Node:
|
||||
"""For a given node, adds an OutputLogger that observes the output of that node,
|
||||
and all its users use the OutputLogger output instead.
|
||||
The OutputLogger will contain the debug_handle which can be used to compare
|
||||
graphs after transforms"""
|
||||
|
||||
# to avoid circular dep
|
||||
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
|
||||
|
||||
# add a logger after the node
|
||||
with model.graph.inserting_after(node):
|
||||
get_new_attr_name = get_new_attr_name_with_prefix(f"{node.name}_logger")
|
||||
logger_name = get_new_attr_name(model)
|
||||
setattr(
|
||||
model,
|
||||
logger_name,
|
||||
OutputLogger(debug_handle, node.name, node.meta.get("nn_module_stack")),
|
||||
)
|
||||
logger_node = model.graph.call_module(logger_name, (node,), {})
|
||||
|
||||
orig_users = list(node.users.keys())
|
||||
for user_node in orig_users:
|
||||
if user_node is logger_node:
|
||||
continue
|
||||
user_node.replace_input_with(node, logger_node)
|
||||
|
||||
return logger_node
|
||||
|
||||
|
||||
def prepare_for_propagation_comparison(model: GraphModule) -> GraphModule:
|
||||
"""Add output loggers to node that has numeric_debug_handle
|
||||
|
||||
Args:
|
||||
model (GraphModule): original model
|
||||
Returns:
|
||||
a model with output loggers for all nodes that has numeric_debug_handle_id
|
||||
"""
|
||||
# don't change the original model
|
||||
model = copy.deepcopy(model)
|
||||
for n in model.graph.nodes:
|
||||
if (
|
||||
CUSTOM_KEY not in n.meta
|
||||
or NUMERIC_DEBUG_HANDLE_KEY not in n.meta[CUSTOM_KEY]
|
||||
):
|
||||
continue
|
||||
numeric_debug_handle = n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]
|
||||
_insert_logger(model, n, numeric_debug_handle)
|
||||
|
||||
model.recompile()
|
||||
return model
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class QuantizationComparisonResult:
|
||||
actual: torch.Tensor
|
||||
ref: torch.Tensor
|
||||
|
||||
@property
|
||||
def mse_loss(self) -> object:
|
||||
return self.loss(F.mse_loss)
|
||||
|
||||
@property
|
||||
def sqnr(self) -> object:
|
||||
return self.loss(compute_sqnr)
|
||||
|
||||
def loss(
|
||||
self, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
|
||||
) -> object:
|
||||
return _loss_fn(loss_function, self.actual, self.ref)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# Don't include the tensors themselves as they are quite large to print
|
||||
# out.
|
||||
return (
|
||||
f"QuantizationComparisonResult(mse_loss={self.mse_loss}, sqnr={self.sqnr})"
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not isinstance(self.actual, (torch.Tensor, list, tuple, dict)):
|
||||
raise ValueError(
|
||||
f"`self.actual` value must be a Tensor, list, tuple or dict, got: {self.actual}"
|
||||
)
|
||||
|
||||
if not isinstance(self.ref, (torch.Tensor, list, tuple, dict)):
|
||||
raise ValueError(
|
||||
f"`self.ref` value must be a Tensor, list, tuple or dict, got: {self.ref}"
|
||||
)
|
||||
|
||||
if not _tensor_shape_equals(self.ref, self.actual):
|
||||
raise ValueError(
|
||||
f"Cannot compare tensors with different shapes: ref={self.ref} vs actual={self.actual}"
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NodeAccuracySummary:
|
||||
handle: int
|
||||
actual_node_name: str
|
||||
actual_module_stack: str
|
||||
ref_node_name: str
|
||||
ref_module_stack: str
|
||||
results: Sequence[QuantizationComparisonResult]
|
||||
|
||||
|
||||
def _module_stack_to_str(module_stack: object) -> str:
|
||||
"""Simplifies the stack from ("mod", "mod.foo", "mod.foo.0", "mod.foo.0.linear")
|
||||
to "mod.foo.0.linear"
|
||||
"""
|
||||
if not isinstance(module_stack, dict):
|
||||
return str(module_stack)
|
||||
module_values_list = list(module_stack.values())
|
||||
if len(module_values_list) > 0:
|
||||
owning_module = module_values_list[-1][0]
|
||||
return str(owning_module)
|
||||
else:
|
||||
return str(module_stack)
|
||||
|
||||
|
||||
def extract_results_from_loggers(
|
||||
model: GraphModule,
|
||||
) -> dict[int, tuple[Optional[str], object, list[object]]]:
|
||||
"""For a given model, extract the tensors stats and related information for each debug handle.
|
||||
The reason we have a list of object, instead of Tensor is because the output of node may not be
|
||||
a Tensor, it could be (nested) list, tuple or dict as well.
|
||||
|
||||
Returns:
|
||||
A dict is keyed by the debug_handle id and the values are a list of object recorded
|
||||
in loggers
|
||||
|
||||
"""
|
||||
# Results maps debug handle to a tensor list for each model being compared.
|
||||
handles: dict[int, tuple[Optional[str], object, list[object]]] = {}
|
||||
for _name, module in model.named_children():
|
||||
if isinstance(module, OutputLogger) and len(module.stats) > 0:
|
||||
handles[module.debug_handle] = (
|
||||
module.node_name,
|
||||
module.nn_module_stack,
|
||||
module.stats,
|
||||
)
|
||||
|
||||
return handles
|
||||
|
||||
|
||||
def compare_results(
|
||||
ref_results: dict[int, tuple[Optional[str], object, list[torch.Tensor]]],
|
||||
actual_results: dict[int, tuple[Optional[str], object, list[torch.Tensor]]],
|
||||
) -> dict[int, NodeAccuracySummary]:
|
||||
"""Given two dict mapping from `debug_handle_id` (int) to list of tensors
|
||||
return a map from `debug_handle_id` to `NodeAccuracySummary` that contains
|
||||
comparison information like SQNR, MSE etc.
|
||||
|
||||
Args:
|
||||
ref_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]): reference results for each debug_handle_id
|
||||
actual_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]): actual results for each debug_handle_id
|
||||
|
||||
Returns:
|
||||
Dict[int, NodeAccuracySummary]
|
||||
"""
|
||||
comparisons = {}
|
||||
for debug_handle, (ref_name, ref_stack, ref_stats) in ref_results.items():
|
||||
if debug_handle not in actual_results:
|
||||
log.debug(
|
||||
"Cannot compare for handle %s because it wasn't found in the transformed model",
|
||||
debug_handle,
|
||||
)
|
||||
continue
|
||||
actual_name, actual_stack, actual_stats = actual_results[debug_handle]
|
||||
try:
|
||||
results = [
|
||||
QuantizationComparisonResult(actual=a, ref=b)
|
||||
for a, b in zip(actual_stats, ref_stats)
|
||||
]
|
||||
except Exception as e:
|
||||
# Add extra information for an exception from QuantizationComparisonResult
|
||||
# if the shapes didn't match, to include the handle and the node names.
|
||||
raise ValueError(
|
||||
f"For numeric_debug_handle={debug_handle} from ref node {ref_name} and actual node {actual_name}"
|
||||
) from e
|
||||
|
||||
comparisons[debug_handle] = NodeAccuracySummary(
|
||||
handle=debug_handle,
|
||||
actual_node_name=actual_name or "",
|
||||
actual_module_stack=_module_stack_to_str(actual_stack),
|
||||
ref_node_name=ref_name or "",
|
||||
ref_module_stack=_module_stack_to_str(ref_stack),
|
||||
results=results,
|
||||
)
|
||||
|
||||
return comparisons
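# Illustrative sketch (hypothetical usage, not part of the original file): an
# end-to-end numeric debugging flow built from the helpers above. `float_model`
# and `quantized_model` are assumed to be GraphModules that share the same
# numeric debug handles, and `example_inputs` is a tuple of example tensors.
def _example_compare_models(float_model, quantized_model, example_inputs):
    ref_logged = prepare_for_propagation_comparison(float_model)
    actual_logged = prepare_for_propagation_comparison(quantized_model)
    ref_logged(*example_inputs)
    actual_logged(*example_inputs)
    ref_results = extract_results_from_loggers(ref_logged)
    actual_results = extract_results_from_loggers(actual_logged)
    return compare_results(ref_results, actual_results)  # dict[int, NodeAccuracySummary]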
@@ -0,0 +1,82 @@
# mypy: allow-untyped-defs
|
||||
import logging
|
||||
import operator
|
||||
|
||||
import torch
|
||||
from torch.ao.quantization.pt2e.utils import (
|
||||
_filter_sym_size_users,
|
||||
_is_valid_annotation,
|
||||
)
|
||||
from torch.fx.node import map_arg
|
||||
from torch.fx.passes.infra.pass_base import PassBase, PassResult
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.WARNING)
|
||||
|
||||
__all__ = ["DuplicateDQPass"]
|
||||
|
||||
_QUANTIZE_OPS = [
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
|
||||
torch.ops.quantized_decomposed.quantize_per_channel.default,
|
||||
]
|
||||
|
||||
_DEQUANTIZE_OPS = [
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
|
||||
torch.ops.quantized_decomposed.dequantize_per_channel.default,
|
||||
]
|
||||
|
||||
|
||||
def _maybe_duplicate_dq(
|
||||
gm: torch.fx.GraphModule, dq_node: torch.fx.Node, user: torch.fx.Node
|
||||
):
|
||||
annotation = user.meta.get("quantization_annotation", None)
|
||||
if not _is_valid_annotation(annotation):
|
||||
return
|
||||
with gm.graph.inserting_after(dq_node):
|
||||
new_node = gm.graph.node_copy(dq_node)
|
||||
|
||||
def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node:
|
||||
if n == dq_node:
|
||||
return new_node
|
||||
else:
|
||||
return n
|
||||
|
||||
new_args = map_arg(user.args, maybe_replace_node)
|
||||
new_kwargs = map_arg(user.kwargs, maybe_replace_node)
|
||||
user.args = new_args # type: ignore[assignment]
|
||||
user.kwargs = new_kwargs # type: ignore[assignment]
|
||||
|
||||
|
||||
class DuplicateDQPass(PassBase):
|
||||
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
|
||||
for node in graph_module.graph.nodes:
|
||||
if node.op == "call_function" and node.target in _DEQUANTIZE_OPS:
|
||||
dq_users = _filter_sym_size_users(node)
|
||||
if len(dq_users) <= 1:
|
||||
continue
|
||||
# Do not duplicate dq for dynamic quantization
|
||||
# Pattern: choose_qparam - getitem - q - dq
|
||||
q_node = node.args[0]
|
||||
if q_node.op == "call_function" and q_node.target in _QUANTIZE_OPS:
|
||||
getitem_node = q_node.args[1]
|
||||
if (
|
||||
isinstance(getitem_node, torch.fx.node.Node)
|
||||
and getitem_node.op == "call_function"
|
||||
and getitem_node.target == operator.getitem
|
||||
):
|
||||
choose_qparam_node = getitem_node.args[0]
|
||||
if (
|
||||
isinstance(choose_qparam_node, torch.fx.node.Node)
|
||||
and choose_qparam_node.op == "call_function"
|
||||
and choose_qparam_node.target
|
||||
== torch.ops.quantized_decomposed.choose_qparams.tensor
|
||||
):
|
||||
continue
|
||||
for user in dq_users:
|
||||
_maybe_duplicate_dq(graph_module, node, user)
|
||||
graph_module.graph.eliminate_dead_code()
|
||||
graph_module.recompile()
|
||||
return PassResult(graph_module, True)
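# Illustrative sketch (hypothetical usage, not part of the original file): running
# the pass on a quantized GraphModule so that every non-sym-size user of a shared
# dequantize node gets its own copy.
def _example_run_duplicate_dq(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    result = DuplicateDQPass()(gm)  # PassBase.__call__ dispatches to call()
    return result.graph_module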
@@ -0,0 +1,240 @@
# mypy: allow-untyped-defs
|
||||
import types
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.ao.quantization.utils import _assert_and_get_unique_device
|
||||
|
||||
|
||||
__all__ = [
|
||||
"model_is_exported",
|
||||
]
|
||||
|
||||
_EXPORTED_TRAINING_ATTR = "_exported_training"
|
||||
|
||||
|
||||
class _WrapperModule(torch.nn.Module):
|
||||
"""Class to wrap a callable in an :class:`torch.nn.Module`. Use this if you
|
||||
are trying to export a callable.
|
||||
"""
|
||||
|
||||
def __init__(self, fn):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
"""Simple forward that just calls the ``fn`` provided to :meth:`WrapperModule.__init__`."""
|
||||
return self.fn(*args, **kwargs)
|
||||
|
||||
|
||||
def model_is_exported(m: torch.nn.Module) -> bool:
|
||||
"""
|
||||
Return True if the `torch.nn.Module` was exported, False otherwise
|
||||
(e.g. if the model was FX symbolically traced or not traced at all).
|
||||
"""
|
||||
return isinstance(m, torch.fx.GraphModule) and any(
|
||||
"val" in n.meta for n in m.graph.nodes
|
||||
)
|
||||
|
||||
|
||||
def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool):
|
||||
"""
|
||||
Switch dropout patterns in the model between train and eval modes.
|
||||
|
||||
Dropout has different behavior in train vs eval mode. For exported models,
|
||||
however, calling `model.train()` or `model.eval()` does not automatically switch
|
||||
the dropout behavior between the two modes, so here we need to rewrite the aten
|
||||
dropout patterns manually to achieve the same effect.
|
||||
|
||||
See https://github.com/pytorch/pytorch/issues/103681.
|
||||
"""
|
||||
# Avoid circular dependencies
|
||||
from .utils import _get_aten_graph_module_for_pattern
|
||||
|
||||
# Needed to ensure subgraph matches are self-contained
|
||||
m.graph.eliminate_dead_code()
|
||||
m.recompile()
|
||||
|
||||
for inplace in [False, True]:
|
||||
|
||||
def dropout_train(x):
|
||||
return F.dropout(x, p=0.5, training=True, inplace=inplace)
|
||||
|
||||
def dropout_eval(x):
|
||||
return F.dropout(x, p=0.5, training=False, inplace=inplace)
|
||||
|
||||
example_inputs = (torch.randn(1),)
|
||||
if train_to_eval:
|
||||
match_pattern = _get_aten_graph_module_for_pattern(
|
||||
_WrapperModule(dropout_train),
|
||||
example_inputs,
|
||||
)
|
||||
replacement_pattern = _get_aten_graph_module_for_pattern(
|
||||
_WrapperModule(dropout_eval),
|
||||
example_inputs,
|
||||
)
|
||||
else:
|
||||
match_pattern = _get_aten_graph_module_for_pattern(
|
||||
_WrapperModule(dropout_eval),
|
||||
example_inputs,
|
||||
)
|
||||
replacement_pattern = _get_aten_graph_module_for_pattern(
|
||||
_WrapperModule(dropout_train),
|
||||
example_inputs,
|
||||
)
|
||||
|
||||
from torch.fx.subgraph_rewriter import replace_pattern_with_filters
|
||||
|
||||
replace_pattern_with_filters(
|
||||
m,
|
||||
match_pattern,
|
||||
replacement_pattern,
|
||||
match_filters=[],
|
||||
ignore_literals=True,
|
||||
)
|
||||
m.recompile()
|
||||
|
||||
|
||||
def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool):
|
||||
"""
|
||||
Switch batchnorm patterns in the model between train and eval modes.
|
||||
|
||||
Batchnorm has different behavior in train vs eval mode. For exported models,
|
||||
however, calling `model.train()` or `model.eval()` does not automatically switch
|
||||
the batchnorm behavior between the two modes, so here we need to rewrite the aten
|
||||
batchnorm patterns manually to achieve the same effect.
|
||||
"""
|
||||
# TODO(Leslie): This function still fails to support custom momentum and eps value.
|
||||
# Enable this support in future updates.
|
||||
|
||||
# Avoid circular dependencies
|
||||
from .utils import _get_aten_graph_module_for_pattern
|
||||
|
||||
# Needed to ensure subgraph matches are self-contained
|
||||
m.graph.eliminate_dead_code()
|
||||
m.recompile()
|
||||
|
||||
def bn_train(
|
||||
x: torch.Tensor,
|
||||
bn_weight: torch.Tensor,
|
||||
bn_bias: torch.Tensor,
|
||||
bn_running_mean: torch.Tensor,
|
||||
bn_running_var: torch.Tensor,
|
||||
):
|
||||
return F.batch_norm(
|
||||
x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True
|
||||
)
|
||||
|
||||
def bn_eval(
|
||||
x: torch.Tensor,
|
||||
bn_weight: torch.Tensor,
|
||||
bn_bias: torch.Tensor,
|
||||
bn_running_mean: torch.Tensor,
|
||||
bn_running_var: torch.Tensor,
|
||||
):
|
||||
return F.batch_norm(
|
||||
x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=False
|
||||
)
|
||||
|
||||
example_inputs = (
|
||||
torch.randn(1, 1, 3, 3), # x
|
||||
torch.randn(1), # bn_weight
|
||||
torch.randn(1), # bn_bias
|
||||
torch.randn(1), # bn_running_mean
|
||||
torch.randn(1), # bn_running_var
|
||||
)
|
||||
|
||||
device = _assert_and_get_unique_device(m)
|
||||
is_cuda = device is not None and device.type == "cuda"
|
||||
bn_train_aten = _get_aten_graph_module_for_pattern(
|
||||
_WrapperModule(bn_train),
|
||||
example_inputs,
|
||||
is_cuda,
|
||||
)
|
||||
bn_eval_aten = _get_aten_graph_module_for_pattern(
|
||||
_WrapperModule(bn_eval),
|
||||
example_inputs,
|
||||
is_cuda,
|
||||
)
|
||||
|
||||
if train_to_eval:
|
||||
match_pattern = bn_train_aten
|
||||
replacement_pattern = bn_eval_aten
|
||||
else:
|
||||
match_pattern = bn_eval_aten
|
||||
replacement_pattern = bn_train_aten
|
||||
|
||||
from torch.fx.subgraph_rewriter import replace_pattern_with_filters
|
||||
|
||||
replace_pattern_with_filters(
|
||||
m,
|
||||
match_pattern,
|
||||
replacement_pattern,
|
||||
match_filters=[],
|
||||
ignore_literals=True,
|
||||
)
|
||||
m.recompile()
|
||||
|
||||
|
||||
# TODO: expose these under this namespace?
|
||||
def _move_exported_model_to_eval(model: torch.fx.GraphModule):
|
||||
"""
|
||||
Move an exported GraphModule to eval mode.
|
||||
|
||||
This is equivalent to model.eval() but only for certain special ops like dropout, batchnorm.
|
||||
QAT users should call this before performing inference on the model.
|
||||
|
||||
This call is idempotent; if the model is already in eval mode, nothing will happen.
|
||||
"""
|
||||
is_training = getattr(model, _EXPORTED_TRAINING_ATTR, True)
|
||||
if not is_training:
|
||||
return model
|
||||
setattr(model, _EXPORTED_TRAINING_ATTR, False)
|
||||
_replace_dropout(model, train_to_eval=True)
|
||||
_replace_batchnorm(model, train_to_eval=True)
|
||||
return model
|
||||
|
||||
|
||||
def _move_exported_model_to_train(model: torch.fx.GraphModule):
|
||||
"""
|
||||
Move an exported GraphModule to train mode.
|
||||
|
||||
This is equivalent to model.train() but only for certain special ops like dropout, batchnorm.
|
||||
QAT users should call this before performing training on the model.
|
||||
|
||||
This call is idempotent; if the model is already in train mode, nothing will happen.
|
||||
"""
|
||||
is_training = getattr(model, _EXPORTED_TRAINING_ATTR, False)
|
||||
if is_training:
|
||||
return model
|
||||
setattr(model, _EXPORTED_TRAINING_ATTR, True)
|
||||
_replace_dropout(model, train_to_eval=False)
|
||||
_replace_batchnorm(model, train_to_eval=False)
|
||||
return model
|
||||
|
||||
|
||||
def _allow_exported_model_train_eval(model: torch.fx.GraphModule):
|
||||
"""
|
||||
Allow users to call `model.train()` and `model.eval()` on an exported model,
|
||||
but with the effect of changing behavior between the two modes limited to special
|
||||
ops only, which are currently dropout and batchnorm.
|
||||
|
||||
    Note: This does not achieve the same effect as what `model.train()` and `model.eval()`
    do in eager models, but only provides an approximation. In particular, user code
    branching on the `training` flag will not function correctly in general because the branch
    is already specialized at export time. Additionally, other ops beyond dropout and batchnorm
    that have different train/eval behavior will also not be converted properly.
    """
|
||||
|
||||
def _train(self, mode: bool = True):
|
||||
if mode:
|
||||
_move_exported_model_to_train(self)
|
||||
else:
|
||||
_move_exported_model_to_eval(self)
|
||||
|
||||
def _eval(self):
|
||||
_move_exported_model_to_eval(self)
|
||||
|
||||
model.train = types.MethodType(_train, model) # type: ignore[method-assign]
|
||||
model.eval = types.MethodType(_eval, model) # type: ignore[method-assign]
|
||||
return model
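# Illustrative sketch (hypothetical usage, not part of the original file): after
# export, re-enable train()/eval() so they toggle the dropout/batchnorm patterns.
# `exported_gm` is assumed to be e.g. export_for_training(model, example_inputs).module().
def _example_enable_train_eval(exported_gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    m = _allow_exported_model_train_eval(exported_gm)
    m.eval()   # rewrites aten dropout/batchnorm patterns to eval behavior
    m.train()  # rewrites them back to training behavior
    return m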
venv/Lib/site-packages/torch/ao/quantization/pt2e/graph_utils.py (new file, 181 lines)
@@ -0,0 +1,181 @@
# mypy: allow-untyped-defs
|
||||
import itertools
|
||||
import operator
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.export import ExportedProgram
|
||||
from torch.fx import Node
|
||||
from torch.fx.passes.utils.source_matcher_utils import (
|
||||
check_subgraphs_connected,
|
||||
get_source_partitions,
|
||||
SourcePartition,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"find_sequential_partitions",
|
||||
"get_equivalent_types",
|
||||
"update_equivalent_types_dict",
|
||||
"bfs_trace_with_node_process",
|
||||
]
|
||||
|
||||
_EQUIVALENT_TYPES: list[set] = [
|
||||
{torch.nn.Conv1d, torch.nn.functional.conv1d},
|
||||
{torch.nn.Conv2d, torch.nn.functional.conv2d},
|
||||
{torch.nn.AdaptiveAvgPool2d, torch.nn.functional.adaptive_avg_pool2d},
|
||||
{torch.nn.ReLU, torch.nn.functional.relu, torch.nn.functional.relu_},
|
||||
{torch.nn.BatchNorm2d, torch.nn.functional.batch_norm},
|
||||
{torch.nn.Hardtanh, torch.nn.functional.hardtanh, torch.nn.functional.hardtanh_},
|
||||
{torch.add, operator.add, operator.iadd, "add", "add_"},
|
||||
{torch.mul, operator.mul, operator.imul, "mul", "mul_"},
|
||||
]
|
||||
|
||||
|
||||
def _create_equivalent_types_dict():
|
||||
_DICT = {}
|
||||
for values in _EQUIVALENT_TYPES:
|
||||
for v in values:
|
||||
_DICT[v] = list(values)
|
||||
return _DICT
|
||||
|
||||
|
||||
_EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict()
|
||||
|
||||
|
||||
def get_equivalent_types() -> list[set]:
|
||||
return _EQUIVALENT_TYPES
|
||||
|
||||
|
||||
def update_equivalent_types_dict(customized_equivalent_types=None):
|
||||
"""Help function for user who wants to customize the _EQUIVALENT_TYPES and _EQUIVALENT_TYPES_DICT.
|
||||
When customized_equivalent_types passes in,
|
||||
re-generate _EQUIVALENT_TYPES and _EQUIVALENT_TYPES_DICT.
|
||||
"""
|
||||
if customized_equivalent_types is None:
|
||||
raise ValueError("customized_equivalent_types should not be None")
|
||||
global _EQUIVALENT_TYPES
|
||||
global _EQUIVALENT_TYPES_DICT
|
||||
_EQUIVALENT_TYPES = customized_equivalent_types
|
||||
_EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict()
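
# --- Illustrative usage sketch (not part of the original file) ---
# A hypothetical customization that also treats Conv3d and its functional form
# as equivalent when matching partitions; the list passed in replaces
# _EQUIVALENT_TYPES wholesale, so extend the current defaults:
#
#     custom = get_equivalent_types() + [{torch.nn.Conv3d, torch.nn.functional.conv3d}]
#     update_equivalent_types_dict(custom)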
|
||||
|
||||
|
||||
def _partitions_sequential(partitions: Sequence[SourcePartition]):
|
||||
prev_partition = None
|
||||
for partition in partitions:
|
||||
if prev_partition is not None and not check_subgraphs_connected(
|
||||
prev_partition, partition
|
||||
):
|
||||
return False
|
||||
prev_partition = partition
|
||||
return True
|
||||
|
||||
|
||||
def _get_matching_types(partition_type):
|
||||
matching_types = [partition_type]
|
||||
if partition_type in _EQUIVALENT_TYPES_DICT:
|
||||
matching_types.extend(_EQUIVALENT_TYPES_DICT[partition_type])
|
||||
return matching_types
|
||||
|
||||
|
||||
def _valid_type_sequence(partition_types: list[Any]):
|
||||
partition_types_set = set() # type: ignore[var-annotated]
|
||||
for partition_type in partition_types:
|
||||
matching_types = _get_matching_types(partition_type)
|
||||
matching_types_set = set(matching_types)
|
||||
if len(partition_types_set & matching_types_set) > 0:
|
||||
return False
|
||||
partition_types_set |= matching_types_set
|
||||
return True
|
||||
|
||||
|
||||
def find_sequential_partitions(
|
||||
gm: torch.fx.GraphModule,
|
||||
partition_types: list[Any],
|
||||
include_functional_equivalent=True,
|
||||
filter_fn: Optional[Callable[[Node], bool]] = None,
|
||||
):
|
||||
if not _valid_type_sequence(partition_types):
|
||||
raise ValueError(
|
||||
f"Invalid partition types: {partition_types}. Each type in the sequence must be unique"
|
||||
)
|
||||
|
||||
typed_partitions: OrderedDict[Any, list[SourcePartition]] = OrderedDict()
|
||||
for partition_type in partition_types:
|
||||
types_to_match = _get_matching_types(partition_type)
|
||||
partitions = get_source_partitions(gm.graph, types_to_match, filter_fn)
|
||||
typed_partitions[partition_type] = list(
|
||||
itertools.chain.from_iterable(partitions.values())
|
||||
)
|
||||
|
||||
typed_partitions_list = list(typed_partitions.values())
|
||||
fusion_candidates = itertools.product(*typed_partitions_list)
|
||||
fused_partitions = [
|
||||
candidate
|
||||
for candidate in fusion_candidates
|
||||
if _partitions_sequential(candidate)
|
||||
]
|
||||
return fused_partitions
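
# --- Illustrative usage sketch (not part of the original file) ---
# Finding Conv2d -> ReLU pairs in a captured graph, assuming `gm` is a
# torch.fx.GraphModule produced by torch.export:
#
#     conv_relu_pairs = find_sequential_partitions(gm, [torch.nn.Conv2d, torch.nn.ReLU])
#     for conv_partition, relu_partition in conv_relu_pairs:
#         ...  # e.g. annotate the matched partitions for fused quantization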
|
||||
|
||||
|
||||
def _get_submodule(
|
||||
graph_module: torch.fx.GraphModule, node: torch.fx.Node, arg_index: int
|
||||
) -> tuple[str, torch.nn.Module, torch.fx.Node]:
|
||||
submod_node = node.args[arg_index]
|
||||
assert isinstance(submod_node, torch.fx.Node)
|
||||
assert submod_node.op == "get_attr"
|
||||
assert isinstance(submod_node.target, str)
|
||||
submodule = graph_module.get_submodule(submod_node.target)
|
||||
# pyre-ignore
|
||||
return submod_node.target, submodule, node
|
||||
|
||||
|
||||
def _get_control_flow_submodules(
|
||||
graph_module: torch.fx.GraphModule,
|
||||
) -> list[tuple[str, torch.nn.Module, torch.fx.Node]]:
|
||||
"""
|
||||
Returns a list of submodules used for control flow operations
|
||||
(torch.ops.higher_order.cond/map) that are in the given toplevel graph (does not look
|
||||
into submodules). Specifically, the returned value is a list containing a
|
||||
tuple of (name of the submodule that's stored in the graph module, the
|
||||
submodule itself, and the fx node that uses this submodule).
|
||||
"""
|
||||
control_flow_submodules = []
|
||||
for node in graph_module.graph.nodes:
|
||||
if node.op != "call_function":
|
||||
continue
|
||||
|
||||
if node.target is torch.ops.higher_order.cond:
|
||||
control_flow_submodules.append(_get_submodule(graph_module, node, 1))
|
||||
control_flow_submodules.append(_get_submodule(graph_module, node, 2))
|
||||
if node.target is torch.ops.higher_order.map_impl:
|
||||
control_flow_submodules.append(_get_submodule(graph_module, node, 0))
|
||||
|
||||
return control_flow_submodules
|
||||
|
||||
|
||||
def bfs_trace_with_node_process(
|
||||
model: Union[ExportedProgram, torch.fx.GraphModule], node_op: Callable
|
||||
) -> None:
|
||||
"""Traverse the graph module and apply node_op to each node."""
|
||||
|
||||
assert isinstance(
|
||||
model, (ExportedProgram, torch.fx.GraphModule)
|
||||
), f"Expected GraphModule or ExportedProgram, got {type(model)}"
|
||||
gm = model.graph_module if isinstance(model, ExportedProgram) else model
|
||||
queue = [gm]
|
||||
while queue:
|
||||
current_graph_module = queue.pop(0)
|
||||
for node in current_graph_module.graph.nodes:
|
||||
if node.op in ["output", "placeholder"]:
|
||||
continue
|
||||
|
||||
node_op(node)
|
||||
|
||||
control_flow_submodules = [
|
||||
submodule
|
||||
for _, submodule, _ in _get_control_flow_submodules(current_graph_module)
|
||||
]
|
||||
queue.extend(control_flow_submodules)
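
# --- Illustrative usage sketch (not part of the original file) ---
# Printing every call_function target in a model, including nodes inside
# cond/map control-flow submodules, assuming `ep` is an ExportedProgram:
#
#     def _print_target(node):
#         if node.op == "call_function":
#             print(node.target)
#
#     bfs_trace_with_node_process(ep, _print_target)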
|
|
@@ -0,0 +1,215 @@
# mypy: allow-untyped-defs
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch._export.error import InternalError
|
||||
from torch.ao.quantization.pt2e.utils import (
|
||||
_filter_sym_size_users,
|
||||
_find_q_dq_node_for_user,
|
||||
_is_valid_annotation,
|
||||
)
|
||||
from torch.ao.quantization.quantizer import QuantizationSpecBase
|
||||
from torch.fx.passes.infra.pass_base import PassBase, PassResult
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.ERROR)
|
||||
|
||||
__all__ = ["PortNodeMetaForQDQ"]
|
||||
|
||||
_METADATA_TO_PORT = [
|
||||
"stack_trace",
|
||||
"quantization_tag",
|
||||
]
|
||||
|
||||
_QUANTIZE_OPS = [
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
|
||||
torch.ops.quantized_decomposed.quantize_per_channel.default,
|
||||
]
|
||||
|
||||
_DEQUANTIZE_OPS = [
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
|
||||
torch.ops.quantized_decomposed.dequantize_per_channel.default,
|
||||
]
|
||||
|
||||
_CHOOSE_QPARAMS_OPS = [
|
||||
torch.ops.quantized_decomposed.choose_qparams.tensor,
|
||||
torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor,
|
||||
]
|
||||
|
||||
|
||||
def _add_metadata(to_node: torch.fx.Node, from_node: torch.fx.Node) -> None:
|
||||
from_meta = from_node.meta
|
||||
for meta_name in _METADATA_TO_PORT:
|
||||
if meta_name in from_meta:
|
||||
to_node.meta[meta_name] = from_meta[meta_name]
|
||||
|
||||
|
||||
def _has_quant_annotation(node: torch.fx.Node) -> bool:
|
||||
return "quantization_annotation" in node.meta
|
||||
|
||||
|
||||
def _find_choose_qparams_node(node: torch.fx.Node) -> Optional[torch.fx.Node]:
|
||||
# BFS to look for choose qparams
|
||||
from collections import deque
|
||||
|
||||
queue = deque(list(node.users.keys()))
|
||||
while len(queue):
|
||||
n = queue.popleft()
|
||||
if n.op == "output":
|
||||
continue
|
||||
if n.op == "call_function" and n.target in _CHOOSE_QPARAMS_OPS:
|
||||
return n
|
||||
for k in n.users.keys():
|
||||
queue.append(k)
|
||||
return None
|
||||
|
||||
|
||||
def _port_metadata_for_input_quant_nodes(
|
||||
input_node: torch.fx.Node,
|
||||
node: torch.fx.Node,
|
||||
qspec: Optional[QuantizationSpecBase],
|
||||
):
|
||||
if qspec is None:
|
||||
return
|
||||
|
||||
is_dynamic_quant = getattr(qspec, "is_dynamic", None)
|
||||
if is_dynamic_quant is not None and is_dynamic_quant is True:
|
||||
choose_qparams_node = _find_choose_qparams_node(input_node)
|
||||
if choose_qparams_node is None:
|
||||
raise ValueError(f"No chose qparams node found for {node}")
|
||||
choose_qparam_users = _filter_sym_size_users(choose_qparams_node)
|
||||
if len(choose_qparam_users) != 2:
|
||||
raise InternalError(f"Expecting exactly two user for {choose_qparams_node}")
|
||||
scale_node = choose_qparam_users.pop()
|
||||
dynamic_q_node = next(iter(scale_node.users.keys()))
|
||||
dynamic_q_node_users = _filter_sym_size_users(dynamic_q_node)
|
||||
if len(dynamic_q_node_users) > 1:
|
||||
raise InternalError(f"Expecting single user for {dynamic_q_node}")
|
||||
dynamic_dq_node = dynamic_q_node_users.pop()
|
||||
_add_metadata(choose_qparams_node, node)
|
||||
_add_metadata(dynamic_q_node, node)
|
||||
_add_metadata(dynamic_dq_node, node)
|
||||
else:
|
||||
q_node, dq_node = _find_q_dq_node_for_user(input_node, node)
|
||||
if q_node is None or dq_node is None:
|
||||
return
|
||||
        # add metadata for all the nodes between q_node and the get_attr node,
        # if the q_node can be traced back to a get_attr node
|
||||
q_to_get_attr_nodes = [q_node]
|
||||
q_node_input = q_node.args[0]
|
||||
while (
|
||||
isinstance(q_node_input, torch.fx.Node)
|
||||
and q_node_input.op == "call_function"
|
||||
and q_node_input.target
|
||||
in [
|
||||
torch.ops.aten.flatten.using_ints,
|
||||
torch.ops.aten.permute.default,
|
||||
torch.ops.aten.permute_copy.default,
|
||||
torch.ops.aten.slice_copy.Tensor,
|
||||
torch.ops.aten.squeeze.dim,
|
||||
torch.ops.aten.squeeze_copy.dim,
|
||||
torch.ops.aten.transpose.Dimname,
|
||||
torch.ops.aten.transpose.int,
|
||||
torch.ops.aten.transpose_,
|
||||
torch.ops.aten.view_copy.default,
|
||||
torch.ops.aten.view.default,
|
||||
torch.ops.aten._mkldnn_transpose,
|
||||
]
|
||||
):
|
||||
q_to_get_attr_nodes.append(q_node_input)
|
||||
q_node_input = q_node_input.args[0]
|
||||
if isinstance(q_node_input, torch.fx.Node) and q_node_input.op == "get_attr":
|
||||
for n in q_to_get_attr_nodes:
|
||||
_add_metadata(n, q_node_input)
|
||||
_add_metadata(dq_node, node)
|
||||
|
||||
|
||||
def _port_metadata_for_output_quant_nodes(
|
||||
node: torch.fx.Node, qspec: Optional[QuantizationSpecBase]
|
||||
):
|
||||
if qspec is None:
|
||||
return
|
||||
|
||||
node_users = _filter_sym_size_users(node)
|
||||
if len(node.users) == 0:
|
||||
return
|
||||
if len(node_users) != 1:
|
||||
logger.warning(f"Expecting {node} to have single user") # noqa: G004
|
||||
q_node = node_users.pop()
|
||||
if q_node.op != "call_function" or q_node.target not in _QUANTIZE_OPS:
|
||||
logger.warning(
|
||||
f"Expecting {node} user to be a quantized op but got {q_node}" # noqa: G004
|
||||
) # noqa: G004
|
||||
return
|
||||
|
||||
_add_metadata(q_node, node)
|
||||
|
||||
|
||||
class PortNodeMetaForQDQ(PassBase):
|
||||
"""
|
||||
Port metadata for nodes added by quantization flow.
|
||||
For static quant these are:
|
||||
- quantizer_per_tensor.default, dequantize_per_tensor.default
|
||||
- quantizer_per_channel.default, dequantize_per_channel.default
|
||||
For dynamic quant these are:
|
||||
- choose_qparams.tensor
|
||||
- quantizer_per_tensor.tensor, dequantize_per_tensor.tensor
|
||||
- quantizer_per_channel.default, dequantize_per_channel.default
|
||||
|
||||
Rules of porting metadata:
|
||||
- Metadata to be ported:
|
||||
- nn_module_stack
|
||||
- stack_trace
|
||||
- quantization_tag
|
||||
- Metadata to NOT be ported:
|
||||
- Everything else
|
||||
- Rules:
|
||||
- Statically quantized patterns:
|
||||
- Dequantize nodes on the inputs to be quantized inherit metadata of the consumer node.
|
||||
- Quantize nodes on the outputs inherit metadata of the producer node.
|
||||
- Example 1:
|
||||
- Original: [Conv -> AvgPool -> Linear]
|
||||
- Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ]
|
||||
- Inner brackets specify which nodes Q/DQ inherit metdata from
|
||||
- [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> [DQ -> Linear -> Q] -> DQ]
|
||||
- Note first Q and last DQ do not inherit metadata from any nodes
|
||||
- Example 2:
|
||||
- Original: [Conv -> AvgPool -> Linear]
|
||||
- AvgPool is not quantized
|
||||
- Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ]
|
||||
- Inner brackets specify which nodes Q/DQ inherit metdata from
|
||||
- [Q-> [DQ -> Conv -> Q] -> DQ -> [AvgPool] -> Q -> [DQ -> Linear -> Q] -> DQ]
|
||||
- Note DQ and Q nodes around AvgPool do not inherit metadata from AvgPool because
|
||||
AvgPool was not supposed to be quantized. Metadata porting relies on quantization_annotation
|
||||
on the nodes (in this case AvgPool node) to conclude if the node or patter was
|
||||
supposed to be quantized. And subsequntly decide if the preceding Q, if any, should
|
||||
inherit metadata from AvgPool.
|
||||
- Dynamically quantized patterns:
|
||||
- Input that are dynamically quantized have choose_qparams, quantize and dequantize nodes
|
||||
- For example, below linear is dynamically quantized while rest statically:
|
||||
- Original: [Conv -> AvgPool -> Linear]
|
||||
- Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> choose_params -> Q -> DQ -> Linear]
|
||||
- Quantized [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> DQ -> [choose_params -> Q -> DQ -> Linear]]
|
||||
- Note first Q does not inherit metadata from any nodes
|
||||
NB:
|
||||
- The best place for porting metadata is during observer conversion to q/dq. This is because it precisely
|
||||
knows which quantization spec is converted to q/dq and thus from where the metadata should be ported.
|
||||
However, since FX and PT2E quant workflow are on a common code-base, this hurts readability quite a bit.
|
||||
Doing it via a separate pass, helps readability of the code. Once we are able to refactor PT2E quant
|
||||
code, this pass should like to be integrated in the refactored variant of "convert" step.
|
||||
"""
|
||||
|
||||
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
|
||||
for node in graph_module.graph.nodes:
|
||||
annotation = node.meta.get("quantization_annotation", None)
|
||||
if _is_valid_annotation(annotation):
|
||||
input_qspec_map = node.meta["quantization_annotation"].input_qspec_map
|
||||
output_qspec = node.meta["quantization_annotation"].output_qspec
|
||||
for input_node, qspec in input_qspec_map.items():
|
||||
_port_metadata_for_input_quant_nodes(input_node, node, qspec)
|
||||
_port_metadata_for_output_quant_nodes(node, output_qspec)
|
||||
return PassResult(graph_module, True)
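
# --- Illustrative usage sketch (not part of the original file) ---
# The pass is a regular fx PassBase, so it can be applied directly to a
# hypothetical `converted_gm` produced by the PT2E convert step in order to
# copy stack_trace / quantization_tag onto the inserted Q/DQ nodes:
#
#     result = PortNodeMetaForQDQ()(converted_gm)
#     converted_gm = result.graph_module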
|
venv/Lib/site-packages/torch/ao/quantization/pt2e/prepare.py
@@ -0,0 +1,574 @@
# mypy: allow-untyped-defs
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch._subclasses import FakeTensor
|
||||
from torch.ao.quantization import (
|
||||
CUSTOM_KEY,
|
||||
NUMERIC_DEBUG_HANDLE_KEY,
|
||||
ObserverOrFakeQuantize,
|
||||
QConfigMapping,
|
||||
)
|
||||
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
|
||||
from torch.ao.quantization.fx.prepare import (
|
||||
_create_obs_or_fq_from_qspec,
|
||||
_insert_obs_or_fq,
|
||||
_is_activation_post_process_node,
|
||||
_save_state,
|
||||
)
|
||||
from torch.ao.quantization.qconfig import QConfigAny
|
||||
from torch.ao.quantization.quantizer import (
|
||||
EdgeOrNode,
|
||||
QuantizationSpecBase,
|
||||
SharedQuantizationSpec,
|
||||
)
|
||||
from torch.fx import Graph, GraphModule, Node
|
||||
from torch.fx.node import Argument
|
||||
|
||||
|
||||
# TODO: make pt2e folder private?
|
||||
__all__ = [
|
||||
"prepare",
|
||||
]
|
||||
|
||||
|
||||
def _find_root_edge_or_node(
|
||||
edge_or_node: EdgeOrNode, shared_with_map: dict[EdgeOrNode, EdgeOrNode]
|
||||
) -> EdgeOrNode:
|
||||
"""Find the root node for the sharing tree
|
||||
Args:
|
||||
edge_or_node: edge/node that we want to find the root
|
||||
        shared_with_map: each edge/node points to its parent; the root node points to itself
|
||||
|
||||
Returns:
|
||||
root edge/node
|
||||
"""
|
||||
parent = shared_with_map[edge_or_node]
|
||||
if parent == edge_or_node:
|
||||
return edge_or_node
|
||||
root = _find_root_edge_or_node(parent, shared_with_map)
|
||||
# path compression
|
||||
shared_with_map[edge_or_node] = root
|
||||
return root
|
||||
|
||||
|
||||
def _union(
|
||||
parent: EdgeOrNode,
|
||||
child: EdgeOrNode,
|
||||
shared_with_map: dict[EdgeOrNode, EdgeOrNode],
|
||||
) -> None:
|
||||
"""Merge the subtree for `child` with `parent`, the order is important here"""
|
||||
root_parent = _find_root_edge_or_node(parent, shared_with_map)
|
||||
root_child = _find_root_edge_or_node(child, shared_with_map)
|
||||
# union the two trees by pointing the root of child to root of parent
|
||||
shared_with_map[root_child] = root_parent
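
# --- Illustrative sketch (not part of the original file) ---
# `_find_root_edge_or_node` and `_union` implement a plain union-find (disjoint
# set). The self-contained toy below shows the same grouping behavior on string
# keys; nothing beyond the two helpers above is assumed.
def _union_find_demo() -> dict[str, str]:
    shared_with_map = {k: k for k in ("a", "b", "c", "d")}
    _union("a", "b", shared_with_map)  # "b"'s tree now hangs under "a"
    _union("b", "c", shared_with_map)  # "c" joins the {"a", "b"} tree
    roots = {k: _find_root_edge_or_node(k, shared_with_map) for k in shared_with_map}
    # roots == {"a": "a", "b": "a", "c": "a", "d": "d"}
    return roots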
|
||||
|
||||
|
||||
def _update_shared_with(
|
||||
child: EdgeOrNode,
|
||||
qspec: QuantizationSpecBase,
|
||||
shared_with_map: dict[EdgeOrNode, EdgeOrNode],
|
||||
):
|
||||
"""Update the `shared_with_map` based on the qspec, this applies the `SharedQuantizationSpec`
|
||||
configuration and established the relationship between `edge_or_node` with the edge/node that it
|
||||
is pointing to, we'll use this information in the end to get the group id
|
||||
"""
|
||||
if isinstance(qspec, SharedQuantizationSpec):
|
||||
parent = qspec.edge_or_node
|
||||
# we point from edge_or_node to the node that it is sharing_with, e.g.
|
||||
# qspec for a = SharedQuantizationSpec(b) means `a` points to `b`
|
||||
_union(parent, child, shared_with_map)
|
||||
|
||||
|
||||
def _unwrap_shared_qspec(
|
||||
qspec: QuantizationSpecBase,
|
||||
edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase],
|
||||
shared_with_map: dict[EdgeOrNode, EdgeOrNode],
|
||||
) -> QuantizationSpecBase:
|
||||
"""Unwraps qspec to get the final root qspec (non SharedQuantizationSpec)
|
||||
if qspec is SharedQuantizationSpec
|
||||
(1). tries to find the root edge or node for the node that the qspec points to
|
||||
(2). recursively find the root qspec based on the qspec for the root node
|
||||
"""
|
||||
if isinstance(qspec, SharedQuantizationSpec):
|
||||
sharing_with = qspec.edge_or_node
|
||||
root = _find_root_edge_or_node(sharing_with, shared_with_map)
|
||||
qspec = edge_or_node_to_qspec[root]
|
||||
return _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
|
||||
return qspec
|
||||
|
||||
|
||||
def _has_same_attr(
|
||||
qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase, attr_name: str
|
||||
):
|
||||
return (
|
||||
hasattr(qspec_a, attr_name)
|
||||
and hasattr(qspec_b, attr_name)
|
||||
and getattr(qspec_a, attr_name) == getattr(qspec_b, attr_name)
|
||||
) or (not hasattr(qspec_a, attr_name) and not hasattr(qspec_b, attr_name))
|
||||
|
||||
|
||||
def _get_edge_or_node_to_qspec(
|
||||
model: torch.fx.GraphModule,
|
||||
) -> dict[EdgeOrNode, QuantizationSpecBase]:
|
||||
"""Get a map from EdgeOrNode to quantization spec based on annotations on the nodes"""
|
||||
edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase] = {}
|
||||
for n in model.graph.nodes:
|
||||
if hasattr(n, "meta") and "quantization_annotation" in n.meta:
|
||||
qa = n.meta["quantization_annotation"]
|
||||
for input_to_n, qspec in qa.input_qspec_map.items():
|
||||
input_edge = (input_to_n, n)
|
||||
edge_or_node_to_qspec[input_edge] = qspec
|
||||
if qa.output_qspec is not None:
|
||||
output_node = n
|
||||
qspec = qa.output_qspec
|
||||
edge_or_node_to_qspec[output_node] = qspec
|
||||
return edge_or_node_to_qspec
|
||||
|
||||
|
||||
def _union_input_edge_with(
|
||||
input_edge,
|
||||
input_edge_root_qspec,
|
||||
edge_or_node,
|
||||
edge_or_node_to_qspec,
|
||||
shared_with_map,
|
||||
):
|
||||
"""Union input edge with another edge or node, used in implicit sharing to point the current input
|
||||
edge to other user edges of the producer node, or the output of producer node since these are
|
||||
referring to the same Tensor
|
||||
"""
|
||||
root_qspec = None
|
||||
if edge_or_node in edge_or_node_to_qspec:
|
||||
qspec = edge_or_node_to_qspec[edge_or_node]
|
||||
root_qspec = _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
|
||||
# TODO: add assertions for types of root qspecs
|
||||
if root_qspec is not None and all(
|
||||
_has_same_attr(root_qspec, input_edge_root_qspec, attr)
|
||||
for attr in [
|
||||
"dtype",
|
||||
"is_dynamic",
|
||||
"quant_min",
|
||||
"quant_max",
|
||||
"qscheme",
|
||||
"ch_axis",
|
||||
"scale",
|
||||
"zero_point",
|
||||
]
|
||||
):
|
||||
# the input arg to the node should reuse the existing output observer for arg
|
||||
# since dtype is the same (we may want to extend this to be a more strict check
|
||||
# in the future)
|
||||
# so we point from `input_edge` to `arg` (output of the argument)
|
||||
_union(edge_or_node, input_edge, shared_with_map)
|
||||
|
||||
|
||||
def _get_edge_or_node_to_group_id(
|
||||
edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase]
|
||||
) -> dict[EdgeOrNode, int]:
|
||||
"""Map from edge/node to the group ID, generated from quantization annotations,
|
||||
edge/node with the same group ID should use the same observer/fake_quant instance
|
||||
|
||||
This is applying SharedQuantizationSpec configuration and map each edge/node to a group
|
||||
There is another implicit sharing that's built in the quantization, when we have the following:
|
||||
* op1 -> op2
|
||||
* output of op1: int8_qspec
|
||||
* (op1 -> op2) input edge: int8_qspec
|
||||
we'll assume sharing between the output of op1 and input of (op1 -> op2) since these are the same Tensor.
|
||||
|
||||
Figuring out the correct group ID for all edge/node is a standard union find problem:
|
||||
https://www.geeksforgeeks.org/introduction-to-disjoint-set-data-structure-or-union-find-algorithm/
|
||||
|
||||
Args:
|
||||
edge_or_node_to_qspec: Dictionary from edge_or_node to the qspec, derived from annotations
|
||||
Returns:
|
||||
        edge_or_node_to_group_id: Dictionary from edge_or_node to group_id (int); all edges or nodes that
          belong to the same group should have the same id
|
||||
|
||||
Example:
|
||||
        op2 -> cat1 -> cat2
           op1 /        /
                     op3
|
||||
edge_or_node_to_qspec: {
|
||||
op1: int8_qspec,
|
||||
op2: int8_qspec,
|
||||
            (op1, cat1): int8_qspec,
|
||||
(op2, cat1): SharedQuantizationSpec((op1, cat1)),
|
||||
cat1: SharedQuantizationSpec((op1, cat1)),
|
||||
(op3, cat2): int8_qspec,
|
||||
(cat1, cat2): SharedQuantizationSpec((op3, cat2)),
|
||||
cat2: SharedQuantizationSpec((op3, cat2)),
|
||||
}
|
||||
|
||||
edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
|
||||
edge_or_node_to_group_id: {
|
||||
op1: 1,
|
||||
op2: 1,
|
||||
(op1, cat1): 1,
|
||||
(op2, cat1): 1,
|
||||
cat1: 1,
|
||||
(op3, cat2): 1,
|
||||
(cat1, cat2): 1,
|
||||
cat2: 1,
|
||||
}
|
||||
        # everything is in the same group because (cat1) and (cat1, cat2) are implicitly shared, which
        # connects the two sharing groups around the cat1 and cat2 ops due to transitive sharing
|
||||
"""
|
||||
    # means the observer of the key should be shared with the observer of the value; by default it will
    # be shared with itself
|
||||
shared_with_map: dict[EdgeOrNode, EdgeOrNode] = {
|
||||
k: k for k in edge_or_node_to_qspec.keys()
|
||||
}
|
||||
for edge_or_node, qspec in edge_or_node_to_qspec.items():
|
||||
if isinstance(edge_or_node, torch.fx.Node):
|
||||
output_node = edge_or_node
|
||||
_update_shared_with(output_node, qspec, shared_with_map)
|
||||
else:
|
||||
input_edge = edge_or_node
|
||||
input_edge_root_qspec = _unwrap_shared_qspec(
|
||||
qspec, edge_or_node_to_qspec, shared_with_map
|
||||
)
|
||||
|
||||
assert isinstance(input_edge, tuple)
|
||||
arg, n = input_edge
|
||||
if n.meta["quantization_annotation"].allow_implicit_sharing:
|
||||
# NOTE: the order is important here, we first share with other users and then share with previous
|
||||
# output because the reverse order could cause circular dependency
|
||||
# e.g node1 -> node2
|
||||
# \ -> node3
|
||||
# when processing (node1, node2), if we first point (node1, node2) to node1
|
||||
# Step 1. shared_map = {(node1, node2): node1}
|
||||
# Step 2. after that, we point the (node1, node2) to its other user (node1, node3) ,
|
||||
# which means shared_map = {(node1, node2): node1, node1: (node1, node3)}
|
||||
# because we will point the root of (node1, node2) (in this case node1) to the root of (node1, node3)
|
||||
# Step 3. and when we process (node1, node3), it can try to point to node1 as well, then we'll
|
||||
# have a circular dependency
|
||||
                # the following order works around this issue, but it does not allow arbitrary configurations
                # of sharing, so it might break in a different case in the future; when it breaks,
                # the quantizer writer can check the notes here to debug the issue
|
||||
|
||||
# sharing with other users of the producer node
|
||||
# (arg, user)
|
||||
if not isinstance(arg, Node) or not isinstance(n, Node):
|
||||
raise Exception( # noqa: TRY002
|
||||
f"Expected input_edge to have type Tuple[Node, Node], but got: {arg, n}"
|
||||
)
|
||||
for user in arg.users:
|
||||
if user is n:
|
||||
continue
|
||||
arg_to_user_edge = (arg, user)
|
||||
_union_input_edge_with(
|
||||
input_edge,
|
||||
input_edge_root_qspec,
|
||||
arg_to_user_edge,
|
||||
edge_or_node_to_qspec,
|
||||
shared_with_map,
|
||||
)
|
||||
|
||||
# sharing with output of producer node
|
||||
_union_input_edge_with(
|
||||
input_edge,
|
||||
input_edge_root_qspec,
|
||||
arg,
|
||||
edge_or_node_to_qspec,
|
||||
shared_with_map,
|
||||
)
|
||||
|
||||
_update_shared_with(input_edge, qspec, shared_with_map)
|
||||
|
||||
    # now that we have the sharing relations between all edges and nodes, we can assign group ids
|
||||
cur_group_id = 0
|
||||
edge_or_node_to_group_id: dict[EdgeOrNode, int] = {}
|
||||
for edge_or_node in shared_with_map.keys():
|
||||
root = _find_root_edge_or_node(edge_or_node, shared_with_map)
|
||||
if root not in edge_or_node_to_group_id:
|
||||
edge_or_node_to_group_id[root] = cur_group_id
|
||||
cur_group_id += 1
|
||||
edge_or_node_to_group_id[edge_or_node] = edge_or_node_to_group_id[root]
|
||||
|
||||
return edge_or_node_to_group_id
|
||||
|
||||
|
||||
def _get_obs_or_fq_map(
|
||||
edge_or_node_to_group_id: dict[EdgeOrNode, int],
|
||||
edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase],
|
||||
is_qat: bool,
|
||||
) -> dict[EdgeOrNode, ObserverOrFakeQuantize]:
|
||||
"""Generates the EdgeOrNode to observer/fake_quant instances
|
||||
Makes sure that for EdgeOrNode that has the same group_id should have the same observer or fake quant
|
||||
instances
|
||||
"""
|
||||
obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize] = {}
|
||||
group_id_to_obs_or_fq: dict[int, ObserverOrFakeQuantize] = {}
|
||||
for edge_or_node, qspec in edge_or_node_to_qspec.items():
|
||||
group_id = edge_or_node_to_group_id[edge_or_node]
|
||||
if group_id not in group_id_to_obs_or_fq:
|
||||
# TODO: maybe edge_or_node_to_qspec should be edge_or_node_to_root_qspec, this will simplify
|
||||
# the implementation for _create_obs_or_fq_from_qspec
|
||||
group_id_to_obs_or_fq[group_id] = _create_obs_or_fq_from_qspec(
|
||||
qspec, obs_or_fq_map, is_qat
|
||||
)
|
||||
obs_or_fq_map[edge_or_node] = group_id_to_obs_or_fq[group_id]
|
||||
return obs_or_fq_map
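
# --- Illustrative note (not part of the original file) ---
# Edges/nodes that share a group id map to the *same* observer/fake_quant
# object, e.g. reusing the names from _get_edge_or_node_to_group_id's docstring:
#
#     obs_or_fq_map[(op1, cat1)] is obs_or_fq_map[cat1]   # -> True
#
# which is what makes the shared inputs/outputs of `cat` observe identical
# quantization parameters.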
|
||||
|
||||
|
||||
def _maybe_insert_input_observer_for_arg_or_kwarg(
|
||||
node: Union[Node, Any],
|
||||
arg: Argument,
|
||||
qconfig: QConfigAny,
|
||||
model: torch.nn.Module,
|
||||
named_modules: dict[str, torch.nn.Module],
|
||||
obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
|
||||
is_qat: bool,
|
||||
) -> Argument:
|
||||
"""
|
||||
Given a `node` and an `arg`, inserts an input observer between
|
||||
`node` and `arg` if necessary.
|
||||
"""
|
||||
# for ops such as torch.cat([x0, x1]),
|
||||
# traverse through the list
|
||||
if isinstance(arg, (list, tuple)):
|
||||
new_arg_to_return = []
|
||||
for inner_arg in arg:
|
||||
new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
|
||||
node,
|
||||
inner_arg,
|
||||
qconfig,
|
||||
model,
|
||||
named_modules,
|
||||
obs_or_fq_map,
|
||||
is_qat,
|
||||
)
|
||||
new_arg_to_return.append(new_inner_arg)
|
||||
return type(arg)(new_arg_to_return)
|
||||
|
||||
if not isinstance(arg, Node):
|
||||
return arg
|
||||
assert isinstance(arg, Node)
|
||||
# default (no observer)
|
||||
new_arg = arg
|
||||
|
||||
# find the original `arg` node to the current node, skipping inserted observer/fake_quant nodes
|
||||
original_arg = arg
|
||||
while _is_activation_post_process_node(original_arg, named_modules):
|
||||
original_arg = original_arg.args[0] # type: ignore[assignment]
|
||||
assert isinstance(
|
||||
original_arg, Node
|
||||
), f"expect original argument to be a Node, but got: {type(original_arg)}"
|
||||
|
||||
input_edge = (original_arg, node)
|
||||
if input_edge not in obs_or_fq_map:
|
||||
return new_arg
|
||||
# input_edge needs to be observed
|
||||
input_edge_obs_or_fq = obs_or_fq_map[input_edge]
|
||||
if input_edge_obs_or_fq is None:
|
||||
return new_arg
|
||||
|
||||
arg_as_output_obs_or_fq = obs_or_fq_map.get(original_arg, None)
|
||||
# the arg is observed as the output and is using the same instance as the input_edge
|
||||
# we'll reuse the inserted observer/fake_quant
|
||||
if arg_as_output_obs_or_fq is not None and id(arg_as_output_obs_or_fq) == id(
|
||||
input_edge_obs_or_fq
|
||||
):
|
||||
return new_arg
|
||||
|
||||
# otherwise, we'll insert a new observer/fake_quant node
|
||||
|
||||
# skip inserting new observers if the same observer instance is inserted before for another user
|
||||
# Example:
|
||||
# conv1 -> obs1 -> existing_obs -> conv2
|
||||
# \ -> conv3
|
||||
#
|
||||
# instead of inserting new observers we will have:
|
||||
# conv1 -> obs1 -> existing_obs -> conv2
|
||||
# \ -> conv3
|
||||
for maybe_obs_node in arg.users.keys():
|
||||
if not _is_activation_post_process_node(maybe_obs_node, named_modules):
|
||||
continue
|
||||
maybe_obs_mod = named_modules[maybe_obs_node.target] # type: ignore[index]
|
||||
if id(maybe_obs_mod) == id(input_edge_obs_or_fq):
|
||||
return maybe_obs_node
|
||||
|
||||
assert isinstance(model.graph, Graph)
|
||||
new_arg = _insert_obs_or_fq(
|
||||
arg, input_edge_obs_or_fq, model, named_modules, model.graph
|
||||
)
|
||||
return new_arg
|
||||
|
||||
|
||||
def _maybe_insert_input_observers_for_node(
|
||||
node: Node,
|
||||
qconfig: QConfigAny,
|
||||
model: torch.nn.Module,
|
||||
named_modules: dict[str, torch.nn.Module],
|
||||
obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
|
||||
is_qat: bool,
|
||||
) -> None:
|
||||
"""
|
||||
If needed, inserts observers to the input args and kwargs of `node`.
|
||||
Note: modifies `node` inplace.
|
||||
|
||||
For example, if cur_node needs an observer after prev_node, we change from
|
||||
|
||||
prev_node -> cur_node
|
||||
|
||||
To
|
||||
|
||||
prev_node -> obs -> cur_node
|
||||
|
||||
"""
|
||||
# Look through every input arg. If that arg's target dtype does not
|
||||
# match the current node's target dtype, insert an observer.
|
||||
new_args = []
|
||||
for arg in node.args:
|
||||
new_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
|
||||
node,
|
||||
arg,
|
||||
qconfig,
|
||||
model,
|
||||
named_modules,
|
||||
obs_or_fq_map,
|
||||
is_qat,
|
||||
)
|
||||
new_args.append(new_arg)
|
||||
|
||||
# Clone has a memory_format kwarg, zeros_like has a pin_memory kwarg, and
|
||||
    # gelu has an approximate kwarg that persists in the exported graph.
    # This is just a workaround for these.
|
||||
assert (
|
||||
node.target == torch.ops.aten.clone.default
|
||||
or node.target == torch.ops.aten.zeros_like.default
|
||||
or node.target == torch.ops.aten.gelu.default
|
||||
or len(node.kwargs) == 0
|
||||
), " expecting kwargs for aten op IR to be empty"
|
||||
|
||||
# assign the new args to the node, inplace
|
||||
node.args = tuple(new_args)
|
||||
|
||||
|
||||
def _maybe_insert_output_observer_for_node(
|
||||
node: Node,
|
||||
model: torch.nn.Module,
|
||||
named_modules: dict[str, torch.nn.Module],
|
||||
graph: Graph,
|
||||
obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
|
||||
is_qat: bool,
|
||||
) -> Optional[Node]:
|
||||
if node in obs_or_fq_map:
|
||||
output_act_obs_or_fq = obs_or_fq_map[node]
|
||||
new_output = _insert_obs_or_fq(
|
||||
node, output_act_obs_or_fq, model, named_modules, graph
|
||||
)
|
||||
# propagate numeric debug handle from original node to observer/fake_quant node
|
||||
if (
|
||||
isinstance(node, Node)
|
||||
and isinstance(new_output, Node)
|
||||
and CUSTOM_KEY in node.meta
|
||||
and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
|
||||
):
|
||||
if CUSTOM_KEY not in new_output.meta:
|
||||
new_output.meta[CUSTOM_KEY] = {}
|
||||
new_output.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = node.meta[
|
||||
CUSTOM_KEY
|
||||
][NUMERIC_DEBUG_HANDLE_KEY]
|
||||
return new_output
|
||||
return None
|
||||
|
||||
|
||||
def _maybe_insert_input_and_output_observers_for_node(
|
||||
node: Node,
|
||||
model: torch.fx.GraphModule,
|
||||
obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
|
||||
is_qat: bool,
|
||||
):
|
||||
this_node_quantization_annotation = (
|
||||
node.meta["quantization_annotation"]
|
||||
if "quantization_annotation" in node.meta
|
||||
else None
|
||||
)
|
||||
if this_node_quantization_annotation is None:
|
||||
return
|
||||
|
||||
named_modules = dict(model.named_modules(remove_duplicate=False))
|
||||
_maybe_insert_input_observers_for_node(
|
||||
node,
|
||||
None, # qconfig
|
||||
model,
|
||||
named_modules,
|
||||
obs_or_fq_map,
|
||||
is_qat,
|
||||
)
|
||||
|
||||
output_is_a_tensor = "val" in node.meta and isinstance(node.meta["val"], FakeTensor)
|
||||
if not output_is_a_tensor:
|
||||
return
|
||||
|
||||
# this returns the new observer node if it was needed
|
||||
maybe_output_obs_node = _maybe_insert_output_observer_for_node(
|
||||
node, model, named_modules, model.graph, obs_or_fq_map, is_qat
|
||||
)
|
||||
|
||||
if maybe_output_obs_node is None:
|
||||
return
|
||||
# Update users of original node to use the output observer
|
||||
# instead. For example, change
|
||||
#
|
||||
# next_node
|
||||
# /
|
||||
# cur_node -> obs
|
||||
#
|
||||
# to
|
||||
#
|
||||
# next_node
|
||||
# /
|
||||
# cur_node -> obs
|
||||
#
|
||||
# We need to save orig users before updating uses because
|
||||
# the list of users will change as we update uses
|
||||
orig_users = list(node.users.keys())
|
||||
for user_node in orig_users:
|
||||
if user_node is maybe_output_obs_node:
|
||||
continue
|
||||
user_node.replace_input_with(node, maybe_output_obs_node)
|
||||
|
||||
|
||||
def prepare(
|
||||
model: GraphModule,
|
||||
node_name_to_scope: dict[str, tuple[str, type]],
|
||||
is_qat: bool,
|
||||
obs_or_fq_callback=None,
|
||||
) -> GraphModule:
|
||||
# Since we are mutating the graph as we go, we iterate over the original
|
||||
# nodes before observer insertion, instead of model.graph.nodes.
|
||||
nodes_before_observation = list(model.graph.nodes)
|
||||
|
||||
    # At a high level we construct a map from EdgeOrNode to an observer_or_fake_quant instance;
    # all edges/nodes that belong to the same group will use the same instance,
    # and when we insert observers we'll just query this map to get the correct observer_or_fake_quant
    # instance
|
||||
edge_or_node_to_qspec = _get_edge_or_node_to_qspec(model)
|
||||
edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
|
||||
obs_or_fq_map = _get_obs_or_fq_map(
|
||||
edge_or_node_to_group_id, edge_or_node_to_qspec, is_qat
|
||||
)
|
||||
if obs_or_fq_callback:
|
||||
obs_or_fq_callback(model, obs_or_fq_map)
|
||||
|
||||
for node in nodes_before_observation:
|
||||
# TODO: simplify logic for inserting observers
|
||||
_maybe_insert_input_and_output_observers_for_node(
|
||||
node, model, obs_or_fq_map, is_qat
|
||||
)
|
||||
|
||||
model = GraphModule(model, model.graph)
|
||||
|
||||
_save_state(
|
||||
model,
|
||||
{}, # node_name_to_qconfig
|
||||
node_name_to_scope,
|
||||
PrepareCustomConfig(),
|
||||
{}, # equalization_node_name_to_qconfig
|
||||
QConfigMapping(),
|
||||
is_qat,
|
||||
set(), # observed_node_names
|
||||
)
|
||||
return model
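
# --- Illustrative usage sketch (not part of the original file) ---
# `prepare` is normally reached through the public prepare_pt2e/prepare_qat_pt2e
# APIs after a quantizer has annotated the captured model; a rough direct call,
# assuming `annotated_gm` is such an annotated GraphModule, looks like:
#
#     prepared_gm = prepare(
#         annotated_gm,
#         node_name_to_scope={},  # hypothetical empty scope map for this sketch
#         is_qat=False,
#     )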
|
venv/Lib/site-packages/torch/ao/quantization/pt2e/qat_utils.py
@@ -0,0 +1,991 @@
# mypy: allow-untyped-defs
|
||||
import copy
|
||||
import dataclasses
|
||||
import itertools
|
||||
import operator
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
|
||||
from torch.ao.quantization.pt2e.export_utils import _WrapperModule
|
||||
from torch.ao.quantization.quantizer import (
|
||||
DerivedQuantizationSpec,
|
||||
EdgeOrNode,
|
||||
QuantizationSpecBase,
|
||||
SharedQuantizationSpec,
|
||||
)
|
||||
from torch.fx import Graph, GraphModule, Node
|
||||
from torch.fx.subgraph_rewriter import replace_pattern_with_filters, ReplacedPatterns
|
||||
|
||||
from .utils import (
|
||||
_get_aten_graph_module_for_pattern,
|
||||
_is_bn_node,
|
||||
_is_conv_or_conv_transpose_node,
|
||||
_is_conv_transpose_fn,
|
||||
fold_bn_weights_into_conv_node,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.fx.passes.utils.matcher_with_name_node_map_utils import InternalMatch
|
||||
|
||||
__all__ = [] # type: ignore[var-annotated]
|
||||
|
||||
|
||||
def _get_quantized_conv_bn_example_inputs_kwargs(
|
||||
is_per_channel: bool,
|
||||
has_bias: bool,
|
||||
bias_is_quantized: bool,
|
||||
is_cuda: bool,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Optional example inputs for quantized and folded conv-bn patterns
|
||||
used in convert, expressed as kwargs.
|
||||
"""
|
||||
kwargs = {}
|
||||
# Per tensor quantization uses literals to represent scale and zero
|
||||
# point, so there is no need to include them here as kwargs
|
||||
if is_per_channel:
|
||||
kwargs["weight_scale"] = torch.tensor([1], dtype=torch.float)
|
||||
kwargs["weight_zero_point"] = torch.tensor([0], dtype=torch.int)
|
||||
if has_bias and bias_is_quantized:
|
||||
kwargs["bias_scale"] = torch.tensor([1], dtype=torch.float)
|
||||
kwargs["bias_zero_point"] = torch.tensor([0], dtype=torch.int)
|
||||
if has_bias:
|
||||
kwargs["conv_bias"] = torch.randn(1)
|
||||
if is_cuda:
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
kwargs[k] = v.cuda()
|
||||
return kwargs
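
# --- Illustrative note (not part of the original file) ---
# For example, is_per_channel=True, has_bias=True, bias_is_quantized=True and
# is_cuda=False yields kwargs with the keys weight_scale, weight_zero_point,
# bias_scale, bias_zero_point and conv_bias, matching the extra placeholders
# expected by the quantized conv-bn patterns defined below.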
|
||||
|
||||
|
||||
def _get_conv_bn_pattern(conv_fn: Callable) -> Callable:
|
||||
def _conv_bn_pattern(
|
||||
x: torch.Tensor,
|
||||
conv_weight: torch.Tensor,
|
||||
conv_bias: torch.Tensor,
|
||||
bn_weight: torch.Tensor,
|
||||
bn_bias: torch.Tensor,
|
||||
bn_running_mean: torch.Tensor,
|
||||
bn_running_var: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
x = conv_fn(x, conv_weight, conv_bias)
|
||||
x = F.batch_norm(
|
||||
x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True
|
||||
)
|
||||
return x
|
||||
|
||||
return _WrapperModule(_conv_bn_pattern)
|
||||
|
||||
|
||||
# TODO: merge this with the `no_conv_bias` case
|
||||
def _get_qat_conv_bn_pattern(conv_fn: Callable) -> Callable:
|
||||
def _qat_conv_bn_pattern(
|
||||
x: torch.Tensor,
|
||||
conv_weight: torch.Tensor,
|
||||
conv_bias: torch.Tensor,
|
||||
bn_weight: torch.Tensor,
|
||||
bn_bias: torch.Tensor,
|
||||
bn_running_mean: torch.Tensor,
|
||||
bn_running_var: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Approximated method to fuse conv and bn. It requires only one forward pass.
|
||||
conv_orig = conv / scale_factor where scale_factor = bn.weight / running_std.
|
||||
This is based on `nniqat.ConvBn2d._forward_approximate`.
|
||||
"""
|
||||
# TODO: allow setting eps
|
||||
bn_eps = 1e-5
|
||||
running_std = torch.sqrt(bn_running_var + bn_eps)
|
||||
scale_factor = bn_weight / running_std
|
||||
weight_shape = [1] * len(conv_weight.shape)
|
||||
weight_in_channel_axis = 1 if _is_conv_transpose_fn(conv_fn) else 0
|
||||
weight_shape[weight_in_channel_axis] = -1
|
||||
bias_shape = [1] * len(conv_weight.shape)
|
||||
bias_shape[1] = -1
|
||||
scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
|
||||
zero_bias = torch.zeros_like(conv_bias, dtype=x.dtype)
|
||||
x = conv_fn(x, scaled_weight, zero_bias)
|
||||
x = x / scale_factor.reshape(bias_shape)
|
||||
x = x + conv_bias.reshape(bias_shape)
|
||||
x = F.batch_norm(
|
||||
x,
|
||||
bn_running_mean,
|
||||
bn_running_var,
|
||||
bn_weight,
|
||||
bn_bias,
|
||||
training=True,
|
||||
eps=bn_eps,
|
||||
)
|
||||
return x
|
||||
|
||||
return _WrapperModule(_qat_conv_bn_pattern)
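
# --- Illustrative note (not part of the original file) ---
# Why the pattern above scales and then un-scales: with
# scale_factor = bn_weight / sqrt(bn_running_var + bn_eps), running the conv on
# conv_weight * scale_factor lets the weight fake-quant (inserted by prepare on
# the scaled weight) observe the *folded* conv-bn weights, and dividing the conv
# output by scale_factor afterwards undoes that scaling before the real
# F.batch_norm is applied, so the overall conv -> bn numerics are preserved.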
|
||||
|
||||
|
||||
def _get_qat_conv_bn_pattern_no_conv_bias(conv_fn: Callable) -> Callable:
|
||||
def _qat_conv_bn_pattern_no_conv_bias(
|
||||
x: torch.Tensor,
|
||||
conv_weight: torch.Tensor,
|
||||
# Not used, only for matching convenience
|
||||
conv_bias: torch.Tensor,
|
||||
bn_weight: torch.Tensor,
|
||||
bn_bias: torch.Tensor,
|
||||
bn_running_mean: torch.Tensor,
|
||||
bn_running_var: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Same as `_get_qat_conv_bn_pattern`, but handles the case with no conv bias.
|
||||
"""
|
||||
# TODO: allow setting eps
|
||||
bn_eps = 1e-5
|
||||
running_std = torch.sqrt(bn_running_var + bn_eps)
|
||||
scale_factor = bn_weight / running_std
|
||||
weight_shape = [1] * len(conv_weight.shape)
|
||||
weight_in_channel_axis = 1 if _is_conv_transpose_fn(conv_fn) else 0
|
||||
weight_shape[weight_in_channel_axis] = -1
|
||||
bias_shape = [1] * len(conv_weight.shape)
|
||||
bias_shape[1] = -1
|
||||
scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
|
||||
x = conv_fn(x, scaled_weight, None)
|
||||
x = x / scale_factor.reshape(bias_shape)
|
||||
x = F.batch_norm(
|
||||
x,
|
||||
bn_running_mean,
|
||||
bn_running_var,
|
||||
bn_weight,
|
||||
bn_bias,
|
||||
training=True,
|
||||
eps=bn_eps,
|
||||
)
|
||||
return x
|
||||
|
||||
return _WrapperModule(_qat_conv_bn_pattern_no_conv_bias)
|
||||
|
||||
|
||||
def _append_qdq(x, is_per_channel, is_bias, kwargs):
|
||||
"""
|
||||
Helper function to append q-dq ops after `x`, using dummy values for the qparams
|
||||
and qmin/qmax. We use dummy values here because we match with `ignore_literals=True`
|
||||
and will manually replace these values after subgraph rewriting.
|
||||
|
||||
Return the dq node.
|
||||
"""
|
||||
# Dummy args to be passed into q-dq ops
|
||||
per_channel_axis = 0
|
||||
scale_key = "bias_scale" if is_bias else "weight_scale"
|
||||
zp_key = "bias_zero_point" if is_bias else "weight_zero_point"
|
||||
scale = kwargs[scale_key] if is_per_channel else 1.0
|
||||
zp = kwargs[zp_key] if is_per_channel else 0
|
||||
qmin = -127
|
||||
qmax = 127
|
||||
dtype = torch.int8
|
||||
|
||||
qd = torch.ops.quantized_decomposed
|
||||
if is_per_channel:
|
||||
x = qd.quantize_per_channel(x, scale, zp, per_channel_axis, qmin, qmax, dtype)
|
||||
x = qd.dequantize_per_channel(x, scale, zp, per_channel_axis, qmin, qmax, dtype)
|
||||
else:
|
||||
x = qd.quantize_per_tensor(x, scale, zp, qmin, qmax, dtype)
|
||||
x = qd.dequantize_per_tensor(x, scale, zp, qmin, qmax, dtype)
|
||||
return x
|
||||
|
||||
|
||||
def _get_quantized_qat_conv_bn_pattern(
|
||||
is_per_channel: bool,
|
||||
has_bias: bool,
|
||||
bias_is_quantized: bool,
|
||||
conv_fn: Callable,
|
||||
bn_is_training: bool,
|
||||
) -> Callable:
|
||||
"""
|
||||
Return the quantized version of QAT conv + BN pattern.
|
||||
This is based on `nniqat.ConvBn2d._forward_approximate`,
|
||||
used in QAT convert. We first match this pattern and replace
|
||||
it with the normal [conv - bn] pattern, then fold the BN
|
||||
weights into conv.
|
||||
"""
|
||||
# TODO: allow setting eps
|
||||
bn_eps = 1e-5
|
||||
|
||||
def _quantized_qat_conv_bn_pattern(
|
||||
x: torch.Tensor,
|
||||
conv_weight: torch.Tensor,
|
||||
bn_weight: torch.Tensor,
|
||||
bn_bias: torch.Tensor,
|
||||
bn_running_mean: torch.Tensor,
|
||||
bn_running_var: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
running_std = torch.sqrt(bn_running_var + bn_eps)
|
||||
scale_factor = bn_weight / running_std
|
||||
weight_shape = [1] * len(conv_weight.shape)
|
||||
weight_shape[0] = -1
|
||||
bias_shape = [1] * len(conv_weight.shape)
|
||||
bias_shape[1] = -1
|
||||
scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
|
||||
scaled_weight = _append_qdq(
|
||||
scaled_weight,
|
||||
is_per_channel,
|
||||
is_bias=False,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
if has_bias:
|
||||
zero_bias = torch.zeros_like(kwargs["conv_bias"], dtype=x.dtype)
|
||||
if bias_is_quantized:
|
||||
zero_bias = _append_qdq(
|
||||
zero_bias,
|
||||
is_per_channel,
|
||||
is_bias=True,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
x = conv_fn(x, scaled_weight, zero_bias)
|
||||
else:
|
||||
x = conv_fn(x, scaled_weight, None)
|
||||
x = x / scale_factor.reshape(bias_shape)
|
||||
if has_bias:
|
||||
x = x + kwargs["conv_bias"].reshape(bias_shape)
|
||||
x = F.batch_norm(
|
||||
x,
|
||||
bn_running_mean,
|
||||
bn_running_var,
|
||||
bn_weight,
|
||||
bn_bias,
|
||||
training=bn_is_training,
|
||||
eps=bn_eps,
|
||||
)
|
||||
return x
|
||||
|
||||
return _WrapperModule(_quantized_qat_conv_bn_pattern)
|
||||
|
||||
|
||||
def _get_folded_quantized_qat_conv_bn_pattern(
|
||||
is_per_channel: bool,
|
||||
has_bias: bool,
|
||||
bias_is_quantized: bool,
|
||||
conv_fn: Callable,
|
||||
bn_is_training: bool,
|
||||
) -> Callable:
|
||||
"""
|
||||
Quantized QAT conv - bn pattern with bn weights being folded into conv.
|
||||
"""
|
||||
# TODO: allow setting eps
|
||||
bn_eps = 1e-5
|
||||
|
||||
def _folded_quantized_qat_conv_bn_pattern(
|
||||
x: torch.Tensor,
|
||||
conv_weight: torch.Tensor,
|
||||
bn_weight: torch.Tensor,
|
||||
bn_bias: torch.Tensor,
|
||||
bn_running_mean: torch.Tensor,
|
||||
bn_running_var: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
conv_weight = _append_qdq(
|
||||
conv_weight,
|
||||
is_per_channel,
|
||||
is_bias=False,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
if has_bias:
|
||||
bias = kwargs["conv_bias"]
|
||||
if bias_is_quantized:
|
||||
bias = _append_qdq(
|
||||
bias,
|
||||
is_per_channel,
|
||||
is_bias=True,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
else:
|
||||
bias = None
|
||||
x = conv_fn(x, conv_weight, bias)
|
||||
x = F.batch_norm(
|
||||
x,
|
||||
bn_running_mean,
|
||||
bn_running_var,
|
||||
bn_weight,
|
||||
bn_bias,
|
||||
training=bn_is_training,
|
||||
eps=bn_eps,
|
||||
)
|
||||
return x
|
||||
|
||||
return _WrapperModule(_folded_quantized_qat_conv_bn_pattern)
|
||||
|
||||
|
||||
def _has_conv_bias_filter(
|
||||
match: "InternalMatch",
|
||||
original_graph: Graph,
|
||||
pattern_graph: Graph,
|
||||
) -> bool:
|
||||
"""
|
||||
Match filter for the subgraph rewriter that returns True if the conv node in
|
||||
the original graph has bias.
|
||||
"""
|
||||
for n in match.nodes_map.values():
|
||||
if _is_conv_or_conv_transpose_node(n):
|
||||
return len(n.args) > 2 and n.args[2] is not None
|
||||
raise ValueError("Could not find conv node in matched conv + bn pattern")
|
||||
|
||||
|
||||
def _no_conv_bias_filter(
|
||||
match: "InternalMatch",
|
||||
original_graph: Graph,
|
||||
pattern_graph: Graph,
|
||||
) -> bool:
|
||||
"""
|
||||
Match filter for the subgraph rewriter that returns True if the conv node in
|
||||
the original graph does NOT have bias.
|
||||
"""
|
||||
return not _has_conv_bias_filter(match, original_graph, pattern_graph)
|
||||
|
||||
|
||||
def _is_quantize(n: Node) -> bool:
|
||||
return n.target in [
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
|
||||
torch.ops.quantized_decomposed.quantize_per_channel.default,
|
||||
]
|
||||
|
||||
|
||||
def _is_dequantize(n: Node) -> bool:
|
||||
return n.target in [
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
|
||||
torch.ops.quantized_decomposed.dequantize_per_channel.default,
|
||||
]
|
||||
|
||||
|
||||
def _get_conv_bn_pattern_nodes(r: ReplacedPatterns) -> dict[str, tuple[Node, Node]]:
|
||||
"""
|
||||
Helper function to extract the nodes in the conv-bn fusion pattern after
|
||||
subgraph rewriting, in the form of a map:
|
||||
|
||||
{name: (original_node, replacement_node)}
|
||||
|
||||
The following names must exist in the map:
|
||||
|
||||
"conv", "conv_weight", "conv_input", "bn", "getitem"
|
||||
|
||||
The following names may exist in the map:
|
||||
|
||||
"conv_weight_q", "conv_weight_dq", "conv_bias",
|
||||
"conv_bias_q", "conv_bias_dq"
|
||||
"""
|
||||
|
||||
def _get_nodes(nodes: list[Node]) -> tuple[Node, Node, Optional[Node]]:
|
||||
"""
|
||||
Return a 3-tuple of (conv_node, bn_node, getitem_node).
|
||||
This asserts that the match contains exactly one of each node.
|
||||
"""
|
||||
conv_node, bn_node, getitem_node = None, None, None
|
||||
for n in nodes:
|
||||
if n.op != "call_function":
|
||||
continue
|
||||
if _is_conv_or_conv_transpose_node(n):
|
||||
assert conv_node is None
|
||||
conv_node = n
|
||||
if _is_bn_node(n):
|
||||
assert bn_node is None
|
||||
bn_node = n
|
||||
if n.target == operator.getitem:
|
||||
assert getitem_node is None
|
||||
getitem_node = n
|
||||
assert conv_node is not None
|
||||
assert bn_node is not None
|
||||
return (conv_node, bn_node, getitem_node)
|
||||
|
||||
def _get_q_dq_nodes(n: Node) -> tuple[Node, Node, Node]:
|
||||
"""
|
||||
Return a 3-tuple of (orig_node, q_node, dq_node).
|
||||
"""
|
||||
assert _is_dequantize(n)
|
||||
q_node = n.args[0]
|
||||
assert isinstance(q_node, Node)
|
||||
assert _is_quantize(q_node)
|
||||
orig_node = q_node.args[0]
|
||||
assert isinstance(orig_node, Node)
|
||||
return (orig_node, q_node, n)
|
||||
|
||||
original_nodes = list(_filter_nodes_map(r.nodes_map).values())
|
||||
o_conv, o_bn, o_getitem = _get_nodes(original_nodes)
|
||||
r_conv, r_bn, r_getitem = _get_nodes(r.replacements)
|
||||
|
||||
# Create the mapping from original node to replacement node
|
||||
assert o_getitem is None
|
||||
assert r_getitem is None
|
||||
mapping = {
|
||||
"conv": (o_conv, r_conv),
|
||||
"bn": (o_bn, r_bn),
|
||||
}
|
||||
|
||||
# Extract conv input and weight
|
||||
# Note: here we extract the original nodes indirectly through the pattern nodes
|
||||
# because the args of the original nodes are no longer available after replacement
|
||||
(p_conv, _, _) = _get_nodes(list(r.nodes_map.keys()))
|
||||
(p_conv_input, p_conv_weight, *_) = p_conv.args
|
||||
(r_conv_input, r_conv_weight, *_) = r_conv.args
|
||||
assert isinstance(p_conv_input, Node)
|
||||
assert isinstance(p_conv_weight, Node)
|
||||
assert isinstance(r_conv_input, Node)
|
||||
assert isinstance(r_conv_weight, Node)
|
||||
o_conv_input = r.nodes_map[p_conv_input]
|
||||
o_conv_weight = r.nodes_map[p_conv_weight]
|
||||
|
||||
# If conv weight is quantized, extract the q - dq nodes
|
||||
if _is_dequantize(p_conv_weight):
|
||||
p_conv_weight, p_conv_weight_q, p_conv_weight_dq = _get_q_dq_nodes(
|
||||
p_conv_weight
|
||||
)
|
||||
r_conv_weight, r_conv_weight_q, r_conv_weight_dq = _get_q_dq_nodes(
|
||||
r_conv_weight
|
||||
)
|
||||
o_conv_weight = r.nodes_map[p_conv_weight]
|
||||
o_conv_weight_q = r.nodes_map[p_conv_weight_q]
|
||||
o_conv_weight_dq = r.nodes_map[p_conv_weight_dq]
|
||||
mapping["conv_weight_q"] = (o_conv_weight_q, r_conv_weight_q)
|
||||
mapping["conv_weight_dq"] = (o_conv_weight_dq, r_conv_weight_dq)
|
||||
mapping["conv_input"] = (o_conv_input, r_conv_input)
|
||||
mapping["conv_weight"] = (o_conv_weight, r_conv_weight)
|
||||
|
||||
# Extract conv bias
|
||||
if len(p_conv.args) > 2 and len(r_conv.args) > 2:
|
||||
p_conv_bias = p_conv.args[2]
|
||||
r_conv_bias = r_conv.args[2]
|
||||
assert isinstance(p_conv_bias, Node)
|
||||
assert isinstance(r_conv_bias, Node)
|
||||
o_conv_bias = r.nodes_map[p_conv_bias]
|
||||
|
||||
# If conv bias is quantized, extract the q - dq nodes
|
||||
if _is_dequantize(p_conv_bias):
|
||||
p_conv_bias, p_conv_bias_q, p_conv_bias_dq = _get_q_dq_nodes(p_conv_bias)
|
||||
r_conv_bias, r_conv_bias_q, r_conv_bias_dq = _get_q_dq_nodes(r_conv_bias)
|
||||
o_conv_bias = r.nodes_map[p_conv_bias]
|
||||
o_conv_bias_q = r.nodes_map[p_conv_bias_q]
|
||||
o_conv_bias_dq = r.nodes_map[p_conv_bias_dq]
|
||||
mapping["conv_bias_q"] = (o_conv_bias_q, r_conv_bias_q)
|
||||
mapping["conv_bias_dq"] = (o_conv_bias_dq, r_conv_bias_dq)
|
||||
mapping["conv_bias"] = (o_conv_bias, r_conv_bias)
|
||||
return mapping
|
||||
|
||||
|
||||
def _filter_nodes_map(nodes_map: dict[Node, Node]) -> dict[Node, Node]:
|
||||
"""
|
||||
Return a filtered `nodes_map` returned from the subgraph rewriter.
|
||||
The filtered `nodes_map` will contain only nodes that are actually
|
||||
matched in the pattern, excluding None or placeholder nodes.
|
||||
"""
|
||||
new_nodes_map: dict[Node, Node] = {}
|
||||
for pattern_node, graph_node in nodes_map.items():
|
||||
# bias can be None
|
||||
if graph_node is None:
|
||||
continue
|
||||
# skip pattern placeholder nodes
|
||||
if pattern_node.op == "placeholder":
|
||||
continue
|
||||
new_nodes_map[pattern_node] = graph_node
|
||||
return new_nodes_map
|
||||
|
||||
|
||||
# TODO: this is error prone, use the replace_literals_with_placeholders hack instead
|
||||
def _copy_over_literal_conv_args(original_node: Node, new_node: Node):
|
||||
"""
|
||||
Copy over literal args in conv, such as stride and padding, from the matched node
|
||||
in the original graph to its replacement in the new graph.
|
||||
|
||||
This is needed due to the following limitation in the subgraph rewriter when used
|
||||
with dynamo export: literal (non-tensor) args are not supported in the match and
|
||||
replacement patterns. This is because dynamo export automatically inlines these
|
||||
literal args, making them dead placeholder nodes. In the future, we should check
|
||||
if dynamo export can optionally disable this inlining, or if subgraph rewriter
|
||||
can do the copying for us. See https://github.com/pytorch/pytorch/issues/100419.
|
||||
|
||||
Note: Unlike other tensor args like conv weights and biases, literal args are
|
||||
preserved in the original nodes after replacement, so we can access them here.
|
||||
"""
|
||||
assert _is_conv_or_conv_transpose_node(original_node)
|
||||
assert _is_conv_or_conv_transpose_node(new_node)
|
||||
# x, weight, bias, [stride, padding, dilation, transposed, output_padding, groups]
|
||||
new_args = list(new_node.args)
|
||||
if len(new_args) < 3:
|
||||
# bias is optional, when it is not present, it means it is None
|
||||
new_args.append(None)
|
||||
new_node.args = tuple(new_args[:3]) + original_node.args[3:]
|
||||
|
||||
|
||||
def _update_conv_input_qspec_map_after_replacement(
|
||||
original_node: Node, replacement_node: Node
|
||||
):
|
||||
"""
|
||||
Update the `input_qspec_map` in the annotation after subgraph rewriting.
|
||||
|
||||
The original annotation referred to the nodes in the original graph,
|
||||
so the keys in the `input_qspec_map` will need to be updated to reflect
|
||||
the corresponding nodes in the replacement graph.
|
||||
"""
|
||||
assert _is_conv_or_conv_transpose_node(original_node)
|
||||
assert _is_conv_or_conv_transpose_node(replacement_node)
|
||||
if "quantization_annotation" not in original_node.meta:
|
||||
return
|
||||
original_input_qspec_map = original_node.meta[
|
||||
"quantization_annotation"
|
||||
].input_qspec_map
|
||||
input_qspec_map = {}
|
||||
# get the list of configs, it should be ordered as input, weight, bias
|
||||
# note: this is really hacky, we need a better solution, hopefully
|
||||
# in subgraph_rewriter, issue tracking the problem: https://github.com/pytorch/pytorch/issues/101820
|
||||
all_configs = list(original_input_qspec_map.items())
|
||||
# input activation
|
||||
input_qspec_map[replacement_node.args[0]] = all_configs[0][1]
|
||||
# weight
|
||||
input_qspec_map[replacement_node.args[1]] = all_configs[1][1]
|
||||
# bias
|
||||
if len(replacement_node.args) > 2 and len(all_configs) > 2:
|
||||
input_qspec_map[replacement_node.args[2]] = all_configs[2][1]
|
||||
replacement_node.meta["quantization_annotation"].input_qspec_map = input_qspec_map
|
||||
|
||||
|
||||
def _update_special_qspecs_after_replacement(
|
||||
node: Node,
|
||||
original_to_replacement_node: dict[Node, Node],
|
||||
):
|
||||
"""
|
||||
Update the `SharedQuantizationSpec`s and `DerivedQuantizationSpec`s
|
||||
used in `node`'s quantization annotation after subgraph rewriting.
|
||||
|
||||
The original annotation referred to the nodes in the original graph,
|
||||
so the nodes used in these special quantization specs will need to
|
||||
be updated to the corresponding nodes in the replacement graph.
|
||||
"""
|
||||
|
||||
def _get_new_edge_or_node(edge_or_node: EdgeOrNode):
|
||||
if isinstance(edge_or_node, Node):
|
||||
_node = edge_or_node
|
||||
return original_to_replacement_node.get(_node, _node)
|
||||
elif (
|
||||
isinstance(edge_or_node, tuple)
|
||||
and len(edge_or_node) == 2
|
||||
and all(isinstance(x, Node) for x in edge_or_node)
|
||||
):
|
||||
src, dest = edge_or_node
|
||||
return (
|
||||
original_to_replacement_node.get(src, src),
|
||||
original_to_replacement_node.get(dest, dest),
|
||||
)
|
||||
else:
|
||||
raise ValueError("unexpected type for edge_or_node: ", type(edge_or_node))
|
||||
|
||||
def _get_new_qspec(qspec: QuantizationSpecBase):
|
||||
if isinstance(qspec, SharedQuantizationSpec):
|
||||
new_edge_or_node = _get_new_edge_or_node(qspec.edge_or_node)
|
||||
return SharedQuantizationSpec(new_edge_or_node)
|
||||
elif isinstance(qspec, DerivedQuantizationSpec):
|
||||
new_derived_from = [_get_new_edge_or_node(x) for x in qspec.derived_from]
|
||||
return dataclasses.replace(qspec, derived_from=new_derived_from)
|
||||
else:
|
||||
return qspec
|
||||
|
||||
if "quantization_annotation" not in node.meta:
|
||||
return
|
||||
annotation = node.meta["quantization_annotation"]
|
||||
for input_node, qspec in annotation.input_qspec_map.items():
|
||||
annotation.input_qspec_map[input_node] = _get_new_qspec(qspec)
|
||||
annotation.output_qspec = _get_new_qspec(annotation.output_qspec)
|
||||
|
||||
|
||||
def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
|
||||
# Example inputs for conv-bn1d patterns
|
||||
_conv1d_bn_example_inputs = (
|
||||
torch.randn(1, 1, 3), # x
|
||||
torch.randn(1, 1, 1), # conv_weight
|
||||
torch.randn(1), # conv_bias
|
||||
torch.randn(1), # bn_weight
|
||||
torch.randn(1), # bn_bias
|
||||
torch.randn(1), # bn_running_mean
|
||||
torch.randn(1), # bn_running_var
|
||||
)
|
||||
|
||||
# Example inputs for conv-bn2d patterns
|
||||
_conv2d_bn_example_inputs = (
|
||||
torch.randn(1, 1, 3, 3), # x
|
||||
torch.randn(1, 1, 1, 1), # conv_weight
|
||||
torch.randn(1), # conv_bias
|
||||
torch.randn(1), # bn_weight
|
||||
torch.randn(1), # bn_bias
|
||||
torch.randn(1), # bn_running_mean
|
||||
torch.randn(1), # bn_running_var
|
||||
)
|
||||
|
||||
has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
|
||||
if not has_bn:
|
||||
return m
|
||||
is_cuda_options = [True, False] if torch.cuda.is_available() else [False]
|
||||
for is_cuda in is_cuda_options:
|
||||
m = _fuse_conv_bn_qat_helper(
|
||||
m, F.conv1d, _conv1d_bn_example_inputs, is_cuda=is_cuda
|
||||
)
|
||||
m = _fuse_conv_bn_qat_helper(
|
||||
m, F.conv2d, _conv2d_bn_example_inputs, is_cuda=is_cuda
|
||||
)
|
||||
m = _fuse_conv_bn_qat_helper(
|
||||
m, F.conv_transpose1d, _conv1d_bn_example_inputs, is_cuda=is_cuda
|
||||
)
|
||||
m = _fuse_conv_bn_qat_helper(
|
||||
m, F.conv_transpose2d, _conv2d_bn_example_inputs, is_cuda=is_cuda
|
||||
)
|
||||
return m
|
||||
|
||||
|
||||
def _fuse_conv_bn_qat_helper(
|
||||
m: GraphModule,
|
||||
conv_fn: Callable,
|
||||
example_inputs: tuple[Any, ...],
|
||||
is_cuda: bool,
|
||||
) -> GraphModule:
|
||||
"""
|
||||
Given a graph of decomposed aten ops, replace the (conv + bn) pattern with
|
||||
the fused QAT subgraph equivalent. The input graph should already be annotated.
|
||||
The annotations in the original nodes will be preserved in the corresponding
|
||||
nodes in the new subgraph.
|
||||
|
||||
Note: This also handles the (conv + bn + relu) pattern.
|
||||
"""
|
||||
m.graph.eliminate_dead_code()
|
||||
m.recompile()
|
||||
|
||||
conv_bn_pattern = _get_conv_bn_pattern(conv_fn)
|
||||
match_pattern = _get_aten_graph_module_for_pattern(
|
||||
conv_bn_pattern,
|
||||
example_inputs,
|
||||
is_cuda,
|
||||
)
|
||||
|
||||
# Step (1): Replace patterns with conv bias
|
||||
#
|
||||
# Here we do replacement separately for cases with and without conv bias, since
|
||||
# the replacement patterns for these two cases are substantially different.
|
||||
# TODO: use the public replace_pattern API once it also returns replacement nodes
|
||||
|
||||
qat_conv_bn_pattern = _get_qat_conv_bn_pattern(conv_fn)
|
||||
replacement_pattern_with_conv_bias = _get_aten_graph_module_for_pattern(
|
||||
qat_conv_bn_pattern,
|
||||
example_inputs,
|
||||
is_cuda,
|
||||
)
|
||||
replacements_with_conv_bias = replace_pattern_with_filters(
|
||||
m,
|
||||
match_pattern,
|
||||
replacement_pattern_with_conv_bias,
|
||||
match_filters=[_has_conv_bias_filter],
|
||||
ignore_literals=True,
|
||||
)
|
||||
m.recompile()
|
||||
|
||||
# Step (2): Replace patterns without conv bias
|
||||
|
||||
qat_conv_bn_pattern_no_conv_bias = _get_qat_conv_bn_pattern_no_conv_bias(conv_fn)
|
||||
replacement_pattern_no_conv_bias = _get_aten_graph_module_for_pattern(
|
||||
qat_conv_bn_pattern_no_conv_bias,
|
||||
example_inputs,
|
||||
is_cuda,
|
||||
)
|
||||
replacements_no_conv_bias = replace_pattern_with_filters(
|
||||
m,
|
||||
match_pattern,
|
||||
replacement_pattern_no_conv_bias,
|
||||
match_filters=[_no_conv_bias_filter],
|
||||
ignore_literals=True,
|
||||
)
|
||||
m.recompile()
|
||||
|
||||
# Step (3): Post processing
|
||||
#
|
||||
# Due to limited functionality in the subgraph rewriter, here we manually
|
||||
# update the replacement graph as follows:
|
||||
#
|
||||
# (a) Copy over metadata from original subgraph. This ensures the stack traces
|
||||
# and annotations are preserved in the new subgraph
|
||||
#
|
||||
# (b) Copy over literal args for conv from the original subgraph
|
||||
# TODO: do this for literal args for batchnorm as well
|
||||
#
|
||||
# (c) Update all references of the old nodes in the original subgraph to refer
|
||||
# to the corresponding nodes in the new subgraph in the annotations
|
||||
#
|
||||
# In the future, we should try to push as much of this functionality into the
|
||||
# subgraph rewriter as possible, so we don't have to manually copy anything over.
|
||||
# For more detail, see https://github.com/pytorch/pytorch/issues/100419.
|
||||
|
||||
all_original_to_replacement_nodes = {}
|
||||
for r in replacements_with_conv_bias + replacements_no_conv_bias:
|
||||
replacement_dict = _get_conv_bn_pattern_nodes(r)
|
||||
# The original conv node's "nn_module_stack"
|
||||
conv_nn_module = replacement_dict["conv"][0].meta.get("nn_module_stack", None)
|
||||
for k, node_tuple in replacement_dict.items():
|
||||
original_node, replacement_node = node_tuple
|
||||
# Step (3a): Copy over metadata for all nodes in [conv - bn - getitem]
|
||||
replacement_node.meta = original_node.meta
|
||||
# If original_node is a get_attr node, it doesn't have nn_module_stack.
|
||||
# In this case, we copy nn_module_stack from the original conv node.
|
||||
if (
|
||||
k in ["conv_input", "conv_weight"]
|
||||
and conv_nn_module
|
||||
and "nn_module_stack" not in replacement_node.meta
|
||||
):
|
||||
replacement_node.meta["nn_module_stack"] = copy.deepcopy(conv_nn_module)
|
||||
if _is_conv_or_conv_transpose_node(original_node):
|
||||
# Step (3b): Copy over conv literal args
|
||||
_copy_over_literal_conv_args(original_node, replacement_node)
|
||||
# Step (3c): Update old references in the conv node's input_qspec_map
|
||||
_update_conv_input_qspec_map_after_replacement(
|
||||
original_node, replacement_node
|
||||
)
|
||||
all_original_to_replacement_nodes[original_node] = replacement_node
|
||||
|
||||
# Step (3c): Update old references in the special qspecs for all nodes in the graph
|
||||
for n in m.graph.nodes:
|
||||
_update_special_qspecs_after_replacement(n, all_original_to_replacement_nodes)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
def _duplicate_dequantize_node(m: GraphModule):
|
||||
"""
|
||||
Helper function to duplicate all dequantize nodes in the graph if the
|
||||
node has more than one user. For example:
|
||||
|
||||
Before:
|
||||
quantize -> dequantize -> a
|
||||
\\--> b
|
||||
\\--> c
|
||||
|
||||
After:
|
||||
quantize -> dequantize_1 -> a
|
||||
\\--> dequantize_2 -> b
|
||||
\\--> dequantize_3 -> c
|
||||
|
||||
This is useful for subgraph rewriting. E.g. if we wish to match the
|
||||
pattern [dequantize - a] above, subgraph matching would fail because
|
||||
the dequantize node has users outside the matched portion of the graph.
|
||||
Instead, we match [dequantize_1 - a], which is safe.
|
||||
"""
|
||||
dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor
|
||||
for n in m.graph.nodes:
|
||||
if n.op != "call_function" or n.target != dq_op or len(n.users) == 1:
|
||||
continue
|
||||
for user in list(n.users):
|
||||
with m.graph.inserting_before(n):
|
||||
new_node = m.graph.create_node("call_function", dq_op, n.args, n.kwargs)
|
||||
user.replace_input_with(n, new_node)
|
||||
m.graph.erase_node(n)
|
||||
m.recompile()
|
||||
|
||||
|
||||
def _remove_extra_dequantize(m: GraphModule):
|
||||
"""
|
||||
    Remove duplicate dequant nodes in the graph: for an operator that has
    multiple dequant nodes as users, replace them with a single dequant node
|
||||
that can be shared across all the uses. This should be seen as the "reverse"
|
||||
of `_duplicate_dequantize_node`.
|
||||
"""
|
||||
dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor
|
||||
for n in m.graph.nodes:
|
||||
dq_users = [
|
||||
user
|
||||
for user in n.users
|
||||
if user.op == "call_function" and user.target == dq_op
|
||||
]
|
||||
if len(dq_users) > 1:
|
||||
with m.graph.inserting_after(dq_users[0]):
|
||||
new_node = m.graph.create_node(
|
||||
"call_function", dq_op, dq_users[0].args, {}
|
||||
)
|
||||
for dq_user in dq_users:
|
||||
dq_user.replace_all_uses_with(new_node)
|
||||
m.graph.erase_node(dq_user)
|
||||
m.recompile()
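
# The two helpers above are used as a bracket around subgraph rewriting (see
# _fold_conv_bn_qat_helper below): _duplicate_dequantize_node makes every dequantize
# node single-use so the matcher can claim it, and _remove_extra_dequantize merges
# the duplicates back into a shared node once the replacements are done.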
|
||||
|
||||
|
||||
def _copy_over_q_dq_args(original_node: Node, replacement_node: Node):
|
||||
"""
|
||||
Given a pair of quantize or dequantize nodes, copy over all literal args
|
||||
from the original node to the replacement node.
|
||||
"""
|
||||
# For quantize_per_tensor, scale and zp are literals and need to be copied
|
||||
# For quantize_per_channel, scale and zp are get_attr nodes and should be skipped
|
||||
assert original_node.target == replacement_node.target
|
||||
if original_node.target in (
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
|
||||
):
|
||||
# Args: input, [scale, zp, qmin, qmax, dtype]
|
||||
start_copy_arg_index = 1
|
||||
elif original_node.target in (
|
||||
torch.ops.quantized_decomposed.quantize_per_channel.default,
|
||||
torch.ops.quantized_decomposed.dequantize_per_channel.default,
|
||||
):
|
||||
# Args: input, scale, zp, [axis, qmin, qmax, dtype]
|
||||
start_copy_arg_index = 3
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected quantize/dequantize nodes, got '{original_node.target}'"
|
||||
)
|
||||
replacement_node.args = (
|
||||
replacement_node.args[:start_copy_arg_index]
|
||||
+ original_node.args[start_copy_arg_index:]
|
||||
)
|
||||
|
||||
|
||||
def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
|
||||
# Example inputs for quantized and folded conv-bn1d patterns used in convert
|
||||
_quantized_conv1d_bn_example_inputs = (
|
||||
torch.randn(1, 1, 3), # x
|
||||
torch.randn(1, 1, 1), # conv_weight
|
||||
torch.randn(1), # bn_weight
|
||||
torch.randn(1), # bn_bias
|
||||
torch.randn(1), # bn_running_mean
|
||||
torch.randn(1), # bn_running_var
|
||||
)
|
||||
|
||||
# Example inputs for quantized and folded conv-bn2d patterns used in convert
|
||||
_quantized_conv2d_bn_example_inputs = (
|
||||
torch.randn(1, 1, 3, 3), # x
|
||||
torch.randn(1, 1, 1, 1), # conv_weight
|
||||
torch.randn(1), # bn_weight
|
||||
torch.randn(1), # bn_bias
|
||||
torch.randn(1), # bn_running_mean
|
||||
torch.randn(1), # bn_running_var
|
||||
)
|
||||
|
||||
has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
|
||||
if not has_bn:
|
||||
return m
|
||||
is_cuda_options = [True, False] if torch.cuda.is_available() else [False]
|
||||
for is_cuda in is_cuda_options:
|
||||
m = _fold_conv_bn_qat_helper(
|
||||
m, F.conv1d, _quantized_conv1d_bn_example_inputs, is_cuda=is_cuda
|
||||
)
|
||||
m = _fold_conv_bn_qat_helper(
|
||||
m, F.conv2d, _quantized_conv2d_bn_example_inputs, is_cuda=is_cuda
|
||||
)
|
||||
m = _fold_conv_bn_qat_helper(
|
||||
m, F.conv_transpose1d, _quantized_conv1d_bn_example_inputs, is_cuda=is_cuda
|
||||
)
|
||||
m = _fold_conv_bn_qat_helper(
|
||||
m, F.conv_transpose2d, _quantized_conv2d_bn_example_inputs, is_cuda=is_cuda
|
||||
)
|
||||
|
||||
    # remove the in-place add that batch norm uses to track training stats (e.g. num_batches_tracked)
|
||||
for node in m.graph.nodes:
|
||||
if (
|
||||
node.target == torch.ops.aten.add_.Tensor
|
||||
and node.args[0].op == "get_attr"
|
||||
and node.args[1] == 1
|
||||
and torch.nn.modules.batchnorm.BatchNorm2d
|
||||
in [val[1] for val in node.meta["source_fn_stack"]]
|
||||
):
|
||||
m.graph.erase_node(node)
|
||||
|
||||
m.graph.eliminate_dead_code()
|
||||
m.recompile()
|
||||
|
||||
return m
|
||||
|
||||
|
||||
def _fold_conv_bn_qat_helper(
|
||||
m: GraphModule,
|
||||
conv_fn: Callable,
|
||||
example_inputs: tuple[Any, ...],
|
||||
is_cuda: bool,
|
||||
) -> GraphModule:
|
||||
"""
|
||||
Replace the quantized (conv + bn) pattern with conv with bn weights folded into the weights of conv.
|
||||
"""
|
||||
|
||||
m.graph.eliminate_dead_code()
|
||||
m.recompile()
|
||||
_duplicate_dequantize_node(m)
|
||||
|
||||
# Step (1): Replace QAT pattern with simple [conv - bn] pattern
|
||||
replacements = []
|
||||
replacement_options = itertools.product(
|
||||
[True, False], # is_per_channel
|
||||
[True, False], # has_bias
|
||||
[True, False], # bias_is_quantized
|
||||
[True, False], # bn_is_training
|
||||
)
|
||||
for (
|
||||
is_per_channel,
|
||||
has_bias,
|
||||
bias_is_quantized,
|
||||
bn_is_training,
|
||||
) in replacement_options:
|
||||
# For the cases without bias, `bias_is_quantized` is irrelevant, so here we arbitrarily
|
||||
# filter out one of the values for this flag to avoid having duplicate patterns
|
||||
if not has_bias and bias_is_quantized:
|
||||
continue
|
||||
kwargs = _get_quantized_conv_bn_example_inputs_kwargs(
|
||||
is_per_channel, has_bias, bias_is_quantized, is_cuda
|
||||
)
|
||||
match_pattern = _get_quantized_qat_conv_bn_pattern(
|
||||
is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training
|
||||
)
|
||||
match_pattern = _get_aten_graph_module_for_pattern(
|
||||
match_pattern,
|
||||
example_inputs,
|
||||
is_cuda,
|
||||
**kwargs,
|
||||
)
|
||||
replacement_pattern = _get_folded_quantized_qat_conv_bn_pattern(
|
||||
is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training
|
||||
)
|
||||
replacement_pattern = _get_aten_graph_module_for_pattern(
|
||||
replacement_pattern,
|
||||
example_inputs,
|
||||
is_cuda,
|
||||
**kwargs,
|
||||
)
|
||||
replacements.extend(
|
||||
replace_pattern_with_filters(
|
||||
m,
|
||||
match_pattern,
|
||||
replacement_pattern,
|
||||
ignore_literals=True,
|
||||
)
|
||||
)
|
||||
m.recompile()
|
||||
_remove_extra_dequantize(m)
|
||||
|
||||
for r in replacements:
|
||||
node_map = _get_conv_bn_pattern_nodes(r)
|
||||
|
||||
# Step (2): Copy over metadata from original subgraph
|
||||
for original_node, replacement_node in node_map.values():
|
||||
replacement_node.meta = original_node.meta
|
||||
|
||||
# Step (3): Copy over args for weight (and optionally bias) q - dq nodes
|
||||
_copy_over_q_dq_args(*node_map["conv_weight_q"])
|
||||
_copy_over_q_dq_args(*node_map["conv_weight_dq"])
|
||||
if "conv_bias_q" in node_map:
|
||||
assert "conv_bias_dq" in node_map
|
||||
_copy_over_q_dq_args(*node_map["conv_bias_q"])
|
||||
_copy_over_q_dq_args(*node_map["conv_bias_dq"])
|
||||
|
||||
# Step (4): Fold BN weights into conv
|
||||
conv_bias = None
|
||||
(_, conv_node) = node_map["conv"]
|
||||
(_, bn_node) = node_map["bn"]
|
||||
(_, conv_weight) = node_map["conv_weight"]
|
||||
if "conv_bias" in node_map:
|
||||
(_, conv_bias) = node_map["conv_bias"]
|
||||
fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m)
|
||||
|
||||
# Copy over literal args for conv
|
||||
for original_node in _filter_nodes_map(r.nodes_map).values():
|
||||
if _is_conv_or_conv_transpose_node(original_node):
|
||||
_copy_over_literal_conv_args(original_node, conv_node)
|
||||
|
||||
m.graph.eliminate_dead_code()
|
||||
m.recompile()
|
||||
return m
|
|
@ -0,0 +1,6 @@
|
|||
from .rewrite import reference_representation_rewrite
|
||||
|
||||
|
||||
__all__ = [
|
||||
"reference_representation_rewrite",
|
||||
]
|
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,819 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops.out_dtype import out_dtype
|
||||
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
|
||||
from torch.ao.quantization.pt2e.export_utils import _WrapperModule
|
||||
from torch.ao.quantization.pt2e.utils import (
|
||||
_get_aten_graph_module_for_pattern,
|
||||
_replace_literals_with_existing_placeholders,
|
||||
_replace_literals_with_new_placeholders,
|
||||
remove_tensor_overload_for_qdq_ops,
|
||||
)
|
||||
from torch.fx import GraphModule
|
||||
from torch.fx.subgraph_rewriter import replace_pattern
|
||||
|
||||
|
||||
__all__ = [
|
||||
"reference_representation_rewrite",
|
||||
]
|
||||
|
||||
|
||||
def _qdq_quantized_linear(
|
||||
x_i8,
|
||||
x_scale,
|
||||
x_zero_point,
|
||||
x_quant_min,
|
||||
x_quant_max,
|
||||
weight_i8,
|
||||
weight_scale,
|
||||
weight_zero_point,
|
||||
weight_quant_min,
|
||||
weight_quant_max,
|
||||
bias_fp32,
|
||||
out_scale,
|
||||
out_zero_point,
|
||||
out_quant_min,
|
||||
out_quant_max,
|
||||
):
|
||||
x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
|
||||
x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8
|
||||
)
|
||||
weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
|
||||
weight_i8,
|
||||
weight_scale,
|
||||
weight_zero_point,
|
||||
weight_quant_min,
|
||||
weight_quant_max,
|
||||
torch.int8,
|
||||
)
|
||||
out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
|
||||
out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
|
||||
out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8
|
||||
)
|
||||
return out_i8
|
||||
|
||||
|
||||
def _reference_quantized_linear(
|
||||
x_i8,
|
||||
x_scale,
|
||||
x_zero_point,
|
||||
x_quant_min,
|
||||
x_quant_max,
|
||||
weight_i8,
|
||||
weight_scale,
|
||||
weight_zero_point,
|
||||
weight_quant_min,
|
||||
weight_quant_max,
|
||||
bias_fp32,
|
||||
out_scale,
|
||||
out_zero_point,
|
||||
out_quant_min,
|
||||
out_quant_max,
|
||||
):
|
||||
    # without using quant_min/max in clamp, the traced graph will not have quant_min/max args.
|
||||
# This results in failure to match the pattern.
|
||||
# Therefore, we call a torch.ops.aten.clamp here
|
||||
x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max)
|
||||
weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)
|
||||
|
||||
x_i16 = x_i8.to(torch.int16)
|
||||
weight_i16 = weight_i8.to(torch.int16)
|
||||
    # always set bias to None so that the same representation works
    # whether or not bias_scale == x_scale * weight_scale
|
||||
acc_i32 = out_dtype(
|
||||
torch.ops.aten.linear.default,
|
||||
torch.int32,
|
||||
x_i16 - x_zero_point,
|
||||
weight_i16 - weight_zero_point,
|
||||
None,
|
||||
)
|
||||
# TODO: change to mul.Scalar
|
||||
# Note: we are quantizing bias with these scales without signal from user, but it might be OK
|
||||
bias_scale = x_scale * weight_scale
|
||||
bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
|
||||
acc_i32 = acc_i32 + bias_i32
|
||||
# TODO: change to mul.Scalar when we make x_scale/weight_scale etc. Scalar values
|
||||
acc_i32 = (
|
||||
out_dtype(
|
||||
torch.ops.aten.mul.Tensor,
|
||||
torch.int32,
|
||||
acc_i32,
|
||||
x_scale * weight_scale / out_scale,
|
||||
)
|
||||
+ out_zero_point
|
||||
)
|
||||
out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8)
|
||||
return out_i8
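
# Worked example for the final requantization step in _reference_quantized_linear
# above (the values are illustrative, not from the original file): with
# x_scale = 0.02, weight_scale = 0.01, out_scale = 0.05 and out_zero_point = 10,
# an accumulator value acc_i32 = 1000 maps to roughly
#     1000 * (0.02 * 0.01 / 0.05) + 10 = 1000 * 0.004 + 10 = 14
# before the clamp to [out_quant_min, out_quant_max] and the cast to int8
# (exact results depend on the rounding/truncation behavior of out_dtype).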
|
||||
|
||||
|
||||
def _qdq_dynamic_quantized_linear(
|
||||
x_fp32,
|
||||
x_quant_min,
|
||||
x_quant_max,
|
||||
x_eps,
|
||||
weight_i8,
|
||||
weight_scale,
|
||||
weight_zero_point,
|
||||
weight_quant_min,
|
||||
weight_quant_max,
|
||||
bias_fp32,
|
||||
):
|
||||
x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams(
|
||||
x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8
|
||||
)
|
||||
x_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
|
||||
x_fp32, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8
|
||||
)
|
||||
x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
|
||||
x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8
|
||||
)
|
||||
weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
|
||||
weight_i8,
|
||||
weight_scale,
|
||||
weight_zero_point,
|
||||
weight_quant_min,
|
||||
weight_quant_max,
|
||||
torch.int8,
|
||||
)
|
||||
out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
|
||||
return out_fp32
|
||||
|
||||
|
||||
def _reference_dynamic_quantized_linear(
|
||||
x_fp32,
|
||||
x_quant_min,
|
||||
x_quant_max,
|
||||
x_eps,
|
||||
weight_i8,
|
||||
weight_scale,
|
||||
weight_zero_point,
|
||||
weight_quant_min,
|
||||
weight_quant_max,
|
||||
bias_fp32,
|
||||
):
|
||||
x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams(
|
||||
x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8
|
||||
)
|
||||
# decomposed representation for quantize_per_tensor
|
||||
# TODO: use out_dtype(mul, ...) here when the op is ready
|
||||
x_fp32 = x_fp32 / x_scale # fp32
|
||||
# round modes might be different here
|
||||
# pytorch is rounding to even, which is also common for most of the backends
|
||||
x_fp32 = torch.round(x_fp32) # fp32
|
||||
x_i32 = x_fp32.to(dtype=torch.int32) # int32
|
||||
x_i32 = x_i32 + x_zero_point # int32
|
||||
# clamp works for fp32, int32 and int8 dtypes
|
||||
x_i32 = torch.clamp(x_i32, x_quant_min, x_quant_max) # int32
|
||||
x_i8 = x_i32.to(dtype=torch.int8)
|
||||
|
||||
weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)
|
||||
|
||||
x_i16 = x_i8.to(torch.int16)
|
||||
weight_i16 = weight_i8.to(torch.int16)
|
||||
    # always set bias to None so that the same representation works
    # whether or not bias_scale == x_scale * weight_scale
|
||||
acc_i32 = out_dtype(
|
||||
torch.ops.aten.linear.default,
|
||||
torch.int32,
|
||||
x_i16 - x_zero_point,
|
||||
weight_i16 - weight_zero_point,
|
||||
None,
|
||||
)
|
||||
bias_scale = x_scale * weight_scale
|
||||
bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
|
||||
acc_i32 = acc_i32 + bias_i32
|
||||
out_fp32 = acc_i32 * (x_scale * weight_scale)
|
||||
return out_fp32
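
# Note on the dynamic variant above: x_scale and x_zero_point are not fixed ahead of
# time but computed at runtime by choose_qparams from the observed range of x_fp32,
# so only the weight is quantized statically and the output stays in fp32 rather
# than being re-quantized.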
|
||||
|
||||
|
||||
def _qdq_quantized_conv2d(
|
||||
x_i8,
|
||||
x_scale,
|
||||
x_zero_point,
|
||||
x_quant_min,
|
||||
x_quant_max,
|
||||
weight_i8,
|
||||
weight_scale,
|
||||
weight_zero_point,
|
||||
weight_quant_min,
|
||||
weight_quant_max,
|
||||
bias_fp32,
|
||||
out_scale,
|
||||
out_zero_point,
|
||||
out_quant_min,
|
||||
out_quant_max,
|
||||
):
|
||||
stride = [1, 1]
|
||||
padding = [0, 0]
|
||||
dilation = [1, 1]
|
||||
transposed = False
|
||||
output_padding = [0, 0]
|
||||
groups = 1
|
||||
x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
|
||||
x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8
|
||||
)
|
||||
weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
|
||||
weight_i8,
|
||||
weight_scale,
|
||||
weight_zero_point,
|
||||
weight_quant_min,
|
||||
weight_quant_max,
|
||||
torch.int8,
|
||||
)
|
||||
out_fp32 = torch.ops.aten.convolution.default(
|
||||
x_fp32,
|
||||
weight_fp32,
|
||||
bias_fp32,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
transposed,
|
||||
output_padding,
|
||||
groups,
|
||||
)
|
||||
out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
|
||||
out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8
|
||||
)
|
||||
return out_i8
|
||||
|
||||
|
||||
def _reference_quantized_conv2d(
|
||||
x_i8,
|
||||
x_scale,
|
||||
x_zero_point,
|
||||
x_quant_min,
|
||||
x_quant_max,
|
||||
weight_i8,
|
||||
weight_scale,
|
||||
weight_zero_point,
|
||||
weight_quant_min,
|
||||
weight_quant_max,
|
||||
bias_fp32,
|
||||
out_scale,
|
||||
out_zero_point,
|
||||
out_quant_min,
|
||||
out_quant_max,
|
||||
):
|
||||
stride = [1, 1]
|
||||
padding = [0, 0]
|
||||
dilation = [1, 1]
|
||||
transposed = False
|
||||
output_padding = [0, 0]
|
||||
groups = 1
|
||||
    # without using quant_min/max in clamp, the traced graph will not have quant_min/max args.
|
||||
# This results in failure to match the pattern.
|
||||
# Therefore, we call a torch.ops.aten.clamp here
|
||||
x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max)
|
||||
weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)
|
||||
|
||||
x_i16 = x_i8.to(torch.int16)
|
||||
weight_i16 = weight_i8.to(torch.int16)
|
||||
    # always set bias to None so that the same representation works
    # whether or not bias_scale == x_scale * weight_scale
|
||||
acc_i32 = out_dtype(
|
||||
torch.ops.aten.convolution.default,
|
||||
torch.int32,
|
||||
x_i16 - x_zero_point,
|
||||
weight_i16 - weight_zero_point,
|
||||
None,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
transposed,
|
||||
output_padding,
|
||||
groups,
|
||||
)
|
||||
# Note: we are quantizing bias with these scales without signal from user, but it might be OK
|
||||
bias_scale = x_scale * weight_scale
|
||||
    # bias quantization to int32 uses bias_scale = x_scale * weight_scale due to:
    # Take the linear calculation for example
    # Out_(i, j)_fp32 = Sum_(over k)[X_(i, k)_fp32 * W_(i, k)_fp32] + bias_(i)_fp32
    # Represent X, W fp32 via their dequant transforms
    # A_fp32 = (A_q - A_zero_point) * A_scale
    # Out_(i, j)_fp32 = Sum_(over k)[(X_(i, k)_q - X_zp) * X_scale * (W_(i, k)_q - W_zp) * W_scale] + bias_(i)_fp32
    # Factor out X_scale and W_scale
    # Out_(i, j)_fp32 = (X_scale * W_scale) * Sum_(over k)[(X_(i, k)_q - X_zp) * (W_(i, k)_q - W_zp)] + bias_(i)_fp32
    # In order to move the addition of bias_(i)_fp32 inside the scaled sum, we must do
    # Out_(i, j)_fp32 = (X_scale * W_scale) * (Sum_(over k)[(X_(i, k)_q - X_zp) * (W_(i, k)_q - W_zp)] + (1 / (X_scale * W_scale)) * bias_(i)_fp32)  # noqa: B950
    # Note we had to scale bias_fp32 by 1 / (X_scale * W_scale), i.e. divide it by bias_scale
    # Thus bias quantization to int32 must use bias_scale = X_scale * W_scale
|
||||
|
||||
bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
|
||||
# Unsqueeze to match broadcast dims
|
||||
    # Unfortunately I cannot do bias_i32.unsqueeze(0) due to literal matching nightmare
|
||||
# in graph pattern replacement
|
||||
bias_i32 = bias_i32.unsqueeze(-1)
|
||||
bias_i32 = bias_i32.unsqueeze(-1)
|
||||
acc_i32 = acc_i32 + bias_i32
|
||||
# TODO: change to mul.Scalar when we make x_scale/weight_scale etc. Scalar values
|
||||
acc_i32 = (
|
||||
out_dtype(
|
||||
torch.ops.aten.mul.Tensor,
|
||||
torch.int32,
|
||||
acc_i32,
|
||||
x_scale * weight_scale / out_scale,
|
||||
)
|
||||
+ out_zero_point
|
||||
)
|
||||
out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8)
|
||||
return out_i8
|
||||
|
||||
|
||||
def _qdq_quantized_add_relu(
|
||||
x_i8,
|
||||
x_scale,
|
||||
x_zero_point,
|
||||
y_i8,
|
||||
y_scale,
|
||||
y_zero_point,
|
||||
out_scale,
|
||||
out_zero_point,
|
||||
quant_min,
|
||||
quant_max,
|
||||
):
|
||||
x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
|
||||
x_i8, x_scale, x_zero_point, quant_min, quant_max, torch.int8
|
||||
)
|
||||
y_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
|
||||
y_i8, y_scale, y_zero_point, quant_min, quant_max, torch.int8
|
||||
)
|
||||
out_fp32 = x_fp32 + y_fp32
|
||||
out_fp32 = torch.ops.aten.relu(out_fp32)
|
||||
out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
|
||||
out_fp32, out_scale, out_zero_point, quant_min, quant_max, torch.int8
|
||||
)
|
||||
return out_i8
|
||||
|
||||
|
||||
def _reference_quantized_add_relu(
|
||||
x_i8,
|
||||
x_scale,
|
||||
x_zero_point,
|
||||
y_i8,
|
||||
y_scale,
|
||||
y_zero_point,
|
||||
out_scale,
|
||||
out_zero_point,
|
||||
quant_min,
|
||||
quant_max,
|
||||
):
|
||||
"""
|
||||
See comments for `_reference_quantized_add` for more information on
|
||||
how to derive the formula for out_i8 based on x_i8 and y_i8
|
||||
"""
|
||||
x_i32 = x_i8.to(torch.int32)
|
||||
y_i32 = y_i8.to(torch.int32)
|
||||
# TODO: change this to mul.Scalar?
|
||||
x_i32 = out_dtype(
|
||||
torch.ops.aten.mul.Tensor,
|
||||
torch.int32,
|
||||
(x_i32 - x_zero_point),
|
||||
(x_scale / out_scale),
|
||||
)
|
||||
y_i32 = out_dtype(
|
||||
torch.ops.aten.mul.Tensor,
|
||||
torch.int32,
|
||||
(y_i32 - y_zero_point),
|
||||
(y_scale / out_scale),
|
||||
)
|
||||
out_i32 = x_i32 + y_i32 + out_zero_point
|
||||
# out_i32 = torch.ops.aten.clamp(out_i32, out_zero_point)
|
||||
out_i8 = torch.ops.aten.clamp(out_i32, out_zero_point, quant_max).to(torch.int8)
|
||||
return out_i8
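
# Why the lower clamp bound above is out_zero_point rather than quant_min: in the
# quantized domain, the real value 0.0 maps to out_zero_point, so clamping the
# accumulator at out_zero_point is exactly the fused ReLU (max(x, 0)) applied to the
# dequantized values.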
|
||||
|
||||
|
||||
def _qdq_quantized_add(
|
||||
x_i8,
|
||||
x_scale,
|
||||
x_zero_point,
|
||||
y_i8,
|
||||
y_scale,
|
||||
y_zero_point,
|
||||
out_scale,
|
||||
out_zero_point,
|
||||
quant_min,
|
||||
quant_max,
|
||||
):
|
||||
x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
|
||||
x_i8, x_scale, x_zero_point, quant_min, quant_max, torch.int8
|
||||
)
|
||||
y_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
|
||||
y_i8, y_scale, y_zero_point, quant_min, quant_max, torch.int8
|
||||
)
|
||||
out_fp32 = x_fp32 + y_fp32
|
||||
out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
|
||||
out_fp32, out_scale, out_zero_point, quant_min, quant_max, torch.int8
|
||||
)
|
||||
return out_i8
|
||||
|
||||
|
||||
def _reference_quantized_add(
|
||||
x_i8,
|
||||
x_scale,
|
||||
x_zero_point,
|
||||
y_i8,
|
||||
y_scale,
|
||||
y_zero_point,
|
||||
out_scale,
|
||||
out_zero_point,
|
||||
quant_min,
|
||||
quant_max,
|
||||
):
|
||||
"""
|
||||
    # How to derive the formula for out_i8 based on x_i8 and y_i8
    # (since quantized add takes x_i8, y_i8 and their quantization parameters, and produces an out_i8)

    # out_i8 is the quantized output; we can write down the formula for it first:
    out_i8 = out_fp32 / out_scale + out_zero_point                            (1)

    # then out_fp32 is computed from x_fp32 + y_fp32, where x_fp32 and y_fp32 are the dequantized x_i8 and y_i8
    out_fp32 = x_fp32 + y_fp32                                                (2)
    x_fp32 = (x_i8 - x_zero_point) * x_scale                                  (3)
    y_fp32 = (y_i8 - y_zero_point) * y_scale                                  (4)

    # applying the above formulas to the out_i8 equation we can get the following:
    out_i8 = out_fp32 / out_scale + out_zero_point                                                              # (1)
           = (x_fp32 + y_fp32) / out_scale + out_zero_point                                                     # applying (2) to substitute out_fp32 with x_fp32 + y_fp32
           = ((x_i8 - x_zero_point) * x_scale + (y_i8 - y_zero_point) * y_scale) / out_scale + out_zero_point   # apply (3) and (4)
|
||||
"""
|
||||
x_i32 = x_i8.to(torch.int32)
|
||||
y_i32 = y_i8.to(torch.int32)
|
||||
# TODO: use out_dtype op
|
||||
x_i32 = torch.round((x_scale / out_scale) * (x_i32 - x_zero_point)).to(torch.int32)
|
||||
y_i32 = torch.round((y_scale / out_scale) * (y_i32 - y_zero_point)).to(torch.int32)
|
||||
out_i32 = x_i32 + y_i32 + out_zero_point
|
||||
quant_min = -128
|
||||
quant_max = 127
|
||||
out_i8 = torch.ops.aten.clamp(out_i32, quant_min, quant_max).to(torch.int8)
|
||||
return out_i8
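
# A minimal numerical sanity check for the derivation above (the helper name and the
# example values below are illustrative only, not part of the original API): the
# integer-domain formula should agree with dequantize -> float add -> quantize up to
# a rounding difference of at most 1.
def _check_reference_quantized_add_formula():
    x_i8 = torch.tensor([10, -20, 37], dtype=torch.int8)
    y_i8 = torch.tensor([5, 14, -8], dtype=torch.int8)
    x_scale, x_zero_point = 0.05, 2
    y_scale, y_zero_point = 0.10, -1
    out_scale, out_zero_point = 0.08, 3
    # float reference: dequantize, add in fp32, re-quantize
    x_fp32 = (x_i8.to(torch.float32) - x_zero_point) * x_scale
    y_fp32 = (y_i8.to(torch.float32) - y_zero_point) * y_scale
    expected = torch.clamp(
        torch.round((x_fp32 + y_fp32) / out_scale) + out_zero_point, -128, 127
    ).to(torch.int8)
    # integer-domain formula from the docstring of _reference_quantized_add
    actual = torch.clamp(
        torch.round((x_scale / out_scale) * (x_i8.to(torch.int32) - x_zero_point))
        + torch.round((y_scale / out_scale) * (y_i8.to(torch.int32) - y_zero_point))
        + out_zero_point,
        -128,
        127,
    ).to(torch.int8)
    # splitting the rounding across the two terms can shift the result by at most one
    assert (expected.to(torch.int32) - actual.to(torch.int32)).abs().max() <= 1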
|
||||
|
||||
|
||||
def _qdq_quantized_max_pool2d(
|
||||
x_i8,
|
||||
x_scale,
|
||||
x_zero_point,
|
||||
x_quant_min,
|
||||
x_quant_max,
|
||||
out_scale,
|
||||
out_zero_point,
|
||||
out_quant_min,
|
||||
out_quant_max,
|
||||
):
|
||||
kernel_size = 1
|
||||
stride = 1
|
||||
padding = 0
|
||||
dilation = 1
|
||||
ceil_mode = False
|
||||
x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
|
||||
x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8
|
||||
)
|
||||
out_fp32, _ = torch.ops.aten.max_pool2d_with_indices.default(
|
||||
x_fp32, kernel_size, stride, padding, dilation, ceil_mode
|
||||
)
|
||||
out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
|
||||
out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8
|
||||
)
|
||||
return out_i8
|
||||
|
||||
|
||||
def _reference_quantized_max_pool2d(
|
||||
x_i8,
|
||||
x_scale,
|
||||
x_zero_point,
|
||||
x_quant_min,
|
||||
x_quant_max,
|
||||
out_scale,
|
||||
out_zero_point,
|
||||
out_quant_min,
|
||||
out_quant_max,
|
||||
):
|
||||
kernel_size = 1
|
||||
stride = 1
|
||||
padding = 0
|
||||
dilation = 1
|
||||
ceil_mode = False
|
||||
# to preserve x_quant_min, x_quant_max in the graph for pattern matching
|
||||
x_i8 = torch.clamp(x_i8, x_quant_min, x_quant_max)
|
||||
x_i32 = x_i8.to(torch.int32)
|
||||
out_i32, _ = torch.ops.aten.max_pool2d_with_indices.default(
|
||||
x_i32 - x_zero_point, kernel_size, stride, padding, dilation, ceil_mode
|
||||
)
|
||||
out_fp32 = out_i32 * (x_scale / out_scale) + out_zero_point
|
||||
out_fp32 = torch.clamp(out_fp32, out_quant_min, out_quant_max)
|
||||
out_i8 = out_fp32.to(torch.int8)
|
||||
return out_i8
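
# Why max pooling can run directly on the shifted integers above: for per-tensor
# affine quantization with scale > 0, the map q -> (q - x_zero_point) * x_scale is
# monotonically increasing, so taking the max before or after dequantization selects
# the same element, and only a single rescale by x_scale / out_scale is needed after
# pooling.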
|
||||
|
||||
|
||||
def _quantize_per_tensor_int8(x_fp32, scale, zero_point, quant_min, quant_max):
|
||||
x = torch.ops.quantized_decomposed.quantize_per_tensor(
|
||||
x_fp32, scale, zero_point, quant_min, quant_max, torch.int8
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
def _reference_quantize_per_tensor_int8(
|
||||
x_fp32, scale, zero_point, quant_min, quant_max
|
||||
):
|
||||
# TODO: use out_dtype(mul, ...) here when the op is ready
|
||||
x = x_fp32 / scale # fp32
|
||||
# round modes might be different here
|
||||
# pytorch is rounding to even, which is also common for most of the backends
|
||||
x = torch.round(x) # fp32
|
||||
x = x.to(dtype=torch.int32) # int32
|
||||
x = x + zero_point # int32
|
||||
# clamp works for fp32, int32 and int8 dtypes
|
||||
x = torch.clamp(x, quant_min, quant_max) # int32
|
||||
x = x.to(dtype=torch.int8)
|
||||
return x
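
# Worked example for the decomposed quantize above (illustrative values): with
# x_fp32 = 1.23, scale = 0.05, zero_point = -3, quant_min = -128, quant_max = 127:
#     1.23 / 0.05 = 24.6  ->  round to 25  ->  25 + (-3) = 22  ->  clamp to [-128, 127] = 22
# so the quantized int8 value is 22.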
|
||||
|
||||
|
||||
def _dequantize_per_tensor_int8(x_i8, scale, zero_point, quant_min, quant_max):
|
||||
x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
|
||||
x_i8, scale, zero_point, quant_min, quant_max, torch.int8
|
||||
)
|
||||
return x_fp32
|
||||
|
||||
|
||||
def _reference_dequantize_per_tensor_int8(
|
||||
x_i8, scale, zero_point, quant_min, quant_max
|
||||
):
|
||||
    # without using quant_min/max in clamp, the traced graph will not have quant_min/max args.
|
||||
# This results in failure to match the pattern.
|
||||
# Therefore, we call a torch.ops.aten.clamp here
|
||||
x_i8 = torch.ops.aten.clamp(x_i8, quant_min, quant_max)
|
||||
# TODO: use out_dtype op
|
||||
# note: x_i8.to(torch.int32) does not work here
|
||||
# TODO: debug the implementation later when torchdynamo time out issue is resolved
|
||||
return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)
|
||||
|
||||
|
||||
def _quantize_per_channel_int8(
|
||||
x_fp32, scales, zero_points, ch_axis, quant_min, quant_max
|
||||
):
|
||||
out_i8 = torch.ops.quantized_decomposed.quantize_per_channel(
|
||||
x_fp32, scales, zero_points, ch_axis, quant_min, quant_max, torch.int8
|
||||
)
|
||||
return out_i8
|
||||
|
||||
|
||||
def _reference_quantize_per_channel_int8(
|
||||
x_fp32, scales, zero_points, ch_axis, quant_min, quant_max
|
||||
):
|
||||
x_fp32 = torch.transpose(x_fp32, ch_axis, -1)
|
||||
out_i32 = torch.ops.aten.clamp(
|
||||
torch.round(x_fp32 / scales).to(torch.int32) + zero_points, quant_min, quant_max
|
||||
)
|
||||
out_i32 = torch.transpose(out_i32, ch_axis, -1)
|
||||
return out_i32.to(torch.int8)
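
# Note on the transpose trick above: moving ch_axis to the last dimension lets the
# 1-D scales and zero_points (one entry per channel) broadcast against x_fp32 without
# any reshaping, and the second transpose restores the original layout afterwards.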
|
||||
|
||||
|
||||
def _dequantize_per_channel_int8(
|
||||
x_i8, scales, zero_points, ch_axis, quant_min, quant_max
|
||||
):
|
||||
# the following will be replaced as placeholders
|
||||
out_fp32 = torch.ops.quantized_decomposed.dequantize_per_channel(
|
||||
x_i8, scales, zero_points, ch_axis, quant_min, quant_max, torch.int8
|
||||
)
|
||||
return out_fp32
|
||||
|
||||
|
||||
def _reference_dequantize_per_channel_int8(
|
||||
x_i8, scales, zero_points, ch_axis, quant_min, quant_max
|
||||
):
|
||||
# the following will be replaced as placeholders
|
||||
# in order to preserve the quant_min/quant_max args for pattern matching (e.g. matching for int4 quantized ops)
|
||||
# we call a torch.ops.aten.clamp here
|
||||
x_i8 = torch.ops.aten.clamp(x_i8, quant_min, quant_max)
|
||||
x_i8 = torch.transpose(x_i8, ch_axis, -1)
|
||||
x_i32 = x_i8.to(torch.int32)
|
||||
out_fp32 = (x_i32 - zero_points).to(torch.float) * scales
|
||||
out_fp32 = torch.transpose(out_fp32, ch_axis, -1)
|
||||
return out_fp32
|
||||
|
||||
|
||||
def _replace_ph_qdq_per_channel_replacement(gm: torch.fx.GraphModule):
|
||||
return _replace_literals_with_existing_placeholders(
|
||||
gm, exclude_literals=[-1], literal_to_ph_idx={1: 3, -128: 4, 127: 5}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _RewriteInfo:
|
||||
"""Data needed for rewrite, this includes example inputs, pattern and replacement functions
|
||||
and post transformation functions for the exported pattern and replacement GraphModule
|
||||
"""
|
||||
|
||||
# example inputs used for exporting the pattern into GraphModule
|
||||
example_inputs: tuple[Any, ...]
|
||||
pattern: Callable
|
||||
replacement: Callable
|
||||
# post transformation on the exported pattern and replacement GraphModule
|
||||
pattern_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
|
||||
replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
|
||||
|
||||
|
||||
def reference_representation_rewrite(model: GraphModule) -> GraphModule:
|
||||
_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
|
||||
torch.randint(-128, 127, (2, 5), dtype=torch.int8),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-128], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
torch.randint(-128, 127, (5, 5), dtype=torch.int8),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-127], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-128], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
)
|
||||
|
||||
_DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
|
||||
torch.randn((2, 5), dtype=torch.float),
|
||||
-128,
|
||||
127,
|
||||
torch.finfo(torch.float32).eps,
|
||||
torch.randint(-128, 127, (5, 5), dtype=torch.int8),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-127], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
)
|
||||
|
||||
_QUANTIZED_CONV2d_EXAMPLE_INPUTS = (
|
||||
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-128], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-127], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-128], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
)
|
||||
|
||||
_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS = (
|
||||
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-128], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
)
|
||||
|
||||
_QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS = (
|
||||
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-128], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-128], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
)
|
||||
|
||||
_QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
|
||||
torch.randn(1, 3, 3, 3, dtype=torch.float),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-128], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
)
|
||||
|
||||
_DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
|
||||
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
||||
torch.randn(1, dtype=torch.float),
|
||||
torch.zeros(1, dtype=torch.int),
|
||||
torch.tensor([-128], dtype=torch.int),
|
||||
torch.tensor([127], dtype=torch.int),
|
||||
)
|
||||
|
||||
_QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
|
||||
torch.randn(1, 3, 3, 3, dtype=torch.float),
|
||||
torch.randn(3, dtype=torch.float),
|
||||
torch.zeros(3, dtype=torch.int),
|
||||
1,
|
||||
-128,
|
||||
127,
|
||||
)
|
||||
|
||||
_DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
|
||||
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
|
||||
torch.randn(3, dtype=torch.float),
|
||||
torch.zeros(3, dtype=torch.int),
|
||||
1,
|
||||
-128,
|
||||
127,
|
||||
)
|
||||
|
||||
_REWRITE_INFO_LIST = [
|
||||
_RewriteInfo(
|
||||
_DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
|
||||
_WrapperModule(_qdq_dynamic_quantized_linear),
|
||||
_WrapperModule(_reference_dynamic_quantized_linear),
|
||||
partial(
|
||||
_replace_literals_with_existing_placeholders,
|
||||
literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
|
||||
),
|
||||
partial(
|
||||
_replace_literals_with_existing_placeholders,
|
||||
literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
|
||||
),
|
||||
),
|
||||
_RewriteInfo(
|
||||
_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
|
||||
_WrapperModule(_qdq_quantized_linear),
|
||||
_WrapperModule(_reference_quantized_linear),
|
||||
_replace_literals_with_new_placeholders,
|
||||
_replace_literals_with_new_placeholders,
|
||||
),
|
||||
_RewriteInfo(
|
||||
_QUANTIZED_CONV2d_EXAMPLE_INPUTS,
|
||||
_WrapperModule(_qdq_quantized_conv2d),
|
||||
_WrapperModule(_reference_quantized_conv2d),
|
||||
partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]),
|
||||
partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]),
|
||||
),
|
||||
_RewriteInfo(
|
||||
_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS,
|
||||
_WrapperModule(_qdq_quantized_add_relu),
|
||||
_WrapperModule(_reference_quantized_add_relu),
|
||||
),
|
||||
_RewriteInfo(
|
||||
_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS,
|
||||
_WrapperModule(_qdq_quantized_add),
|
||||
_WrapperModule(_reference_quantized_add),
|
||||
),
|
||||
_RewriteInfo(
|
||||
_QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS,
|
||||
_WrapperModule(_qdq_quantized_max_pool2d),
|
||||
_WrapperModule(_reference_quantized_max_pool2d),
|
||||
_replace_literals_with_new_placeholders,
|
||||
_replace_literals_with_new_placeholders,
|
||||
),
|
||||
_RewriteInfo(
|
||||
_QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
|
||||
_WrapperModule(_quantize_per_tensor_int8),
|
||||
_WrapperModule(_reference_quantize_per_tensor_int8),
|
||||
),
|
||||
_RewriteInfo(
|
||||
_DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
|
||||
_WrapperModule(_dequantize_per_tensor_int8),
|
||||
_WrapperModule(_reference_dequantize_per_tensor_int8),
|
||||
),
|
||||
_RewriteInfo(
|
||||
_QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS,
|
||||
_WrapperModule(_quantize_per_channel_int8),
|
||||
_WrapperModule(_reference_quantize_per_channel_int8),
|
||||
_replace_ph_qdq_per_channel_replacement,
|
||||
_replace_ph_qdq_per_channel_replacement,
|
||||
),
|
||||
_RewriteInfo(
|
||||
_DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS,
|
||||
_WrapperModule(_dequantize_per_channel_int8),
|
||||
_WrapperModule(_reference_dequantize_per_channel_int8),
|
||||
_replace_ph_qdq_per_channel_replacement,
|
||||
_replace_ph_qdq_per_channel_replacement,
|
||||
),
|
||||
]
|
||||
|
||||
remove_tensor_overload_for_qdq_ops(model)
|
||||
|
||||
for rewrite_info in _REWRITE_INFO_LIST:
|
||||
example_inputs = rewrite_info.example_inputs
|
||||
pattern = rewrite_info.pattern
|
||||
replacement = rewrite_info.replacement
|
||||
pattern_post_trans = rewrite_info.pattern_post_trans
|
||||
replacement_post_trans = rewrite_info.replacement_post_trans
|
||||
pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs) # type: ignore[arg-type, assignment]
|
||||
remove_tensor_overload_for_qdq_ops(pattern) # type: ignore[arg-type]
|
||||
replacement = _get_aten_graph_module_for_pattern(replacement, example_inputs) # type: ignore[arg-type, assignment]
|
||||
remove_tensor_overload_for_qdq_ops(replacement) # type: ignore[arg-type]
|
||||
if pattern_post_trans:
|
||||
pattern = pattern_post_trans(pattern)
|
||||
if replacement_post_trans:
|
||||
replacement = replacement_post_trans(replacement)
|
||||
pattern.recompile() # type: ignore[attr-defined]
|
||||
replacement.recompile() # type: ignore[attr-defined]
|
||||
replace_pattern(model, pattern, replacement)
|
||||
|
||||
return model
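
# Typical usage (a sketch; the variable name is illustrative): after pt2e
# quantization has produced a GraphModule containing quantize/dequantize ops,
#
#     converted_model = reference_representation_rewrite(converted_model)
#
# rewrites the q/dq patterns listed in _REWRITE_INFO_LIST into the integer-arithmetic
# reference implementations defined above.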
|
610
venv/Lib/site-packages/torch/ao/quantization/pt2e/utils.py
Normal file
|
@ -0,0 +1,610 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import operator
|
||||
import types
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.ao.quantization.pt2e._affine_quantization # noqa: F401
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Makes sure that quantized_decomposed ops are registered
|
||||
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
|
||||
from torch.ao.quantization.quantizer import QuantizationAnnotation
|
||||
from torch.export.unflatten import _assign_attr, _AttrKind
|
||||
from torch.fx import GraphModule, Node
|
||||
from torch.nn.utils.fusion import fuse_conv_bn_weights
|
||||
from torch.utils._pytree import LeafSpec
|
||||
|
||||
|
||||
__all__ = [
|
||||
"fold_bn_weights_into_conv_node",
|
||||
"remove_tensor_overload_for_qdq_ops",
|
||||
]
|
||||
|
||||
_QUANTIZE_OPS = [
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
|
||||
torch.ops.quantized_decomposed.quantize_per_channel.default,
|
||||
]
|
||||
|
||||
|
||||
_DEQUANTIZE_OPS = [
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
|
||||
torch.ops.quantized_decomposed.dequantize_per_channel.default,
|
||||
]
|
||||
|
||||
|
||||
def _is_connected(source: torch.fx.Node, dest: torch.fx.Node) -> bool:
|
||||
"""
|
||||
Assuming dest is one of the ops inserted by quant workflow, this function
|
||||
finds if source and dest are connected. Assumption is that only quant workflow
|
||||
inserted ops exist between source and dest
|
||||
"""
|
||||
quant_workflow_ops = _QUANTIZE_OPS + _DEQUANTIZE_OPS
|
||||
quant_workflow_ops.append(torch.ops.quantized_decomposed.choose_qparams.tensor)
|
||||
while dest.target in quant_workflow_ops:
|
||||
if not isinstance(dest.args[0], torch.fx.Node):
|
||||
raise ValueError(
|
||||
f"expected arg[0] of quant workflow ops to be a node but found {dest.args[0]}"
|
||||
)
|
||||
dest = dest.args[0]
|
||||
return dest == source
|
||||
|
||||
|
||||
def _find_q_dq_node_for_user(
|
||||
    producer: torch.fx.Node, user: torch.fx.Node
|
||||
) -> tuple[Any, Any]:
|
||||
"""
|
||||
Find q, dq pair corresponding to [producer -> q -> dq -> user]
|
||||
    The util works by finding the dq arg of user and ensuring it is connected to the
|
||||
producer
|
||||
"""
|
||||
dq_node = None
|
||||
for n in user.args:
|
||||
if (
|
||||
isinstance(n, torch.fx.Node)
|
||||
and n.op == "call_function"
|
||||
and n.target in _DEQUANTIZE_OPS
|
||||
):
|
||||
            if _is_connected(producer, n):
|
||||
dq_node = n
|
||||
break
|
||||
if dq_node is None:
|
||||
for n in user.kwargs:
|
||||
if (
|
||||
isinstance(n, torch.fx.Node)
|
||||
and n.op == "call_function"
|
||||
and n.target in _DEQUANTIZE_OPS
|
||||
):
|
||||
                if _is_connected(producer, n):
|
||||
dq_node = n
|
||||
break
|
||||
if dq_node is None:
|
||||
return (None, None)
|
||||
|
||||
q_node = None
|
||||
if (
|
||||
dq_node.args[0].op == "call_function" # type: ignore[union-attr]
|
||||
and dq_node.args[0].target in _QUANTIZE_OPS # type: ignore[union-attr]
|
||||
):
|
||||
q_node = dq_node.args[0]
|
||||
return (q_node, dq_node)
|
||||
|
||||
|
||||
def _is_sym_size_node(node: Node):
|
||||
    return node.op == "call_function" and node.target in (
        torch.ops.aten.sym_size.default,
        torch.ops.aten.sym_numel.default,
        torch.ops.aten.sym_numel,
        torch.ops.aten.sym_size,
    )
|
||||
|
||||
|
||||
def _filter_sym_size_users(node: torch.fx.Node) -> list[torch.fx.Node]:
|
||||
    node_users = [user for user in node.users if not _is_sym_size_node(user)]
|
||||
return node_users
|
||||
|
||||
|
||||
def _is_valid_annotation(annotation: QuantizationAnnotation) -> bool:
|
||||
if annotation is None:
|
||||
return False
|
||||
input_qspec_map = annotation.input_qspec_map
|
||||
output_qspec = annotation.output_qspec
|
||||
if len(input_qspec_map) == 0 and output_qspec is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _get_tensor_constant_from_node(node, m):
|
||||
if node is None:
|
||||
return None
|
||||
assert node.op == "get_attr"
|
||||
target_atoms = node.target.split(".")
|
||||
attr_itr = m
|
||||
for i, atom in enumerate(target_atoms):
|
||||
if not hasattr(attr_itr, atom):
|
||||
raise RuntimeError(
|
||||
f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
|
||||
)
|
||||
attr_itr = getattr(attr_itr, atom)
|
||||
return attr_itr
|
||||
|
||||
|
||||
def _get_all_arguments(orig_args, orig_kwargs, args_schema):
|
||||
all_args = []
|
||||
for i, schema in enumerate(args_schema):
|
||||
if schema.name in orig_kwargs:
|
||||
all_args.append(orig_kwargs[schema.name])
|
||||
elif not schema.kwarg_only and i < len(orig_args):
|
||||
all_args.append(orig_args[i])
|
||||
else:
|
||||
all_args.append(schema.default_value)
|
||||
return all_args
|
||||
|
||||
|
||||
def _is_supported_batch_norm_for_training(node: Node):
|
||||
"""
|
||||
Return True if the given node refers to an aten batch norm op QAT supports.
|
||||
"""
|
||||
supported_ops = [
|
||||
torch.ops.aten.batch_norm.default,
|
||||
torch.ops.aten._native_batch_norm_legit.default,
|
||||
# Note: we won't need this op anymore after batch norm consolidation
|
||||
# For now, we need to continue to support it because it gives better
|
||||
# training numerics than `_native_batch_norm_legit`
|
||||
torch.ops.aten.cudnn_batch_norm.default,
|
||||
torch.ops.aten.miopen_batch_norm.default,
|
||||
]
|
||||
return node.target in supported_ops
|
||||
|
||||
|
||||
# TODO: move this to torch/ao/quantization/utils.py
|
||||
def _is_conv_node(n: Node):
|
||||
"""
|
||||
Return whether the node refers to an aten conv op.
|
||||
"""
|
||||
return n.op == "call_function" and n.target in [
|
||||
torch.ops.aten.conv1d.default,
|
||||
torch.ops.aten.conv2d.default,
|
||||
]
|
||||
|
||||
|
||||
def _is_conv_transpose_node(n: Node):
|
||||
"""
|
||||
Return whether the node refers to an aten conv_transpose op.
|
||||
"""
|
||||
return n.op == "call_function" and n.target in [
|
||||
torch.ops.aten.conv_transpose1d,
|
||||
torch.ops.aten.conv_transpose1d.default,
|
||||
torch.ops.aten.conv_transpose2d,
|
||||
torch.ops.aten.conv_transpose2d.input,
|
||||
]
|
||||
|
||||
|
||||
def _is_conv_or_conv_transpose_node(n: Node):
|
||||
"""
|
||||
Return whether the node refers to an aten conv or conv transpose op.
|
||||
"""
|
||||
return _is_conv_node(n) or _is_conv_transpose_node(n)
|
||||
|
||||
|
||||
def _is_conv_transpose_fn(conv_fn: Callable):
|
||||
return conv_fn in [F.conv_transpose1d, F.conv_transpose2d]
|
||||
|
||||
|
||||
def _is_bn_node(n: Node):
|
||||
return (
|
||||
_is_supported_batch_norm_for_training(n)
|
||||
or n.target == torch.ops.aten._native_batch_norm_legit_no_training.default
|
||||
)
|
||||
|
||||
|
||||
def fold_bn_weights_into_conv_node(
    conv_node: Node,
    conv_weight_node: Node,
    conv_bias_node: Optional[Node],
    bn_node: Node,
    m: GraphModule,
) -> None:
    # conv args: input, weight, bias, stride, padding, dilation, ...
    conv_w = _get_tensor_constant_from_node(conv_weight_node, m)
    conv_b = _get_tensor_constant_from_node(conv_bias_node, m)
    transpose = _is_conv_transpose_node(conv_node)

    # eval bn args: input, weight, bias, running mean, running var, momentum, eps
    # train bn args: input, weight, bias, running mean, running var, training, momentum, eps
    bn_args_schema = bn_node.target._schema.arguments  # type: ignore[union-attr]
    bn_args = _get_all_arguments(bn_node.args, bn_node.kwargs, bn_args_schema)
    bn_w = _get_tensor_constant_from_node(bn_args[1], m)
    bn_b = _get_tensor_constant_from_node(bn_args[2], m)
    bn_rm = _get_tensor_constant_from_node(bn_args[3], m)
    bn_rv = _get_tensor_constant_from_node(bn_args[4], m)
    if bn_node.target == torch.ops.aten._native_batch_norm_legit_no_training.default:
        eps_arg_index = 6
    elif _is_supported_batch_norm_for_training(bn_node):
        eps_arg_index = 7
    else:
        raise ValueError(f"BN node target is unexpected: {bn_node.target}")
    bn_eps = bn_args[eps_arg_index]

    fused_weight, fused_bias = fuse_conv_bn_weights(
        conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose
    )

    # update the weight and bias for conv
    conv_args = list(conv_node.args)
    # fill in the default bias argument
    if len(conv_args) == 2:
        conv_args.append(None)

    # fused_weight and fused_bias are nn.Parameter, so register them as parameters on the module
    weight_attr_name = conv_weight_node.target
    assert isinstance(weight_attr_name, str)
    _assign_attr(fused_weight, m, weight_attr_name, _AttrKind.PARAMETER)
    if conv_bias_node is not None:
        bias_attr_name = conv_bias_node.target
        _assign_attr(fused_bias, m, str(bias_attr_name), _AttrKind.PARAMETER)
    else:
        bias_attr_name = weight_attr_name + "_bias"
        _assign_attr(fused_bias, m, bias_attr_name, _AttrKind.PARAMETER)
        with m.graph.inserting_before(conv_node):
            get_bias_node = m.graph.get_attr(bias_attr_name)
        # NOTE: here we assume the bias of conv is not quantized!
        conv_args[2] = get_bias_node
    conv_node.args = tuple(conv_args)

    # native_batch_norm has 3 outputs, we expect getitem calls on the output
    # and we want to replace the uses of getitem 0 with the output of conv
    if bn_node.target == torch.ops.aten.batch_norm.default:
        # With the new training IR, instead of batch_norm + getitem,
        # we only have the batch_norm node.
        #
        # Before:
        #   conv -> bn -> users
        # After:
        #   conv -> users
        #   (bn has no users now)
        bn_node.replace_all_uses_with(conv_node)
    else:
        # Before:
        #   conv -> bn - (first output)  -> users1
        #              \ - (second output) -> users2
        #              \ - (third output)  -> users3
        # After:
        #   conv -> (first output) -> users1
        #   bn -
        #      \ - (second output) -> users2
        #      \ - (third output)  -> users3
        # If users2 and users3 are empty then bn will be removed through dead code elimination.
        for user in bn_node.users:
            if (
                user.op != "call_function"
                or user.target != operator.getitem
                or user.args[1] != 0
            ):
                continue
            user.replace_all_uses_with(conv_node)

        # If the BN node does not have users, erase it from the graph.
        # Note: we need to do this manually because the model can still be in train
        # mode at this point, in which case DCE won't erase the BN node automatically
        # since the node refers to a mutating op. Here we still need to call DCE first
        # to get rid of the unused getitem nodes that consume the BN node.
        m.graph.eliminate_dead_code()
        if len(bn_node.users) == 0:
            m.graph.erase_node(bn_node)


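# The heavy lifting above is done by `fuse_conv_bn_weights`. As a reference for what that
# fold computes, here is a minimal sketch for the plain (non-transposed) conv2d case: the
# conv weight is scaled per output channel by gamma / sqrt(running_var + eps) and the bias
# is shifted accordingly. This helper is illustrative only and is not used by the pass above.
def _reference_fuse_conv_bn_weights_2d(
    conv_w: torch.Tensor,
    conv_b: Optional[torch.Tensor],
    bn_rm: torch.Tensor,
    bn_rv: torch.Tensor,
    bn_eps: float,
    bn_w: torch.Tensor,
    bn_b: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    # one scale factor per output channel
    scale = bn_w / torch.sqrt(bn_rv + bn_eps)
    fused_w = conv_w * scale.reshape(-1, 1, 1, 1)
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    fused_b = (conv_b - bn_rm) * scale + bn_b
    return fused_w, fused_b

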
# fuse conv bn weights, inplace modification of the graph_module and graph
def _fuse_conv_bn_(m: GraphModule) -> None:
    has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
    if not has_bn:
        return
    for n in m.graph.nodes:
        if n.op != "call_function" or n.target not in (
            torch.ops.aten._native_batch_norm_legit_no_training.default,
            torch.ops.aten.batch_norm.default,
        ):
            continue
        bn_node = n
        n = bn_node.args[0]
        if not _is_conv_or_conv_transpose_node(n):
            continue
        conv_node = n
        conv_weight_node = conv_node.args[1]
        conv_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None
        fold_bn_weights_into_conv_node(
            conv_node, conv_weight_node, conv_bias_node, bn_node, m
        )

    m.graph.eliminate_dead_code()
    m.recompile()


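# Minimal usage sketch (illustrative only, not called anywhere in this file): export a
# small conv + batchnorm module for training and run the in-place fusion pass on the
# resulting GraphModule. Assumes a torch build that provides
# torch.export.export_for_training, which the helpers in this file already rely on.
def _example_fuse_conv_bn() -> GraphModule:
    class ConvBN(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, 3)
            self.bn = torch.nn.BatchNorm2d(8)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.bn(self.conv(x))

    example_inputs = (torch.randn(1, 3, 16, 16),)
    m = torch.export.export_for_training(ConvBN(), example_inputs).module()
    _fuse_conv_bn_(m)  # batch norm statistics are folded into the conv weight/bias
    return m

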
def _get_node_name_to_scope(model: GraphModule) -> dict[str, tuple[str, type]]:
    # TODO: move this information to fx node itself
    node_name_to_scope: dict[str, tuple[str, type]] = {}
    for n in model.graph.nodes:
        nn_module_stack = n.meta.get("nn_module_stack", None)
        current_scope = ("", type(None))
        if nn_module_stack:
            bt = list(nn_module_stack.values())[-1]
            current_scope = (bt[0].split(".")[-1], bt[1])
        node_name_to_scope[n.name] = current_scope
    return node_name_to_scope


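# For orientation (hypothetical example; the exact metadata depends on how the model was
# exported): a conv node created inside a submodule "features.0" of type torch.nn.Conv2d
# would map roughly to
#
#     node_name_to_scope["conv2d"] == ("0", torch.nn.Conv2d)
#
# i.e. the last component of the module path plus the module type; nodes without
# nn_module_stack metadata fall back to ("", type(None)).

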
def _get_aten_graph_module_for_pattern(
    pattern: Callable,
    example_inputs: tuple[Any, ...],
    is_cuda: bool = False,
    **kwargs,
) -> GraphModule:
    """
    Convert the pattern to an FX graph with decomposed aten ops.
    """
    if is_cuda:
        example_inputs = tuple(
            [x.cuda() if isinstance(x, torch.Tensor) else x for x in example_inputs]
        )

    aten_pattern = torch.export.export_for_training(
        pattern,  # type: ignore[arg-type]
        example_inputs,
        kwargs,
    ).module()

    aten_pattern.graph.eliminate_dead_code()  # type: ignore[operator, union-attr]
    aten_pattern.recompile()  # type: ignore[operator]

    # ep.module() adds copy_ nodes for the mutated inputs.
    # For patterns, it doesn't matter.
    for node in aten_pattern.graph.nodes:  # type: ignore[union-attr]
        if (
            node.op == "call_function"
            and node.target == torch.ops.aten.copy_.default
            and len(node.users) == 0
        ):
            aten_pattern.graph.erase_node(node)  # type: ignore[operator, union-attr]

    aten_pattern.graph.eliminate_dead_code()  # type: ignore[operator, union-attr]
    aten_pattern.recompile()  # type: ignore[operator]

    return aten_pattern  # type: ignore[return-value]


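# Brief usage sketch (illustrative; the module and shapes below are made up): turn a small
# pattern into a GraphModule of decomposed aten ops, the form that the literal-rewrite and
# subgraph-matching helpers below operate on.
def _example_pattern_to_aten_graph() -> GraphModule:
    class AddReLU(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return torch.relu(x + y)

    example_inputs = (torch.randn(1, 3), torch.randn(1, 3))
    return _get_aten_graph_module_for_pattern(AddReLU(), example_inputs)

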
def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
    """Remove .tensor overload for quantize/dequantize ops so that we can
    use the match_pattern that we get from torchdynamo export to match the output of convert_pt2e
    """
    _MAP = {
        torch.ops.quantized_decomposed.quantize_per_tensor.default: torch.ops.quantized_decomposed.quantize_per_tensor,
        torch.ops.quantized_decomposed.dequantize_per_tensor.default: torch.ops.quantized_decomposed.dequantize_per_tensor,
        torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor,
        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor,
        torch.ops.quantized_decomposed.quantize_per_tensor.tensor2: torch.ops.quantized_decomposed.quantize_per_tensor,
        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor2: torch.ops.quantized_decomposed.dequantize_per_tensor,
        torch.ops.quantized_decomposed.quantize_per_channel.default: torch.ops.quantized_decomposed.quantize_per_channel,
        torch.ops.quantized_decomposed.dequantize_per_channel.default: torch.ops.quantized_decomposed.dequantize_per_channel,
        torch.ops.aten.clamp.Tensor: torch.ops.aten.clamp,
    }
    for n in match_pattern.graph.nodes:
        if n.op != "call_function":
            continue
        if n.target in _MAP:
            n.target = _MAP[n.target]


def _is_literal(arg):
    if isinstance(arg, (int, float)):
        return True
    if isinstance(arg, (tuple, list)):
        return all(map(_is_literal, arg))
    return False


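# Quick illustration of what counts as a literal for the rewrite helpers below: ints,
# floats, and (nested) tuples/lists of them qualify; tensors and fx Nodes do not.
#
#     _is_literal(3)              -> True
#     _is_literal((1.0, (2, 3)))  -> True
#     _is_literal(torch.ones(1))  -> False

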
def _replace_literals_with_new_placeholders(
    gm: torch.fx.GraphModule,
    merge_dup: bool = False,
    exclude_literals: Optional[list[Any]] = None,
):
    """Replace the literals in the graph with placeholder nodes that are created on the fly while we
    traverse the graph, so that the literal arguments in the graph can be matched and replaced

    To use this, the pattern and replacement graph should have the exact same number of literal args
    and they should be used in the exact same order in the pattern and replacement graph.

    If the literal arguments are not used in the same order in pattern and replacement graph, please
    use `_replace_literals_with_existing_placeholders` instead

    Args:
        `gm`: input GraphModule that we'll transform
        `merge_dup`: boolean flag indicating whether, if the same literal appears multiple times in
        the graph, those occurrences should correspond to the same placeholder or not
        `exclude_literals`: a list of literals that will not be replaced with placeholders

    Example:

        # 1. Original Graph
        def pattern(self, x):
            return x + 3

        def replacement(self, x):
            return x - 3

        example_inputs = (torch.randn(1, 3, 3, 3),)
        pattern_gm = _get_aten_graph_module_for_pattern(pattern, example_inputs)
        replacement_gm = _get_aten_graph_module_for_pattern(replacement, example_inputs)

        # 2. Before calling replace literals we'll see the following graph:
        def pattern(self, x):
            return x + 3

        def replacement(self, x):
            return x - 3

        pattern_gm = _replace_literals_with_new_placeholders(pattern_gm)
        replacement_gm = _replace_literals_with_new_placeholders(replacement_gm)

        # 3. After replacing literals with new placeholder nodes

        def pattern(self, x, new_ph):
            return x + new_ph

        def replacement(self, x, new_ph):
            return x - new_ph

    """
    last_ph = None
    cnt = 0
    literal_to_ph: dict[Union[float, bool, int, torch.dtype], Node] = {}
    if exclude_literals is None:
        exclude_literals = []

    in_spec = gm._in_spec
    args_spec = in_spec.children_specs[0]
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            last_ph = node
            cnt += 1
            continue
        with gm.graph.inserting_after(last_ph):
            new_args = []
            for arg in node.args:
                if _is_literal(arg) and arg not in exclude_literals:
                    if merge_dup and arg in literal_to_ph:
                        new_args.append(literal_to_ph[arg])
                    else:
                        ph_node = gm.graph.placeholder("arg" + str(cnt))
                        new_args.append(ph_node)
                        args_spec.children_specs.append(LeafSpec())
                        cnt += 1
                        if merge_dup:
                            literal_to_ph[arg] = ph_node
                else:
                    new_args.append(arg)
            new_args = tuple(new_args)

        node.args = new_args

    # Update `num_nodes`, `num_leaves`, `num_children`.
    args_spec.__post_init__()
    in_spec.__post_init__()
    return gm


def _replace_literals_with_existing_placeholders(
    gm: torch.fx.GraphModule,
    exclude_literals: Optional[list[Any]] = None,
    literal_to_ph_idx: Optional[dict[Union[float, int, bool, torch.dtype], int]] = None,
):
    """Replace the literals in the graph with **existing** placeholder nodes, so that the literal arguments
    in the graph can be matched and replaced

    To use this, all literal args in the graph should be unique and each of them should correspond
    to exactly one placeholder node

    # 1. Original Graph
    def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):
        return torch.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max)

    def replacement(x_i8, scale, zero_point, quant_min, quant_max):
        x_i8 = torch.clamp(x_i8, quant_min, quant_max)
        return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)

    example_inputs = (
        torch.randn(1, 3, 3, 3),
        1.0,
        0,
        -128,
        127,
    )
    pattern_gm = _get_aten_graph_module_for_pattern(pattern, example_inputs)
    replacement_gm = _get_aten_graph_module_for_pattern(replacement, example_inputs)

    # 2. Before calling replace literals we'll see the following graph:
    def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):
        # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values
        return torch.dequantize_per_tensor(x_i8, 1.0, 0, -128, 127)

    def replacement(x_i8, scale, zero_point, quant_min, quant_max):
        # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values
        x_i8 = torch.clamp(x_i8, -128, 127)
        return ((x_i8.to(torch.float32) - 0) * 1.0).to(dtype=torch.float32)

    # Note that the literal args appear in a different order in the pattern and replacement graph,
    # so we can't use _replace_literals_with_new_placeholders

    literal_to_ph_idx = {1.0: 1, 0: 2, -128: 3, 127: 4}
    pattern_gm = _replace_literals_with_existing_placeholders(
        pattern_gm, literal_to_ph_idx=literal_to_ph_idx
    )
    replacement_gm = _replace_literals_with_existing_placeholders(
        replacement_gm, literal_to_ph_idx=literal_to_ph_idx
    )

    # 3. After replacing literals with existing placeholder nodes

    def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):
        # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values
        return torch.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max)

    def replacement(x_i8, scale, zero_point, quant_min, quant_max):
        # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values
        x_i8 = torch.clamp(x_i8, quant_min, quant_max)
        return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)
    """
    if exclude_literals is None:
        exclude_literals = []

    if literal_to_ph_idx is None:
        literal_to_ph_idx = {}

    phs = [node for node in gm.graph.nodes if node.op == "placeholder"]

    for node in gm.graph.nodes:
        if node.op != "call_function":
            continue
        new_args = []
        for arg in node.args:
            if (
                _is_literal(arg)
                and arg not in exclude_literals
                and arg in literal_to_ph_idx
            ):
                ph_idx = literal_to_ph_idx[arg]
                ph_node = phs[ph_idx]
                new_args.append(ph_node)
            else:
                new_args.append(arg)
        new_args = tuple(new_args)
        node.args = new_args
    return gm


# TODO: Handle this in export itself and don't wrap the model in another GraphModule
# in prepare and convert
def _disallow_eval_train(model: GraphModule):
    """
    Disallow calling `model.train()` or `model.eval()` on the given GraphModule.
    This is useful for exported models, where these methods don't actually behave as expected.
    """
    error_message = """
        Calling train() or eval() is not supported for exported models.
        Please call `torch.ao.quantization.move_exported_model_to_train(model)` (or eval) instead.

        If you cannot replace the calls to `model.train()` and `model.eval()`, you may override
        the behavior for these methods by calling `torch.ao.quantization.allow_exported_model_train_eval(model)`,
        which does the above automatically for you. Note that this has limited effect on switching
        behavior between train and eval modes, and should be used only for special ops such as dropout
        and batchnorm.
        """

    def _train(self, mode: bool = True):
        raise NotImplementedError(error_message)

    def _eval(self, mode: bool = True):
        raise NotImplementedError(error_message)

    model.train = types.MethodType(_train, model)  # type: ignore[method-assign]
    model.eval = types.MethodType(_eval, model)  # type: ignore[method-assign]
    return model
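

# Usage sketch (illustrative): after wrapping, calls to .train()/.eval() raise and point
# users at the dedicated move_exported_model_to_train/eval helpers instead.
#
#     exported = torch.export.export_for_training(model, example_inputs).module()
#     exported = _disallow_eval_train(exported)
#     exported.eval()  # raises NotImplementedError with the message above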