Add rank_stabilized for networks
rockerBOO committed Jan 10, 2025
1 parent e896539 commit 6e3d33b
Showing 5 changed files with 28 additions and 6 deletions.
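The change is the same in every file: when rank_stabilized is enabled, the LoRA scale becomes alpha / sqrt(lora_dim) instead of alpha / lora_dim, which matches the rank-stabilized LoRA (rsLoRA) formulation. A minimal standalone sketch of the scaling logic follows; compute_scale is a hypothetical helper used only for illustration and is not part of the patch.

import math

def compute_scale(alpha, lora_dim, rank_stabilized=False):
    # Classic LoRA divides alpha by the rank; rank-stabilized scaling divides by
    # sqrt(rank) so the effective update does not shrink as the rank grows.
    rank_factor = math.sqrt(lora_dim) if rank_stabilized else lora_dim
    return alpha / rank_factor

print(compute_scale(alpha=16, lora_dim=16))                        # 1.0  (classic)
print(compute_scale(alpha=16, lora_dim=16, rank_stabilized=True))  # 4.0  (rank-stabilized)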
10 changes: 7 additions & 3 deletions networks/dylora.py
@@ -31,7 +31,7 @@ class DyLoRAModule(torch.nn.Module):
"""

# NOTE: support dropout in future
- def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, unit=1):
+ def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, unit=1, rank_stabilized=False):
super().__init__()
self.lora_name = lora_name
self.lora_dim = lora_dim
@@ -48,7 +48,10 @@ def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_
if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
- self.scale = alpha / self.lora_dim
+ rank_factor = self.lora_dim
+ if rank_stabilized:
+     rank_factor = math.sqrt(rank_factor)
+ self.scale = alpha / rank_factor
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える

self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
@@ -285,6 +288,7 @@ def __init__(
unit=1,
module_class=DyLoRAModule,
varbose=False,
+ rank_stabilized=False,
) -> None:
super().__init__()
self.multiplier = multiplier
@@ -334,7 +338,7 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules
continue

# dropout and fan_in_fan_out is default
- lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
+ lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit, rank_stabilized=rank_stabilized)
loras.append(lora)
return loras

8 changes: 7 additions & 1 deletion networks/lora.py
@@ -37,6 +37,7 @@ def __init__(
dropout=None,
rank_dropout=None,
module_dropout=None,
+ rank_stabilized=False
):
"""if alpha == 0 or None, alpha is rank (no scaling)."""
super().__init__()
@@ -69,7 +70,10 @@ def __init__(
if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
- self.scale = alpha / self.lora_dim
+ rank_factor = self.lora_dim
+ if rank_stabilized:
+     rank_factor = math.sqrt(rank_factor)
+ self.scale = alpha / rank_factor
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える

# same as microsoft's
@@ -895,6 +899,7 @@ def __init__(
module_class: Type[object] = LoRAModule,
varbose: Optional[bool] = False,
is_sdxl: Optional[bool] = False,
+ rank_stabilized: Optional[bool] = False
) -> None:
"""
LoRA network: すごく引数が多いが、パターンは以下の通り
@@ -1011,6 +1016,7 @@ def create_modules(
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
+ rank_stabilized=rank_stabilized,
)
loras.append(lora)
return loras, skipped
8 changes: 7 additions & 1 deletion networks/lora_fa.py
@@ -37,6 +37,7 @@ def __init__(
dropout=None,
rank_dropout=None,
module_dropout=None,
+ rank_stabilized=False
):
"""if alpha == 0 or None, alpha is rank (no scaling)."""
super().__init__()
@@ -69,7 +70,10 @@ def __init__(
if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
- self.scale = alpha / self.lora_dim
+ rank_factor = self.lora_dim
+ if rank_stabilized:
+     rank_factor = math.sqrt(rank_factor)
+ self.scale = alpha / rank_factor
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える

# # same as microsoft's
@@ -783,6 +787,7 @@ def __init__(
modules_alpha: Optional[Dict[str, int]] = None,
module_class: Type[object] = LoRAModule,
varbose: Optional[bool] = False,
+ rank_stabilized: Optional[bool] = False
) -> None:
"""
LoRA network: すごく引数が多いが、パターンは以下の通り
@@ -889,6 +894,7 @@ def create_modules(
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
+ rank_stabilized=rank_stabilized
)
loras.append(lora)
return loras, skipped
6 changes: 5 additions & 1 deletion networks/lora_flux.py
@@ -44,6 +44,7 @@ def __init__(
rank_dropout=None,
module_dropout=None,
split_dims: Optional[List[int]] = None,
+ rank_stabilized: Optional[bool] = False
):
"""
if alpha == 0 or None, alpha is rank (no scaling).
@@ -93,7 +94,10 @@ def __init__(
if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
- self.scale = alpha / self.lora_dim
+ rank_factor = self.lora_dim
+ if rank_stabilized:
+     rank_factor = math.sqrt(rank_factor)
+ self.scale = alpha / rank_factor
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える

# same as microsoft's
2 changes: 2 additions & 0 deletions networks/lora_sd3.py
@@ -256,6 +256,7 @@ def __init__(
emb_dims: Optional[List[int]] = None,
train_block_indices: Optional[List[bool]] = None,
verbose: Optional[bool] = False,
+ rank_stabilized: Optional[bool] = False
) -> None:
super().__init__()
self.multiplier = multiplier
@@ -404,6 +405,7 @@ def create_modules(
rank_dropout=rank_dropout,
module_dropout=module_dropout,
split_dims=split_dims,
+ rank_stabilized=rank_stabilized
)
loras.append(lora)

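For reference, a hedged usage sketch against the patched DyLoRAModule signature shown above. It assumes torch is installed, that the sd-scripts layout is importable as networks.dylora, and that math is already imported at module level in networks/dylora.py; the concrete numbers are illustrative only.

import torch
from networks.dylora import DyLoRAModule  # assumed sd-scripts package layout

linear = torch.nn.Linear(64, 64)
lora = DyLoRAModule("lora_test", linear, multiplier=1.0, lora_dim=16, alpha=8, unit=4, rank_stabilized=True)
print(lora.scale)  # 8 / sqrt(16) = 2.0; without rank_stabilized it would be 8 / 16 = 0.5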
