Skip to content

Commit

Permalink
Fix initialization location of loss_recorder (move it before the epoch loop so it is not reset every epoch)
Browse files Browse the repository at this point in the history
  • Loading branch information
shirayu committed Oct 27, 2023
1 parent: 9d00c8e · commit: 63992b8
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion fine_tune.py
Original file line number Diff line number Diff line change
@@ -288,14 +288,14 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)

loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1

for m in training_models:
m.train()

loss_recorder = train_util.LossRecorder()
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
2 changes: 1 addition & 1 deletion sdxl_train.py
Original file line number Diff line number Diff line change
@@ -452,14 +452,14 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)

loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1

for m in training_models:
m.train()

loss_recorder = train_util.LossRecorder()
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく

0 comments on commit 63992b8

Please sign in to comment.