
Commit

simplify code by updating accelerate to 0.30.0
kohya-ss committed May 12, 2024
1 parent c1ef6dc commit f33e155
Showing 8 changed files with 16 additions and 15 deletions.
4 changes: 2 additions & 2 deletions fine_tune.py
@@ -271,8 +271,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

 # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
 if use_schedule_free_optimizer:
-    optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
-    optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
+    optimizer_train_if_needed = lambda: optimizer.train()
+    optimizer_eval_if_needed = lambda: optimizer.eval()
 else:
     optimizer_train_if_needed = lambda: None
     optimizer_eval_if_needed = lambda: None
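Note on the change above: the hasattr() unwrap was only needed because accelerator.prepare() wraps the optimizer, and older accelerate versions did not expose the schedule-free train()/eval() methods on that wrapper. With the 0.30.0 bump the wrapped optimizer is expected to forward these calls, so the lambdas can call it directly. A minimal sketch of the behavior this relies on; the model, learning rate, and variable names here are illustrative and not taken from the repository:

import torch
import schedulefree
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(8, 8)
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)

# accelerator.prepare() returns a wrapped optimizer; with accelerate>=0.30.0
# train()/eval() are expected to pass through to the schedule-free optimizer,
# which is what lets the hasattr() unwrap be dropped.
model, optimizer = accelerator.prepare(model, optimizer)

optimizer.train()  # what optimizer_train_if_needed() now calls directly
# ... training steps ...
optimizer.eval()   # what optimizer_eval_if_needed() now calls directly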
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
-accelerate==0.29.2
+accelerate==0.30.0
 transformers==4.36.2
 diffusers[torch]==0.25.0
 ftfy==6.1.1
@@ -9,6 +9,7 @@ pytorch-lightning==1.9.0
 bitsandbytes==0.43.0
 prodigyopt==1.0
 lion-pytorch==0.0.6
+schedulefree==1.2.5
 tensorboard
 safetensors==0.4.2
 # gradio==3.16.2
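The new schedulefree==1.2.5 pin is what the optimizer_train_if_needed / optimizer_eval_if_needed lambdas exist for: per the schedulefree project, a schedule-free optimizer must be switched to train mode before optimizer steps and to eval mode before evaluation or checkpointing, because the two modes use different parameter states. A minimal standalone sketch of that pattern without accelerate; the toy model and data are illustrative only:

import torch
import schedulefree

model = torch.nn.Linear(4, 1)
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()

optimizer.train()  # required before taking optimizer steps
for _ in range(10):
    x, y = torch.randn(16, 4), torch.randn(16, 1)
    optimizer.zero_grad()
    loss_fn(model(x), y).backward()
    optimizer.step()

optimizer.eval()  # switch to the evaluation weights before validation or saving
with torch.no_grad():
    print(loss_fn(model(torch.randn(16, 4)), torch.randn(16, 1)).item())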
4 changes: 2 additions & 2 deletions sdxl_train.py
@@ -435,8 +435,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

 # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
 if use_schedule_free_optimizer:
-    optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
-    optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
+    optimizer_train_if_needed = lambda: optimizer.train()
+    optimizer_eval_if_needed = lambda: optimizer.eval()
 else:
     optimizer_train_if_needed = lambda: None
     optimizer_eval_if_needed = lambda: None
4 changes: 2 additions & 2 deletions sdxl_train_control_net_lllite.py
@@ -294,8 +294,8 @@ def train(args):

 # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
 if use_schedule_free_optimizer:
-    optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
-    optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
+    optimizer_train_if_needed = lambda: optimizer.train()
+    optimizer_eval_if_needed = lambda: optimizer.eval()
 else:
     optimizer_train_if_needed = lambda: None
     optimizer_eval_if_needed = lambda: None
4 changes: 2 additions & 2 deletions sdxl_train_control_net_lllite_old.py
@@ -262,8 +262,8 @@ def train(args):

 # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
 if use_schedule_free_optimizer:
-    optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
-    optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
+    optimizer_train_if_needed = lambda: optimizer.train()
+    optimizer_eval_if_needed = lambda: optimizer.eval()
 else:
     optimizer_train_if_needed = lambda: None
     optimizer_eval_if_needed = lambda: None
4 changes: 2 additions & 2 deletions train_controlnet.py
@@ -284,8 +284,8 @@ def train(args):

 # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
 if use_schedule_free_optimizer:
-    optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
-    optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
+    optimizer_train_if_needed = lambda: optimizer.train()
+    optimizer_eval_if_needed = lambda: optimizer.eval()
 else:
     optimizer_train_if_needed = lambda: None
     optimizer_eval_if_needed = lambda: None
4 changes: 2 additions & 2 deletions train_db.py
@@ -247,8 +247,8 @@ def train(args):

 # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
 if use_schedule_free_optimizer:
-    optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
-    optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
+    optimizer_train_if_needed = lambda: optimizer.train()
+    optimizer_eval_if_needed = lambda: optimizer.eval()
 else:
     optimizer_train_if_needed = lambda: None
     optimizer_eval_if_needed = lambda: None
4 changes: 2 additions & 2 deletions train_textual_inversion_XTI.py
@@ -342,8 +342,8 @@ def train(args):

 # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
 if use_schedule_free_optimizer:
-    optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
-    optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
+    optimizer_train_if_needed = lambda: optimizer.train()
+    optimizer_eval_if_needed = lambda: optimizer.eval()
 else:
     optimizer_train_if_needed = lambda: None
     optimizer_eval_if_needed = lambda: None
