Fix regularization images with validation
Add metadata recording for validation arguments
Add comments about the validation split for clarity of intention
rockerBOO committed Jan 12, 2025
1 parent 4c61adc commit 2bbb40c
Showing 2 changed files with 38 additions and 2 deletions.
33 changes: 31 additions & 2 deletions library/train_util.py
@@ -146,7 +146,12 @@
 TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
 TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
 
-def split_train_val(paths: List[str], is_training_dataset: bool, validation_split: float, validation_seed: int) -> List[str]:
+def split_train_val(
+    paths: List[str],
+    is_training_dataset: bool,
+    validation_split: float,
+    validation_seed: int | None
+) -> List[str]:
     """
     Split the dataset into train and validation
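For orientation, since the hunk collapses the function body: below is a minimal sketch of a deterministic split consistent with this signature. The shuffle-then-slice order and the helper name are assumptions for illustration, not the repository's actual implementation.

import random
from typing import List, Optional

def split_train_val_sketch(
    paths: List[str],
    is_training_dataset: bool,
    validation_split: float,
    validation_seed: Optional[int],
) -> List[str]:
    # Shuffle a copy deterministically so that, for the same seed, the
    # training and validation datasets receive complementary subsets.
    paths = list(paths)
    if validation_seed is not None:
        random.Random(validation_seed).shuffle(paths)
    # The first (1 - validation_split) fraction trains; the rest validates.
    split_index = int(len(paths) * (1.0 - validation_split))
    return paths[:split_index] if is_training_dataset else paths[split_index:]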
@@ -1830,6 +1835,9 @@ def get_item_for_caching(self, bucket, bucket_batch_size, image_index):
 class DreamBoothDataset(BaseDataset):
     IMAGE_INFO_CACHE_FILE = "metadata_cache.json"
 
+    # is_training_dataset defines the type of dataset: training or validation.
+    # if is_training_dataset is True  -> training dataset
+    # if is_training_dataset is False -> validation dataset
     def __init__(
         self,
         subsets: Sequence[DreamBoothSubset],
@@ -1965,8 +1973,29 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
                     size_set_count += 1
             logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
 
+            # We want to create a training and validation split. This should be improved in the
+            # future to allow a clearer distinction between training and validation. It is a
+            # short-term solution to limit what is necessary to implement validation datasets.
+            #
+            # We split the dataset for the subset based on whether we are doing a validation
+            # split. self.is_training_dataset defines the type of dataset:
+            # if self.is_training_dataset is True  -> training dataset
+            # if self.is_training_dataset is False -> validation dataset
             if self.validation_split > 0.0:
-                img_paths = split_train_val(img_paths, self.is_training_dataset, self.validation_split, self.validation_seed)
+                # For regularization images we do not want to split this dataset.
+                if subset.is_reg is True:
+                    # Skip any validation dataset for regularization images
+                    if self.is_training_dataset is False:
+                        img_paths = []
+                    # Otherwise img_paths remain unchanged: the training dataset keeps all
+                    # regularization images and no split is performed.
+                else:
+                    img_paths = split_train_val(
+                        img_paths,
+                        self.is_training_dataset,
+                        self.validation_split,
+                        self.validation_seed
+                    )
 
             logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")

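The new branch encodes a simple rule: when a validation split is configured, regularization subsets are never split (the training dataset keeps every regularization image, the validation dataset gets none), while ordinary image subsets go through split_train_val as before. A hypothetical standalone helper expressing the same decision, reusing split_train_val_sketch from the sketch above:

from typing import List, Optional

def select_subset_paths(
    img_paths: List[str],
    is_reg: bool,
    is_training_dataset: bool,
    validation_split: float,
    validation_seed: Optional[int],
) -> List[str]:
    # No validation configured: every image belongs to the training dataset.
    if validation_split <= 0.0:
        return img_paths
    if is_reg:
        # Regularization images are never split: training keeps them all,
        # validation receives none.
        return img_paths if is_training_dataset else []
    return split_train_val_sketch(
        img_paths, is_training_dataset, validation_split, validation_seed
    )

For example, select_subset_paths(paths, is_reg=True, is_training_dataset=False, validation_split=0.1, validation_seed=42) returns an empty list, which is the behavior this commit fixes.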
7 changes: 7 additions & 0 deletions train_network.py
@@ -898,6 +898,7 @@ def load_model_hook(models, input_dir):

accelerator.print("running training / 学習開始")
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
accelerator.print(f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}")
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
@@ -917,6 +918,7 @@ def load_model_hook(models, input_dir):
"ss_text_encoder_lr": text_encoder_lr,
"ss_unet_lr": args.unet_lr,
"ss_num_train_images": train_dataset_group.num_train_images,
"ss_num_validation_images": val_dataset_group.num_train_images if val_dataset_group is not None else 0,
"ss_num_reg_images": train_dataset_group.num_reg_images,
"ss_num_batches_per_epoch": len(train_dataloader),
"ss_num_epochs": num_train_epochs,
@@ -964,6 +966,11 @@ def load_model_hook(models, input_dir):
"ss_huber_c": args.huber_c,
"ss_fp8_base": bool(args.fp8_base),
"ss_fp8_base_unet": bool(args.fp8_base_unet),
"ss_validation_seed": args.validation_seed,
"ss_validation_split": args.validation_split,
"ss_max_validation_steps": args.max_validation_steps,
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
"ss_validate_every_n_steps": args.validate_every_n_steps,
}

self.update_metadata(metadata, args) # architecture specific metadata
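These ss_* entries travel with the saved network, so the validation settings can later be recovered from the file itself. Assuming the network is saved as a .safetensors file with this metadata attached (the file name below is a placeholder), it can be read back with the safetensors library; note that safetensors stores every metadata value as a string:

from safetensors import safe_open

# Inspect the training metadata embedded in a saved network file.
with safe_open("my_lora.safetensors", framework="pt") as f:
    metadata = f.metadata() or {}

# Values come back as strings; convert as needed.
print(metadata.get("ss_validation_split"))       # e.g. "0.1"
print(metadata.get("ss_num_validation_images"))  # e.g. "12"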
