mirror of
https://github.com/huggingface/diffusers.git
synced 2026-05-28 00:39:35 +08:00
This commit is contained in:
@@ -150,6 +150,56 @@ def unscale_lora_layers(model, weight: float | None = None):
|
||||
module.set_scale(adapter_name, 1.0)
|
||||
|
||||
|
||||
def get_peft_kwargs(
|
||||
rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None
|
||||
):
|
||||
rank_pattern = {}
|
||||
alpha_pattern = {}
|
||||
r = lora_alpha = list(rank_dict.values())[0]
|
||||
|
||||
if len(set(rank_dict.values())) > 1:
|
||||
# get the rank occurring the most number of times
|
||||
r = collections.Counter(rank_dict.values()).most_common()[0][0]
|
||||
|
||||
# for modules with rank different from the most occurring rank, add it to the `rank_pattern`
|
||||
rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
|
||||
rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}
|
||||
|
||||
if network_alpha_dict is not None and len(network_alpha_dict) > 0:
|
||||
if len(set(network_alpha_dict.values())) > 1:
|
||||
# get the alpha occurring the most number of times
|
||||
lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
|
||||
|
||||
# for modules with alpha different from the most occurring alpha, add it to the `alpha_pattern`
|
||||
alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
|
||||
if is_unet:
|
||||
alpha_pattern = {
|
||||
".".join(k.split(".lora_A.")[0].split(".")).replace(".alpha", ""): v
|
||||
for k, v in alpha_pattern.items()
|
||||
}
|
||||
else:
|
||||
alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
|
||||
else:
|
||||
lora_alpha = set(network_alpha_dict.values()).pop()
|
||||
|
||||
target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
|
||||
use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
|
||||
# for now we know that the "bias" keys are only associated with `lora_B`.
|
||||
lora_bias = any("lora_B" in k and k.endswith(".bias") for k in peft_state_dict)
|
||||
|
||||
lora_config_kwargs = {
|
||||
"r": r,
|
||||
"lora_alpha": lora_alpha,
|
||||
"rank_pattern": rank_pattern,
|
||||
"alpha_pattern": alpha_pattern,
|
||||
"target_modules": target_modules,
|
||||
"use_dora": use_dora,
|
||||
"lora_bias": lora_bias,
|
||||
}
|
||||
|
||||
return lora_config_kwargs
|
||||
|
||||
|
||||
def get_adapter_name(model):
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
@@ -294,55 +344,6 @@ def check_peft_version(min_version: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def get_peft_kwargs(
|
||||
rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None
|
||||
):
|
||||
rank_pattern = {}
|
||||
alpha_pattern = {}
|
||||
r = lora_alpha = list(rank_dict.values())[0]
|
||||
|
||||
if len(set(rank_dict.values())) > 1:
|
||||
# get the rank occurring the most number of times
|
||||
r = collections.Counter(rank_dict.values()).most_common()[0][0]
|
||||
|
||||
# for modules with rank different from the most occurring rank, add it to the `rank_pattern`
|
||||
rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
|
||||
rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}
|
||||
|
||||
if network_alpha_dict is not None and len(network_alpha_dict) > 0:
|
||||
if len(set(network_alpha_dict.values())) > 1:
|
||||
# get the alpha occurring the most number of times
|
||||
lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
|
||||
|
||||
# for modules with alpha different from the most occurring alpha, add it to the `alpha_pattern`
|
||||
alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
|
||||
if is_unet:
|
||||
alpha_pattern = {
|
||||
".".join(k.split(".lora_A.")[0].split(".")).replace(".alpha", ""): v
|
||||
for k, v in alpha_pattern.items()
|
||||
}
|
||||
else:
|
||||
alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
|
||||
else:
|
||||
lora_alpha = set(network_alpha_dict.values()).pop()
|
||||
|
||||
target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
|
||||
use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
|
||||
lora_bias = any("lora_B" in k and k.endswith(".bias") for k in peft_state_dict)
|
||||
|
||||
lora_config_kwargs = {
|
||||
"r": r,
|
||||
"lora_alpha": lora_alpha,
|
||||
"rank_pattern": rank_pattern,
|
||||
"alpha_pattern": alpha_pattern,
|
||||
"target_modules": target_modules,
|
||||
"use_dora": use_dora,
|
||||
"lora_bias": lora_bias,
|
||||
}
|
||||
|
||||
return lora_config_kwargs
|
||||
|
||||
|
||||
def _create_lora_config(
|
||||
state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None
|
||||
):
|
||||
@@ -362,6 +363,7 @@ def _create_lora_config(
|
||||
|
||||
_maybe_raise_error_for_ambiguous_keys(lora_config_kwargs)
|
||||
|
||||
# Version checks for DoRA and lora_bias
|
||||
if "use_dora" in lora_config_kwargs and lora_config_kwargs["use_dora"]:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError("DoRA requires PEFT >= 0.9.0. Please upgrade.")
|
||||
@@ -381,6 +383,11 @@ def _maybe_raise_error_for_ambiguous_keys(config):
|
||||
target_modules = config["target_modules"]
|
||||
|
||||
for key in list(rank_pattern.keys()):
|
||||
# try to detect ambiguity
|
||||
# `target_modules` can also be a str, in which case this loop would loop
|
||||
# over the chars of the str. The technically correct way to match LoRA keys
|
||||
# in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
|
||||
# But this cuts it for now.
|
||||
exact_matches = [mod for mod in target_modules if mod == key]
|
||||
substring_matches = [mod for mod in target_modules if key in mod and mod != key]
|
||||
|
||||
@@ -394,6 +401,7 @@ def _maybe_raise_error_for_ambiguous_keys(config):
|
||||
def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
|
||||
warn_msg = ""
|
||||
if incompatible_keys is not None:
|
||||
# Check only for unexpected keys.
|
||||
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
||||
if unexpected_keys:
|
||||
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
|
||||
@@ -403,6 +411,7 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
|
||||
f" {', '.join(lora_unexpected_keys)}. "
|
||||
)
|
||||
|
||||
# Filter missing keys specific to the current adapter.
|
||||
missing_keys = getattr(incompatible_keys, "missing_keys", None)
|
||||
if missing_keys:
|
||||
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
|
||||
|
||||
Reference in New Issue
Block a user