Add rank_stabilized for networks #1870

Open · wants to merge 2 commits into base: sd3
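This PR adds a `rank_stabilized` option following the rank-stabilized LoRA (rsLoRA) formulation: when enabled, the LoRA scale becomes alpha / sqrt(lora_dim) instead of the standard alpha / lora_dim, so the magnitude of the low-rank update no longer shrinks linearly as the rank grows. A minimal sketch of the two scaling rules (the compute_scale helper below is hypothetical, written only to illustrate the logic added in the diffs that follow):

import math

def compute_scale(alpha: float, lora_dim: int, rank_stabilized: bool = False) -> float:
    # Standard LoRA:        scale = alpha / r
    # Rank-stabilized LoRA: scale = alpha / sqrt(r)
    rank_factor = math.sqrt(lora_dim) if rank_stabilized else lora_dim
    return alpha / rank_factor

# For alpha=16, r=64: standard scale is 0.25, rank-stabilized scale is 2.0.
print(compute_scale(16, 64))                        # 0.25
print(compute_scale(16, 64, rank_stabilized=True))  # 2.0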
networks/dylora.py: 7 additions & 3 deletions

@@ -31,7 +31,7 @@ class DyLoRAModule(torch.nn.Module):
     """

     # NOTE: support dropout in future
-    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, unit=1):
+    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, unit=1, rank_stabilized=False):
         super().__init__()
         self.lora_name = lora_name
         self.lora_dim = lora_dim
@@ -48,7 +48,10 @@ def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_
         if type(alpha) == torch.Tensor:
             alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
         alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
-        self.scale = alpha / self.lora_dim
+        rank_factor = self.lora_dim
+        if rank_stabilized:
+            rank_factor = math.sqrt(rank_factor)
+        self.scale = alpha / rank_factor
         self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant

         self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
@@ -285,6 +288,7 @@ def __init__(
         unit=1,
         module_class=DyLoRAModule,
         varbose=False,
+        rank_stabilized=False,
     ) -> None:
         super().__init__()
         self.multiplier = multiplier
@@ -334,7 +338,7 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules
                         continue

                     # dropout and fan_in_fan_out is default
-                    lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
+                    lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit, rank_stabilized=rank_stabilized)
                     loras.append(lora)
             return loras
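For context, `self.scale` multiplies the low-rank update in the module's forward pass, so the rank factor directly sets how strongly the adapter contributes to the output. A simplified sketch of that path, assuming the usual lora_down/lora_up rank-r projections and the wrapped module's org_forward (dropout, DyLoRA's unit-wise rank selection, and other details omitted; this is not the verbatim forward of these modules):

def forward(self, x):
    # base output of the original (frozen) layer
    out = self.org_forward(x)
    # low-rank update, scaled by multiplier * scale, where scale is
    # alpha / r (standard) or alpha / sqrt(r) (rank_stabilized)
    lx = self.lora_up(self.lora_down(x))
    return out + lx * self.multiplier * self.scale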
networks/lora.py: 12 additions & 2 deletions

@@ -37,6 +37,7 @@ def __init__(
         dropout=None,
         rank_dropout=None,
         module_dropout=None,
+        rank_stabilized=False
     ):
         """if alpha == 0 or None, alpha is rank (no scaling)."""
         super().__init__()
@@ -69,7 +70,10 @@ def __init__(
         if type(alpha) == torch.Tensor:
             alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
         alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
-        self.scale = alpha / self.lora_dim
+        rank_factor = self.lora_dim
+        if rank_stabilized:
+            rank_factor = math.sqrt(rank_factor)
+        self.scale = alpha / rank_factor
         self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant

         # same as microsoft's
@@ -895,6 +899,7 @@ def __init__(
         module_class: Type[object] = LoRAModule,
         varbose: Optional[bool] = False,
         is_sdxl: Optional[bool] = False,
+        rank_stabilized: Optional[bool] = False
     ) -> None:
         """
         LoRA network: there are a lot of arguments, but the patterns are as follows
@@ -914,6 +919,7 @@ def __init__(
         self.dropout = dropout
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
+        self.rank_stabilized = rank_stabilized

         self.loraplus_lr_ratio = None
         self.loraplus_unet_lr_ratio = None
@@ -1011,6 +1017,7 @@ def create_modules(
                         dropout=dropout,
                         rank_dropout=rank_dropout,
                         module_dropout=module_dropout,
+                        rank_stabilized=rank_stabilized,
                     )
                     loras.append(lora)
         return loras, skipped
@@ -1385,7 +1392,10 @@ def apply_max_norm_regularization(self, max_norm_value, device):
             up = state_dict[upkeys[i]].to(device)
             alpha = state_dict[alphakeys[i]].to(device)
             dim = down.shape[0]
-            scale = alpha / dim
+            rank_factor = dim
+            if self.rank_stabilized:
+                rank_factor = math.sqrt(rank_factor)
+            scale = alpha / rank_factor

             if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
                 updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
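Note that apply_max_norm_regularization recomputes `scale` from the saved alpha and the weight shapes rather than reading it off the modules, which is why the network object also stores `self.rank_stabilized`: if the two code paths disagreed, the norms being clipped would not match the scale actually applied in the forward pass. A rough sketch of that norm computation for the plain linear branch (simplified from the method shown above; the conv branches are ignored and the function name is illustrative):

import math
import torch

def scaled_update_norm(up: torch.Tensor, down: torch.Tensor, alpha: float, rank_stabilized: bool) -> torch.Tensor:
    # Rebuild the effective weight delta (up @ down) * scale and measure its
    # norm, which is what max-norm regularization compares against the cap.
    dim = down.shape[0]
    rank_factor = math.sqrt(dim) if rank_stabilized else dim
    scale = alpha / rank_factor
    updown = (up @ down) * scale
    return updown.norm()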
networks/lora_fa.py: 12 additions & 2 deletions

@@ -37,6 +37,7 @@ def __init__(
         dropout=None,
         rank_dropout=None,
         module_dropout=None,
+        rank_stabilized=False
     ):
         """if alpha == 0 or None, alpha is rank (no scaling)."""
         super().__init__()
@@ -69,7 +70,10 @@ def __init__(
         if type(alpha) == torch.Tensor:
             alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
         alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
-        self.scale = alpha / self.lora_dim
+        rank_factor = self.lora_dim
+        if rank_stabilized:
+            rank_factor = math.sqrt(rank_factor)
+        self.scale = alpha / rank_factor
         self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant

         # # same as microsoft's
@@ -783,6 +787,7 @@ def __init__(
         modules_alpha: Optional[Dict[str, int]] = None,
         module_class: Type[object] = LoRAModule,
         varbose: Optional[bool] = False,
+        rank_stabilized: Optional[bool] = False
     ) -> None:
         """
         LoRA network: there are a lot of arguments, but the patterns are as follows
@@ -802,6 +807,7 @@ def __init__(
         self.dropout = dropout
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
+        self.rank_stabilized = rank_stabilized

         if modules_dim is not None:
             logger.info(f"create LoRA network from weights")
@@ -889,6 +895,7 @@ def create_modules(
                         dropout=dropout,
                         rank_dropout=rank_dropout,
                         module_dropout=module_dropout,
+                        rank_stabilized=rank_stabilized
                     )
                     loras.append(lora)
         return loras, skipped
@@ -1219,7 +1226,10 @@ def apply_max_norm_regularization(self, max_norm_value, device):
             up = state_dict[upkeys[i]].to(device)
             alpha = state_dict[alphakeys[i]].to(device)
             dim = down.shape[0]
-            scale = alpha / dim
+            rank_factor = dim
+            if self.rank_stabilized:
+                rank_factor = math.sqrt(rank_factor)
+            scale = alpha / rank_factor

             if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
                 updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
networks/lora_flux.py: 12 additions & 2 deletions

@@ -44,6 +44,7 @@ def __init__(
         rank_dropout=None,
         module_dropout=None,
         split_dims: Optional[List[int]] = None,
+        rank_stabilized: Optional[bool] = False
     ):
         """
         if alpha == 0 or None, alpha is rank (no scaling).
@@ -93,7 +94,10 @@ def __init__(
         if type(alpha) == torch.Tensor:
             alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
         alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
-        self.scale = alpha / self.lora_dim
+        rank_factor = self.lora_dim
+        if rank_stabilized:
+            rank_factor = math.sqrt(rank_factor)
+        self.scale = alpha / rank_factor
         self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant

         # same as microsoft's
@@ -562,6 +566,7 @@ def __init__(
         train_double_block_indices: Optional[List[bool]] = None,
         train_single_block_indices: Optional[List[bool]] = None,
         verbose: Optional[bool] = False,
+        rank_stabilized: Optional[bool] = False,
     ) -> None:
         super().__init__()
         self.multiplier = multiplier
@@ -576,6 +581,7 @@ def __init__(
         self.train_blocks = train_blocks if train_blocks is not None else "all"
         self.split_qkv = split_qkv
         self.train_t5xxl = train_t5xxl
+        self.rank_stabilized = rank_stabilized

         self.type_dims = type_dims
         self.in_dims = in_dims
@@ -722,6 +728,7 @@ def create_modules(
                         rank_dropout=rank_dropout,
                         module_dropout=module_dropout,
                         split_dims=split_dims,
+                        rank_stabilized=rank_stabilized,
                     )
                     loras.append(lora)

@@ -1132,7 +1139,10 @@ def apply_max_norm_regularization(self, max_norm_value, device):
             up = state_dict[upkeys[i]].to(device)
             alpha = state_dict[alphakeys[i]].to(device)
             dim = down.shape[0]
-            scale = alpha / dim
+            rank_factor = dim
+            if self.rank_stabilized:
+                rank_factor = math.sqrt(rank_factor)
+            scale = alpha / rank_factor

             if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
                 updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
networks/lora_sd3.py: 7 additions & 1 deletion

@@ -256,6 +256,7 @@ def __init__(
         emb_dims: Optional[List[int]] = None,
         train_block_indices: Optional[List[bool]] = None,
         verbose: Optional[bool] = False,
+        rank_stabilized: Optional[bool] = False
     ) -> None:
         super().__init__()
         self.multiplier = multiplier
@@ -269,6 +270,7 @@ def __init__(
         self.module_dropout = module_dropout
         self.split_qkv = split_qkv
         self.train_t5xxl = train_t5xxl
+        self.rank_stabilized = rank_stabilized

         self.type_dims = type_dims
         self.emb_dims = emb_dims
@@ -404,6 +406,7 @@ def create_modules(
                         rank_dropout=rank_dropout,
                         module_dropout=module_dropout,
                         split_dims=split_dims,
+                        rank_stabilized=rank_stabilized
                     )
                     loras.append(lora)

@@ -814,7 +817,10 @@ def apply_max_norm_regularization(self, max_norm_value, device):
             up = state_dict[upkeys[i]].to(device)
             alpha = state_dict[alphakeys[i]].to(device)
             dim = down.shape[0]
-            scale = alpha / dim
+            rank_factor = dim
+            if self.rank_stabilized:
+                rank_factor = math.sqrt(rank_factor)
+            scale = alpha / rank_factor

             if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
                 updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
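One practical consequence: enabling rank_stabilized with an existing alpha/dim pair multiplies the effective scale by sqrt(dim), so training behavior will change unless alpha is adjusted. A quick check of the arithmetic (illustrative only, not part of the PR):

import math

alpha, dim = 16, 64
standard_scale = alpha / dim             # 0.25
rs_scale = alpha / math.sqrt(dim)        # 2.0, i.e. sqrt(dim) times larger

# Dividing alpha by sqrt(dim) restores the previous effective scale.
adjusted_alpha = alpha / math.sqrt(dim)  # 2.0
assert math.isclose(adjusted_alpha / math.sqrt(dim), standard_scale)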