# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
PEFT utilities: Utilities related to the `peft` library
"""

import collections
import importlib.metadata
from typing import Optional

from packaging import version

from . import logging
from .import_utils import is_peft_available, is_peft_version, is_torch_available
from .torch_utils import empty_device_cache


logger = logging.get_logger(__name__)

if is_torch_available():
    import torch


def recurse_remove_peft_layers(model):
    r"""
    Recursively replace all instances of `LoraLayer` with corresponding new layers in `model`.
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    has_base_layer_pattern = False
    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            has_base_layer_pattern = hasattr(module, "base_layer")
            break

    if has_base_layer_pattern:
        from peft.utils import _get_submodules

        key_list = [key for key, _ in model.named_modules() if "lora" not in key]
        for key in key_list:
            try:
                parent, target, target_name = _get_submodules(model, key)
            except AttributeError:
                continue
            if hasattr(target, "base_layer"):
                setattr(parent, target_name, target.get_base_layer())
    else:
        # This is for backwards compatibility with PEFT <= 0.6.2.
        # TODO can be removed once that PEFT version is no longer supported.
        from peft.tuners.lora import LoraLayer

        for name, module in model.named_children():
            if len(list(module.children())) > 0:
                # compound module, go inside it
                recurse_remove_peft_layers(module)

            module_replaced = False

            if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
                new_module = torch.nn.Linear(
                    module.in_features,
                    module.out_features,
                    bias=module.bias is not None,
                ).to(module.weight.device)
                new_module.weight = module.weight
                if module.bias is not None:
                    new_module.bias = module.bias

                module_replaced = True
            elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d):
                new_module = torch.nn.Conv2d(
                    module.in_channels,
                    module.out_channels,
                    module.kernel_size,
                    module.stride,
                    module.padding,
                    module.dilation,
                    module.groups,
                ).to(module.weight.device)

                new_module.weight = module.weight
                if module.bias is not None:
                    new_module.bias = module.bias

                module_replaced = True

            if module_replaced:
                setattr(model, name, new_module)
                del module

    empty_device_cache()
    return model
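

# A minimal usage sketch (illustrative comment, not executed on import): after
# merging or discarding adapters, this helper restores the plain `torch.nn`
# modules. `pipe` is a hypothetical pipeline whose UNet had LoRA layers injected:
#
#     pipe.unet = recurse_remove_peft_layers(pipe.unet)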


def scale_lora_layers(model, weight):
    """
    Adjust the weight given to the LoRA layers of the model.

    Args:
        model (`torch.nn.Module`):
            The model to scale.
        weight (`float`):
            The weight to be given to the LoRA layers.
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    if weight == 1.0:
        return

    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            module.scale_layer(weight)
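

# Illustrative sketch (assumes `model` already has a PEFT LoRA adapter attached):
# temporarily halve the LoRA contribution for one forward pass, then restore it
# with the matching `unscale_lora_layers` call:
#
#     scale_lora_layers(model, weight=0.5)
#     output = model(inputs)  # hypothetical forward pass
#     unscale_lora_layers(model, weight=0.5)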


def unscale_lora_layers(model, weight: Optional[float] = None):
    """
    Removes the previously passed weight given to the LoRA layers of the model.

    Args:
        model (`torch.nn.Module`):
            The model to unscale.
        weight (`float`, *optional*):
            The weight that was given to the LoRA layers. If `None` or `1.0` is passed, the layers are left
            unchanged. If `0.0` is passed, the per-adapter scale is re-initialized to its default value instead
            of being divided by the weight.
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    if weight is None or weight == 1.0:
        return

    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            if weight != 0:
                module.unscale_layer(weight)
            else:
                for adapter_name in module.active_adapters:
                    # if weight == 0, unscale should re-set the scale to the original value.
                    module.set_scale(adapter_name, 1.0)
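

# Illustrative sketch: passing `0.0` does not divide by zero; it resets every
# active adapter's scale back to its default value, e.g.:
#
#     scale_lora_layers(model, weight=0.0)    # zeroes out the LoRA contribution
#     unscale_lora_layers(model, weight=0.0)  # restores the default scale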


def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
    rank_pattern = {}
    alpha_pattern = {}
    r = lora_alpha = list(rank_dict.values())[0]

    if len(set(rank_dict.values())) > 1:
        # get the rank occurring the most number of times
        r = collections.Counter(rank_dict.values()).most_common()[0][0]

        # for modules with rank different from the most occurring rank, add them to the `rank_pattern`
        rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
        rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}

    if network_alpha_dict is not None and len(network_alpha_dict) > 0:
        if len(set(network_alpha_dict.values())) > 1:
            # get the alpha occurring the most number of times
            lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]

            # for modules with alpha different from the most occurring alpha, add them to the `alpha_pattern`
            alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
            if is_unet:
                alpha_pattern = {k.split(".lora_A.")[0].replace(".alpha", ""): v for k, v in alpha_pattern.items()}
            else:
                alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
        else:
            lora_alpha = set(network_alpha_dict.values()).pop()

    # layer names without the Diffusers-specific LoRA key suffixes
    target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
    use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
    # for now we know that the "bias" keys are only associated with `lora_B`.
    lora_bias = any("lora_B" in k and k.endswith(".bias") for k in peft_state_dict)

    lora_config_kwargs = {
        "r": r,
        "lora_alpha": lora_alpha,
        "rank_pattern": rank_pattern,
        "alpha_pattern": alpha_pattern,
        "target_modules": target_modules,
        "use_dora": use_dora,
        "lora_bias": lora_bias,
    }
    return lora_config_kwargs
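

# Illustrative sketch with hypothetical keys: two modules share rank 4 and one
# uses rank 8, so `r` becomes 4 and the outlier lands in `rank_pattern`:
#
#     rank_dict = {
#         "down_blocks.0.attn.to_q.lora_B.weight": 4,
#         "down_blocks.0.attn.to_k.lora_B.weight": 4,
#         "mid_block.attn.to_q.lora_B.weight": 8,
#     }
#     kwargs = get_peft_kwargs(rank_dict, None, state_dict)  # `state_dict` holds the LoRA weights
#     # kwargs["r"] == 4, kwargs["rank_pattern"] == {"mid_block.attn.to_q": 8}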


def get_adapter_name(model):
    from peft.tuners.tuners_utils import BaseTunerLayer

    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            return f"default_{len(module.r)}"
    return "default_0"
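

# Illustrative sketch: `module.r` maps adapter names to ranks, so its length
# counts the adapters already attached. With one adapter loaded, the next
# auto-generated name would be:
#
#     get_adapter_name(model)  # -> "default_1"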


def set_adapter_layers(model, enabled=True):
    from peft.tuners.tuners_utils import BaseTunerLayer

    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            # Recent versions of PEFT need to call `enable_adapters` instead
            if hasattr(module, "enable_adapters"):
                module.enable_adapters(enabled=enabled)
            else:
                module.disable_adapters = not enabled
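

# Illustrative sketch: toggle all adapters off for a base-model-only forward
# pass, then turn them back on:
#
#     set_adapter_layers(model, enabled=False)
#     set_adapter_layers(model, enabled=True)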


def delete_adapter_layers(model, adapter_name):
    from peft.tuners.tuners_utils import BaseTunerLayer

    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            if hasattr(module, "delete_adapter"):
                module.delete_adapter(adapter_name)
            else:
                raise ValueError(
                    "The version of PEFT you are using is not compatible, please use a version that is greater than 0.6.1"
                )

    # For transformers integration - we need to pop the adapter from the config
    if getattr(model, "_hf_peft_config_loaded", False) and hasattr(model, "peft_config"):
        model.peft_config.pop(adapter_name, None)
        # In case all adapters are deleted, we need to delete the config
        # and make sure to reset the flag
        if len(model.peft_config) == 0:
            del model.peft_config
            model._hf_peft_config_loaded = None
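

# Illustrative sketch: remove a previously loaded adapter by the name it was
# registered under (e.g. the auto-generated "default_0"):
#
#     delete_adapter_layers(model, adapter_name="default_0")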


def set_weights_and_activate_adapters(model, adapter_names, weights):
    from peft.tuners.tuners_utils import BaseTunerLayer

    def get_module_weight(weight_for_adapter, module_name):
        if not isinstance(weight_for_adapter, dict):
            # If weight_for_adapter is a single number, always return it.
            return weight_for_adapter

        for layer_name, weight_ in weight_for_adapter.items():
            if layer_name in module_name:
                return weight_

        parts = module_name.split(".")
        # e.g. key = "down_blocks.1.attentions.0"
        key = f"{parts[0]}.{parts[1]}.attentions.{parts[3]}"
        block_weight = weight_for_adapter.get(key, 1.0)

        return block_weight

    for module_name, module in model.named_modules():
        if isinstance(module, BaseTunerLayer):
            # For backward compatibility with previous PEFT versions, set multiple active adapters
            if hasattr(module, "set_adapter"):
                module.set_adapter(adapter_names)
            else:
                module.active_adapter = adapter_names

            # Set the scaling weight for each adapter for this module
            for adapter_name, weight in zip(adapter_names, weights):
                module.set_scale(adapter_name, get_module_weight(weight, module_name))
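

# Illustrative sketch: activate two adapters, one with a flat weight and one with
# a per-block weight dict (modules that match no dict key fall back to the block
# key or to 1.0):
#
#     set_weights_and_activate_adapters(
#         unet,
#         adapter_names=["style", "subject"],
#         weights=[0.8, {"down_blocks.1.attentions.0": 0.5}],
#     )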


def check_peft_version(min_version: str) -> None:
    r"""
    Checks that the installed version of PEFT is compatible.

    Args:
        min_version (`str`):
            The minimum version of PEFT required; the installed version must be strictly greater.
    """
    if not is_peft_available():
        raise ValueError("PEFT is not installed. Please install it with `pip install peft`")

    is_peft_version_compatible = version.parse(importlib.metadata.version("peft")) > version.parse(min_version)

    if not is_peft_version_compatible:
        raise ValueError(
            "The version of PEFT you are using is not compatible, please use a version that is greater"
            f" than {min_version}"
        )
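

# Illustrative sketch: guard a code path that needs a newer PEFT; this raises
# unless the installed version is strictly greater than the value passed in:
#
#     check_peft_version(min_version="0.6.2")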


def _create_lora_config(
    state_dict,
    network_alphas,
    metadata,
    rank_pattern_dict,
    is_unet: bool = True,
):
    from peft import LoraConfig

    if metadata is not None:
        lora_config_kwargs = metadata
    else:
        lora_config_kwargs = get_peft_kwargs(
            rank_pattern_dict, network_alpha_dict=network_alphas, peft_state_dict=state_dict, is_unet=is_unet
        )

    _maybe_raise_error_for_ambiguous_keys(lora_config_kwargs)

    # Version checks for DoRA and lora_bias
    if "use_dora" in lora_config_kwargs and lora_config_kwargs["use_dora"]:
        if is_peft_version("<", "0.9.0"):
            raise ValueError("DoRA requires PEFT >= 0.9.0. Please upgrade.")

    if "lora_bias" in lora_config_kwargs and lora_config_kwargs["lora_bias"]:
        if is_peft_version("<=", "0.13.2"):
            raise ValueError("lora_bias requires PEFT >= 0.14.0. Please upgrade.")

    try:
        return LoraConfig(**lora_config_kwargs)
    except TypeError as e:
        raise TypeError("`LoraConfig` class could not be instantiated.") from e
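

# Illustrative sketch: when no serialized metadata is available, the config is
# derived from the state dict itself (the variable names below are hypothetical):
#
#     config = _create_lora_config(
#         state_dict=lora_state_dict,
#         network_alphas=None,
#         metadata=None,
#         rank_pattern_dict=rank_dict,
#     )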


def _maybe_raise_error_for_ambiguous_keys(config):
    rank_pattern = config["rank_pattern"].copy()
    target_modules = config["target_modules"]

    for key in list(rank_pattern.keys()):
        # Try to detect ambiguity. `target_modules` can also be a str, in which case this
        # loop would iterate over the chars of the str. The technically correct way to match
        # LoRA keys in PEFT is to use `LoraModel._check_target_module_exists(lora_config, key)`,
        # but this check suffices for now.
        exact_matches = [mod for mod in target_modules if mod == key]
        substring_matches = [mod for mod in target_modules if key in mod and mod != key]

        if exact_matches and substring_matches:
            if is_peft_version("<", "0.14.1"):
                raise ValueError(
                    "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
                )
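

# Illustrative sketch of the ambiguity being detected: "proj" is both an exact
# target module and a substring of "proj_in", so a `rank_pattern` entry for
# "proj" could apply to either module on older PEFT versions:
#
#     config = {"rank_pattern": {"proj": 8}, "target_modules": ["proj", "proj_in"]}
#     _maybe_raise_error_for_ambiguous_keys(config)  # raises on peft < 0.14.1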


def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
    warn_msg = ""
    if incompatible_keys is not None:
        # Check only for unexpected keys.
        unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
        if unexpected_keys:
            lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
            if lora_unexpected_keys:
                warn_msg = (
                    "Loading adapter weights from state_dict led to unexpected keys found in the model:"
                    f" {', '.join(lora_unexpected_keys)}. "
                )

        # Filter missing keys specific to the current adapter.
        missing_keys = getattr(incompatible_keys, "missing_keys", None)
        if missing_keys:
            lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
            if lora_missing_keys:
                warn_msg += (
                    "Loading adapter weights from state_dict led to missing keys in the model:"
                    f" {', '.join(lora_missing_keys)}."
                )

    if warn_msg:
        logger.warning(warn_msg)
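

# Illustrative sketch: `incompatible_keys` is typically the named tuple returned
# by `torch.nn.Module.load_state_dict(..., strict=False)`; keys that mention the
# adapter are surfaced as a warning:
#
#     result = model.load_state_dict(peft_state_dict, strict=False)
#     _maybe_warn_for_unhandled_keys(result, adapter_name="default_0")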