# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import TYPE_CHECKING, Dict, List, Union

from ..utils import logging


if TYPE_CHECKING:
    # import here to avoid circular imports
    from ..models import UNet2DConditionModel

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

def _translate_into_actual_layer_name(name):
    """Translate user-friendly name (e.g. 'mid') into actual layer name (e.g. 'mid_block.attentions.0')"""
    if name == "mid":
        return "mid_block.attentions.0"

    updown, block, attn = name.split(".")

    updown = updown.replace("down", "down_blocks").replace("up", "up_blocks")
    block = block.replace("block_", "")
    attn = "attentions." + attn

    return ".".join((updown, block, attn))

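
# A minimal illustration (not part of the library API) of how the helper above maps the
# user-facing dotted names onto the UNet's module names; the inputs below are made-up
# examples following the "updown.block_i.attn_j" convention the function parses:
#
#     _translate_into_actual_layer_name("mid")             # -> "mid_block.attentions.0"
#     _translate_into_actual_layer_name("down.block_1.0")  # -> "down_blocks.1.attentions.0"
#     _translate_into_actual_layer_name("up.block_0.2")    # -> "up_blocks.0.attentions.2"
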
def _maybe_expand_lora_scales(
    unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]], default_scale=1.0
):
    """Expand each adapter's entry in `weight_scales` into a per-attention-layer scale dict; plain numbers are passed through unchanged."""
    blocks_with_transformer = {
        "down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")],
        "up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")],
    }
    transformer_per_block = {"down": unet.config.layers_per_block, "up": unet.config.layers_per_block + 1}

    expanded_weight_scales = [
        _maybe_expand_lora_scales_for_one_adapter(
            weight_for_adapter,
            blocks_with_transformer,
            transformer_per_block,
            unet.state_dict(),
            default_scale=default_scale,
        )
        for weight_for_adapter in weight_scales
    ]

    return expanded_weight_scales

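
# A rough sketch (an assumption for illustration, not library-verified values) of how the
# expansion above is typically fed: `unet` is assumed to be an SDXL-style
# `UNet2DConditionModel` whose down blocks 1-2 and up blocks 0-1 carry attentions, and
# `weight_scales` holds one entry per loaded LoRA adapter:
#
#     weight_scales = [
#         0.5,                                                 # adapter 1: a plain float is returned as-is
#         {"down": 0.8, "mid": 1.0, "up": {"block_0": 0.6}},   # adapter 2: expanded per attention layer
#     ]
#     expanded = _maybe_expand_lora_scales(unet, weight_scales, default_scale=1.0)
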
def _maybe_expand_lora_scales_for_one_adapter(
    scales: Union[float, Dict],
    blocks_with_transformer: Dict[str, List[int]],
    transformer_per_block: Dict[str, int],
    state_dict: Dict,
    default_scale: float = 1.0,
):
    """
    Expands the inputs into a more granular dictionary. See the example below for more details.

    Parameters:
        scales (`Union[float, Dict]`):
            A single scale (applied everywhere) or a nested scales dict to expand.
        blocks_with_transformer (`Dict[str, List[int]]`):
            Dict with keys 'up' and 'down', listing which block indices have transformer layers.
        transformer_per_block (`Dict[str, int]`):
            Dict with keys 'up' and 'down', giving how many transformer layers each block has.
        state_dict (`Dict`):
            State dict of the UNet, used to check that every expanded layer name actually exists.
        default_scale (`float`, *optional*, defaults to 1.0):
            Scale applied to any block that is not explicitly listed in `scales`.

    E.g. turns
    ```python
    scales = {"down": 2, "mid": 3, "up": {"block_0": 4, "block_1": [5, 6, 7]}}
    blocks_with_transformer = {"down": [1, 2], "up": [0, 1]}
    transformer_per_block = {"down": 2, "up": 3}
    ```
    into
    ```python
    {
        "down.block_1.0": 2,
        "down.block_1.1": 2,
        "down.block_2.0": 2,
        "down.block_2.1": 2,
        "mid": 3,
        "up.block_0.0": 4,
        "up.block_0.1": 4,
        "up.block_0.2": 4,
        "up.block_1.0": 5,
        "up.block_1.1": 6,
        "up.block_1.2": 7,
    }
    ```
    """
    if sorted(blocks_with_transformer.keys()) != ["down", "up"]:
        raise ValueError("blocks_with_transformer needs to be a dict with keys `'down'` and `'up'`")

    if sorted(transformer_per_block.keys()) != ["down", "up"]:
        raise ValueError("transformer_per_block needs to be a dict with keys `'down'` and `'up'`")

    if not isinstance(scales, dict):
        # don't expand if scales is a single number
        return scales

    scales = copy.deepcopy(scales)

    if "mid" not in scales:
        scales["mid"] = default_scale
    elif isinstance(scales["mid"], list):
        if len(scales["mid"]) == 1:
            scales["mid"] = scales["mid"][0]
        else:
            raise ValueError(f"Expected 1 scale for mid, got {len(scales['mid'])}.")

for updown in ["up", "down"]:
|
|
if updown not in scales:
|
|
scales[updown] = default_scale
|
|
|
|
# eg {"down": 1} to {"down": {"block_1": 1, "block_2": 1}}}
|
|
if not isinstance(scales[updown], dict):
|
|
scales[updown] = {f"block_{i}": copy.deepcopy(scales[updown]) for i in blocks_with_transformer[updown]}
|
|
|
|
# eg {"down": {"block_1": 1}} to {"down": {"block_1": [1, 1]}}
|
|
for i in blocks_with_transformer[updown]:
|
|
block = f"block_{i}"
|
|
# set not assigned blocks to default scale
|
|
if block not in scales[updown]:
|
|
scales[updown][block] = default_scale
|
|
if not isinstance(scales[updown][block], list):
|
|
scales[updown][block] = [scales[updown][block] for _ in range(transformer_per_block[updown])]
|
|
elif len(scales[updown][block]) == 1:
|
|
# a list specifying scale to each masked IP input
|
|
scales[updown][block] = scales[updown][block] * transformer_per_block[updown]
|
|
elif len(scales[updown][block]) != transformer_per_block[updown]:
|
|
raise ValueError(
|
|
f"Expected {transformer_per_block[updown]} scales for {updown}.{block}, got {len(scales[updown][block])}."
|
|
)
|
|
|
|
# eg {"down": "block_1": [1, 1]}} to {"down.block_1.0": 1, "down.block_1.1": 1}
|
|
for i in blocks_with_transformer[updown]:
|
|
block = f"block_{i}"
|
|
for tf_idx, value in enumerate(scales[updown][block]):
|
|
scales[f"{updown}.{block}.{tf_idx}"] = value
|
|
|
|
del scales[updown]
|
|
|
|
    for layer in scales.keys():
        if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()):
            raise ValueError(
                f"Can't set lora scale for layer {layer}. It either doesn't exist in this unet or it has no attentions."
            )

    return {_translate_into_actual_layer_name(name): weight for name, weight in scales.items()}
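

# A self-contained sketch (illustrative only, not part of the upstream module) that reuses
# the docstring example above with a hypothetical, minimal state_dict whose key names are
# just enough to satisfy the final existence check; because of the relative imports, it
# only runs when the file is executed as a module of its package.
if __name__ == "__main__":
    example_scales = {"down": 2, "mid": 3, "up": {"block_0": 4, "block_1": [5, 6, 7]}}
    example_blocks_with_transformer = {"down": [1, 2], "up": [0, 1]}
    example_transformer_per_block = {"down": 2, "up": 3}
    # only the key names matter for the existence check; the values are never inspected
    fake_state_dict = {
        f"{name}.weight": None
        for name in [
            "mid_block.attentions.0",
            "down_blocks.1.attentions.0",
            "down_blocks.1.attentions.1",
            "down_blocks.2.attentions.0",
            "down_blocks.2.attentions.1",
            "up_blocks.0.attentions.0",
            "up_blocks.0.attentions.1",
            "up_blocks.0.attentions.2",
            "up_blocks.1.attentions.0",
            "up_blocks.1.attentions.1",
            "up_blocks.1.attentions.2",
        ]
    }
    expanded = _maybe_expand_lora_scales_for_one_adapter(
        example_scales,
        example_blocks_with_transformer,
        example_transformer_per_block,
        fake_state_dict,
    )
    # e.g. expanded["down_blocks.1.attentions.0"] == 2 and expanded["up_blocks.1.attentions.2"] == 7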