Use @property
shirayu committed Oct 27, 2023

1 parent efef5c8 commit 0d21925
Showing 5 changed files with 10 additions and 9 deletions.
4 changes: 2 additions & 2 deletions fine_tune.py
@@ -406,15 +406,15 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 accelerator.log(logs, step=global_step)

 loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
-avr_loss: float = loss_recorder.get_moving_average()
+avr_loss: float = loss_recorder.moving_average
 logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
 progress_bar.set_postfix(**logs)

 if global_step >= args.max_train_steps:
     break

 if args.logging_dir is not None:
-    logs = {"loss/epoch": loss_recorder.get_moving_average()}
+    logs = {"loss/epoch": loss_recorder.moving_average}
     accelerator.log(logs, step=epoch + 1)

 accelerator.wait_for_everyone()
3 changes: 2 additions & 1 deletion library/train_util.py
@@ -4700,5 +4700,6 @@ def add(self, *, epoch:int, step: int, loss: float) -> None:
         self.loss_list[step] = loss
         self.loss_total += loss

-    def get_moving_average(self) -> float:
+    @property
+    def moving_average(self) -> float:
         return self.loss_total / len(self.loss_list)
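For context, a minimal self-contained sketch of the LossRecorder class as it reads after this commit. Only the tail of add() and the new property appear in the diff; the constructor and the first-epoch branch below are assumptions filled in so the snippet runs:

from typing import List


class LossRecorder:
    def __init__(self) -> None:
        self.loss_list: List[float] = []
        self.loss_total: float = 0.0

    def add(self, *, epoch: int, step: int, loss: float) -> None:
        if epoch == 0:
            # Assumed: the first epoch appends one slot per step.
            self.loss_list.append(loss)
        else:
            # Assumed: later epochs retire the old value for this step,
            # then the context lines from the hunk above overwrite it.
            self.loss_total -= self.loss_list[step]
            self.loss_list[step] = loss
        self.loss_total += loss

    @property
    def moving_average(self) -> float:
        # Average loss over all recorded steps, derived on each access.
        return self.loss_total / len(self.loss_list)

The @property decorator keeps the computation lazy while letting call sites read the value like a plain attribute, which is exactly the change the remaining hunks apply.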
4 changes: 2 additions & 2 deletions sdxl_train.py
@@ -633,15 +633,15 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 accelerator.log(logs, step=global_step)

 loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
-avr_loss: float = loss_recorder.get_moving_average()
+avr_loss: float = loss_recorder.moving_average
 logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
 progress_bar.set_postfix(**logs)

 if global_step >= args.max_train_steps:
     break

 if args.logging_dir is not None:
-    logs = {"loss/epoch": loss_recorder.get_moving_average()}
+    logs = {"loss/epoch": loss_recorder.moving_average}
     accelerator.log(logs, step=epoch + 1)

 accelerator.wait_for_everyone()
4 changes: 2 additions & 2 deletions train_db.py
@@ -392,15 +392,15 @@ def train(args):
 accelerator.log(logs, step=global_step)

 loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
-avr_loss: float = loss_recorder.get_moving_average()
+avr_loss: float = loss_recorder.moving_average
 logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
 progress_bar.set_postfix(**logs)

 if global_step >= args.max_train_steps:
     break

 if args.logging_dir is not None:
-    logs = {"loss/epoch": loss_recorder.get_moving_average()}
+    logs = {"loss/epoch": loss_recorder.moving_average}
     accelerator.log(logs, step=epoch + 1)

 accelerator.wait_for_everyone()
4 changes: 2 additions & 2 deletions train_network.py
@@ -854,7 +854,7 @@ def remove_model(old_ckpt_name):

 current_loss = loss.detach().item()
 loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
-avr_loss: float = loss_recorder.get_moving_average()
+avr_loss: float = loss_recorder.moving_average
 logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
 progress_bar.set_postfix(**logs)

@@ -869,7 +869,7 @@ def remove_model(old_ckpt_name):
     break

 if args.logging_dir is not None:
-    logs = {"loss/epoch": loss_recorder.get_moving_average()}
+    logs = {"loss/epoch": loss_recorder.moving_average}
     accelerator.log(logs, step=epoch + 1)

 accelerator.wait_for_everyone()
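A quick illustration of the new call-site behavior, using hypothetical loss values: the property is read without parentheses, and because no setter is defined, accidental assignment fails loudly instead of silently shadowing a method.

recorder = LossRecorder()
recorder.add(epoch=0, step=0, loss=0.5)
recorder.add(epoch=0, step=1, loss=0.3)

print(recorder.moving_average)  # 0.4 -- attribute access, no call parentheses

# recorder.moving_average = 1.0  # would raise AttributeError: no setter is defined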
