Use LossRecorder
shirayu committed Oct 27, 2023
1 parent 0d21925 commit 9d00c8e
Showing 3 changed files with 15 additions and 33 deletions.
sdxl_train_control_net_lllite.py (16 changes: 5 additions & 11 deletions)
@@ -350,8 +350,7 @@ def train(args):
             "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
         )
 
-    loss_list = []
-    loss_total = 0.0
+    loss_recorder = train_util.LossRecorder()
     del train_dataset_group
 
     # function for saving/removing
@@ -500,14 +499,9 @@ def remove_model(old_ckpt_name):
                             remove_model(remove_ckpt_name)
 
             current_loss = loss.detach().item()
-            if epoch == 0:
-                loss_list.append(current_loss)
-            else:
-                loss_total -= loss_list[step]
-                loss_list[step] = current_loss
-                loss_total += current_loss
-            avr_loss = loss_total / len(loss_list)
-            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+            avr_loss: float = loss_recorder.moving_average
+            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
 
             if args.logging_dir is not None:
@@ -518,7 +512,7 @@ def remove_model(old_ckpt_name):
                    break
 
         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_total / len(loss_list)}
+            logs = {"loss/epoch": loss_recorder.moving_average}
             accelerator.log(logs, step=epoch + 1)
 
         accelerator.wait_for_everyone()
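Each of the three scripts now touches the new helper through exactly two calls: loss_recorder.add(epoch=..., step=..., loss=...) and the loss_recorder.moving_average property. The class itself is not part of this diff; below is a minimal sketch that simply reproduces the removed loss_list/loss_total bookkeeping (the real implementation in train_util may differ):

# Hypothetical sketch of the LossRecorder interface used above;
# the actual train_util.LossRecorder may be implemented differently.
from typing import List


class LossRecorder:
    def __init__(self) -> None:
        self.loss_list: List[float] = []
        self.loss_total: float = 0.0

    def add(self, *, epoch: int, step: int, loss: float) -> None:
        if epoch == 0:
            # first epoch: grow the per-step list
            self.loss_list.append(loss)
        else:
            # later epochs: replace the value previously recorded for this step
            self.loss_total -= self.loss_list[step]
            self.loss_list[step] = loss
        self.loss_total += loss

    @property
    def moving_average(self) -> float:
        # mean of the most recent loss recorded at each step
        return self.loss_total / len(self.loss_list)

Because add() swaps out the value stored for the same step once the first epoch is over, moving_average stays a mean over the latest loss seen at each step, which is exactly what the deleted inline code computed; the refactor only moves that bookkeeping into one shared class.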
sdxl_train_control_net_lllite_old.py (16 changes: 5 additions & 11 deletions)
@@ -323,8 +323,7 @@ def train(args):
             "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
         )
 
-    loss_list = []
-    loss_total = 0.0
+    loss_recorder = train_util.LossRecorder()
     del train_dataset_group
 
     # function for saving/removing
@@ -470,14 +469,9 @@ def remove_model(old_ckpt_name):
                             remove_model(remove_ckpt_name)
 
             current_loss = loss.detach().item()
-            if epoch == 0:
-                loss_list.append(current_loss)
-            else:
-                loss_total -= loss_list[step]
-                loss_list[step] = current_loss
-                loss_total += current_loss
-            avr_loss = loss_total / len(loss_list)
-            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+            avr_loss: float = loss_recorder.moving_average
+            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
 
             if args.logging_dir is not None:
@@ -488,7 +482,7 @@ def remove_model(old_ckpt_name):
                    break
 
         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_total / len(loss_list)}
+            logs = {"loss/epoch": loss_recorder.moving_average}
             accelerator.log(logs, step=epoch + 1)
 
         accelerator.wait_for_everyone()
train_controlnet.py (16 changes: 5 additions & 11 deletions)
@@ -337,8 +337,7 @@ def train(args):
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
 
-    loss_list = []
-    loss_total = 0.0
+    loss_recorder = train_util.LossRecorder()
     del train_dataset_group
 
     # function for saving/removing
@@ -500,14 +499,9 @@ def remove_model(old_ckpt_name):
                             remove_model(remove_ckpt_name)
 
             current_loss = loss.detach().item()
-            if epoch == 0:
-                loss_list.append(current_loss)
-            else:
-                loss_total -= loss_list[step]
-                loss_list[step] = current_loss
-                loss_total += current_loss
-            avr_loss = loss_total / len(loss_list)
-            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+            avr_loss: float = loss_recorder.moving_average
+            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
 
             if args.logging_dir is not None:
@@ -518,7 +512,7 @@ def remove_model(old_ckpt_name):
                    break
 
         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_total / len(loss_list)}
+            logs = {"loss/epoch": loss_recorder.moving_average}
             accelerator.log(logs, step=epoch + 1)
 
         accelerator.wait_for_everyone()
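As a quick illustration of the behavior the three scripts now share (using the hypothetical LossRecorder sketch above, not the real train_util class):

# Two epochs over the same three steps: epoch 0 fills the list,
# epoch 1 overwrites each step's slot before averaging.
recorder = LossRecorder()
for epoch in range(2):
    for step, loss in enumerate([0.9, 0.6, 0.3]):
        recorder.add(epoch=epoch, step=step, loss=loss)
print(recorder.moving_average)  # approximately 0.6, the mean of the latest per-step losses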
