diff --git a/fine_tune.py b/fine_tune.py index 176087065..e1ed47496 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -91,9 +91,10 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/flux_train.py b/flux_train.py index fced3bef9..6f98adea8 100644 --- a/flux_train.py +++ b/flux_train.py @@ -138,9 +138,10 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 9d36a41d3..cecd00019 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -126,9 +126,10 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/flux_train_network.py b/flux_train_network.py index 75e975bae..b3aebecc7 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -339,6 +339,7 @@ def get_noise_pred_and_target( network, weight_dtype, train_unet, + is_train=True ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) @@ -375,7 +376,7 @@ def get_noise_pred_and_target( def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): # if not args.split_mode: # normal forward - with accelerator.autocast(): + with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = unet( img=img, @@ -420,7 +421,9 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t intermediate_txt.requires_grad_(True) vec.requires_grad_(True) pe.requires_grad_(True) - model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) + + with torch.set_grad_enabled(is_train and train_unet): + model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) """ return model_pred diff --git a/library/config_util.py b/library/config_util.py index 12d0be173..a2e07dc6c 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -73,6 +73,8 @@ class BaseSubsetParams: token_warmup_min: int = 1 token_warmup_step: float = 0 custom_attributes: Optional[Dict[str, Any]] = None + validation_seed: int = 0 + validation_split: float = 0.0 @dataclass @@ -102,6 +104,8 @@ class BaseDatasetParams: resolution: 
Optional[Tuple[int, int]] = None network_multiplier: float = 1.0 debug_dataset: bool = False + validation_seed: Optional[int] = None + validation_split: float = 0.0 @dataclass @@ -113,8 +117,7 @@ class DreamBoothDatasetParams(BaseDatasetParams): bucket_reso_steps: int = 64 bucket_no_upscale: bool = False prior_loss_weight: float = 1.0 - - + @dataclass class FineTuningDatasetParams(BaseDatasetParams): batch_size: int = 1 @@ -234,6 +237,8 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] "enable_bucket": bool, "max_bucket_reso": int, "min_bucket_reso": int, + "validation_seed": int, + "validation_split": float, "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), "network_multiplier": float, } @@ -462,119 +467,136 @@ def search_value(key: str, fallbacks: Sequence[dict], default_value=None): return default_value - -def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint): +def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint) -> Tuple[DatasetGroup, Optional[DatasetGroup]]: datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] for dataset_blueprint in dataset_group_blueprint.datasets: + extra_dataset_params = {} + if dataset_blueprint.is_controlnet: subset_klass = ControlNetSubset dataset_klass = ControlNetDataset elif dataset_blueprint.is_dreambooth: subset_klass = DreamBoothSubset dataset_klass = DreamBoothDataset + # DreamBooth datasets support splitting training and validation datasets + extra_dataset_params = {"is_training_dataset": True} else: subset_klass = FineTuningSubset dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params), **extra_dataset_params) datasets.append(dataset) - # print info - info = "" - for i, dataset in enumerate(datasets): - is_dreambooth = isinstance(dataset, DreamBoothDataset) - is_controlnet = isinstance(dataset, ControlNetDataset) - info += dedent( - f"""\ - [Dataset {i}] - batch_size: {dataset.batch_size} - resolution: {(dataset.width, dataset.height)} - enable_bucket: {dataset.enable_bucket} - network_multiplier: {dataset.network_multiplier} - """ - ) + val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] + for dataset_blueprint in dataset_group_blueprint.datasets: + if dataset_blueprint.params.validation_split < 0.0 or dataset_blueprint.params.validation_split > 1.0: + logger.warning(f"Dataset param `validation_split` ({dataset_blueprint.params.validation_split}) is not a valid number between 0.0 and 1.0, skipping validation split...") + continue - if dataset.enable_bucket: - info += indent( - dedent( - f"""\ - min_bucket_reso: {dataset.min_bucket_reso} - max_bucket_reso: {dataset.max_bucket_reso} - bucket_reso_steps: {dataset.bucket_reso_steps} - bucket_no_upscale: {dataset.bucket_no_upscale} - \n""" - ), - " ", - ) - else: - info += "\n" - - for j, subset in enumerate(dataset.subsets): - info += indent( - dedent( - f"""\ - [Subset {j} of Dataset {i}] - image_dir: "{subset.image_dir}" - image_count: {subset.img_count} - num_repeats: {subset.num_repeats} - shuffle_caption: {subset.shuffle_caption} - keep_tokens: {subset.keep_tokens} - keep_tokens_separator: {subset.keep_tokens_separator} - caption_separator: 
{subset.caption_separator} - secondary_separator: {subset.secondary_separator} - enable_wildcard: {subset.enable_wildcard} - caption_dropout_rate: {subset.caption_dropout_rate} - caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs} - caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} - caption_prefix: {subset.caption_prefix} - caption_suffix: {subset.caption_suffix} - color_aug: {subset.color_aug} - flip_aug: {subset.flip_aug} - face_crop_aug_range: {subset.face_crop_aug_range} - random_crop: {subset.random_crop} - token_warmup_min: {subset.token_warmup_min} - token_warmup_step: {subset.token_warmup_step} - alpha_mask: {subset.alpha_mask} - custom_attributes: {subset.custom_attributes} - """ - ), - " ", - ) + # if the dataset isn't setting a validation split, there is no current validation dataset + if dataset_blueprint.params.validation_split == 0.0: + continue - if is_dreambooth: - info += indent( - dedent( - f"""\ - is_reg: {subset.is_reg} - class_tokens: {subset.class_tokens} - caption_extension: {subset.caption_extension} - \n""" - ), - " ", - ) - elif not is_controlnet: - info += indent( - dedent( - f"""\ - metadata_file: {subset.metadata_file} - \n""" - ), - " ", - ) + extra_dataset_params = {} + if dataset_blueprint.is_controlnet: + subset_klass = ControlNetSubset + dataset_klass = ControlNetDataset + elif dataset_blueprint.is_dreambooth: + subset_klass = DreamBoothSubset + dataset_klass = DreamBoothDataset + # DreamBooth datasets support splitting training and validation datasets + extra_dataset_params = {"is_training_dataset": False} + else: + subset_klass = FineTuningSubset + dataset_klass = FineTuningDataset - logger.info(f"{info}") + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params), **extra_dataset_params) + val_datasets.append(dataset) + + def print_info(_datasets, dataset_type: str): + info = "" + for i, dataset in enumerate(_datasets): + is_dreambooth = isinstance(dataset, DreamBoothDataset) + is_controlnet = isinstance(dataset, ControlNetDataset) + info += dedent(f"""\ + [{dataset_type} {i}] + batch_size: {dataset.batch_size} + resolution: {(dataset.width, dataset.height)} + enable_bucket: {dataset.enable_bucket} + """) + + if dataset.enable_bucket: + info += indent(dedent(f"""\ + min_bucket_reso: {dataset.min_bucket_reso} + max_bucket_reso: {dataset.max_bucket_reso} + bucket_reso_steps: {dataset.bucket_reso_steps} + bucket_no_upscale: {dataset.bucket_no_upscale} + \n"""), " ") + else: + info += "\n" + + for j, subset in enumerate(dataset.subsets): + info += indent(dedent(f"""\ + [Subset {j} of {dataset_type} {i}] + image_dir: "{subset.image_dir}" + image_count: {subset.img_count} + num_repeats: {subset.num_repeats} + shuffle_caption: {subset.shuffle_caption} + keep_tokens: {subset.keep_tokens} + caption_dropout_rate: {subset.caption_dropout_rate} + caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs} + caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} + caption_prefix: {subset.caption_prefix} + caption_suffix: {subset.caption_suffix} + color_aug: {subset.color_aug} + flip_aug: {subset.flip_aug} + face_crop_aug_range: {subset.face_crop_aug_range} + random_crop: {subset.random_crop} + token_warmup_min: {subset.token_warmup_min} + token_warmup_step: {subset.token_warmup_step} + alpha_mask: {subset.alpha_mask} + custom_attributes: {subset.custom_attributes} + """), " ") + + if 
is_dreambooth: + info += indent(dedent(f"""\ + is_reg: {subset.is_reg} + class_tokens: {subset.class_tokens} + caption_extension: {subset.caption_extension} + \n"""), " ") + elif not is_controlnet: + info += indent(dedent(f"""\ + metadata_file: {subset.metadata_file} + \n"""), " ") + + logger.info(info) + + print_info(datasets, "Dataset") + + if len(val_datasets) > 0: + print_info(val_datasets, "Validation Dataset") # make buckets first because it determines the length of dataset # and set the same seed for all datasets seed = random.randint(0, 2**31) # actual seed is seed + epoch_no + for i, dataset in enumerate(datasets): - logger.info(f"[Dataset {i}]") + logger.info(f"[Prepare dataset {i}]") dataset.make_buckets() dataset.set_seed(seed) - return DatasetGroup(datasets) + for i, dataset in enumerate(val_datasets): + logger.info(f"[Prepare validation dataset {i}]") + dataset.make_buckets() + dataset.set_seed(seed) + + return ( + DatasetGroup(datasets), + DatasetGroup(val_datasets) if val_datasets else None + ) def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index faf443048..ad3e69ffb 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -1,7 +1,9 @@ +from diffusers.schedulers.scheduling_ddpm import DDPMScheduler import torch import argparse import random import re +from torch.types import Number from typing import List, Optional, Union from .utils import setup_logging @@ -63,7 +65,7 @@ def enforce_zero_terminal_snr(betas): noise_scheduler.alphas_cumprod = alphas_cumprod -def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False): +def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False): snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma)) if v_prediction: @@ -74,13 +76,13 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False return loss -def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler): +def scale_v_prediction_loss_like_noise_prediction(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler): scale = get_snr_scale(timesteps, noise_scheduler) loss = loss * scale return loss -def get_snr_scale(timesteps, noise_scheduler): +def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler): snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 scale = snr_t / (snr_t + 1) @@ -89,14 +91,14 @@ def get_snr_scale(timesteps, noise_scheduler): return scale -def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss): +def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor): scale = get_snr_scale(timesteps, noise_scheduler) # logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}") loss = loss + loss / scale * v_pred_like_loss return loss -def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False): +def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, 
v_prediction=False): snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 if v_prediction: @@ -453,7 +455,7 @@ def get_weighted_text_embeddings( # https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2 -def pyramid_noise_like(noise, device, iterations=6, discount=0.4): +def pyramid_noise_like(noise, device, iterations=6, discount=0.4) -> torch.FloatTensor: b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant! u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device) for i in range(iterations): @@ -466,7 +468,7 @@ def pyramid_noise_like(noise, device, iterations=6, discount=0.4): # https://www.crosslabs.org//blog/diffusion-with-offset-noise -def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale): +def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale) -> torch.FloatTensor: if noise_offset is None: return noise if adaptive_noise_scale is not None: @@ -482,7 +484,7 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale): return noise -def apply_masked_loss(loss, batch): +def apply_masked_loss(loss, batch) -> torch.FloatTensor: if "conditioning_images" in batch: # conditioning image is -1 to 1. we need to convert it to 0 to 1 mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel diff --git a/library/strategy_sd.py b/library/strategy_sd.py index d0a3a68bf..a44fc4092 100644 --- a/library/strategy_sd.py +++ b/library/strategy_sd.py @@ -40,7 +40,7 @@ def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: text = [text] if isinstance(text, str) else text return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)] - def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]: + def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: text = [text] if isinstance(text, str) else text tokens_list = [] weights_list = [] diff --git a/library/train_util.py b/library/train_util.py index 72b5b24db..4d143c373 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -12,6 +12,7 @@ import re import shutil import time +import typing from typing import ( Any, Callable, @@ -145,6 +146,37 @@ TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" +def split_train_val( + paths: List[str], + is_training_dataset: bool, + validation_split: float, + validation_seed: int | None +) -> List[str]: + """ + Split the dataset into train and validation + + Shuffle the dataset based on the validation_seed, or on the current random seed if none is given. + For example, a validation_split of 0.2 on 100 images gives: + [0:80] = 80 training images + [80:] = 20 validation images + """ + if validation_seed is not None: + logger.info(f"Using validation seed: {validation_seed}") + prevstate = random.getstate() + random.seed(validation_seed) + random.shuffle(paths) + random.setstate(prevstate) + else: + random.shuffle(paths) + + # Split the dataset between training and validation + if is_training_dataset: + # Training dataset is the first part, up to the cut point + return paths[0:math.ceil(len(paths) * (1 - validation_split))] + else: + # Validation dataset is the remainder from the same cut point, so the halves never overlap + return paths[math.ceil(len(paths) * (1 - validation_split)):] + class ImageInfo: def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: @@ -397,6 +429,8 @@ def __init__( token_warmup_min: int, token_warmup_step: Union[float, int], custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, ) -> None: self.image_dir = image_dir self.alpha_mask = alpha_mask if alpha_mask is not None else False @@ -424,6 +458,9 @@ def __init__( self.img_count = 0 + self.validation_seed = validation_seed + self.validation_split = validation_split + class DreamBoothSubset(BaseSubset): def __init__( @@ -453,6 +490,8 @@ def __init__( token_warmup_min, token_warmup_step, custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -478,6 +517,8 @@ def __init__( token_warmup_min, token_warmup_step, custom_attributes=custom_attributes, + validation_seed=validation_seed, + validation_split=validation_split, ) self.is_reg = is_reg @@ -518,6 +559,8 @@ def __init__( token_warmup_min, token_warmup_step, custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" @@ -543,6 +586,8 @@ def __init__( token_warmup_min, token_warmup_step, custom_attributes=custom_attributes, + validation_seed=validation_seed, + validation_split=validation_split, ) self.metadata_file = metadata_file @@ -579,6 +624,8 @@ def __init__( token_warmup_min, token_warmup_step, custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -604,6 +651,8 @@ def __init__( token_warmup_min, token_warmup_step, custom_attributes=custom_attributes, + validation_seed=validation_seed, + validation_split=validation_split, ) self.conditioning_data_dir = conditioning_data_dir @@ -1786,9 +1835,13 @@ def get_item_for_caching(self, bucket, bucket_batch_size, image_index): class DreamBoothDataset(BaseDataset): IMAGE_INFO_CACHE_FILE = "metadata_cache.json" + # The is_training_dataset defines the type of dataset, training or validation + # if is_training_dataset is True -> training dataset + # if is_training_dataset is False -> validation dataset def __init__( self, subsets: Sequence[DreamBoothSubset], + is_training_dataset: bool, batch_size: int, resolution, network_multiplier: float, @@ -1799,6 +1852,8 @@ def __init__( bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset: bool, + validation_split: float, + validation_seed: Optional[int], ) -> None: 
super().__init__(resolution, network_multiplier, debug_dataset) @@ -1808,6 +1863,9 @@ def __init__( self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight self.latents_cache = None + self.is_training_dataset = is_training_dataset + self.validation_seed = validation_seed + self.validation_split = validation_split self.enable_bucket = enable_bucket if self.enable_bucket: @@ -1915,6 +1973,30 @@ def load_dreambooth_dir(subset: DreamBoothSubset): size_set_count += 1 logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") + # We want to create a training and validation split. This should be improved in the future + # to allow a clearer distinction between training and validation. This can be seen as a + # short-term solution to limit what is necessary to implement validation datasets + # + # We split the dataset for the subset based on if we are doing a validation split + # The self.is_training_dataset defines the type of dataset, training or validation + # if self.is_training_dataset is True -> training dataset + # if self.is_training_dataset is False -> validation dataset + if self.validation_split > 0.0: + # For regularization images we do not want to split this dataset. + if subset.is_reg is True: + # Skip any validation dataset for regularization images + if self.is_training_dataset is False: + img_paths = [] + # Otherwise the img_paths remain as original img_paths and no split + # required for training images dataset of regularization images + else: + img_paths = split_train_val( + img_paths, + self.is_training_dataset, + self.validation_split, + self.validation_seed + ) + logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") if use_cached_info_for_subset: @@ -1973,9 +2055,10 @@ def load_dreambooth_dir(subset: DreamBoothSubset): num_reg_images = 0 reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = [] for subset in subsets: - if subset.num_repeats < 1: + num_repeats = subset.num_repeats if self.is_training_dataset else 1 + if num_repeats < 1: logger.warning( - f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}" + f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {num_repeats}" ) continue @@ -1993,12 +2076,12 @@ def load_dreambooth_dir(subset: DreamBoothSubset): continue if subset.is_reg: - num_reg_images += subset.num_repeats * len(img_paths) + num_reg_images += num_repeats * len(img_paths) else: - num_train_images += subset.num_repeats * len(img_paths) + num_train_images += num_repeats * len(img_paths) for img_path, caption, size in zip(img_paths, captions, sizes): - info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path) + info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path) if size is not None: info.image_size = size if subset.is_reg: @@ -2009,10 +2092,12 @@ def load_dreambooth_dir(subset: DreamBoothSubset): subset.img_count = len(img_paths) self.subsets.append(subset) - logger.info(f"{num_train_images} train images with repeating.") + images_split_name = "train" if self.is_training_dataset else "validation" + logger.info(f"{num_train_images} {images_split_name} images with repeats.") + self.num_train_images = num_train_images - logger.info(f"{num_reg_images} reg images.") + logger.info(f"{num_reg_images} reg images with repeats.") if num_train_images < num_reg_images: logger.warning("some of 
reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") @@ -2050,6 +2135,8 @@ def __init__( bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset: bool, + validation_seed: int, + validation_split: float, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -2275,7 +2362,9 @@ def __init__( max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, - debug_dataset: float, + debug_dataset: bool, + validation_split: float, + validation_seed: Optional[int], ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -2324,13 +2413,17 @@ def __init__( bucket_no_upscale, 1.0, debug_dataset, + validation_split, + validation_seed, ) # config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい) self.image_data = self.dreambooth_dataset_delegate.image_data self.batch_size = batch_size self.num_train_images = self.dreambooth_dataset_delegate.num_train_images - self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images + self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images + self.validation_split = validation_split + self.validation_seed = validation_seed # assert all conditioning data exists missing_imgs = [] @@ -4544,7 +4637,6 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar config_args = argparse.Namespace(**ignore_nesting_dict) args = parser.parse_args(namespace=config_args) args.config_file = os.path.splitext(args.config_file)[0] - logger.info(args.config_file) return args @@ -4887,7 +4979,7 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]: import schedulefree as sf except ImportError: raise ImportError("No schedulefree / schedulefreeがインストールされていないようです") - + if optimizer_type == "RAdamScheduleFree".lower(): optimizer_class = sf.RAdamScheduleFree logger.info(f"use RAdamScheduleFree optimizer | {optimizer_kwargs}") @@ -5838,13 +5930,13 @@ def save_sd_model_on_train_end_common( huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) -def get_timesteps(min_timestep, max_timestep, b_size, device): +def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor: timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") timesteps = timesteps.long().to(device) return timesteps -def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): +def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]: # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) if args.noise_offset: @@ -5905,11 +5997,16 @@ def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler def conditional_loss( model_pred: torch.Tensor, target: torch.Tensor, loss_type: str, reduction: str, huber_c: Optional[torch.Tensor] = None ): + """ + NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already + """ if loss_type == "l2": loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction) elif loss_type == "l1": loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction) elif loss_type == "huber": + if huber_c is None: + raise NotImplementedError("huber_c not implemented correctly") huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": @@ -5917,6 +6014,8 @@ def conditional_loss( elif reduction 
== "sum": loss = torch.sum(loss) elif loss_type == "smooth_l1": + if huber_c is None: + raise NotImplementedError("huber_c not implemented correctly") huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": @@ -6329,6 +6428,30 @@ def sample_image_inference( wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption +def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str): + """ + Initialize experiment trackers with tracker specific behaviors + """ + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + default_tracker_name if args.log_tracker_name is None else args.log_tracker_name, + config=get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + if "wandb" in [tracker.name for tracker in accelerator.trackers]: + import wandb + wandb_tracker = accelerator.get_tracker("wandb", unwrap=True) + + # Define specific metrics to handle validation and epochs "steps" + wandb_tracker.define_metric("epoch", hidden=True) + wandb_tracker.define_metric("val_step", hidden=True) + # endregion diff --git a/sd3_train.py b/sd3_train.py index 120455e7b..3bff6a50f 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -149,9 +149,10 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/sd3_train_network.py b/sd3_train_network.py index fb7711bda..c7417802d 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -312,6 +312,7 @@ def get_noise_pred_and_target( network, weight_dtype, train_unet, + is_train=True ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) @@ -339,7 +340,7 @@ def get_noise_pred_and_target( t5_attn_mask = None # call model - with accelerator.autocast(): + with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): # TODO support attention mask model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled) diff --git a/sdxl_train.py b/sdxl_train.py index b9d529243..a60f6df63 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -176,9 +176,10 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index ffbf03cab..c6e8136f7 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -114,7 +114,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + 
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 365059b75..00e51a673 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -123,7 +123,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 5b372befc..63457cc61 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -103,7 +103,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/tools/cache_latents.py b/tools/cache_latents.py index c034f949a..515ece98d 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -116,10 +116,11 @@ def cache_to_disk(args: argparse.Namespace) -> None: } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None # acceleratorを準備する logger.info("prepare accelerator") diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index 5888b8e3d..00459658e 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -103,10 +103,11 @@ def cache_to_disk(args: argparse.Namespace) -> None: } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None # acceleratorを準備する logger.info("prepare accelerator") diff --git a/train_control_net.py b/train_control_net.py index 177d2b11f..ba016ac5d 100644 --- a/train_control_net.py +++ b/train_control_net.py @@ -100,7 +100,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/train_db.py b/train_db.py index ad21f8d1b..edd674034 100644 --- a/train_db.py +++ b/train_db.py @@ -89,9 +89,10 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = 
config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/train_network.py b/train_network.py index 5e82b307c..e7d93a108 100644 --- a/train_network.py +++ b/train_network.py @@ -2,17 +2,19 @@ import argparse import math import os +import typing +from typing import Any, List import sys import random import time import json from multiprocessing import Value -from typing import Any, List import toml from tqdm import tqdm import torch +from torch.types import Number from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -20,6 +22,7 @@ from accelerate.utils import set_seed from accelerate import Accelerator from diffusers import DDPMScheduler +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from library import deepspeed_utils, model_util, strategy_base, strategy_sd import library.train_util as train_util @@ -114,7 +117,7 @@ def generate_step_logs( ) if ( args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None - ): + ): logs[f"lr/d*lr/group{i}"] = ( optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] ) @@ -196,10 +199,10 @@ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) return noise_scheduler - def encode_images_to_latents(self, args, accelerator, vae, images): + def encode_images_to_latents(self, args, vae: AutoencoderKL, images: torch.FloatTensor) -> torch.FloatTensor: return vae.encode(images).latent_dist.sample() - def shift_scale_latents(self, args, latents): + def shift_scale_latents(self, args, latents: torch.FloatTensor) -> torch.FloatTensor: return latents * self.vae_scale_factor def get_noise_pred_and_target( @@ -214,6 +217,7 @@ def get_noise_pred_and_target( network, weight_dtype, train_unet, + is_train=True ): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified @@ -227,7 +231,7 @@ def get_noise_pred_and_target( t.requires_grad_(True) # Predict the noise residual - with accelerator.autocast(): + with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): noise_pred = self.call_unet( args, accelerator, @@ -271,7 +275,7 @@ def get_noise_pred_and_target( return noise_pred, target, timesteps, None - def post_process_loss(self, loss, args, timesteps, noise_scheduler): + def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor: if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: @@ -308,6 +312,107 @@ def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, # endregion + def process_batch( + self, + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy: strategy_base.TextEncodingStrategy, + tokenize_strategy: strategy_base.TokenizeStrategy, + is_train=True, + train_text_encoder=True, + train_unet=True + ) -> torch.Tensor: + """ + Process a batch for the network + """ + with torch.no_grad(): + if 
"latents" in batch and batch["latents"] is not None: + latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device)) + else: + # latentに変換 + latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype)) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = typing.cast(torch.FloatTensor, torch.nan_to_num(latents, 0, out=latents)) + + latents = self.shift_scale_latents(args, latents) + + text_encoder_conds = [] + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs + + if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder: + # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached' + with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): + # Get the text embedding for conditioning + if args.weighted_captions: + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, + self.get_models_for_text_encoding(args, accelerator, text_encoders), + input_ids_list, + weights_list, + ) + else: + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + encoded_text_encoder_conds = text_encoding_strategy.encode_tokens( + tokenize_strategy, + self.get_models_for_text_encoding(args, accelerator, text_encoders), + input_ids, + ) + if args.full_fp16: + encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds] + + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] + + # sample noise, call unet, get target + noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target( + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet, + network, + weight_dtype, + train_unet, + is_train=is_train + ) + + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) + + return loss.mean() + def train(self, args): session_id = random.randint(0, 2**32) training_started_at = time.time() @@ -373,10 +478,11 @@ def train(self, args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class train_dataset_group = 
train_util.load_arbitrary_dataset(args) + val_dataset_group = None # placeholder until validation dataset supported for arbitrary current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -384,8 +490,12 @@ def train(self, args): collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) if args.debug_dataset: - train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly + train_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly train_util.debug_dataset(train_dataset_group) + + if val_dataset_group is not None: + val_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly + train_util.debug_dataset(val_dataset_group) return if len(train_dataset_group) == 0: logger.error( @@ -397,6 +507,10 @@ def train(self, args): assert ( train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + if val_dataset_group is not None: + assert ( + val_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" self.assert_extra_args(args, train_dataset_group) # may change some args @@ -444,6 +558,8 @@ def train(self, args): vae.eval() train_dataset_group.new_cache_latents(vae, accelerator) + if val_dataset_group is not None: + val_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -459,6 +575,8 @@ def train(self, args): if text_encoder_outputs_caching_strategy is not None: strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy) self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, train_dataset_group, weight_dtype) + if val_dataset_group is not None: + self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, val_dataset_group, weight_dtype) # prepare network net_kwargs = {} @@ -567,6 +685,8 @@ def train(self, args): # strategies are set here because they cannot be referenced in another process. 
Copy them with the dataset + # some strategies can be None train_dataset_group.set_current_strategies() + if val_dataset_group is not None: + val_dataset_group.set_current_strategies() # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers @@ -579,6 +699,15 @@ def train(self, args): num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) + + val_dataloader = torch.utils.data.DataLoader( + val_dataset_group if val_dataset_group is not None else [], + shuffle=False, + batch_size=1, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) # 学習ステップ数を計算する if args.max_train_epochs is not None: @@ -654,8 +783,8 @@ def train(self, args): text_encoder2=(text_encoders[1] if flags[1] else None) if len(text_encoders) > 1 else None, network=network, ) - ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - ds_model, optimizer, train_dataloader, lr_scheduler + ds_model, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, val_dataloader, lr_scheduler ) training_model = ds_model else: @@ -676,8 +805,8 @@ def train(self, args): else: pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set - network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - network, optimizer, train_dataloader, lr_scheduler + network, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( + network, optimizer, train_dataloader, val_dataloader, lr_scheduler ) training_model = network @@ -769,6 +898,7 @@ def load_model_hook(models, input_dir): accelerator.print("running training / 学習開始") accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num validation images * repeats / 検証画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}") accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") @@ -788,6 +918,7 @@ def load_model_hook(models, input_dir): "ss_text_encoder_lr": text_encoder_lr, "ss_unet_lr": args.unet_lr, "ss_num_train_images": train_dataset_group.num_train_images, + "ss_num_validation_images": val_dataset_group.num_train_images if val_dataset_group is not None else 0, "ss_num_reg_images": train_dataset_group.num_reg_images, "ss_num_batches_per_epoch": len(train_dataloader), "ss_num_epochs": num_train_epochs, @@ -835,6 +966,11 @@ def load_model_hook(models, input_dir): "ss_huber_c": args.huber_c, "ss_fp8_base": bool(args.fp8_base), "ss_fp8_base_unet": bool(args.fp8_base_unet), + "ss_validation_seed": args.validation_seed, + "ss_validation_split": args.validation_split, + "ss_max_validation_steps": args.max_validation_steps, + "ss_validate_every_n_epochs": args.validate_every_n_epochs, + "ss_validate_every_n_steps": args.validate_every_n_steps, } self.update_metadata(metadata, args) # architecture specific metadata @@ -1051,20 +1187,15 @@ def load_model_hook(models, input_dir): noise_scheduler = self.get_noise_scheduler(args, accelerator.device) - if accelerator.is_main_process: - init_kwargs = {} - if args.wandb_run_name: - init_kwargs["wandb"] = {"name": args.wandb_run_name} - if 
args.log_tracker_config is not None: - init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers( - "network_train" if args.log_tracker_name is None else args.log_tracker_name, - config=train_util.get_sanitized_config_or_none(args), - init_kwargs=init_kwargs, - ) + train_util.init_trackers(accelerator, args, "network_train") loss_recorder = train_util.LossRecorder() + val_step_loss_recorder = train_util.LossRecorder() + val_epoch_loss_recorder = train_util.LossRecorder() + del train_dataset_group + if val_dataset_group is not None: + del val_dataset_group # callback for step start if hasattr(accelerator.unwrap_model(network), "on_step_start"): @@ -1109,10 +1240,17 @@ def remove_model(old_ckpt_name): optimizer_eval_fn() self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) optimizer_train_fn() - if len(accelerator.trackers) > 0: + is_tracking = len(accelerator.trackers) > 0 + if is_tracking: # log empty object to commit the sample images to wandb accelerator.log({}, step=0) + validation_steps = ( + min(args.max_validation_steps, len(val_dataloader)) + if args.max_validation_steps is not None + else len(val_dataloader) + ) + # training loop if initial_step > 0: # only if skip_until_initial_step is specified for skip_epoch in range(epoch_to_start): # skip epochs @@ -1132,13 +1270,14 @@ def remove_model(old_ckpt_name): clean_memory_on_device(accelerator.device) for epoch in range(epoch_to_start, num_train_epochs): - accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n") current_epoch.value = epoch + 1 metadata["ss_epoch"] = str(epoch + 1) accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) + # TRAINING skipped_dataloader = None if initial_step > 0: skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1) @@ -1156,98 +1295,24 @@ def remove_model(old_ckpt_name): # temporary, for batch processing self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) - else: - with torch.no_grad(): - # latentに変換 - latents = self.encode_images_to_latents(args, accelerator, vae, batch["images"].to(vae_dtype)) - latents = latents.to(dtype=weight_dtype) - - # NaNが含まれていれば警告を表示し0に置き換える - if torch.any(torch.isnan(latents)): - accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.nan_to_num(latents, 0, out=latents) - - latents = self.shift_scale_latents(args, latents) - - # get multiplier for each sample - if network_has_multiplier: - multipliers = batch["network_multipliers"] - # if all multipliers are same, use single multiplier - if torch.all(multipliers == multipliers[0]): - multipliers = multipliers[0].item() - else: - raise NotImplementedError("multipliers for each sample is not supported yet") - # print(f"set multiplier: {multipliers}") - accelerator.unwrap_model(network).set_multiplier(multipliers) - - text_encoder_conds = [] - text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) - if text_encoder_outputs_list is not None: - text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs - - if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder: - # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached' - with 
torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): - # Get the text embedding for conditioning - if args.weighted_captions: - input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) - encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights( - tokenize_strategy, - self.get_models_for_text_encoding(args, accelerator, text_encoders), - input_ids_list, - weights_list, - ) - else: - input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] - encoded_text_encoder_conds = text_encoding_strategy.encode_tokens( - tokenize_strategy, - self.get_models_for_text_encoding(args, accelerator, text_encoders), - input_ids, - ) - if args.full_fp16: - encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds] - - # if text_encoder_conds is not cached, use encoded_text_encoder_conds - if len(text_encoder_conds) == 0: - text_encoder_conds = encoded_text_encoder_conds - else: - # if encoded_text_encoder_conds is not None, update cached text_encoder_conds - for i in range(len(encoded_text_encoder_conds)): - if encoded_text_encoder_conds[i] is not None: - text_encoder_conds[i] = encoded_text_encoder_conds[i] - - # sample noise, call unet, get target - noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target( - args, - accelerator, - noise_scheduler, - latents, - batch, - text_encoder_conds, - unet, - network, - weight_dtype, - train_unet, + loss = self.process_batch( + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=True, + train_text_encoder=train_text_encoder, + train_unet=train_unet ) - huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) - loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) - if weighting is not None: - loss = loss * weighting - if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): - loss = apply_masked_loss(loss, batch) - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights - - # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc. 
- loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) - - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - accelerator.backward(loss) if accelerator.sync_gradients: self.all_reduce_network(accelerator, network) # sync DDP grad manually @@ -1302,19 +1367,144 @@ def remove_model(old_ckpt_name): if args.scale_weight_norms: progress_bar.set_postfix(**{**max_mean_logs, **logs}) - if len(accelerator.trackers) > 0: + + if is_tracking: logs = self.generate_step_logs( - args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm + args, + current_loss, + avr_loss, + lr_scheduler, + lr_descriptions, + optimizer, + keys_scaled, + mean_norm, + maximum_norm ) accelerator.log(logs, step=global_step) + # VALIDATION PER STEP + should_validate_step = ( + args.validate_every_n_steps is not None + and global_step != 0 # Skip first step + and global_step % args.validate_every_n_steps == 0 + ) + if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: + val_progress_bar = tqdm( + range(validation_steps), smoothing=0, + disable=not accelerator.is_local_main_process, + desc="validation steps" + ) + for val_step, batch in enumerate(val_dataloader): + if val_step >= validation_steps: + break + + # temporary, for batch processing + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + + loss = self.process_batch( + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=False + ) + + current_loss = loss.detach().item() + val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) + val_progress_bar.update(1) + val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average }) + + if is_tracking: + logs = { + "loss/validation/step_current": current_loss, + "val_step": (epoch * validation_steps) + val_step, + } + accelerator.log(logs, step=global_step) + + if is_tracking: + loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average + logs = { + "loss/validation/step_average": val_step_loss_recorder.moving_average, + "loss/validation/step_divergence": loss_validation_divergence, + } + accelerator.log(logs, step=global_step) + if global_step >= args.max_train_steps: break - if len(accelerator.trackers) > 0: - logs = {"loss/epoch": loss_recorder.moving_average} - accelerator.log(logs, step=epoch + 1) + # EPOCH VALIDATION + should_validate_epoch = ( + (epoch + 1) % args.validate_every_n_epochs == 0 + if args.validate_every_n_epochs is not None + else True + ) + + if should_validate_epoch and len(val_dataloader) > 0: + val_progress_bar = tqdm( + range(validation_steps), smoothing=0, + disable=not accelerator.is_local_main_process, + desc="epoch validation steps" + ) + + for val_step, batch in enumerate(val_dataloader): + if val_step >= validation_steps: + break + + # temporary, for batch processing + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + + loss = self.process_batch( + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=False + ) + + current_loss = loss.detach().item() + val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) + val_progress_bar.update(1) + val_progress_bar.set_postfix({ "val_epoch_avg_loss": 
val_epoch_loss_recorder.moving_average }) + + if is_tracking: + logs = { + "loss/validation/epoch_current": current_loss, + "epoch": epoch + 1, + "val_step": (epoch * validation_steps) + val_step + } + accelerator.log(logs, step=global_step) + + if is_tracking: + avr_loss: float = val_epoch_loss_recorder.moving_average + loss_validation_divergence = avr_loss - loss_recorder.moving_average + logs = { + "loss/validation/epoch_average": avr_loss, + "loss/validation/epoch_divergence": loss_validation_divergence, + "epoch": epoch + 1 + } + accelerator.log(logs, step=global_step) + # END OF EPOCH + if is_tracking: + logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1} + accelerator.log(logs, step=global_step) + accelerator.wait_for_everyone() # 指定エポックごとにモデルを保存 @@ -1496,9 +1686,36 @@ def setup_parser() -> argparse.ArgumentParser: help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch." + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする", ) - # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio") - # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") - # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio") + parser.add_argument( + "--validation_seed", + type=int, + default=None, + help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する" + ) + parser.add_argument( + "--validation_split", + type=float, + default=0.0, + help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合" + ) + parser.add_argument( + "--validate_every_n_steps", + type=int, + default=None, + help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます" + ) + parser.add_argument( + "--validate_every_n_epochs", + type=int, + default=None, + help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます" + ) + parser.add_argument( + "--max_validation_steps", + type=int, + default=None, + help="Max number of validation dataset items processed. 
By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します" + ) return parser diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 65da4859b..113f35997 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -320,9 +320,10 @@ def train(self, args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None self.assert_extra_args(args, train_dataset_group) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 2a2b42310..6ff97d03f 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -239,7 +239,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings) current_epoch = Value("i", 0) current_step = Value("i", 0)
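
Notes (reviewer sketches, not part of the patch):

The split logic added to library/train_util.py can be exercised in isolation. Below is a minimal standalone sketch of what split_train_val does, assuming only the standard library; the name split_train_val_sketch is illustrative, not the library function. It shuffles with the validation seed while restoring the global RNG state, then slices both halves from the same cut point so train and validation partition the list exactly.

import math
import random
from typing import List, Optional

def split_train_val_sketch(
    paths: List[str], is_training_dataset: bool, validation_split: float, validation_seed: Optional[int]
) -> List[str]:
    if validation_seed is not None:
        prev_state = random.getstate()  # save the global RNG so training randomness is untouched
        random.seed(validation_seed)
        random.shuffle(paths)
        random.setstate(prev_state)
    else:
        random.shuffle(paths)
    cut = math.ceil(len(paths) * (1 - validation_split))  # same cut point for both halves
    return paths[:cut] if is_training_dataset else paths[cut:]

paths = [f"img_{i:03d}.png" for i in range(100)]
train = split_train_val_sketch(list(paths), True, 0.2, validation_seed=42)
val = split_train_val_sketch(list(paths), False, 0.2, validation_seed=42)
assert len(train) == 80 and len(val) == 20 and not (set(train) & set(val))

Saving and restoring the RNG state is what lets a fixed validation_seed keep the split reproducible while the rest of the run still follows the training seed; both dataset objects must shuffle in the same order for the two slices to stay disjoint.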
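DreamBoothDataset.load_dreambooth_dir never splits regularization subsets: the training dataset keeps every reg image and the validation dataset drops them all. A condensed sketch of that branch, reusing split_train_val_sketch from the sketch above (select_subset_paths is an illustrative name):

from typing import List, Optional

def select_subset_paths(
    img_paths: List[str], is_reg: bool, is_training_dataset: bool,
    validation_split: float, validation_seed: Optional[int],
) -> List[str]:
    if validation_split <= 0.0:
        return img_paths  # no split requested, every image trains
    if is_reg:
        # reg images are never split: all go to training, none to validation
        return img_paths if is_training_dataset else []
    return split_train_val_sketch(list(img_paths), is_training_dataset, validation_split, validation_seed)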
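The validation cadence in train_network.py reduces to two small predicates: validate every validate_every_n_steps optimizer steps (skipping step 0), and cap each validation run at max_validation_steps batches. A sketch with illustrative names:

from typing import Optional

def should_validate_step(global_step: int, validate_every_n_steps: Optional[int]) -> bool:
    return (
        validate_every_n_steps is not None
        and global_step != 0  # skip the very first step
        and global_step % validate_every_n_steps == 0
    )

def num_validation_steps(val_dataloader_len: int, max_validation_steps: Optional[int]) -> int:
    return min(max_validation_steps, val_dataloader_len) if max_validation_steps is not None else val_dataloader_len

assert [s for s in range(1, 601) if should_validate_step(s, 200)] == [200, 400, 600]
assert num_validation_steps(50, 16) == 16 and num_validation_steps(50, None) == 50

Epoch-end validation is separate: it runs every epoch unless validate_every_n_epochs says otherwise, subject to the same cap.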
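Both validation loops reuse process_batch with is_train=False; gradients are switched off not by a separate eval path but by the torch.set_grad_enabled(is_train and train_unet) gate wrapped around each model call. A self-contained sketch of that pattern:

import torch

def forward_maybe_grad(module: torch.nn.Module, x: torch.Tensor, is_train: bool, train_module: bool) -> torch.Tensor:
    # gradients flow only when this is a training batch AND the module is actually being trained
    with torch.set_grad_enabled(is_train and train_module):
        return module(x)

layer = torch.nn.Linear(4, 4)
assert forward_maybe_grad(layer, torch.randn(2, 4), is_train=True, train_module=True).requires_grad
assert not forward_maybe_grad(layer, torch.randn(2, 4), is_train=False, train_module=True).requires_grad

The same gate covers the text encoders via torch.set_grad_enabled(is_train and train_text_encoder), so a validation batch never accumulates gradients even when both modules are trainable.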