Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add scale map to apply_max_norm_regularization #1873

Draft
wants to merge 1 commit into
base: sd3
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4537,6 +4537,10 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
ignore_nesting_dict[section_name] = section_dict
continue

if section_name == "scale_weight_norms_map":
ignore_nesting_dict[section_name] = section_dict
continue

# if value is dict, save all key and value into one dict
for key, value in section_dict.items():
ignore_nesting_dict[key] = value
Expand Down
7 changes: 6 additions & 1 deletion networks/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -1366,7 +1366,7 @@ def pre_calculation(self):
org_module._lora_restored = False
lora.enabled = False

def apply_max_norm_regularization(self, max_norm_value, device):
def apply_max_norm_regularization(self, max_norm, device, scale_map: dict[str, float]={}):
downkeys = []
upkeys = []
alphakeys = []
Expand All @@ -1381,6 +1381,11 @@ def apply_max_norm_regularization(self, max_norm_value, device):
alphakeys.append(key.replace("lora_down.weight", "alpha"))

for i in range(len(downkeys)):
max_norm_value = max_norm
for key in scale_map.keys():
if key in downkeys[i]:
max_norm_value = scale_map[key]

down = state_dict[downkeys[i]].to(device)
up = state_dict[upkeys[i]].to(device)
alpha = state_dict[alphakeys[i]].to(device)
Expand Down
19 changes: 18 additions & 1 deletion train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from typing import Any, List
import toml

import ast

from tqdm import tqdm

import torch
Expand Down Expand Up @@ -1260,8 +1262,9 @@ def remove_model(old_ckpt_name):
optimizer.zero_grad(set_to_none=True)

if args.scale_weight_norms:
scale_map = args.scale_weight_norms_map if args.scale_weight_norms_map else {}
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
args.scale_weight_norms, accelerator.device
args.scale_weight_norms, accelerator.device, scale_map=scale_map
)
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
else:
Expand Down Expand Up @@ -1356,6 +1359,14 @@ def remove_model(old_ckpt_name):

logger.info("model saved.")

def parse_dict(input_str):
"""Convert string input into a dictionary."""
try:
# Use ast.literal_eval to safely evaluate the string as a Python literal (dict)
return ast.literal_eval(input_str)
except ValueError:
raise argparse.ArgumentTypeError(f"Invalid dictionary format: {input_str}")


def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -1458,6 +1469,12 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ(1が初期値としては適当)",
)
parser.add_argument(
"--scale_weight_norms_map",
type=parse_dict,
default="{}",
help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ(1が初期値としては適当)",
)
parser.add_argument(
"--base_weights",
type=str,
Expand Down
Loading