Skip to content

Commit

Permalink
Apply is_training_dataset only to DreamBoothDataset. Add validation_s…
Browse files Browse the repository at this point in the history
…plit check and warning
  • Loading branch information
rockerBOO committed Jan 9, 2025
1 parent d6f158d commit 264167f
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions library/config_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,36 +471,49 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []

for dataset_blueprint in dataset_group_blueprint.datasets:
extra_dataset_params = {}

if dataset_blueprint.is_controlnet:
subset_klass = ControlNetSubset
dataset_klass = ControlNetDataset
elif dataset_blueprint.is_dreambooth:
subset_klass = DreamBoothSubset
dataset_klass = DreamBoothDataset
# DreamBooth datasets support splitting training and validation datasets
extra_dataset_params = {"is_training_dataset": True}
else:
subset_klass = FineTuningSubset
dataset_klass = FineTuningDataset

subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
dataset = dataset_klass(subsets=subsets, is_training_dataset=True, **asdict(dataset_blueprint.params))
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params), **extra_dataset_params)
datasets.append(dataset)

val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
for dataset_blueprint in dataset_group_blueprint.datasets:
if dataset_blueprint.params.validation_split <= 0.0:
if dataset_blueprint.params.validation_split < 0.0 or dataset_blueprint.params.validation_split > 1.0:
logging.warning(f"Dataset param `validation_split` ({dataset_blueprint.params.validation_split}) is not a valid number between 0.0 and 1.0, skipping validation split...")
continue

# if the dataset isn't setting a validation split, there is no current validation dataset
if dataset_blueprint.params.validation_split == 0.0:
continue

extra_dataset_params = {}
if dataset_blueprint.is_controlnet:
subset_klass = ControlNetSubset
dataset_klass = ControlNetDataset
elif dataset_blueprint.is_dreambooth:
subset_klass = DreamBoothSubset
dataset_klass = DreamBoothDataset
# DreamBooth datasets support splitting training and validation datasets
extra_dataset_params = {"is_training_dataset": False}
else:
subset_klass = FineTuningSubset
dataset_klass = FineTuningDataset

subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
dataset = dataset_klass(subsets=subsets, is_training_dataset=False, **asdict(dataset_blueprint.params))
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params), **extra_dataset_params)
val_datasets.append(dataset)

def print_info(_datasets, dataset_type: str):
Expand Down

0 comments on commit 264167f

Please sign in to comment.