Val loss (sd3 wip) (need help) #1856

Closed · wants to merge 1 commit
125 changes: 122 additions & 3 deletions library/config_util.py
@@ -102,7 +102,8 @@ class BaseDatasetParams:
resolution: Optional[Tuple[int, int]] = None
network_multiplier: float = 1.0
debug_dataset: bool = False

validation_seed: Optional[int] = None
validation_split: float = 0.0

@dataclass
class DreamBoothDatasetParams(BaseDatasetParams):
@@ -236,6 +237,8 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
"min_bucket_reso": int,
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
"network_multiplier": float,
"validation_seed": int,
"validation_split": float,
}

# options handled by argparse but not handled by user config
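For context, a minimal sketch of the kind of user config these new schema entries are meant to validate, written as the dict form of the usual TOML dataset config. The values and the image path are illustrative assumptions, not taken from this PR.

# Hypothetical dataset config exercising the new keys (illustrative values only).
user_config = {
    "datasets": [
        {
            "resolution": 1024,
            "validation_seed": 42,    # fixes the shuffle used for the train/val split
            "validation_split": 0.1,  # fraction of each subset held out for validation
            "subsets": [
                {"image_dir": "/path/to/train_images", "num_repeats": 1},
            ],
        }
    ]
}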
@@ -478,9 +481,42 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
dataset_klass = FineTuningDataset

subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
dataset = dataset_klass(subsets=subsets, is_train=True, **asdict(dataset_blueprint.params))
datasets.append(dataset)

val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []

for dataset_blueprint in dataset_group_blueprint.datasets:
if dataset_blueprint.params.validation_split <= 0.0:
continue
if dataset_blueprint.is_controlnet:
subset_klass = ControlNetSubset
dataset_klass = ControlNetDataset
elif dataset_blueprint.is_dreambooth:
subset_klass = DreamBoothSubset
dataset_klass = DreamBoothDataset
else:
subset_klass = FineTuningSubset
dataset_klass = FineTuningDataset

subsets = []
for subset_blueprint in dataset_blueprint.subsets:
subset_blueprint.params.num_repeats = 1
subset_blueprint.params.color_aug = False
subset_blueprint.params.flip_aug = False
subset_blueprint.params.random_crop = False
subset_blueprint.params.caption_dropout_rate = 0.0
subset_blueprint.params.caption_dropout_every_n_epochs = 0
subset_blueprint.params.caption_tag_dropout_rate = 0.0
subset_blueprint.params.token_warmup_step = 0

if subset_klass != DreamBoothSubset or not subset_blueprint.params.is_reg:
subsets.append(subset_klass(**asdict(subset_blueprint.params)))

dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params))
val_datasets.append(dataset)

# print info
info = ""
for i, dataset in enumerate(datasets):
@@ -566,6 +602,78 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu

logger.info(f"{info}")

# print validation info
info = ""
for i, dataset in enumerate(val_datasets):
is_dreambooth = isinstance(dataset, DreamBoothDataset)
is_controlnet = isinstance(dataset, ControlNetDataset)
info += dedent(
f"""\
[Validation Dataset {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
enable_bucket: {dataset.enable_bucket}
network_multiplier: {dataset.network_multiplier}
"""
)

if dataset.enable_bucket:
info += indent(
dedent(
f"""\
min_bucket_reso: {dataset.min_bucket_reso}
max_bucket_reso: {dataset.max_bucket_reso}
bucket_reso_steps: {dataset.bucket_reso_steps}
bucket_no_upscale: {dataset.bucket_no_upscale}
\n"""
),
" ",
)
else:
info += "\n"

for j, subset in enumerate(dataset.subsets):
info += indent(
dedent(
f"""\
[Subset {j} of Dataset {i}]
image_dir: "{subset.image_dir}"
image_count: {subset.img_count}
shuffle_caption: {subset.shuffle_caption}
keep_tokens: {subset.keep_tokens}
keep_tokens_separator: {subset.keep_tokens_separator}
caption_prefix: {subset.caption_prefix}
caption_suffix: {subset.caption_suffix}
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
"""
),
" ",
)

if is_dreambooth:
info += indent(
dedent(
f"""\
is_reg: {subset.is_reg}
class_tokens: {subset.class_tokens}
caption_extension: {subset.caption_extension}
\n"""
),
" ",
)
elif not is_controlnet:
info += indent(
dedent(
f"""\
metadata_file: {subset.metadata_file}
\n"""
),
" ",
)

logger.info(f"{info}")

# make buckets first because it determines the length of dataset
# and set the same seed for all datasets
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
@@ -574,8 +682,19 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
dataset.make_buckets()
dataset.set_seed(seed)

return DatasetGroup(datasets)
if not val_datasets:
    return DatasetGroup(datasets)

for i, dataset in enumerate(val_datasets):
    logger.info(f"[Validation Dataset {i}]")
    dataset.make_buckets()
    dataset.set_seed(seed)

return (
    DatasetGroup(datasets),
    DatasetGroup(val_datasets)
)

def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
def extract_dreambooth_params(name: str) -> Tuple[int, str]:
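Since generate_dataset_group_by_blueprint now returns either a bare DatasetGroup or a (train, validation) pair, callers in the training scripts need to handle both shapes. A hedged sketch of one way to do that; the blueprint.dataset_group call site is assumed from the existing training scripts and is not part of this diff.

# Sketch only: unpack the conditional return value introduced above.
result = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
if isinstance(result, tuple):
    train_dataset_group, val_dataset_group = result
else:
    train_dataset_group, val_dataset_group = result, None  # no validation_split configured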
23 changes: 22 additions & 1 deletion library/train_util.py
@@ -1799,11 +1799,16 @@ def __init__(
bucket_no_upscale: bool,
prior_loss_weight: float,
debug_dataset: bool,
is_train: bool,
validation_split: float,
validation_seed: Optional[int],
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset)

assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"

self.is_train = is_train
self.validation_split = validation_split
self.validation_seed = validation_seed
self.batch_size = batch_size
self.size = min(self.width, self.height)  # the shorter side
self.prior_loss_weight = prior_loss_weight
@@ -1878,6 +1883,8 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
# we may need to check image size and existence of image files, but it takes time, so user should check it before training
else:
img_paths = glob_images(subset.image_dir, "*")
if self.validation_split > 0.0:
img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed)
sizes = [None] * len(img_paths)

# new caching: get image size from cache files
@@ -6328,6 +6335,20 @@ def sample_image_inference(
# not to commit images to avoid inconsistency between training and logging steps
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption

def split_train_val(paths, is_train, validation_split, validation_seed):
    # Shuffle deterministically when a validation seed is given, so the training
    # and validation datasets see the same ordering and receive complementary slices.
    if validation_seed is not None:
        logger.info(f"Using validation seed: {validation_seed}")
        prevstate = random.getstate()
        random.seed(validation_seed)
        random.shuffle(paths)
        random.setstate(prevstate)
    else:
        random.shuffle(paths)

    # Split at a single index so the two slices partition the list without overlap.
    train_count = math.ceil(len(paths) * (1 - validation_split))
    if is_train:
        return paths[:train_count]
    else:
        return paths[train_count:]
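A small usage sketch, not part of the diff: with a fixed validation_seed the training and validation calls shuffle identically, so the two slices cover the image list without overlap. The file names below are made up for illustration.

paths = [f"img_{i:03d}.png" for i in range(10)]
train = split_train_val(list(paths), is_train=True, validation_split=0.2, validation_seed=42)
val = split_train_val(list(paths), is_train=False, validation_split=0.2, validation_seed=42)
assert not set(train) & set(val)             # disjoint
assert sorted(train + val) == sorted(paths)  # together they cover every image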

# endregion

Expand Down