diff --git a/networks/dylora.py b/networks/dylora.py
index b0925453c..00e786a1f 100644
--- a/networks/dylora.py
+++ b/networks/dylora.py
@@ -31,7 +31,7 @@ class DyLoRAModule(torch.nn.Module):
     """
 
     # NOTE: support dropout in future
-    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, unit=1):
+    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, unit=1, rank_stabilized=False):
         super().__init__()
         self.lora_name = lora_name
         self.lora_dim = lora_dim
@@ -48,7 +48,10 @@ def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_
         if type(alpha) == torch.Tensor:
             alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
         alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
-        self.scale = alpha / self.lora_dim
+        rank_factor = self.lora_dim
+        if rank_stabilized:
+            rank_factor = math.sqrt(rank_factor)
+        self.scale = alpha / rank_factor
         self.register_buffer("alpha", torch.tensor(alpha))  # 定数として扱える
 
         self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
@@ -285,6 +288,7 @@ def __init__(
         unit=1,
         module_class=DyLoRAModule,
         varbose=False,
+        rank_stabilized=False,
     ) -> None:
         super().__init__()
         self.multiplier = multiplier
@@ -334,7 +338,7 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules
                         continue
 
                     # dropout and fan_in_fan_out is default
-                    lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
+                    lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit, rank_stabilized=rank_stabilized)
                     loras.append(lora)
 
         return loras
diff --git a/networks/lora.py b/networks/lora.py
index 6f33f1a1e..0b752d89b 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -37,6 +37,7 @@ def __init__(
         dropout=None,
         rank_dropout=None,
         module_dropout=None,
+        rank_stabilized=False
     ):
         """if alpha == 0 or None, alpha is rank (no scaling)."""
         super().__init__()
@@ -69,7 +70,10 @@ def __init__(
         if type(alpha) == torch.Tensor:
             alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
         alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
-        self.scale = alpha / self.lora_dim
+        rank_factor = self.lora_dim
+        if rank_stabilized:
+            rank_factor = math.sqrt(rank_factor)
+        self.scale = alpha / rank_factor
         self.register_buffer("alpha", torch.tensor(alpha))  # 定数として扱える
 
         # same as microsoft's
@@ -895,6 +899,7 @@ def __init__(
         module_class: Type[object] = LoRAModule,
         varbose: Optional[bool] = False,
         is_sdxl: Optional[bool] = False,
+        rank_stabilized: Optional[bool] = False
     ) -> None:
         """
         LoRA network: すごく引数が多いが、パターンは以下の通り
@@ -914,6 +919,7 @@ def __init__(
         self.dropout = dropout
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
+        self.rank_stabilized = rank_stabilized
 
         self.loraplus_lr_ratio = None
         self.loraplus_unet_lr_ratio = None
@@ -1011,6 +1017,7 @@ def create_modules(
                         dropout=dropout,
                         rank_dropout=rank_dropout,
                         module_dropout=module_dropout,
+                        rank_stabilized=rank_stabilized,
                     )
                     loras.append(lora)
         return loras, skipped
@@ -1385,7 +1392,10 @@ def apply_max_norm_regularization(self, max_norm_value, device):
             up = state_dict[upkeys[i]].to(device)
             alpha = state_dict[alphakeys[i]].to(device)
             dim = down.shape[0]
-            scale = alpha / dim
+            rank_factor = dim
+            if self.rank_stabilized:
+                rank_factor = math.sqrt(rank_factor)
+            scale = alpha / rank_factor
 
             if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
                 updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
diff --git a/networks/lora_fa.py b/networks/lora_fa.py
index 919222ce8..65f36dbf1 100644
--- a/networks/lora_fa.py
+++ b/networks/lora_fa.py
@@ -37,6 +37,7 @@ def __init__(
         dropout=None,
         rank_dropout=None,
         module_dropout=None,
+        rank_stabilized=False
     ):
         """if alpha == 0 or None, alpha is rank (no scaling)."""
         super().__init__()
@@ -69,7 +70,10 @@ def __init__(
         if type(alpha) == torch.Tensor:
             alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
         alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
-        self.scale = alpha / self.lora_dim
+        rank_factor = self.lora_dim
+        if rank_stabilized:
+            rank_factor = math.sqrt(rank_factor)
+        self.scale = alpha / rank_factor
         self.register_buffer("alpha", torch.tensor(alpha))  # 定数として扱える
 
         # # same as microsoft's
@@ -783,6 +787,7 @@ def __init__(
         modules_alpha: Optional[Dict[str, int]] = None,
         module_class: Type[object] = LoRAModule,
         varbose: Optional[bool] = False,
+        rank_stabilized: Optional[bool] = False
     ) -> None:
         """
         LoRA network: すごく引数が多いが、パターンは以下の通り
@@ -802,6 +807,7 @@ def __init__(
         self.dropout = dropout
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
+        self.rank_stabilized = rank_stabilized
 
         if modules_dim is not None:
             logger.info(f"create LoRA network from weights")
@@ -889,6 +895,7 @@ def create_modules(
                         dropout=dropout,
                         rank_dropout=rank_dropout,
                         module_dropout=module_dropout,
+                        rank_stabilized=rank_stabilized
                     )
                     loras.append(lora)
         return loras, skipped
@@ -1219,7 +1226,10 @@ def apply_max_norm_regularization(self, max_norm_value, device):
             up = state_dict[upkeys[i]].to(device)
             alpha = state_dict[alphakeys[i]].to(device)
             dim = down.shape[0]
-            scale = alpha / dim
+            rank_factor = dim
+            if self.rank_stabilized:
+                rank_factor = math.sqrt(rank_factor)
+            scale = alpha / rank_factor
 
             if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
                 updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
diff --git a/networks/lora_flux.py b/networks/lora_flux.py
index 91e9cd77f..ef272ed7b 100644
--- a/networks/lora_flux.py
+++ b/networks/lora_flux.py
@@ -44,6 +44,7 @@ def __init__(
         rank_dropout=None,
         module_dropout=None,
         split_dims: Optional[List[int]] = None,
+        rank_stabilized: Optional[bool] = False
     ):
         """
         if alpha == 0 or None, alpha is rank (no scaling).
@@ -93,7 +94,10 @@ def __init__(
         if type(alpha) == torch.Tensor:
             alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
         alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
-        self.scale = alpha / self.lora_dim
+        rank_factor = self.lora_dim
+        if rank_stabilized:
+            rank_factor = math.sqrt(rank_factor)
+        self.scale = alpha / rank_factor
         self.register_buffer("alpha", torch.tensor(alpha))  # 定数として扱える
 
         # same as microsoft's
@@ -562,6 +566,7 @@ def __init__(
         train_double_block_indices: Optional[List[bool]] = None,
         train_single_block_indices: Optional[List[bool]] = None,
         verbose: Optional[bool] = False,
+        rank_stabilized: Optional[bool] = False,
     ) -> None:
         super().__init__()
         self.multiplier = multiplier
@@ -576,6 +581,7 @@ def __init__(
         self.train_blocks = train_blocks if train_blocks is not None else "all"
         self.split_qkv = split_qkv
         self.train_t5xxl = train_t5xxl
+        self.rank_stabilized = rank_stabilized
 
         self.type_dims = type_dims
         self.in_dims = in_dims
@@ -722,6 +728,7 @@ def create_modules(
                         rank_dropout=rank_dropout,
                         module_dropout=module_dropout,
                         split_dims=split_dims,
+                        rank_stabilized=rank_stabilized,
                     )
                     loras.append(lora)
 
@@ -1132,7 +1139,10 @@ def apply_max_norm_regularization(self, max_norm_value, device):
             up = state_dict[upkeys[i]].to(device)
             alpha = state_dict[alphakeys[i]].to(device)
             dim = down.shape[0]
-            scale = alpha / dim
+            rank_factor = dim
+            if self.rank_stabilized:
+                rank_factor = math.sqrt(rank_factor)
+            scale = alpha / rank_factor
 
             if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
                 updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
diff --git a/networks/lora_sd3.py b/networks/lora_sd3.py
index ce6d1a16f..2838f9e8d 100644
--- a/networks/lora_sd3.py
+++ b/networks/lora_sd3.py
@@ -256,6 +256,7 @@ def __init__(
         emb_dims: Optional[List[int]] = None,
         train_block_indices: Optional[List[bool]] = None,
         verbose: Optional[bool] = False,
+        rank_stabilized: Optional[bool] = False
     ) -> None:
         super().__init__()
         self.multiplier = multiplier
@@ -269,6 +270,7 @@ def __init__(
         self.module_dropout = module_dropout
         self.split_qkv = split_qkv
         self.train_t5xxl = train_t5xxl
+        self.rank_stabilized = rank_stabilized
 
         self.type_dims = type_dims
         self.emb_dims = emb_dims
@@ -404,6 +406,7 @@ def create_modules(
                         rank_dropout=rank_dropout,
                         module_dropout=module_dropout,
                         split_dims=split_dims,
+                        rank_stabilized=rank_stabilized
                     )
                     loras.append(lora)
 
@@ -814,7 +817,10 @@ def apply_max_norm_regularization(self, max_norm_value, device):
             up = state_dict[upkeys[i]].to(device)
             alpha = state_dict[alphakeys[i]].to(device)
             dim = down.shape[0]
-            scale = alpha / dim
+            rank_factor = dim
+            if self.rank_stabilized:
+                rank_factor = math.sqrt(rank_factor)
+            scale = alpha / rank_factor
 
             if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
                 updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
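
Note (not part of the patch): below is a minimal sketch of the scaling rule the diff threads through every module __init__ and apply_max_norm_regularization; the lora_scale helper is hypothetical and only illustrates the arithmetic. With rank_stabilized=True the adapter output is scaled by alpha / sqrt(lora_dim) (rank-stabilized LoRA, rsLoRA) instead of the usual alpha / lora_dim, so the effective scale shrinks more slowly as the rank grows.

import math

def lora_scale(alpha: float, lora_dim: int, rank_stabilized: bool = False) -> float:
    # Same rule the patch adds: divide alpha by r normally,
    # by sqrt(r) when rank-stabilized scaling is enabled.
    rank_factor = math.sqrt(lora_dim) if rank_stabilized else lora_dim
    return alpha / rank_factor

for r in (4, 16, 64):
    print(f"rank {r:>2}: standard {lora_scale(16, r):.3f}  rank-stabilized {lora_scale(16, r, True):.3f}")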