Add rank-stabilized to apply_max_norm_regularization
rockerBOO committed Jan 11, 2025
1 parent 6e3d33b commit 4ae674f
Showing 4 changed files with 22 additions and 4 deletions.
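The change itself is small: each network gains a rank_stabilized flag, and apply_max_norm_regularization divides alpha by sqrt(dim) instead of dim when the flag is set. This follows rank-stabilized LoRA (rsLoRA), where scaling by alpha / sqrt(r) keeps the magnitude of the low-rank update roughly constant as the rank r grows, instead of letting it shrink. A minimal sketch of the scaling rule the diffs below implement (the helper name is illustrative, not from the diff):

    import math

    def lora_scale(alpha: float, rank: int, rank_stabilized: bool = False) -> float:
        # Conventional LoRA merges the update as (alpha / rank) * (up @ down);
        # rank-stabilized LoRA divides by sqrt(rank) instead, so the effective
        # update magnitude does not decay as rank grows.
        rank_factor = math.sqrt(rank) if rank_stabilized else rank
        return alpha / rank_factor

    # For alpha=16, rank=64: conventional scale is 0.25, rank-stabilized is 2.0.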
6 changes: 5 additions & 1 deletion networks/lora.py
@@ -919,6 +919,7 @@ def __init__(
         self.dropout = dropout
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
+        self.rank_stabilized = rank_stabilized
 
         self.loraplus_lr_ratio = None
         self.loraplus_unet_lr_ratio = None
@@ -1391,7 +1392,10 @@ def apply_max_norm_regularization(self, max_norm_value, device):
             up = state_dict[upkeys[i]].to(device)
             alpha = state_dict[alphakeys[i]].to(device)
             dim = down.shape[0]
-            scale = alpha / dim
+            rank_factor = dim
+            if self.rank_stabilized:
+                rank_factor = math.sqrt(rank_factor)
+            scale = alpha / rank_factor
 
             if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
                 updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
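For context, the scale computed in the hunk above feeds the per-module norm clamp that gives apply_max_norm_regularization its name. A minimal sketch of that step for a single linear-layer module, assuming the surrounding loop follows the usual sd-scripts max-norm pattern; clamp_module_norm is a hypothetical helper and details of the real loop may differ:

    import math
    import torch

    def clamp_module_norm(up, down, alpha, max_norm_value, rank_stabilized=False):
        # Hypothetical single-module version of the max-norm step (linear case;
        # the real function also handles conv weights).
        dim = down.shape[0]
        rank_factor = math.sqrt(dim) if rank_stabilized else dim
        scale = alpha / rank_factor
        updown = (up @ down) * scale                 # merged low-rank update
        norm = updown.norm().clamp(min=max_norm_value / 2)
        desired = torch.clamp(norm, max=max_norm_value)
        ratio = desired / norm                       # <= 1 once the cap is hit
        sqrt_ratio = ratio ** 0.5                    # split the correction evenly
        return up * sqrt_ratio, down * sqrt_ratio    # up @ down now shrinks by ratio

Because rank stabilization enlarges scale by a factor of sqrt(dim), the merged update reaches the cap sooner, which is why the scale here must match the one used during training.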
6 changes: 5 additions & 1 deletion networks/lora_fa.py
@@ -807,6 +807,7 @@ def __init__(
         self.dropout = dropout
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
+        self.rank_stabilized = rank_stabilized
 
         if modules_dim is not None:
             logger.info(f"create LoRA network from weights")
@@ -1225,7 +1226,10 @@ def apply_max_norm_regularization(self, max_norm_value, device):
             up = state_dict[upkeys[i]].to(device)
             alpha = state_dict[alphakeys[i]].to(device)
             dim = down.shape[0]
-            scale = alpha / dim
+            rank_factor = dim
+            if self.rank_stabilized:
+                rank_factor = math.sqrt(rank_factor)
+            scale = alpha / rank_factor
 
             if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
                 updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
8 changes: 7 additions & 1 deletion networks/lora_flux.py
@@ -566,6 +566,7 @@ def __init__(
         train_double_block_indices: Optional[List[bool]] = None,
         train_single_block_indices: Optional[List[bool]] = None,
         verbose: Optional[bool] = False,
+        rank_stabilized: Optional[bool] = False,
     ) -> None:
         super().__init__()
         self.multiplier = multiplier
@@ -580,6 +581,7 @@ def __init__(
         self.train_blocks = train_blocks if train_blocks is not None else "all"
         self.split_qkv = split_qkv
         self.train_t5xxl = train_t5xxl
+        self.rank_stabilized = rank_stabilized
 
         self.type_dims = type_dims
         self.in_dims = in_dims
@@ -726,6 +728,7 @@ def create_modules(
                             rank_dropout=rank_dropout,
                             module_dropout=module_dropout,
                             split_dims=split_dims,
+                            rank_stabilized=rank_stabilized,
                         )
                         loras.append(lora)
 
@@ -1136,7 +1139,10 @@ def apply_max_norm_regularization(self, max_norm_value, device):
             up = state_dict[upkeys[i]].to(device)
             alpha = state_dict[alphakeys[i]].to(device)
             dim = down.shape[0]
-            scale = alpha / dim
+            rank_factor = dim
+            if self.rank_stabilized:
+                rank_factor = math.sqrt(rank_factor)
+            scale = alpha / rank_factor
 
             if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
                 updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
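lora_flux.py is the only file in this commit that also shows the flag being plumbed from the network constructor into each module via create_modules. A sketch of how a module might consume it, assuming a constructor in the spirit of sd-scripts' LoRAModule (names and signature here are illustrative, not the actual class):

    import math

    class LoRAModuleSketch:
        # Illustrative only; not the real sd-scripts LoRAModule signature.
        def __init__(self, lora_dim: int, alpha: float, rank_stabilized: bool = False):
            self.lora_dim = lora_dim
            self.rank_stabilized = rank_stabilized
            rank_factor = math.sqrt(lora_dim) if rank_stabilized else lora_dim
            self.scale = alpha / rank_factor  # applied to up(down(x)) in forward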
6 changes: 5 additions & 1 deletion networks/lora_sd3.py
@@ -270,6 +270,7 @@ def __init__(
         self.module_dropout = module_dropout
         self.split_qkv = split_qkv
         self.train_t5xxl = train_t5xxl
+        self.rank_stabilized = rank_stabilized
 
         self.type_dims = type_dims
         self.emb_dims = emb_dims
@@ -816,7 +817,10 @@ def apply_max_norm_regularization(self, max_norm_value, device):
             up = state_dict[upkeys[i]].to(device)
             alpha = state_dict[alphakeys[i]].to(device)
             dim = down.shape[0]
-            scale = alpha / dim
+            rank_factor = dim
+            if self.rank_stabilized:
+                rank_factor = math.sqrt(rank_factor)
+            scale = alpha / rank_factor
 
             if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
                 updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
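To exercise this code path in sd-scripts, max-norm regularization has to be enabled (--scale_weight_norms) and the network needs the flag set. Optional LoRA kwargs are normally passed through --network_args, so usage would presumably look like the following, assuming rank_stabilized is accepted there (that parsing lives outside this diff):

    accelerate launch train_network.py \
      --network_module networks.lora \
      --network_args "rank_stabilized=True" \
      --scale_weight_norms 1.0 \
      ...  # remaining training arguments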
