From 4ae674f7074ac5fa4e54107db3431ffdda8e98e3 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Fri, 10 Jan 2025 19:19:33 -0500
Subject: [PATCH] Add rank-stabilized to apply_max_norm_regularization

---
 networks/lora.py      | 6 +++++-
 networks/lora_fa.py   | 6 +++++-
 networks/lora_flux.py | 8 +++++++-
 networks/lora_sd3.py  | 6 +++++-
 4 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/networks/lora.py b/networks/lora.py
index d70368379..0b752d89b 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -919,6 +919,7 @@ def __init__(
         self.dropout = dropout
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
+        self.rank_stabilized = rank_stabilized

         self.loraplus_lr_ratio = None
         self.loraplus_unet_lr_ratio = None
@@ -1391,7 +1392,10 @@ def apply_max_norm_regularization(self, max_norm_value, device):
                 up = state_dict[upkeys[i]].to(device)
                 alpha = state_dict[alphakeys[i]].to(device)
                 dim = down.shape[0]
-                scale = alpha / dim
+                rank_factor = dim
+                if self.rank_stabilized:
+                    rank_factor = math.sqrt(rank_factor)
+                scale = alpha / rank_factor

                 if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
                     updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
diff --git a/networks/lora_fa.py b/networks/lora_fa.py
index 067783c9b..65f36dbf1 100644
--- a/networks/lora_fa.py
+++ b/networks/lora_fa.py
@@ -807,6 +807,7 @@ def __init__(
         self.dropout = dropout
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
+        self.rank_stabilized = rank_stabilized

         if modules_dim is not None:
             logger.info(f"create LoRA network from weights")
@@ -1225,7 +1226,10 @@ def apply_max_norm_regularization(self, max_norm_value, device):
                 up = state_dict[upkeys[i]].to(device)
                 alpha = state_dict[alphakeys[i]].to(device)
                 dim = down.shape[0]
-                scale = alpha / dim
+                rank_factor = dim
+                if self.rank_stabilized:
+                    rank_factor = math.sqrt(rank_factor)
+                scale = alpha / rank_factor

                 if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
                     updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
diff --git a/networks/lora_flux.py b/networks/lora_flux.py
index 8f11cf784..ef272ed7b 100644
--- a/networks/lora_flux.py
+++ b/networks/lora_flux.py
@@ -566,6 +566,7 @@ def __init__(
         train_double_block_indices: Optional[List[bool]] = None,
         train_single_block_indices: Optional[List[bool]] = None,
         verbose: Optional[bool] = False,
+        rank_stabilized: Optional[bool] = False,
     ) -> None:
         super().__init__()
         self.multiplier = multiplier
@@ -580,6 +581,7 @@ def __init__(
         self.train_blocks = train_blocks if train_blocks is not None else "all"
         self.split_qkv = split_qkv
         self.train_t5xxl = train_t5xxl
+        self.rank_stabilized = rank_stabilized

         self.type_dims = type_dims
         self.in_dims = in_dims
@@ -726,6 +728,7 @@ def create_modules(
                         rank_dropout=rank_dropout,
                         module_dropout=module_dropout,
                         split_dims=split_dims,
+                        rank_stabilized=rank_stabilized,
                     )
                     loras.append(lora)

@@ -1136,7 +1139,10 @@ def apply_max_norm_regularization(self, max_norm_value, device):
                 up = state_dict[upkeys[i]].to(device)
                 alpha = state_dict[alphakeys[i]].to(device)
                 dim = down.shape[0]
-                scale = alpha / dim
+                rank_factor = dim
+                if self.rank_stabilized:
+                    rank_factor = math.sqrt(rank_factor)
+                scale = alpha / rank_factor

                 if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
                     updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
diff --git a/networks/lora_sd3.py b/networks/lora_sd3.py
index b9b0621a6..2838f9e8d 100644
--- a/networks/lora_sd3.py
+++ b/networks/lora_sd3.py
@@ -270,6 +270,7 @@ def __init__(
         self.module_dropout = module_dropout
         self.split_qkv = split_qkv
         self.train_t5xxl = train_t5xxl
+        self.rank_stabilized = rank_stabilized

         self.type_dims = type_dims
         self.emb_dims = emb_dims
@@ -816,7 +817,10 @@ def apply_max_norm_regularization(self, max_norm_value, device):
                 up = state_dict[upkeys[i]].to(device)
                 alpha = state_dict[alphakeys[i]].to(device)
                 dim = down.shape[0]
-                scale = alpha / dim
+                rank_factor = dim
+                if self.rank_stabilized:
+                    rank_factor = math.sqrt(rank_factor)
+                scale = alpha / rank_factor

                 if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
                     updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
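
Reviewer note on the scaling change (not part of the diff): dividing alpha by
sqrt(rank) instead of rank follows rank-stabilized LoRA (rsLoRA, Kalajdzievski
2023), which keeps the magnitude of the LoRA update from shrinking as rank
grows. Below is a minimal sketch of the scale computation the hunks above
apply; lora_scale is a hypothetical helper for illustration only, not a
function defined in these files:

    import math

    def lora_scale(alpha: float, rank: int, rank_stabilized: bool) -> float:
        # Standard LoRA scales the update by alpha / rank; rank-stabilized
        # LoRA divides by sqrt(rank) instead, so higher-rank adapters are
        # not scaled down as aggressively.
        rank_factor = math.sqrt(rank) if rank_stabilized else rank
        return alpha / rank_factor

    # For alpha=16, rank=64: standard scale is 16/64 = 0.25, while the
    # rank-stabilized scale is 16/sqrt(64) = 2.0.

Since apply_max_norm_regularization clamps the norms of the merged up @ down
weights using this scale, computing the rank factor the same way as the
forward pass avoids over- or under-regularizing networks that were created
with rank_stabilized enabled.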