Adding all project files
This commit is contained in:
parent
6c9e127bdc
commit
cd4316ad0f
42289 changed files with 8009643 additions and 0 deletions
68
venv/Lib/site-packages/torch/ao/quantization/stubs.py
Normal file
68
venv/Lib/site-packages/torch/ao/quantization/stubs.py
Normal file
|
@ -0,0 +1,68 @@
|
|||
# mypy: allow-untyped-defs
|
||||
|
||||
from torch import nn
|
||||
|
||||
|
||||
class QuantStub(nn.Module):
|
||||
r"""Quantize stub module, before calibration, this is same as an observer,
|
||||
it will be swapped as `nnq.Quantize` in `convert`.
|
||||
|
||||
Args:
|
||||
qconfig: quantization configuration for the tensor,
|
||||
if qconfig is not provided, we will get qconfig from parent modules
|
||||
"""
|
||||
|
||||
def __init__(self, qconfig=None):
|
||||
super().__init__()
|
||||
if qconfig:
|
||||
self.qconfig = qconfig
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class DeQuantStub(nn.Module):
|
||||
r"""Dequantize stub module, before calibration, this is same as identity,
|
||||
this will be swapped as `nnq.DeQuantize` in `convert`.
|
||||
|
||||
Args:
|
||||
qconfig: quantization configuration for the tensor,
|
||||
if qconfig is not provided, we will get qconfig from parent modules
|
||||
"""
|
||||
|
||||
def __init__(self, qconfig=None):
|
||||
super().__init__()
|
||||
if qconfig:
|
||||
self.qconfig = qconfig
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class QuantWrapper(nn.Module):
|
||||
r"""A wrapper class that wraps the input module, adds QuantStub and
|
||||
DeQuantStub and surround the call to module with call to quant and dequant
|
||||
modules.
|
||||
|
||||
This is used by the `quantization` utility functions to add the quant and
|
||||
dequant modules, before `convert` function `QuantStub` will just be observer,
|
||||
it observes the input tensor, after `convert`, `QuantStub`
|
||||
will be swapped to `nnq.Quantize` which does actual quantization. Similarly
|
||||
for `DeQuantStub`.
|
||||
"""
|
||||
quant: QuantStub
|
||||
dequant: DeQuantStub
|
||||
module: nn.Module
|
||||
|
||||
def __init__(self, module):
|
||||
super().__init__()
|
||||
qconfig = getattr(module, "qconfig", None)
|
||||
self.add_module("quant", QuantStub(qconfig))
|
||||
self.add_module("dequant", DeQuantStub(qconfig))
|
||||
self.add_module("module", module)
|
||||
self.train(module.training)
|
||||
|
||||
def forward(self, X):
|
||||
X = self.quant(X)
|
||||
X = self.module(X)
|
||||
return self.dequant(X)
|
Loading…
Add table
Add a link
Reference in a new issue