Validation loss #1864

Open

rockerBOO wants to merge 77 commits into base: sd3

Changes from 1 commit (of 77 commits)
5b19bda
Add validation loss
rockerBOO Nov 5, 2023
33c311e
new ratio code
rockerBOO Nov 5, 2023
3de9e6c
Add validation split of datasets
rockerBOO Nov 5, 2023
a93c524
Update args to validation_seed and validation_split
rockerBOO Nov 5, 2023
c892521
Add process_batch for train_network
rockerBOO Nov 5, 2023
e545fdf
Removed/cleanup a line
rockerBOO Nov 5, 2023
9c591bd
Remove unnecessary subset line from collate
rockerBOO Nov 5, 2023
569ca72
Set grad enabled if is_train and train_text_encoder
rockerBOO Nov 7, 2023
b558a5b
val
gesen2egee Mar 9, 2024
78cfb01
improve
gesen2egee Mar 10, 2024
923b761
Update train_network.py
gesen2egee Mar 10, 2024
47359b8
Update train_network.py
gesen2egee Mar 10, 2024
a51723c
fix timesteps
gesen2egee Mar 11, 2024
7d84ac2
only use train subset to val
gesen2egee Mar 11, 2024
befbec5
Update train_network.py
gesen2egee Mar 11, 2024
63e58f7
Update train_network.py
gesen2egee Mar 11, 2024
a6c41c6
Update train_network.py
gesen2egee Mar 11, 2024
bd7e229
fix
gesen2egee Mar 13, 2024
5d7ed0d
Merge remote-tracking branch 'kohya-ss/dev' into val
gesen2egee Mar 13, 2024
d05965d
Update train_network.py
gesen2egee Mar 13, 2024
b5e8045
fix control net
gesen2egee Mar 16, 2024
086f600
Merge branch 'main' into val
gesen2egee Apr 10, 2024
36d4023
Update config_util.py
gesen2egee Apr 10, 2024
229c5a3
Update train_util.py
gesen2egee Apr 10, 2024
3b251b7
Update config_util.py
gesen2egee Apr 10, 2024
459b125
Update config_util.py
gesen2egee Apr 10, 2024
89ad69b
Update train_util.py
gesen2egee Apr 11, 2024
fde8026
Update config_util.py
gesen2egee Apr 11, 2024
31507b9
Remove unnecessary is_train changes and use apply_debiased_estimation…
gesen2egee Aug 2, 2024
1db4951
Update train_db.py
gesen2egee Aug 4, 2024
6816217
Update train_db.py
gesen2egee Aug 4, 2024
96eb74f
Update train_db.py
gesen2egee Aug 4, 2024
b9bdd10
Update train_network.py
gesen2egee Aug 4, 2024
3d68754
Update train_db.py
gesen2egee Aug 4, 2024
a593e83
Update train_network.py
gesen2egee Aug 4, 2024
f6dbf7c
Update train_network.py
gesen2egee Aug 4, 2024
aa850aa
Update train_network.py
gesen2egee Aug 4, 2024
cdb2d9c
Update train_network.py
gesen2egee Aug 4, 2024
3028027
Update train_network.py
gesen2egee Oct 4, 2024
dece2c3
Update train_db.py
gesen2egee Oct 4, 2024
05bb918
Add Validation loss for LoRA training
hinablue Dec 27, 2024
62164e5
Change val loss calculate method
hinablue Dec 27, 2024
64bd531
Split val latents/batch and pick up val latents shape size which equa…
hinablue Dec 28, 2024
cb89e02
Change val latent loss compare
hinablue Dec 28, 2024
8743532
val
gesen2egee Mar 9, 2024
449c1c5
Adding modified train_util and config_util
rockerBOO Jan 2, 2025
7f6e124
Merge branch 'gesen2egee/val' into validation-loss-upstream
rockerBOO Jan 3, 2025
d23c732
Merge remote-tracking branch 'hina/feature/val-loss' into validation-…
rockerBOO Jan 3, 2025
7470173
Remove defunct code for train_controlnet.py
rockerBOO Jan 3, 2025
534059d
Typos and lingering is_train
rockerBOO Jan 3, 2025
c8c3569
Cleanup order, types, print to logger
rockerBOO Jan 3, 2025
fbfc275
Update text for train/reg with repeats
rockerBOO Jan 3, 2025
58bfa36
Add seed help clarifying info
rockerBOO Jan 3, 2025
6604b36
Remove duplicate assignment
rockerBOO Jan 3, 2025
0522070
Fix training, validation split, revert to using upstream implemenation
rockerBOO Jan 3, 2025
695f389
Move get_huber_threshold_if_needed
rockerBOO Jan 3, 2025
1f9ba40
Add step break for validation epoch. Remove unused variable
rockerBOO Jan 3, 2025
1c0ae30
Add missing functions for training batch
rockerBOO Jan 3, 2025
bbf6bbd
Use self.get_noise_pred_and_target and drop fixed timesteps
rockerBOO Jan 6, 2025
f4840ef
Revert train_db.py
rockerBOO Jan 6, 2025
1c63e7c
Cleanup unused code and formatting
rockerBOO Jan 6, 2025
c64d1a2
Add validate_every_n_epochs, change name validate_every_n_steps
rockerBOO Jan 6, 2025
f885029
Fix validate epoch, cleanup imports
rockerBOO Jan 6, 2025
fcb2ff0
Clean up some validation help documentation
rockerBOO Jan 6, 2025
742bee9
Set validation steps in multiple lines for readability
rockerBOO Jan 6, 2025
1231f51
Remove unused train_util code, fix accelerate.log for wandb, add init…
rockerBOO Jan 8, 2025
556f3f1
Fix documentation, remove unused function, fix bucket reso for sd1.5,…
rockerBOO Jan 8, 2025
9fde0d7
Handle tuple return from generate_dataset_group_by_blueprint
rockerBOO Jan 8, 2025
1e61392
Revert bucket_reso_steps to correct 64
rockerBOO Jan 8, 2025
d6f158d
Fix incorrect destructoring for load_abritrary_dataset
rockerBOO Jan 8, 2025
264167f
Apply is_training_dataset only to DreamBoothDataset. Add validation_s…
rockerBOO Jan 9, 2025
4c61adc
Add divergence to logs
rockerBOO Jan 12, 2025
2bbb40c
Fix regularization images with validation
rockerBOO Jan 12, 2025
0456858
Fix validate_every_n_steps always running first step
rockerBOO Jan 12, 2025
ee9265c
Fix validate_every_n_steps for gradient accumulation
rockerBOO Jan 12, 2025
25929dd
Remove Validating... print to fix output layout
rockerBOO Jan 12, 2025
b489082
Disable repeats for validation datasets
rockerBOO Jan 12, 2025
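
Taken together, the commits above add a held-out validation pass to training: datasets are split with --validation_split / --validation_seed, and validation loss is computed at the cadence set by --validate_every_n_steps / --validate_every_n_epochs. As a rough illustration of the split those options describe (a minimal sketch, not the PR's actual code; split_train_val and the file names are hypothetical):

import random
from typing import List, Tuple

def split_train_val(items: List[str], validation_split: float, validation_seed: int) -> Tuple[List[str], List[str]]:
    # Shuffle a copy with a dedicated RNG so the split is reproducible
    # and the global random state is left untouched.
    items = list(items)
    rng = random.Random(validation_seed)
    rng.shuffle(items)
    n_val = int(len(items) * validation_split)
    return items[n_val:], items[:n_val]  # (train, validation)

images = [f"img_{i:04d}.png" for i in range(100)]
train_images, val_images = split_train_val(images, validation_split=0.1, validation_seed=42)
print(len(train_images), len(val_images))  # -> 90 10

A fixed validation_seed matters because the dataset is rebuilt on every run; without it the held-out images would change from run to run and validation loss curves would not be comparable.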
improve
gesen2egee committed Mar 10, 2024

Verified: this commit was signed with the committer’s verified signature (rockerBOO / Dave Lage).
commit 78cfb01922ff97bbc62ff12a4d69eaaa2d89d7c1
library/config_util.py: 260 changes (187 additions, 73 deletions)
@@ -41,12 +41,17 @@
     DatasetGroup,
 )
 from .utils import setup_logging
 
 setup_logging()
 import logging
 
 logger = logging.getLogger(__name__)
 
 
 def add_config_arguments(parser: argparse.ArgumentParser):
-    parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル")
+    parser.add_argument(
+        "--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル"
+    )
 
 
 # TODO: inherit Params class in Subset, Dataset
@@ -60,6 +65,8 @@ class BaseSubsetParams:
     caption_separator: str = (",",)
     keep_tokens: int = 0
     keep_tokens_separator: str = (None,)
+    secondary_separator: Optional[str] = None
+    enable_wildcard: bool = False
     color_aug: bool = False
     flip_aug: bool = False
     face_crop_aug_range: Optional[Tuple[float, float]] = None
@@ -181,6 +188,8 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
         "shuffle_caption": bool,
         "keep_tokens": int,
         "keep_tokens_separator": str,
+        "secondary_separator": str,
+        "enable_wildcard": bool,
         "token_warmup_min": int,
         "token_warmup_step": Any(float, int),
         "caption_prefix": str,
@@ -247,9 +256,10 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
     }
 
     def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
-        assert (
-            support_dreambooth or support_finetuning or support_controlnet
-        ), "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
+        assert support_dreambooth or support_finetuning or support_controlnet, (
+            "Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
+            + " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。"
+        )
 
         self.db_subset_schema = self.__merge_dict(
             self.SUBSET_ASCENDABLE_SCHEMA,
@@ -361,7 +371,9 @@ def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) ->
             return self.argparse_config_validator(argparse_namespace)
         except MultipleInvalid:
             # XXX: this should be a bug
-            logger.error("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。")
+            logger.error(
+                "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
+            )
             raise
 
     # NOTE: value would be overwritten by latter dict if there is already the same key
@@ -447,7 +459,6 @@ def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
 
         return default_value
 
-
 def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
     datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []

@@ -467,7 +478,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
         datasets.append(dataset)
 
-    val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
+    val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
 
     for dataset_blueprint in dataset_group_blueprint.datasets:
         if dataset_blueprint.params.validation_split <= 0.0:
             continue
@@ -485,75 +496,174 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
         dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params))
         val_datasets.append(dataset)
 
-    def print_info(_datasets):
-        info = ""
-        for i, dataset in enumerate(_datasets):
-            is_dreambooth = isinstance(dataset, DreamBoothDataset)
-            is_controlnet = isinstance(dataset, ControlNetDataset)
-            info += dedent(f"""\
-                [Dataset {i}]
-                batch_size: {dataset.batch_size}
-                resolution: {(dataset.width, dataset.height)}
-                enable_bucket: {dataset.enable_bucket}
-            """)
-
-            if dataset.enable_bucket:
-                info += indent(dedent(f"""\
-                    min_bucket_reso: {dataset.min_bucket_reso}
-                    max_bucket_reso: {dataset.max_bucket_reso}
-                    bucket_reso_steps: {dataset.bucket_reso_steps}
-                    bucket_no_upscale: {dataset.bucket_no_upscale}
-                \n"""), "  ")
-            else:
-                info += "\n"
-
-            for j, subset in enumerate(dataset.subsets):
-                info += indent(dedent(f"""\
-                    [Subset {j} of Dataset {i}]
-                    image_dir: "{subset.image_dir}"
-                    image_count: {subset.img_count}
-                    num_repeats: {subset.num_repeats}
-                    shuffle_caption: {subset.shuffle_caption}
-                    keep_tokens: {subset.keep_tokens}
-                    caption_dropout_rate: {subset.caption_dropout_rate}
-                    caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
-                    caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
-                    caption_prefix: {subset.caption_prefix}
-                    caption_suffix: {subset.caption_suffix}
-                    color_aug: {subset.color_aug}
-                    flip_aug: {subset.flip_aug}
-                    face_crop_aug_range: {subset.face_crop_aug_range}
-                    random_crop: {subset.random_crop}
-                    token_warmup_min: {subset.token_warmup_min},
-                    token_warmup_step: {subset.token_warmup_step},
-                """), "  ")
-
-                if is_dreambooth:
-                    info += indent(dedent(f"""\
-                        is_reg: {subset.is_reg}
-                        class_tokens: {subset.class_tokens}
-                        caption_extension: {subset.caption_extension}
-                    \n"""), "  ")
-                elif not is_controlnet:
-                    info += indent(dedent(f"""\
-                        metadata_file: {subset.metadata_file}
-                    \n"""), "  ")
-
-        print(info)
-
-    print_info(datasets)
-
-    if len(val_datasets) > 0:
-        print("Validation dataset")
-        print_info(val_datasets)
-
+    # print info
+    info = ""
+    for i, dataset in enumerate(datasets):
+        is_dreambooth = isinstance(dataset, DreamBoothDataset)
+        is_controlnet = isinstance(dataset, ControlNetDataset)
+        info += dedent(
+            f"""\
+            [Dataset {i}]
+            batch_size: {dataset.batch_size}
+            resolution: {(dataset.width, dataset.height)}
+            enable_bucket: {dataset.enable_bucket}
+            network_multiplier: {dataset.network_multiplier}
+            """
+        )
+
+        if dataset.enable_bucket:
+            info += indent(
+                dedent(
+                    f"""\
+                    min_bucket_reso: {dataset.min_bucket_reso}
+                    max_bucket_reso: {dataset.max_bucket_reso}
+                    bucket_reso_steps: {dataset.bucket_reso_steps}
+                    bucket_no_upscale: {dataset.bucket_no_upscale}
+                \n"""
+                ),
+                "  ",
+            )
+        else:
+            info += "\n"
+
+        for j, subset in enumerate(dataset.subsets):
+            info += indent(
+                dedent(
+                    f"""\
+                    [Subset {j} of Dataset {i}]
+                    image_dir: "{subset.image_dir}"
+                    image_count: {subset.img_count}
+                    num_repeats: {subset.num_repeats}
+                    shuffle_caption: {subset.shuffle_caption}
+                    keep_tokens: {subset.keep_tokens}
+                    keep_tokens_separator: {subset.keep_tokens_separator}
+                    caption_dropout_rate: {subset.caption_dropout_rate}
+                    caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
+                    caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
+                    caption_prefix: {subset.caption_prefix}
+                    caption_suffix: {subset.caption_suffix}
+                    color_aug: {subset.color_aug}
+                    flip_aug: {subset.flip_aug}
+                    face_crop_aug_range: {subset.face_crop_aug_range}
+                    random_crop: {subset.random_crop}
+                    token_warmup_min: {subset.token_warmup_min},
+                    token_warmup_step: {subset.token_warmup_step},
+                """
+                ),
+                "  ",
+            )
+
+            if is_dreambooth:
+                info += indent(
+                    dedent(
+                        f"""\
+                        is_reg: {subset.is_reg}
+                        class_tokens: {subset.class_tokens}
+                        caption_extension: {subset.caption_extension}
+                    \n"""
+                    ),
+                    "  ",
+                )
+            elif not is_controlnet:
+                info += indent(
+                    dedent(
+                        f"""\
+                        metadata_file: {subset.metadata_file}
+                    \n"""
+                    ),
+                    "  ",
+                )
+
+    logger.info(f'{info}')
+
+    # print validation info
+    info = ""
+    for i, dataset in enumerate(val_datasets):
+        is_dreambooth = isinstance(dataset, DreamBoothDataset)
+        is_controlnet = isinstance(dataset, ControlNetDataset)
+        info += dedent(
+            f"""\
+            [Validation Dataset {i}]
+            batch_size: {dataset.batch_size}
+            resolution: {(dataset.width, dataset.height)}
+            enable_bucket: {dataset.enable_bucket}
+            network_multiplier: {dataset.network_multiplier}
+            """
+        )
+
+        if dataset.enable_bucket:
+            info += indent(
+                dedent(
+                    f"""\
+                    min_bucket_reso: {dataset.min_bucket_reso}
+                    max_bucket_reso: {dataset.max_bucket_reso}
+                    bucket_reso_steps: {dataset.bucket_reso_steps}
+                    bucket_no_upscale: {dataset.bucket_no_upscale}
+                \n"""
+                ),
+                "  ",
+            )
+        else:
+            info += "\n"
+
+        for j, subset in enumerate(dataset.subsets):
+            info += indent(
+                dedent(
+                    f"""\
+                    [Subset {j} of Dataset {i}]
+                    image_dir: "{subset.image_dir}"
+                    image_count: {subset.img_count}
+                    num_repeats: {subset.num_repeats}
+                    shuffle_caption: {subset.shuffle_caption}
+                    keep_tokens: {subset.keep_tokens}
+                    keep_tokens_separator: {subset.keep_tokens_separator}
+                    caption_dropout_rate: {subset.caption_dropout_rate}
+                    caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
+                    caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
+                    caption_prefix: {subset.caption_prefix}
+                    caption_suffix: {subset.caption_suffix}
+                    color_aug: {subset.color_aug}
+                    flip_aug: {subset.flip_aug}
+                    face_crop_aug_range: {subset.face_crop_aug_range}
+                    random_crop: {subset.random_crop}
+                    token_warmup_min: {subset.token_warmup_min},
+                    token_warmup_step: {subset.token_warmup_step},
+                """
+                ),
+                "  ",
+            )
+
+            if is_dreambooth:
+                info += indent(
+                    dedent(
+                        f"""\
+                        is_reg: {subset.is_reg}
+                        class_tokens: {subset.class_tokens}
+                        caption_extension: {subset.caption_extension}
+                    \n"""
+                    ),
+                    "  ",
+                )
+            elif not is_controlnet:
+                info += indent(
+                    dedent(
+                        f"""\
+                        metadata_file: {subset.metadata_file}
+                    \n"""
+                    ),
+                    "  ",
+                )
+
+    logger.info(f'{info}')
 
     # make buckets first because it determines the length of dataset
     # and set the same seed for all datasets
     seed = random.randint(0, 2**31)  # actual seed is seed + epoch_no
     for i, dataset in enumerate(datasets):
-        print(f"[Dataset {i}]")
+        logger.info(f"[Dataset {i}]")
         dataset.make_buckets()
         dataset.set_seed(seed)
 
     for i, dataset in enumerate(val_datasets):
         print(f"[Validation Dataset {i}]")
         dataset.make_buckets()
@@ -562,8 +672,8 @@ def print_info(_datasets):
     return (
         DatasetGroup(datasets),
         DatasetGroup(val_datasets) if val_datasets else None
-        )
+    )
 
 def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
     def extract_dreambooth_params(name: str) -> Tuple[int, str]:
         tokens = name.split("_")
@@ -642,13 +752,17 @@ def load_user_config(file: str) -> dict:
             with open(file, "r") as f:
                 config = json.load(f)
         except Exception:
-            logger.error(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
+            logger.error(
+                f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
+            )
             raise
     elif file.name.lower().endswith(".toml"):
         try:
             config = toml.load(file)
         except Exception:
-            logger.error(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
+            logger.error(
+                f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
+            )
             raise
     else:
         raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
@@ -675,13 +789,13 @@ def load_user_config(file: str) -> dict:
     train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
 
     logger.info("[argparse_namespace]")
-    logger.info(f'{vars(argparse_namespace)}')
+    logger.info(f"{vars(argparse_namespace)}")
 
     user_config = load_user_config(config_args.dataset_config)
 
     logger.info("")
     logger.info("[user_config]")
-    logger.info(f'{user_config}')
+    logger.info(f"{user_config}")
 
     sanitizer = ConfigSanitizer(
         config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
@@ -690,10 +804,10 @@ def load_user_config(file: str) -> dict:
 
     logger.info("")
     logger.info("[sanitized_user_config]")
-    logger.info(f'{sanitized_user_config}')
+    logger.info(f"{sanitized_user_config}")
 
     blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
 
     logger.info("")
     logger.info("[blueprint]")
-    logger.info(f'{blueprint}')
+    logger.info(f"{blueprint}")
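
The user-facing change in this file is the return value of generate_dataset_group_by_blueprint: it now returns a (train, validation) pair, with None in the second slot when no dataset configures validation_split. A hedged sketch of the call-site pattern (variable names are illustrative; the real wiring lives in the training scripts such as train_network.py):

# Sketch only: assumes the kohya-ss/sd-scripts layout, with `blueprint` produced by
# BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) as shown above.
from library import config_util

train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)

if val_dataset_group is not None:
    # The validation loss pass runs over this group at the cadence given by
    # --validate_every_n_steps / --validate_every_n_epochs.
    print(f"validation datasets: {len(val_dataset_group.datasets)}")

Keeping the second element Optional means existing single-dataset callers only need to unpack the tuple and guard on None, rather than restructure their training loop.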