diff --git a/fine_tune.py b/fine_tune.py
index 27d647392..afec7d273 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -288,6 +288,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
 
+    loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
@@ -295,7 +296,6 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         for m in training_models:
             m.train()
 
-        loss_recorder = train_util.LossRecorder()
         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
             with accelerator.accumulate(training_models[0]):  # it seems this doesn't support multiple models, but this will do for now
diff --git a/sdxl_train.py b/sdxl_train.py
index 9017d7b8c..f681f28fc 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -452,6 +452,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
 
+    loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
@@ -459,7 +460,6 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         for m in training_models:
             m.train()
 
-        loss_recorder = train_util.LossRecorder()
         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
             with accelerator.accumulate(training_models[0]):  # it seems this doesn't support multiple models, but this will do for now
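
Both hunks make the same change in the two trainers: `loss_recorder = train_util.LossRecorder()` is hoisted out of the epoch loop and constructed once before it, so the recorded losses persist across epoch boundaries instead of being discarded when each new epoch re-created the recorder. Below is a minimal sketch of that pattern; `SimpleLossRecorder`, `train()`, and `steps_per_epoch` are illustrative stand-ins and assumptions, not the actual `train_util.LossRecorder` API from this repository.

```python
# Illustrative only: SimpleLossRecorder is a hypothetical stand-in,
# not the real train_util.LossRecorder.
class SimpleLossRecorder:
    def __init__(self):
        self.losses: list[float] = []

    def add(self, loss: float) -> None:
        self.losses.append(loss)

    @property
    def moving_average(self) -> float:
        return sum(self.losses) / len(self.losses)


def train(num_train_epochs: int, steps_per_epoch: int) -> None:
    # After this change: one recorder for the whole run, created before the
    # epoch loop (mirrors the hoisted `loss_recorder = train_util.LossRecorder()`).
    loss_recorder = SimpleLossRecorder()
    for epoch in range(num_train_epochs):
        # Before this change the recorder was re-created here, so anything
        # accumulated in earlier epochs was thrown away at each epoch boundary.
        for step in range(steps_per_epoch):
            current_loss = 1.0 / (epoch * steps_per_epoch + step + 1)  # dummy loss
            loss_recorder.add(current_loss)
        print(f"epoch {epoch + 1}: avr_loss={loss_recorder.moving_average:.4f}")


if __name__ == "__main__":
    train(num_train_epochs=2, steps_per_epoch=3)
```

With this placement the reported average keeps incorporating earlier epochs' losses; with the old placement each epoch's average would have started from empty history.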