From 07a094f5eb8a75ab08202fd890898d15846b9cf8 Mon Sep 17 00:00:00 2001 From: gesen2egee Date: Fri, 27 Dec 2024 00:44:14 +0800 Subject: [PATCH] val loss (wip) val loss --- library/config_util.py | 125 +++++++++++++++++++++++++++++++- library/train_util.py | 23 +++++- train_network.py | 160 +++++++++++++++++++++++++++++++++++++++-- 3 files changed, 299 insertions(+), 9 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 12d0be173..d213bdb94 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -102,7 +102,8 @@ class BaseDatasetParams: resolution: Optional[Tuple[int, int]] = None network_multiplier: float = 1.0 debug_dataset: bool = False - + validation_seed: Optional[int] = None + validation_split: float = 0.0 @dataclass class DreamBoothDatasetParams(BaseDatasetParams): @@ -236,6 +237,8 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] "min_bucket_reso": int, "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), "network_multiplier": float, + "validation_seed": int, + "validation_split": float, } # options handled by argparse but not handled by user config @@ -478,9 +481,42 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, is_train=True, **asdict(dataset_blueprint.params)) datasets.append(dataset) + val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] + + for dataset_blueprint in dataset_group_blueprint.datasets: + if dataset_blueprint.params.validation_split <= 0.0: + continue + if dataset_blueprint.is_controlnet: + subset_klass = ControlNetSubset + dataset_klass = ControlNetDataset + elif dataset_blueprint.is_dreambooth: + 
subset_klass = DreamBoothSubset + dataset_klass = DreamBoothDataset + else: + subset_klass = FineTuningSubset + dataset_klass = FineTuningDataset + + subsets = [] + for subset_blueprint in dataset_blueprint.subsets: + subset_blueprint.params.num_repeats = 1 + subset_blueprint.params.color_aug = False + subset_blueprint.params.flip_aug = False + subset_blueprint.params.random_crop = False + subset_blueprint.params.random_crop = None + subset_blueprint.params.caption_dropout_rate = 0.0 + subset_blueprint.params.caption_dropout_every_n_epochs = 0 + subset_blueprint.params.caption_tag_dropout_rate = 0.0 + subset_blueprint.params.token_warmup_step = 0 + + if subset_klass != DreamBoothSubset or (subset_klass == DreamBoothSubset and not subset_blueprint.params.is_reg): + subsets.append(subset_klass(**asdict(subset_blueprint.params))) + + dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) + val_datasets.append(dataset) + # print info info = "" for i, dataset in enumerate(datasets): @@ -566,6 +602,78 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu logger.info(f"{info}") + # print validation info + info = "" + for i, dataset in enumerate(val_datasets): + is_dreambooth = isinstance(dataset, DreamBoothDataset) + is_controlnet = isinstance(dataset, ControlNetDataset) + info += dedent( + f"""\ + [Validation Dataset {i}] + batch_size: {dataset.batch_size} + resolution: {(dataset.width, dataset.height)} + enable_bucket: {dataset.enable_bucket} + network_multiplier: {dataset.network_multiplier} + """ + ) + + if dataset.enable_bucket: + info += indent( + dedent( + f"""\ + min_bucket_reso: {dataset.min_bucket_reso} + max_bucket_reso: {dataset.max_bucket_reso} + bucket_reso_steps: {dataset.bucket_reso_steps} + bucket_no_upscale: {dataset.bucket_no_upscale} + \n""" + ), + " ", + ) + else: + info += "\n" + + for j, subset in enumerate(dataset.subsets): + info += indent( + dedent( + f"""\ + [Subset {j} of 
Dataset {i}] + image_dir: "{subset.image_dir}" + image_count: {subset.img_count} + shuffle_caption: {subset.shuffle_caption} + keep_tokens: {subset.keep_tokens} + keep_tokens_separator: {subset.keep_tokens_separator} + caption_prefix: {subset.caption_prefix} + caption_suffix: {subset.caption_suffix} + token_warmup_min: {subset.token_warmup_min}, + token_warmup_step: {subset.token_warmup_step}, + """ + ), + " ", + ) + + if is_dreambooth: + info += indent( + dedent( + f"""\ + is_reg: {subset.is_reg} + class_tokens: {subset.class_tokens} + caption_extension: {subset.caption_extension} + \n""" + ), + " ", + ) + elif not is_controlnet: + info += indent( + dedent( + f"""\ + metadata_file: {subset.metadata_file} + \n""" + ), + " ", + ) + + logger.info(f'{info}') + # make buckets first because it determines the length of dataset # and set the same seed for all datasets seed = random.randint(0, 2**31) # actual seed is seed + epoch_no @@ -574,8 +682,19 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset.make_buckets() dataset.set_seed(seed) - return DatasetGroup(datasets) + if not val_datasets: + return DatasetGroup(datasets), None + else: + for i, dataset in enumerate(val_datasets): + logger.info(f"[Validation Dataset {i}]") + dataset.make_buckets() + dataset.set_seed(seed) + + return ( + DatasetGroup(datasets), + DatasetGroup(val_datasets) + ) def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): def extract_dreambooth_params(name: str) -> Tuple[int, str]: diff --git a/library/train_util.py b/library/train_util.py index 72b5b24db..534489bdc 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1799,11 +1799,16 @@ def __init__( bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset: bool, + is_train: bool, + validation_split: float, + validation_seed: Optional[int], ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) assert
resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" - + self.is_train = is_train + self.validation_split = validation_split + self.validation_seed = validation_seed self.batch_size = batch_size self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight @@ -1878,6 +1883,8 @@ def load_dreambooth_dir(subset: DreamBoothSubset): # we may need to check image size and existence of image files, but it takes time, so user should check it before training else: img_paths = glob_images(subset.image_dir, "*") + if self.validation_split > 0.0: + img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed) sizes = [None] * len(img_paths) # new caching: get image size from cache files @@ -6328,6 +6335,20 @@ def sample_image_inference( # not to commit images to avoid inconsistency between training and logging steps wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption +def split_train_val(paths, is_train, validation_split, validation_seed): + if validation_seed is not None: + print(f"Using validation seed: {validation_seed}") + prevstate = random.getstate() + random.seed(validation_seed) + random.shuffle(paths) + random.setstate(prevstate) + else: + random.shuffle(paths) + + if is_train: + return paths[0:math.ceil(len(paths) * (1 - validation_split))] + else: + return paths[len(paths) - round(len(paths) * validation_split):] # endregion diff --git a/train_network.py b/train_network.py index 5e82b307c..16c1b40c9 100644 --- a/train_network.py +++ b/train_network.py @@ -44,6 +44,7 @@ setup_logging() import logging +import itertools logger = logging.getLogger(__name__) @@ -306,6 +307,98 @@ def prepare_unet_with_accelerator( def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): pass + def process_val_batch(self, batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, 
weight_dtype, accelerator, args, tokenize_strategy): + total_loss = 0.0 + timesteps_list = [10, 350, 500, 650, 990] + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + with torch.no_grad(): + # latentに変換 + latents = self.encode_images_to_latents(args, accelerator, vae, batch["images"].to(vae_dtype)) + latents = latents.to(dtype=weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + + latents = self.shift_scale_latents(args, latents) + + text_encoder_conds = [] + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs + + if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None: + # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached' + with torch.set_grad_enabled(False), accelerator.autocast(): + # Get the text embedding for conditioning + if args.weighted_captions: + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, + self.get_models_for_text_encoding(args, accelerator, text_encoders), + input_ids_list, + weights_list, + ) + else: + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + encoded_text_encoder_conds = text_encoding_strategy.encode_tokens( + tokenize_strategy, + self.get_models_for_text_encoding(args, accelerator, text_encoders), + input_ids, + ) + if args.full_fp16: + encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds] + + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if 
len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + for fixed_timesteps in timesteps_list: + with torch.set_grad_enabled(False), accelerator.autocast(): + noise = torch.randn_like(latents, device=latents.device) + b_size = latents.shape[0] + timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Predict the noise residual + with accelerator.autocast(): + noise_pred = self.call_unet( + args, + accelerator, + unet, + noisy_latents.requires_grad_(False), + timesteps, + text_encoder_conds, + batch, + weight_dtype, + ) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + total_loss += loss + + average_loss = total_loss / len(timesteps_list) + return average_loss + # endregion def train(self, args): @@ -327,7 +420,6 @@ def train(self, args): tokenize_strategy = self.get_tokenize_strategy(args) strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored - # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. 
latents_caching_strategy = self.get_latents_caching_strategy(args) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) @@ -373,11 +465,11 @@ def train(self, args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class train_dataset_group = train_util.load_arbitrary_dataset(args) - + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None @@ -444,7 +536,9 @@ def train(self, args): vae.eval() train_dataset_group.new_cache_latents(vae, accelerator) - + if val_dataset_group is not None: + print("Cache validation latents...") + val_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -580,6 +674,17 @@ def train(self, args): persistent_workers=args.persistent_data_loader_workers, ) + val_dataloader = torch.utils.data.DataLoader( + val_dataset_group if val_dataset_group is not None else [], + shuffle=False, + batch_size=1, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + cyclic_val_dataloader = itertools.cycle(val_dataloader) + # 学習ステップ数を計算する if args.max_train_epochs is not None: args.max_train_steps = args.max_train_epochs * math.ceil( @@ -1064,6 +1169,7 @@ def load_model_hook(models, input_dir): ) loss_recorder = train_util.LossRecorder() + val_loss_recorder = train_util.LossRecorder() del train_dataset_group # callback for step start @@ -1308,11 +1414,31 @@ def remove_model(old_ckpt_name): ) accelerator.log(logs, step=global_step) + if len(val_dataloader) > 0: + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or 
(args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: + accelerator.print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, tokenize_strategy) + total_loss += loss.detach().item() + current_val_loss = total_loss / validation_steps + val_loss_recorder.add(epoch=0, step=global_step, loss=current_val_loss) + + if args.logging_dir is not None: + logs = {"loss/current_val_loss": current_val_loss} + accelerator.log(logs, step=global_step) + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/average_val_loss": avr_loss} + accelerator.log(logs, step=global_step) + if global_step >= args.max_train_steps: break if len(accelerator.trackers) > 0: - logs = {"loss/epoch": loss_recorder.moving_average} + logs = {"loss/epoch_average": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) accelerator.wait_for_everyone() @@ -1496,6 +1622,30 @@ def setup_parser() -> argparse.ArgumentParser: help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch." + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする", ) + parser.add_argument( + "--validation_seed", + type=int, + default=None, + help="Validation seed" + ) + parser.add_argument( + "--validation_split", + type=float, + default=0.0, + help="Split for validation images out of the training dataset" + ) + parser.add_argument( + "--validation_every_n_step", + type=int, + default=None, + help="Number of train steps for counting validation loss. 
By default, validation is performed once per training epoch" ) + parser.add_argument( + "--max_validation_steps", + type=int, + default=None, + help="Max number of validation steps for counting validation loss. By default, validation will run over the entire validation dataset" ) # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio") # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")