# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Adapted from
https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/integrations/bitsandbytes.py
"""

import inspect
from inspect import signature
from typing import Tuple

from ...utils import is_accelerate_available, is_bitsandbytes_available, is_torch_available, logging
from ..quantization_config import QuantizationMethod


if is_torch_available():
    import torch
    import torch.nn as nn

if is_bitsandbytes_available():
    import bitsandbytes as bnb

if is_accelerate_available():
    import accelerate
    from accelerate import init_empty_weights
    from accelerate.hooks import add_hook_to_module, remove_hook_from_module

logger = logging.get_logger(__name__)


def _replace_with_bnb_linear(
    model,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
):
    """
    Private method that wraps the recursion for module replacement.

    Returns the converted model and a boolean that indicates if the conversion has been successful or not.
    """
    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
            # Check if the current key is not in the `modules_to_not_convert`
            current_key_name_str = ".".join(current_key_name)
            if not any(
                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
            ):
                with init_empty_weights():
                    in_features = module.in_features
                    out_features = module.out_features

                    if quantization_config.quantization_method() == "llm_int8":
                        model._modules[name] = bnb.nn.Linear8bitLt(
                            in_features,
                            out_features,
                            module.bias is not None,
                            has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
                            threshold=quantization_config.llm_int8_threshold,
                        )
                        has_been_replaced = True
                    else:
                        if (
                            quantization_config.llm_int8_skip_modules is not None
                            and name in quantization_config.llm_int8_skip_modules
                        ):
                            pass
                        else:
                            extra_kwargs = (
                                {"quant_storage": quantization_config.bnb_4bit_quant_storage}
                                if "quant_storage" in list(signature(bnb.nn.Linear4bit).parameters)
                                else {}
                            )
                            model._modules[name] = bnb.nn.Linear4bit(
                                in_features,
                                out_features,
                                module.bias is not None,
                                quantization_config.bnb_4bit_compute_dtype,
                                compress_statistics=quantization_config.bnb_4bit_use_double_quant,
                                quant_type=quantization_config.bnb_4bit_quant_type,
                                **extra_kwargs,
                            )
                            has_been_replaced = True
                    # Store the module class in case we need to transpose the weight later
                    model._modules[name].source_cls = type(module)
                    # Force requires grad to False to avoid unexpected errors
                    model._modules[name].requires_grad_(False)
        if len(list(module.children())) > 0:
            _, has_been_replaced = _replace_with_bnb_linear(
                module,
                modules_to_not_convert,
                current_key_name,
                quantization_config,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced


def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
    """
    Helper function to replace the `nn.Linear` layers within `model` with either `bnb.nn.Linear8bitLt` or
    `bnb.nn.Linear4bit` using the `bitsandbytes` library.

    References:
        * `bnb.nn.Linear8bitLt`: [LLM.int8(): 8-bit Matrix Multiplication for Transformers at
          Scale](https://huggingface.co/papers/2208.07339)
        * `bnb.nn.Linear4bit`: [QLoRA: Efficient Finetuning of Quantized
          LLMs](https://huggingface.co/papers/2305.14314)

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        modules_to_not_convert (`List[str]`, *optional*, defaults to `[]`):
            Names of the modules not to convert to `Linear8bitLt`/`Linear4bit`. In practice, these modules are kept in
            full precision for numerical stability reasons.
        current_key_name (`List[str]`, *optional*):
            An array to track the current key of the recursion. This is used to check whether the current key (or part
            of it) is in the list of modules to not convert (for instance, modules that are offloaded to `cpu` or
            `disk`).
        quantization_config (`BitsAndBytesConfig`):
            The quantization configuration that defines the settings to apply (8-bit vs. 4-bit mode, compute dtype,
            skipped modules, etc.).
    """
    # The docstring documents an empty list as the default, so normalize `None` here before the membership checks.
    modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert

    model, _ = _replace_with_bnb_linear(model, modules_to_not_convert, current_key_name, quantization_config)

    has_been_replaced = any(
        isinstance(replaced_module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt))
        for _, replaced_module in model.named_modules()
    )
    if not has_been_replaced:
        logger.warning(
            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
            " Please double check your model architecture, or submit an issue on github if you think this is"
            " a bug."
        )

    return model
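# Illustrative usage (a sketch, not executed in this module): driving `replace_with_bnb_linear` directly with a
# 4-bit config. The import path of `BitsAndBytesConfig` and the module name "proj_out" below are assumptions made
# for the example only.
#
#     from diffusers.quantizers.quantization_config import BitsAndBytesConfig
#
#     config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
#     model = replace_with_bnb_linear(model, modules_to_not_convert=["proj_out"], quantization_config=config)
#
# The replacement modules are created under `init_empty_weights()`, so they live on the meta device until the
# quantized checkpoint weights are actually loaded into them.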


# Adapted from PEFT: https://github.com/huggingface/peft/blob/6d458b300fc2ed82e19f796b53af4c97d03ea604/src/peft/utils/integrations.py#L81
def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None, dtype: "torch.dtype" = None):
    """
    Helper function to dequantize 4bit or 8bit bnb weights.

    If the weight is not a bnb quantized weight, it will be returned as is.
    """
    if not isinstance(weight, torch.nn.Parameter):
        raise TypeError(f"Input weight should be of type nn.Parameter, got {type(weight)} instead")

    cls_name = weight.__class__.__name__
    if cls_name not in ("Params4bit", "Int8Params"):
        return weight

    if cls_name == "Params4bit":
        output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
        msg = f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
        if dtype:
            msg = f"The model is going to be first dequantized in {output_tensor.dtype} and type-casted to {dtype}"
            output_tensor = output_tensor.to(dtype)
        logger.warning_once(msg)
        return output_tensor

    if state.SCB is None:
        state.SCB = weight.SCB

    if hasattr(bnb.functional, "int8_vectorwise_dequant"):
        # Use bitsandbytes API if available (requires v0.45.0+)
        dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, state.SCB)
    else:
        # Multiply by (scale/127) to dequantize.
        dequantized = weight.data * state.SCB.view(-1, 1) * 7.874015718698502e-3

    if dtype:
        dequantized = dequantized.to(dtype)
    return dequantized
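# Illustrative usage (a sketch, assuming bitsandbytes is installed and `layer` is a hypothetical
# `bnb.nn.Linear4bit` module whose quantized weights have been loaded on a CUDA device):
#
#     w_fp = dequantize_bnb_weight(layer.weight, state=None, dtype=torch.float16)
#     assert w_fp.shape == (layer.out_features, layer.in_features)
#
# For an 8-bit layer, pass `state=layer.state` so the per-row scales (`SCB`) are available; the fallback path above
# then computes `weight * SCB.view(-1, 1) / 127`, i.e. the inverse of row-wise int8 quantization.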


def _create_accelerate_new_hook(old_hook):
    r"""
    Creates a new hook based on the old hook. Use it only if you know what you are doing! This method is a copy of:
    https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245 with
    some changes.
    """
    old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__)
    old_hook_attr = old_hook.__dict__
    filtered_old_hook_attr = {}
    old_hook_init_signature = inspect.signature(old_hook_cls.__init__)
    for k in old_hook_attr.keys():
        if k in old_hook_init_signature.parameters:
            filtered_old_hook_attr[k] = old_hook_attr[k]
    new_hook = old_hook_cls(**filtered_old_hook_attr)
    return new_hook
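# Illustrative usage (a sketch): if `module._hf_hook` is, for example, an `accelerate.hooks.AlignDevicesHook`, this
# helper rebuilds an equivalent hook from the constructor arguments found in the old hook's `__dict__`:
#
#     new_hook = _create_accelerate_new_hook(module._hf_hook)
#     add_hook_to_module(new_module, new_hook)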


def _dequantize_and_replace(
    model,
    dtype,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
):
    """
    Converts a quantized model into its dequantized original version. The newly converted model will have some
    performance drop compared to the original model before quantization - use it only for specific use cases such as
    QLoRA adapters merging.

    Returns the converted model and a boolean that indicates if the conversion has been successful or not.
    """
    quant_method = quantization_config.quantization_method()

    target_cls = bnb.nn.Linear8bitLt if quant_method == "llm_int8" else bnb.nn.Linear4bit

    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if isinstance(module, target_cls) and name not in modules_to_not_convert:
            # Check if the current key is not in the `modules_to_not_convert`
            current_key_name_str = ".".join(current_key_name)

            if not any(
                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
            ):
                bias = getattr(module, "bias", None)

                device = module.weight.device
                with init_empty_weights():
                    new_module = torch.nn.Linear(module.in_features, module.out_features, bias=bias is not None)

                if quant_method == "llm_int8":
                    state = module.state
                else:
                    state = None

                new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state, dtype))

                if bias is not None:
                    new_module.bias = bias

                # Create a new hook and attach it in case we use accelerate
                if hasattr(module, "_hf_hook"):
                    old_hook = module._hf_hook
                    new_hook = _create_accelerate_new_hook(old_hook)

                    remove_hook_from_module(module)
                    add_hook_to_module(new_module, new_hook)

                new_module.to(device)
                model._modules[name] = new_module
                has_been_replaced = True
        if len(list(module.children())) > 0:
            _, has_been_replaced = _dequantize_and_replace(
                module,
                dtype=dtype,
                modules_to_not_convert=modules_to_not_convert,
                current_key_name=current_key_name,
                quantization_config=quantization_config,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced


def dequantize_and_replace(
    model,
    modules_to_not_convert=None,
    quantization_config=None,
):
    model, has_been_replaced = _dequantize_and_replace(
        model,
        dtype=model.dtype,
        modules_to_not_convert=modules_to_not_convert,
        quantization_config=quantization_config,
    )

    if not has_been_replaced:
        logger.warning(
            "No bitsandbytes quantized linear modules were found to dequantize. This could lead to unexpected"
            " behaviour. Please check your model."
        )

    return model
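# Illustrative usage (a sketch, assuming `model` was previously quantized with `replace_with_bnb_linear` and its
# weights loaded; `config` is the same `BitsAndBytesConfig`-style object used for quantization and "proj_out" a
# hypothetical module name):
#
#     model = dequantize_and_replace(model, modules_to_not_convert=["proj_out"], quantization_config=config)
#
# Afterwards every `bnb.nn.Linear4bit` / `bnb.nn.Linear8bitLt` layer outside `modules_to_not_convert` is a plain
# `torch.nn.Linear` again, which is what workflows such as merging QLoRA adapters expect.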


def _check_bnb_status(module) -> Tuple[bool, bool, bool]:
    """Returns `(is_loaded_in_bnb, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb)` for the given module."""
    is_loaded_in_4bit_bnb = (
        hasattr(module, "is_loaded_in_4bit")
        and module.is_loaded_in_4bit
        and getattr(module, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
    )
    is_loaded_in_8bit_bnb = (
        hasattr(module, "is_loaded_in_8bit")
        and module.is_loaded_in_8bit
        and getattr(module, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
    )
    return is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb
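# Illustrative usage (a sketch, `pipe.transformer` being a hypothetical quantized pipeline component):
#
#     is_bnb, is_4bit, is_8bit = _check_bnb_status(pipe.transformer)
#
# `is_bnb` is True whenever the module is bitsandbytes-quantized in either precision.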