update
Some checks failed
Secret Leaks / trufflehog (push) Has been cancelled

This commit is contained in:
DN6
2026-05-23 01:12:11 +05:30
parent 5eaaa3f0ea
commit 774807bb77

View File

@@ -150,6 +150,56 @@ def unscale_lora_layers(model, weight: float | None = None):
module.set_scale(adapter_name, 1.0)
def get_peft_kwargs(
rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None
):
rank_pattern = {}
alpha_pattern = {}
r = lora_alpha = list(rank_dict.values())[0]
if len(set(rank_dict.values())) > 1:
# get the rank occurring the most number of times
r = collections.Counter(rank_dict.values()).most_common()[0][0]
# for modules with rank different from the most occurring rank, add it to the `rank_pattern`
rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}
if network_alpha_dict is not None and len(network_alpha_dict) > 0:
if len(set(network_alpha_dict.values())) > 1:
# get the alpha occurring the most number of times
lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
# for modules with alpha different from the most occurring alpha, add it to the `alpha_pattern`
alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
if is_unet:
alpha_pattern = {
".".join(k.split(".lora_A.")[0].split(".")).replace(".alpha", ""): v
for k, v in alpha_pattern.items()
}
else:
alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
else:
lora_alpha = set(network_alpha_dict.values()).pop()
target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
# for now we know that the "bias" keys are only associated with `lora_B`.
lora_bias = any("lora_B" in k and k.endswith(".bias") for k in peft_state_dict)
lora_config_kwargs = {
"r": r,
"lora_alpha": lora_alpha,
"rank_pattern": rank_pattern,
"alpha_pattern": alpha_pattern,
"target_modules": target_modules,
"use_dora": use_dora,
"lora_bias": lora_bias,
}
return lora_config_kwargs
def get_adapter_name(model):
from peft.tuners.tuners_utils import BaseTunerLayer
@@ -294,55 +344,6 @@ def check_peft_version(min_version: str) -> None:
)
def get_peft_kwargs(
rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None
):
rank_pattern = {}
alpha_pattern = {}
r = lora_alpha = list(rank_dict.values())[0]
if len(set(rank_dict.values())) > 1:
# get the rank occurring the most number of times
r = collections.Counter(rank_dict.values()).most_common()[0][0]
# for modules with rank different from the most occurring rank, add it to the `rank_pattern`
rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}
if network_alpha_dict is not None and len(network_alpha_dict) > 0:
if len(set(network_alpha_dict.values())) > 1:
# get the alpha occurring the most number of times
lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
# for modules with alpha different from the most occurring alpha, add it to the `alpha_pattern`
alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
if is_unet:
alpha_pattern = {
".".join(k.split(".lora_A.")[0].split(".")).replace(".alpha", ""): v
for k, v in alpha_pattern.items()
}
else:
alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
else:
lora_alpha = set(network_alpha_dict.values()).pop()
target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
lora_bias = any("lora_B" in k and k.endswith(".bias") for k in peft_state_dict)
lora_config_kwargs = {
"r": r,
"lora_alpha": lora_alpha,
"rank_pattern": rank_pattern,
"alpha_pattern": alpha_pattern,
"target_modules": target_modules,
"use_dora": use_dora,
"lora_bias": lora_bias,
}
return lora_config_kwargs
def _create_lora_config(
state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None
):
@@ -362,6 +363,7 @@ def _create_lora_config(
_maybe_raise_error_for_ambiguous_keys(lora_config_kwargs)
# Version checks for DoRA and lora_bias
if "use_dora" in lora_config_kwargs and lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError("DoRA requires PEFT >= 0.9.0. Please upgrade.")
@@ -381,6 +383,11 @@ def _maybe_raise_error_for_ambiguous_keys(config):
target_modules = config["target_modules"]
for key in list(rank_pattern.keys()):
# try to detect ambiguity
# `target_modules` can also be a str, in which case this loop would loop
# over the chars of the str. The technically correct way to match LoRA keys
# in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
# But this cuts it for now.
exact_matches = [mod for mod in target_modules if mod == key]
substring_matches = [mod for mod in target_modules if key in mod and mod != key]
@@ -394,6 +401,7 @@ def _maybe_raise_error_for_ambiguous_keys(config):
def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
warn_msg = ""
if incompatible_keys is not None:
# Check only for unexpected keys.
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
@@ -403,6 +411,7 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
f" {', '.join(lora_unexpected_keys)}. "
)
# Filter missing keys specific to the current adapter.
missing_keys = getattr(incompatible_keys, "missing_keys", None)
if missing_keys:
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]