diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
new file mode 100644
index 000000000..9ae67b0e9
--- /dev/null
+++ b/.github/workflows/tests.yml
@@ -0,0 +1,33 @@
+
+name: Python package
+
+on: [push]
+
+jobs:
+  build:
+
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        os: [ubuntu-latest]
+        python-version: ["3.10"]
+
+    steps:
+      - uses: actions/checkout@v4
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: 'pip' # caching pip dependencies
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip setuptools wheel
+          pip install dadaptation==3.2 torch==2.4.0 torchvision==0.19.0 accelerate==0.33.0
+          pip install -r requirements.txt
+
+      - name: Test with pytest
+        run: |
+          pip install pytest
+          pytest
diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml
index 0149dcdd3..87ebdf894 100644
--- a/.github/workflows/typos.yml
+++ b/.github/workflows/typos.yml
@@ -1,9 +1,11 @@
 ---
-# yamllint disable rule:line-length
 name: Typos

-on: # yamllint disable-line rule:truthy
+on:
   push:
+    branches:
+      - main
+      - dev
   pull_request:
     types:
       - opened
@@ -18,4 +20,4 @@ jobs:
       - uses: actions/checkout@v4

       - name: typos-action
-        uses: crate-ci/typos@v1.24.3
+        uses: crate-ci/typos@v1.28.1
diff --git a/README.md b/README.md
index f9c85e3ac..f02725191 100644
--- a/README.md
+++ b/README.md
@@ -14,6 +14,26 @@ The command to install PyTorch is as follows:

 ### Recent Updates

+Dec 3, 2024:
+
+- `--blocks_to_swap` now works in FLUX.1 ControlNet training. Sample commands for 24GB VRAM and 16GB VRAM are added [here](#flux1-controlnet-training).
+
+Dec 2, 2024:
+
+- FLUX.1 ControlNet training is supported. PR [#1813](https://github.com/kohya-ss/sd-scripts/pull/1813). Thanks to minux302! See the PR and [here](#flux1-controlnet-training) for details.
+  - Not fully tested. Feedback is welcome.
+  - 80GB VRAM is required for 1024x1024 resolution, and 48GB VRAM is required for 512x512 resolution.
+  - Currently, it only works in a Linux environment (or Windows WSL2) because DeepSpeed is required.
+  - Multi-GPU training is not tested.
+
+Dec 1, 2024:
+
+- Pseudo Huber loss is now available for FLUX.1 and SD3.5 training. See PR [#1808](https://github.com/kohya-ss/sd-scripts/pull/1808) for details. Thanks to recris!
+  - Specify `--loss_type huber` or `--loss_type smooth_l1` to use it. `--huber_c` and `--huber_scale` are also available. An illustrative sketch of the loss is shown below.
+
+- [Prodigy + ScheduleFree](https://github.com/LoganBooker/prodigy-plus-schedule-free) is supported. See PR [#1811](https://github.com/kohya-ss/sd-scripts/pull/1811) for details. Thanks to rockerBOO!
+
 Nov 14, 2024:

 - Improved the implementation of block swap and made it available for both FLUX.1 and SD3 LoRA training. See [FLUX.1 LoRA training](#flux1-lora-training) etc. for how to use the new options. Training is possible with about 8-10GB of VRAM.
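+For reference, the pseudo Huber loss behaves like L2 for small residuals and like L1 for large ones, which makes training more robust to outlier samples. A minimal illustrative sketch with a fixed threshold is below; the actual implementation lives in `train_util.conditional_loss`, and the threshold can depend on the timestep (see `train_util.get_huber_threshold_if_needed`), so the constant `huber_c` here is a simplification.
+
+```
+import torch
+
+def pseudo_huber_loss(pred: torch.Tensor, target: torch.Tensor, huber_c: float = 0.1) -> torch.Tensor:
+    # quadratic near zero, linear for large residuals; huber_c sets the crossover scale
+    diff = pred - target
+    return 2 * huber_c * (torch.sqrt(diff * diff + huber_c * huber_c) - huber_c)
+```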
@@ -28,6 +48,7 @@ Nov 14, 2024:
 - [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training)
 - [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1)
 - [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training)
+- [FLUX.1 ControlNet training](#flux1-controlnet-training)
 - [FLUX.1 OFT training](#flux1-oft-training)
 - [Inference for FLUX.1 with LoRA model](#inference-for-flux1-with-lora-model)
 - [FLUX.1 fine-tuning](#flux1-fine-tuning)
@@ -245,6 +266,30 @@ example:

 If you specify one of `train_double_block_indices` or `train_single_block_indices`, the other will be trained as usual.

+### FLUX.1 ControlNet training
+We have added a new training script, `flux_train_control_net.py`, for FLUX.1 ControlNet training. See `--help` for the full list of options.
+
+A sample command is shown below. It requires a GPU with 80GB of VRAM.
+```
+accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_control_net.py \
+--pretrained_model_name_or_path flux1-dev.safetensors --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors \
+--ae ae.safetensors --save_model_as safetensors --sdpa --persistent_data_loader_workers \
+--max_data_loader_n_workers 1 --seed 42 --gradient_checkpointing --mixed_precision bf16 \
+--optimizer_type adamw8bit --learning_rate 2e-5 \
+--highvram --max_train_epochs 1 --save_every_n_steps 1000 --dataset_config dataset.toml \
+--output_dir /path/to/output/dir --output_name flux-cn \
+--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 --deepspeed
+```
+
+For 24GB VRAM GPUs, you can train with a batch size of 1 by swapping 16 blocks and caching latents and text encoder outputs to disk. Remove `--deepspeed` and add the following options. Not fully tested.
+```
+ --blocks_to_swap 16 --cache_latents_to_disk --cache_text_encoder_outputs_to_disk
+```
+
+Training is also possible on 16GB VRAM GPUs with around 30 blocks swapped.
+
+`--gradient_accumulation_steps` is also available. The default value is 1 (no accumulation); the original PR uses 8. An illustrative sketch of the `shift` timestep sampling used in these commands follows this section.
+
 ### FLUX.1 OFT training

 You can train OFT with almost the same options as LoRA, such as `--timestep_sampling`. The following points are different.
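+A rough sketch of what `--timestep_sampling shift` with `--discrete_flow_shift` does during flow-matching training, assuming uniformly drawn timesteps (the actual logic, including the other sampling modes, lives in `flux_train_utils.get_noisy_model_input_and_timesteps`):
+
+```
+import torch
+
+def shift_timesteps(t: torch.Tensor, shift: float = 3.1582) -> torch.Tensor:
+    # remap t in [0, 1) so that more training steps land in the high-noise region
+    return (t * shift) / (1 + (shift - 1) * t)
+
+def make_noisy_input(latents: torch.Tensor, noise: torch.Tensor, shift: float = 3.1582) -> torch.Tensor:
+    t = shift_timesteps(torch.rand(latents.shape[0]), shift).view(-1, 1, 1, 1)
+    # flow matching: linearly interpolate between clean latents (t=0) and pure noise (t=1)
+    return t * noise + (1.0 - t) * latents
+```
+
+With `shift` greater than 1, sampled timesteps are pushed toward the noisy end, which is generally where high-resolution models need the most training signal.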
diff --git a/fine_tune.py b/fine_tune.py index b1668d657..09fa9537a 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -380,9 +380,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) # Predict the noise residual with accelerator.autocast(): @@ -394,11 +392,10 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): else: target = noise + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss: # do not mean over batch dimension for snr weight or scale v-pred loss - loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c - ) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: @@ -410,9 +407,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): loss = loss.mean() # mean over batch dimension else: - loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c - ) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "mean", huber_c) accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: diff --git a/flux_train.py b/flux_train.py index e5e4d17c1..ca151fc24 100644 --- a/flux_train.py +++ b/flux_train.py @@ -676,9 +676,8 @@ def grad_hook(parameter: torch.Tensor): target = noise - latents # calculate loss - loss = train_util.conditional_loss( - model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c) if weighting is not None: loss = loss * weighting if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): diff --git a/flux_train_control_net.py b/flux_train_control_net.py new file mode 100644 index 000000000..5548fd991 --- /dev/null +++ b/flux_train_control_net.py @@ -0,0 +1,877 @@ +# training with captions + +# Swap blocks between CPU and GPU: +# This implementation is inspired by and based on the work of 2kpr. +# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading. +# The original idea has been adapted and extended to fit the current project's needs. 
+ +# Key features: +# - CPU offloading during forward and backward passes +# - Use of fused optimizer and grad_hook for efficient gradient processing +# - Per-block fused optimizer instances + +import argparse +import copy +import math +import os +import time +from concurrent.futures import ThreadPoolExecutor +from multiprocessing import Value +from typing import List, Optional, Tuple, Union + +import toml +import torch +import torch.nn as nn +from tqdm import tqdm + +from library import utils +from library.device_utils import clean_memory_on_device, init_ipex + +init_ipex() + +from accelerate.utils import set_seed + +import library.train_util as train_util +from library import ( + deepspeed_utils, + flux_train_utils, + flux_utils, + strategy_base, + strategy_flux, +) +from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler +from library.utils import add_logging_arguments, setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import library.config_util as config_util + +# import library.sdxl_train_util as sdxl_train_util +from library.config_util import ( + BlueprintGenerator, + ConfigSanitizer, +) +from library.custom_train_functions import add_custom_train_arguments, apply_masked_loss + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + # sdxl_train_util.verify_sdxl_training_args(args) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + # temporary: backward compatibility for deprecated options. remove in the future + if not args.skip_cache_check: + args.skip_cache_check = args.skip_latents_validity_check + + # assert ( + # not args.weighted_captions + # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + if args.cpu_offload_checkpointing and not args.gradient_checkpointing: + logger.warning( + "cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります" + ) + args.gradient_checkpointing = True + + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, ( + "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + ) + + cache_latents = args.cache_latents + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. 
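+    # strategies are process-global singletons: they are registered here with set_strategy() and read back elsewhere (including by the dataset) via get_strategy()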
+ if args.cache_latents: + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "conditioning_data_dir"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + user_config = { + "datasets": [ + { + "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( + args.train_data_dir, args.conditioning_data_dir, args.caption_extension + ) + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認 + + _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) + if args.debug_dataset: + if args.cache_text_encoder_outputs: + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( + strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False + ) + ) + t5xxl_max_token_length = ( + args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512) + ) + strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length)) + + train_dataset_group.set_current_strategies() + train_util.debug_dataset(train_dataset_group, True) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify the metadata file and train_data_dir option. 
/ 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + + # load VAE for caching latents + ae = None + if cache_latents: + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + ae.to(accelerator.device, dtype=weight_dtype) + ae.requires_grad_(False) + ae.eval() + + train_dataset_group.new_cache_latents(ae, accelerator) + + ae.to("cpu") # if no sampling, vae can be deleted + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # prepare tokenize strategy + if args.t5xxl_max_token_length is None: + if is_schnell: + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + flux_tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length) + strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy) + + # load clip_l, t5xxl for caching text encoder outputs + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + clip_l.eval() + t5xxl.eval() + clip_l.requires_grad_(False) + t5xxl.requires_grad_(False) + + text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + # cache text encoder outputs + sample_prompts_te_outputs = None + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad here + clip_l.to(accelerator.device) + t5xxl.to(accelerator.device) + + text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False, args.apply_t5_attn_mask + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) + + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator) + + # cache sample prompt's embeddings to free text encoder's memory + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + 
tokens_and_masks = flux_tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + flux_tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + + accelerator.wait_for_everyone() + + # now we can delete Text Encoders to free memory + clip_l = None + t5xxl = None + clean_memory_on_device(accelerator.device) + + # load FLUX + is_schnell, flux = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors + ) + flux.requires_grad_(False) + + # load controlnet + controlnet_dtype = torch.float32 if args.deepspeed else weight_dtype + controlnet = flux_utils.load_controlnet( + args.controlnet, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors + ) + controlnet.train() + + if args.gradient_checkpointing: + if not args.deepspeed: + flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing) + controlnet.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing) + + # block swap + + # backward compatibility + if args.blocks_to_swap is None: + blocks_to_swap = args.double_blocks_to_swap or 0 + if args.single_blocks_to_swap is not None: + blocks_to_swap += args.single_blocks_to_swap // 2 + if blocks_to_swap > 0: + logger.warning( + "double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead." + " / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。" + ) + logger.info( + f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}." + ) + args.blocks_to_swap = blocks_to_swap + del blocks_to_swap + + is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + # This idea is based on 2kpr's great work. Thank you! + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + flux.enable_block_swap(args.blocks_to_swap, accelerator.device) + flux.move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + # ControlNet only has two blocks, so we can keep it on GPU + # controlnet.enable_block_swap(args.blocks_to_swap, accelerator.device) + else: + flux.to(accelerator.device) + + if not cache_latents: + # load VAE here if not cached + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu") + ae.requires_grad_(False) + ae.eval() + ae.to(accelerator.device, dtype=weight_dtype) + + training_models = [] + params_to_optimize = [] + training_models.append(controlnet) + name_and_params = list(controlnet.named_parameters()) + # single param group for now + params_to_optimize.append({"params": [p for _, p in name_and_params], "lr": args.learning_rate}) + param_names = [[n for n, _ in name_and_params]] + + # calculate number of trainable parameters + n_params = 0 + for group in params_to_optimize: + for p in group["params"]: + n_params += p.numel() + + accelerator.print(f"number of trainable parameters: {n_params}") + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + if args.blockwise_fused_optimizers: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters. 
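+        # Each parameter's post-accumulate gradient hook steps only its own block's optimizer and frees the gradient right away.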
+ # This balances memory usage and management complexity. + + # split params into groups. currently different learning rates are not supported + grouped_params = [] + param_group = {} + for group in params_to_optimize: + named_parameters = list(controlnet.named_parameters()) + assert len(named_parameters) == len(group["params"]), "number of parameters does not match" + for p, np in zip(group["params"], named_parameters): + # determine target layer and block index for each parameter + block_type = "other" # double, single or other + if np[0].startswith("double_blocks"): + block_index = int(np[0].split(".")[1]) + block_type = "double" + elif np[0].startswith("single_blocks"): + block_index = int(np[0].split(".")[1]) + block_type = "single" + else: + block_index = -1 + + param_group_key = (block_type, block_index) + if param_group_key not in param_group: + param_group[param_group_key] = [] + param_group[param_group_key].append(p) + + block_types_and_indices = [] + for param_group_key, param_group in param_group.items(): + block_types_and_indices.append(param_group_key) + grouped_params.append({"params": param_group, "lr": args.learning_rate}) + + num_params = 0 + for p in param_group: + num_params += p.numel() + accelerator.print(f"block {param_group_key}: {num_params} parameters") + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers") + + if train_util.is_schedulefree_optimizer(optimizers[0], args): + raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers") + optimizer_train_fn = lambda: None # dummy function + optimizer_eval_fn = lambda: None # dummy function + else: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + if args.blockwise_fused_optimizers: + # prepare lr schedulers for each optimizer + lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + flux.to(weight_dtype) + controlnet.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + t5xxl.to(weight_dtype) # TODO check works with fp16 or not + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + flux.to(weight_dtype) + controlnet.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + t5xxl.to(weight_dtype) + + # if we don't cache text encoder outputs, move them to device + if not args.cache_text_encoder_outputs: + clip_l.to(accelerator.device) + t5xxl.to(accelerator.device) + + clean_memory_on_device(accelerator.device) + + if args.deepspeed: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=controlnet) + # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] + + else: + # accelerator does some magic + # if we doesn't swap blocks, we can move the model to device + controlnet = accelerator.prepare(controlnet) # , device_placement=[not is_swapping_blocks]) + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. 
+ train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + + for param_group, param_name_group in zip(optimizer.param_groups, param_names): + for parameter, param_name in zip(param_group["params"], param_name_group): + if parameter.requires_grad: + + def create_grad_hook(p_name, p_group): + def grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, p_group) + tensor.grad = None + + return grad_hook + + parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group)) + + elif args.blockwise_fused_optimizers: + # prepare for additional optimizers and lr schedulers + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + # counters are used to determine when to step the optimizer + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def grad_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(grad_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print( + # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + # ) + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler 
= FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + if is_swapping_blocks: + flux.prepare_block_swap_before_forward() + + # For --sample_at_first + optimizer_eval_fn() + flux_train_utils.sample_images( + accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet + ) + optimizer_train_fn() + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) + + loss_recorder = train_util.LossRecorder() + epoch = 0 # avoid error when max_train_steps is 0 + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for m in training_models: + m.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + + if args.blockwise_fused_optimizers: + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step + + with accelerator.accumulate(*training_models): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) + else: + with torch.no_grad(): + # encode images to latents. images are [-1, 1] + latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list + else: + # not cached or training, so get from text encoders + tokens_and_masks = batch["input_ids_list"] + with torch.no_grad(): + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + text_encoder_conds = text_encoding_strategy.encode_tokens( + flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask + ) + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] + + # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype + ) + + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = ( + flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width) + .to(device=accelerator.device) + .to(weight_dtype) + ) + + # get guidance: ensure args.guidance_scale is float + guidance_vec = torch.full((bsz,), 
float(args.guidance_scale), device=accelerator.device, dtype=weight_dtype) + + # call model + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None + + with accelerator.autocast(): + block_samples, block_single_samples = controlnet( + img=packed_noisy_model_input, + img_ids=img_ids, + controlnet_cond=batch["conditioning_images"].to(accelerator.device).to(weight_dtype), + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + # note: timesteps are divided by 1000 because the model scales them by 1000 internally; kept as-is to match the reference implementation + model_pred = flux( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + block_controlnet_hidden_states=block_samples, + block_controlnet_single_hidden_states=block_single_samples, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + + # flow matching loss: this is different from SD3 + target = noise - latents + + # calculate loss + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # per-sample loss weights + loss = loss * loss_weights + loss = loss.mean() + + # backward + accelerator.backward(loss) + + if not (args.fused_backward_pass or args.blockwise_fused_optimizers): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.blockwise_fused_optimizers: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + optimizer_eval_fn() + flux_train_utils.sample_images( + accelerator, + args, + None, + global_step, + flux, + ae, + [clip_l, t5xxl], + sample_prompts_te_outputs, + controlnet=controlnet, + ) + + # save the model at the specified step interval + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(controlnet), + ) + optimizer_train_fn() + + current_loss = loss.detach().item() # already averaged, so batch size should not matter + if len(accelerator.trackers) > 0: + logs = {"loss": current_loss} + train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) + + accelerator.log(logs,
step=global_step) + + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if len(accelerator.trackers) > 0: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + optimizer_eval_fn() + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(controlnet), + ) + + flux_train_utils.sample_images( + accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet + ) + optimizer_train_fn() + + is_main_process = accelerator.is_main_process + # if is_main_process: + controlnet = accelerator.unwrap_model(controlnet) + + accelerator.end_training() + optimizer_eval_fn() + + if args.save_state or args.save_state_on_train_end: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, controlnet) + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) # TODO split this + train_util.add_dataset_arguments(parser, False, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + add_custom_train_arguments(parser) # TODO remove this from here + train_util.add_dit_training_arguments(parser) + flux_train_utils.add_flux_train_arguments(parser) + + parser.add_argument( + "--mem_eff_save", + action="store_true", + help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う", + ) + + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます", + ) + parser.add_argument( + "--blockwise_fused_optimizers", + action="store_true", + help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする", + ) + parser.add_argument( + "--skip_latents_validity_check", + action="store_true", + help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", + ) + parser.add_argument( + "--double_blocks_to_swap", + type=int, + default=None, + help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください", + ) + parser.add_argument( + "--single_blocks_to_swap", + type=int, + default=None, + help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください", + ) + parser.add_argument( + "--cpu_offload_checkpointing", + action="store_true", + help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + 
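+    # merge in values from a .toml config file, if one is specified, via read_config_from_file below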
train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/flux_train_network.py b/flux_train_network.py index 19b9fda87..c5f2a949a 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -6,7 +6,8 @@ import torch from accelerate import Accelerator -from library.device_utils import init_ipex, clean_memory_on_device + +from library.device_utils import clean_memory_on_device, init_ipex init_ipex() @@ -177,7 +178,7 @@ def get_text_encoder_outputs_caching_strategy(self, args): if args.cache_text_encoder_outputs: fluxTokenizeStrategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() t5xxl_max_token_length = fluxTokenizeStrategy.t5xxl_max_length - + # if the text encoders is trained, we need tokenization, so is_partial is True return strategy_flux.FluxTextEncoderOutputsCachingStrategy( args.cache_text_encoder_outputs_to_disk, @@ -473,7 +474,7 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t ) target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) - return model_pred, target, timesteps, None, weighting + return model_pred, target, timesteps, weighting def post_process_loss(self, loss, args, timesteps, noise_scheduler): return loss diff --git a/library/flux_models.py b/library/flux_models.py index fa3c7ad2b..328ad481d 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -2,15 +2,15 @@ # license: Apache-2.0 License -from concurrent.futures import Future, ThreadPoolExecutor -from dataclasses import dataclass import math import os import time +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass from typing import Dict, List, Optional, Union from library import utils -from library.device_utils import init_ipex, clean_memory_on_device +from library.device_utils import clean_memory_on_device, init_ipex init_ipex() @@ -18,6 +18,7 @@ from einops import rearrange from torch import Tensor, nn from torch.utils.checkpoint import checkpoint + from library import custom_offloading_utils # USE_REENTRANT = True @@ -1013,6 +1014,8 @@ def forward( txt_ids: Tensor, timesteps: Tensor, y: Tensor, + block_controlnet_hidden_states=None, + block_controlnet_single_hidden_states=None, guidance: Tensor | None = None, txt_attention_mask: Tensor | None = None, ) -> Tensor: @@ -1031,18 +1034,29 @@ def forward( ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) + if block_controlnet_hidden_states is not None: + controlnet_depth = len(block_controlnet_hidden_states) + if block_controlnet_single_hidden_states is not None: + controlnet_single_depth = len(block_controlnet_single_hidden_states) if not self.blocks_to_swap: - for block in self.double_blocks: + for block_idx, block in enumerate(self.double_blocks): img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + if block_controlnet_hidden_states is not None and controlnet_depth > 0: + img = img + block_controlnet_hidden_states[block_idx % controlnet_depth] + img = torch.cat((txt, img), 1) - for block in self.single_blocks: + for block_idx, block in enumerate(self.single_blocks): img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + if block_controlnet_single_hidden_states is not None and controlnet_single_depth > 0: + img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth] else: for block_idx, block in enumerate(self.double_blocks): 
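+                # block swap: wait until this block's weights are resident on the GPU; finished blocks are moved back to CPU in the background by submit_move_blocks below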
self.offloader_double.wait_for_block(block_idx) img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + if block_controlnet_hidden_states is not None and controlnet_depth > 0: + img = img + block_controlnet_hidden_states[block_idx % controlnet_depth] self.offloader_double.submit_move_blocks(self.double_blocks, block_idx) @@ -1052,6 +1066,8 @@ def forward( self.offloader_single.wait_for_block(block_idx) img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + if block_controlnet_single_hidden_states is not None and controlnet_single_depth > 0: + img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth] self.offloader_single.submit_move_blocks(self.single_blocks, block_idx) @@ -1066,6 +1082,246 @@ def forward( return img +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +class ControlNetFlux(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, params: FluxParams, controlnet_depth=2, controlnet_single_depth=0): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(controlnet_depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + for _ in range(controlnet_single_depth) + ] + ) + + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + self.blocks_to_swap = None + + self.offloader_double = None + self.offloader_single = None + self.num_double_blocks = len(self.double_blocks) + self.num_single_blocks = len(self.single_blocks) + + # add ControlNet blocks + self.controlnet_blocks = nn.ModuleList([]) + for _ in range(controlnet_depth): + controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) + controlnet_block = zero_module(controlnet_block) + self.controlnet_blocks.append(controlnet_block) + self.controlnet_blocks_for_single = nn.ModuleList([]) + for _ in range(controlnet_single_depth): + controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) + controlnet_block = zero_module(controlnet_block) + self.controlnet_blocks_for_single.append(controlnet_block) + self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.gradient_checkpointing = False + self.input_hint_block = nn.Sequential( + nn.Conv2d(3, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, 
padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + zero_module(nn.Conv2d(16, 16, 3, padding=1)) + ) + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + self.time_in.enable_gradient_checkpointing() + self.vector_in.enable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.enable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.enable_gradient_checkpointing(cpu_offload=cpu_offload) + + print(f"FLUX: Gradient checkpointing enabled. CPU offload: {cpu_offload}") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + self.time_in.disable_gradient_checkpointing() + self.vector_in.disable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.disable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.disable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing disabled.") + + def enable_block_swap(self, num_blocks: int, device: torch.device): + self.blocks_to_swap = num_blocks + double_blocks_to_swap = num_blocks // 2 + single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + + assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, ( + f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. " + f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks." + ) + + self.offloader_double = custom_offloading_utils.ModelOffloader( + self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True + ) + self.offloader_single = custom_offloading_utils.ModelOffloader( + self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True + ) + print( + f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." + ) + + def move_to_device_except_swap_blocks(self, device: torch.device): + # assume model is on cpu. 
do not move blocks to device to reduce temporary memory usage + if self.blocks_to_swap: + save_double_blocks = self.double_blocks + save_single_blocks = self.single_blocks + self.double_blocks = None + self.single_blocks = None + + self.to(device) + + if self.blocks_to_swap: + self.double_blocks = save_double_blocks + self.single_blocks = save_single_blocks + + def prepare_block_swap_before_forward(self): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self.offloader_double.prepare_block_devices_before_forward(self.double_blocks) + self.offloader_single.prepare_block_devices_before_forward(self.single_blocks) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + controlnet_cond: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + txt_attention_mask: Tensor | None = None, + ) -> tuple[tuple[Tensor, ...], tuple[Tensor, ...]]: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + controlnet_cond = self.input_hint_block(controlnet_cond) + controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + controlnet_cond = self.pos_embed_input(controlnet_cond) + img = img + controlnet_cond + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + block_samples = () + block_single_samples = () + if not self.blocks_to_swap: + for block in self.double_blocks: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + block_samples = block_samples + (img,) + + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + block_single_samples = block_single_samples + (img,) + else: + for block_idx, block in enumerate(self.double_blocks): + self.offloader_double.wait_for_block(block_idx) + + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + block_samples = block_samples + (img,) + + self.offloader_double.submit_move_blocks(self.double_blocks, block_idx) + + img = torch.cat((txt, img), 1) + + for block_idx, block in enumerate(self.single_blocks): + self.offloader_single.wait_for_block(block_idx) + + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + block_single_samples = block_single_samples + (img,) + + self.offloader_single.submit_move_blocks(self.single_blocks, block_idx) + + controlnet_block_samples = () + controlnet_single_block_samples = () + for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks): + block_sample = controlnet_block(block_sample) + controlnet_block_samples = controlnet_block_samples + (block_sample,) + for block_sample, controlnet_block in zip(block_single_samples, self.controlnet_blocks_for_single): + block_sample = controlnet_block(block_sample) + controlnet_single_block_samples = controlnet_single_block_samples + (block_sample,) + + return controlnet_block_samples, controlnet_single_block_samples + + """ class FluxUpper(nn.Module): "" diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index d90644a25..de2e2b48d 100644
--- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -40,6 +40,7 @@ def sample_images( text_encoders, sample_prompts_te_outputs, prompt_replacement=None, + controlnet=None ): if steps == 0: if not args.sample_at_first: @@ -67,6 +68,8 @@ def sample_images( flux = accelerator.unwrap_model(flux) if text_encoders is not None: text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] + if controlnet is not None: + controlnet = accelerator.unwrap_model(controlnet) # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) prompts = train_util.load_prompts(args.sample_prompts) @@ -98,6 +101,7 @@ def sample_images( steps, sample_prompts_te_outputs, prompt_replacement, + controlnet ) else: # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) @@ -121,6 +125,7 @@ def sample_images( steps, sample_prompts_te_outputs, prompt_replacement, + controlnet ) torch.set_rng_state(rng_state) @@ -142,6 +147,7 @@ def sample_image_inference( steps, sample_prompts_te_outputs, prompt_replacement, + controlnet ): assert isinstance(prompt_dict, dict) # negative_prompt = prompt_dict.get("negative_prompt") @@ -150,7 +156,7 @@ def sample_image_inference( height = prompt_dict.get("height", 512) scale = prompt_dict.get("scale", 3.5) seed = prompt_dict.get("seed") - # controlnet_image = prompt_dict.get("controlnet_image") + controlnet_image = prompt_dict.get("controlnet_image") prompt: str = prompt_dict.get("prompt", "") # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) @@ -169,7 +175,6 @@ def sample_image_inference( # if negative_prompt is None: # negative_prompt = "" - height = max(64, height - height % 16) # round to divisible by 16 width = max(64, width - width % 16) # round to divisible by 16 logger.info(f"prompt: {prompt}") @@ -223,10 +228,15 @@ def sample_image_inference( img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None + if controlnet_image is not None: + controlnet_image = Image.open(controlnet_image).convert("RGB") + controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) + controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1) + controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device) + with accelerator.autocast(), torch.no_grad(): - x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask) + x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image) - x = x.float() x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) # latent to image @@ -301,18 +311,39 @@ def denoise( timesteps: list[float], guidance: float = 4.0, t5_attn_mask: Optional[torch.Tensor] = None, + controlnet: Optional[flux_models.ControlNetFlux] = None, + controlnet_img: Optional[torch.Tensor] = None, ): # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + + for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) 
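+        # run the ControlNet first (if provided); its block outputs are injected into the corresponding blocks of the main model call below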
diff --git a/library/flux_utils.py b/library/flux_utils.py
index f3093615d..8be1d63ee 100644
--- a/library/flux_utils.py
+++ b/library/flux_utils.py
@@ -1,14 +1,14 @@
-from dataclasses import replace
 import json
 import os
+from dataclasses import replace
 from typing import List, Optional, Tuple, Union
+
 import einops
 import torch
-
-from safetensors.torch import load_file
-from safetensors import safe_open
 from accelerate import init_empty_weights
-from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config
+from safetensors import safe_open
+from safetensors.torch import load_file
+from transformers import CLIPConfig, CLIPTextModel, T5Config, T5EncoderModel
 
 from library.utils import setup_logging
 
@@ -153,6 +153,22 @@ def load_ae(
     return ae
 
 
+def load_controlnet(
+    ckpt_path: Optional[str], is_schnell: bool, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
+):
+    logger.info("Building ControlNet")
+    name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
+    with torch.device(device):
+        controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params).to(dtype)
+
+    if ckpt_path is not None:
+        logger.info(f"Loading state dict from {ckpt_path}")
+        sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
+        info = controlnet.load_state_dict(sd, strict=False, assign=True)
+        logger.info(f"Loaded ControlNet: {info}")
+    return controlnet
+
+
 def load_clip_l(
     ckpt_path: Optional[str],
     dtype: torch.dtype,
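A minimal usage sketch for the new loader; the checkpoint path and dtype below are illustrative, and per the branch above, `ckpt_path=None` builds a randomly initialized ControlNet and skips the state-dict load:

```
import torch
from library import flux_utils

controlnet = flux_utils.load_controlnet(
    ckpt_path="controlnet.safetensors",  # placeholder path
    is_schnell=False,                    # False selects the dev config
    dtype=torch.bfloat16,
    device="cuda",
)
```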
diff --git a/library/sd3_models.py b/library/sd3_models.py
index 8b90205db..2f3c82eed 100644
--- a/library/sd3_models.py
+++ b/library/sd3_models.py
@@ -1017,22 +1017,35 @@ def cropped_scaled_pos_embed(self, h, w, device=None, dtype=None, random_crop: b
                 patched_size = patched_size_
                 break
         if patched_size is None:
-            raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.")
+            # raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.")
+            # use largest latent size
+            patched_size = self.resolution_area_to_latent_size[-1][1]
 
         pos_embed = self.resolution_pos_embeds[patched_size]
-        pos_embed_size = round(math.sqrt(pos_embed.shape[1]))
+        pos_embed_size = round(math.sqrt(pos_embed.shape[1]))  # max size, patched_size * POS_EMBED_MAX_RATIO
         if h > pos_embed_size or w > pos_embed_size:
             # # fallback to normal pos_embed
             # return self.cropped_pos_embed(h * p, w * p, device=device, random_crop=random_crop)
             # extend pos_embed size
             logger.warning(
-                f"Using normal pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide."
+                f"Add new pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide."
             )
-            pos_embed_size = max(h, w)
-            pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, pos_embed_size, sample_size=patched_size)
+            patched_size = max(h, w)
+            grid_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO)
+            pos_embed_size = grid_size
+            pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, grid_size, sample_size=patched_size)
             pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
             self.resolution_pos_embeds[patched_size] = pos_embed
-            logger.info(f"Updated pos_embed for size {pos_embed_size}x{pos_embed_size}")
+            logger.info(f"Added pos_embed for size {patched_size}x{patched_size}")
+
+            # print(torch.allclose(pos_embed.to(torch.float32).cpu(), self.pos_embed.to(torch.float32).cpu(), atol=5e-2))
+            # diff = pos_embed.to(torch.float32).cpu() - self.pos_embed.to(torch.float32).cpu()
+            # print(diff.abs().max(), diff.abs().mean())
+
+            # insert to resolution_area_to_latent_size, by adding and sorting
+            area = pos_embed_size**2
+            self.resolution_area_to_latent_size.append((area, patched_size))
+            self.resolution_area_to_latent_size = sorted(self.resolution_area_to_latent_size)
 
         if not random_crop:
             top = (pos_embed_size - h) // 2
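Illustrative numbers for the fallback above; the concrete `POS_EMBED_MAX_RATIO` value is not visible in this hunk, so `RATIO` below is a stand-in:

```
RATIO = 1.5  # stand-in for MMDiT.POS_EMBED_MAX_RATIO; not taken from this patch
h, w = 96, 40                          # patched latent grid that overflows the current embed
patched_size = max(h, w)               # 96
grid_size = int(patched_size * RATIO)  # 144: size of the newly generated embedding grid
area = grid_size**2                    # key under which the resolution is re-registered
```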
default is 1.0" + " / Huber損失のスケールパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは1.0", ) parser.add_argument( @@ -4121,7 +4130,7 @@ def task(): accelerator.load_state(dirname) -def get_optimizer(args, trainable_params): +def get_optimizer(args, trainable_params) -> tuple[str, str, object]: # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" optimizer_type = args.optimizer_type @@ -4408,10 +4417,10 @@ def get_optimizer(args, trainable_params): optimizer_class = sf.SGDScheduleFree logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}") else: - raise ValueError(f"Unknown optimizer type: {optimizer_type}") - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - # make optimizer as train mode: we don't need to call train again, because eval will not be called in training loop - optimizer.train() + optimizer_class = None + + if optimizer_class is not None: + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) if optimizer is None: # 任意のoptimizerを使う @@ -4513,6 +4522,10 @@ def __instancecheck__(self, instance): optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()]) + if hasattr(optimizer, "train") and callable(optimizer.train): + # make optimizer as train mode before training for schedulefree optimizer. the optimizer will be in eval mode in sampling and saving. + optimizer.train() + return optimizer_name, optimizer_args, optimizer @@ -5344,29 +5357,10 @@ def save_sd_model_on_train_end_common( huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) -def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device): +def get_timesteps(min_timestep, max_timestep, b_size, device): timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") - - if args.loss_type == "huber" or args.loss_type == "smooth_l1": - if args.huber_schedule == "exponential": - alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps - huber_c = torch.exp(-alpha * timesteps) - elif args.huber_schedule == "snr": - alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps) - sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 - huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c - elif args.huber_schedule == "constant": - huber_c = torch.full((b_size,), args.huber_c) - else: - raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") - huber_c = huber_c.to(device) - elif args.loss_type == "l2": - huber_c = None # may be anything, as it's not used - else: - raise NotImplementedError(f"Unknown loss type {args.loss_type}") - timesteps = timesteps.long().to(device) - return timesteps, huber_c + return timesteps def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): @@ -5388,7 +5382,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): min_timestep = 0 if args.min_timestep is None else args.min_timestep max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep - timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device) + timesteps = 
@@ -5344,29 +5357,10 @@ def save_sd_model_on_train_end_common(
         huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
 
 
-def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
+def get_timesteps(min_timestep, max_timestep, b_size, device):
     timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
-
-    if args.loss_type == "huber" or args.loss_type == "smooth_l1":
-        if args.huber_schedule == "exponential":
-            alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
-            huber_c = torch.exp(-alpha * timesteps)
-        elif args.huber_schedule == "snr":
-            alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps)
-            sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
-            huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
-        elif args.huber_schedule == "constant":
-            huber_c = torch.full((b_size,), args.huber_c)
-        else:
-            raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
-        huber_c = huber_c.to(device)
-    elif args.loss_type == "l2":
-        huber_c = None  # may be anything, as it's not used
-    else:
-        raise NotImplementedError(f"Unknown loss type {args.loss_type}")
-
     timesteps = timesteps.long().to(device)
-    return timesteps, huber_c
+    return timesteps
 
 
 def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
@@ -5388,7 +5382,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
     min_timestep = 0 if args.min_timestep is None else args.min_timestep
     max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
 
-    timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device)
+    timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device)
 
     # Add noise to the latents according to the noise magnitude at each timestep
     # (this is the forward diffusion process)
@@ -5401,11 +5395,34 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
     else:
         noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
-    return noise, noisy_latents, timesteps, huber_c
+    return noise, noisy_latents, timesteps
+
+
+def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]:
+    if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"):
+        return None
+
+    b_size = timesteps.shape[0]
+    if args.huber_schedule == "exponential":
+        alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
+        result = torch.exp(-alpha * timesteps) * args.huber_scale
+    elif args.huber_schedule == "snr":
+        if not hasattr(noise_scheduler, "alphas_cumprod"):
+            raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.")
+        alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu())
+        sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
+        result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
+        result = result.to(timesteps.device)
+    elif args.huber_schedule == "constant":
+        result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device)
+    else:
+        raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
+
+    return result
 
 
 def conditional_loss(
-    model_pred: torch.Tensor, target: torch.Tensor, reduction: str, loss_type: str, huber_c: Optional[torch.Tensor]
+    model_pred: torch.Tensor, target: torch.Tensor, loss_type: str, reduction: str, huber_c: Optional[torch.Tensor] = None
 ):
     if loss_type == "l2":
         loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
@@ -5426,7 +5443,7 @@ def conditional_loss(
     elif reduction == "sum":
         loss = torch.sum(loss)
     else:
-        raise NotImplementedError(f"Unsupported Loss Type {loss_type}")
+        raise NotImplementedError(f"Unsupported Loss Type: {loss_type}")
     return loss
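A worked example of the "exponential" schedule above: with the defaults `huber_c=0.1` and `huber_scale=1.0`, the per-sample threshold decays smoothly from 1.0 at t=0 down to 0.1 at t=num_train_timesteps.

```
import math
import torch

huber_c, huber_scale, num_train_timesteps = 0.1, 1.0, 1000
alpha = -math.log(huber_c) / num_train_timesteps  # about 0.0023
timesteps = torch.tensor([0, 500, 1000])
print(torch.exp(-alpha * timesteps) * huber_scale)
# tensor([1.0000, 0.3162, 0.1000]): per-sample thresholds fed to conditional_loss
```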
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 000000000..484d3aef6
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,8 @@
+[pytest]
+minversion = 6.0
+testpaths =
+    tests
+filterwarnings =
+    ignore::DeprecationWarning
+    ignore::UserWarning
+    ignore::FutureWarning
diff --git a/sd3_train.py b/sd3_train.py
index fbbdf8bdb..f8e0c7343 100644
--- a/sd3_train.py
+++ b/sd3_train.py
@@ -681,8 +681,8 @@ def grad_hook(parameter: torch.Tensor):
     progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
     global_step = 0
 
-    # noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
-    # noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+    # only used to get timesteps, etc. TODO manage timesteps etc. separately
+    dummy_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
 
     if accelerator.is_main_process:
         init_kwargs = {}
@@ -850,9 +850,8 @@ def grad_hook(parameter: torch.Tensor):
                     #     1,
                     # )
                     # calculate loss
-                    loss = train_util.conditional_loss(
-                        model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None
-                    )
+                    huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, dummy_scheduler)
+                    loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
                     if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
                         loss = apply_masked_loss(loss, batch)
                     loss = loss.mean([1, 2, 3])
diff --git a/sd3_train_network.py b/sd3_train_network.py
index 164f834a2..e7e7cb636 100644
--- a/sd3_train_network.py
+++ b/sd3_train_network.py
@@ -382,7 +382,7 @@ def get_noise_pred_and_target(
 
                 target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
 
-        return model_pred, target, timesteps, None, weighting
+        return model_pred, target, timesteps, weighting
 
     def post_process_loss(self, loss, args, timesteps, noise_scheduler):
         return loss
diff --git a/sdxl_train.py b/sdxl_train.py
index b62d10b5b..29346c432 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -699,9 +699,7 @@ def optimizer_hook(parameter: torch.Tensor):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
-                    args, noise_scheduler, latents
-                )
+                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
 
                 noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype
 
@@ -715,6 +713,7 @@ def optimizer_hook(parameter: torch.Tensor):
                 else:
                     target = noise
 
+                huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
                 if (
                     args.min_snr_gamma
                     or args.scale_v_pred_loss_like_noise_pred
@@ -723,9 +722,7 @@ def optimizer_hook(parameter: torch.Tensor):
                     or args.masked_loss
                 ):
                     # do not mean over batch dimension for snr weight or scale v-pred loss
-                    loss = train_util.conditional_loss(
-                        noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
-                    )
+                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
                    if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
                         loss = apply_masked_loss(loss, batch)
                     loss = loss.mean([1, 2, 3])
@@ -741,9 +738,7 @@ def optimizer_hook(parameter: torch.Tensor):
 
                     loss = loss.mean()  # mean over batch dimension
                 else:
-                    loss = train_util.conditional_loss(
-                        noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
-                    )
+                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "mean", huber_c)
 
                 accelerator.backward(loss)
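Every training script is migrated with the same mechanical pattern (sd3_train.py and sdxl_train.py above, the ControlNet, LLLite, DreamBooth, and Textual Inversion scripts below). For a custom script built on train_util, the before/after is:

```
# old API (removed in this patch):
#   noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
#   loss = train_util.conditional_loss(pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)

# new API:
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(pred.float(), target.float(), args.loss_type, "none", huber_c)
```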
diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py
index 50f2cf20d..182adf358 100644
--- a/sdxl_train_control_net.py
+++ b/sdxl_train_control_net.py
@@ -516,9 +516,7 @@ def remove_model(old_ckpt_name):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
-                    args, noise_scheduler, latents
-                )
+                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
 
                 controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)
 
@@ -537,9 +535,8 @@ def remove_model(old_ckpt_name):
                 else:
                     target = noise
 
-                loss = train_util.conditional_loss(
-                    noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
-                )
+                huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
+                loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
                 loss = loss.mean([1, 2, 3])
 
                 loss_weights = batch["loss_weights"]  # weight for each sample
diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index aa3fd6623..75a71fd98 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -467,9 +467,7 @@ def remove_model(old_ckpt_name):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
-                    args, noise_scheduler, latents
-                )
+                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
 
                 noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype
 
@@ -488,9 +486,8 @@ def remove_model(old_ckpt_name):
                 else:
                     target = noise
 
-                loss = train_util.conditional_loss(
-                    noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
-                )
+                huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
+                loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
                 loss = loss.mean([1, 2, 3])
 
                 loss_weights = batch["loss_weights"]  # weight for each sample
diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py
index 2d4465234..5b372befc 100644
--- a/sdxl_train_control_net_lllite_old.py
+++ b/sdxl_train_control_net_lllite_old.py
@@ -12,6 +12,7 @@
 import torch
 from library.device_utils import init_ipex, clean_memory_on_device
 
+
 init_ipex()
 
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -324,7 +325,9 @@ def train(args):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
+            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name,
+            config=train_util.get_sanitized_config_or_none(args),
+            init_kwargs=init_kwargs,
         )
 
     loss_recorder = train_util.LossRecorder()
@@ -406,7 +409,7 @@ def remove_model(old_ckpt_name):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
 
                 noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype
 
@@ -426,7 +429,8 @@ def remove_model(old_ckpt_name):
                 else:
                     target = noise
 
-                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
+                loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
                 loss = loss.mean([1, 2, 3])
 
                 loss_weights = batch["loss_weights"]  # weight for each sample
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
new file mode 100644
index 000000000..f6ade91a6
--- /dev/null
+++ b/tests/test_optimizer.py
@@ -0,0 +1,153 @@
+from unittest.mock import patch
+from library.train_util import get_optimizer
+from train_network import setup_parser
+import torch
+from torch.nn import Parameter
+
+# Optimizer libraries
+import bitsandbytes as bnb
+from lion_pytorch import lion_pytorch
+import schedulefree
+
+import dadaptation
+import dadaptation.experimental as dadapt_experimental
+
+import prodigyopt
+import schedulefree as sf
+import transformers
+
+
+def test_default_get_optimizer():
+    with patch("sys.argv", [""]):
+        parser = setup_parser()
+        args = parser.parse_args()
+        params_t = torch.tensor([1.5, 1.5])
+
+        param = Parameter(params_t)
+        optimizer_name, optimizer_args, optimizer = get_optimizer(args, [param])
+        assert optimizer_name == "torch.optim.adamw.AdamW"
+        assert optimizer_args == ""
+        assert isinstance(optimizer, torch.optim.AdamW)
+
+
+def test_get_schedulefree_optimizer():
+    with patch("sys.argv", ["", "--optimizer_type", "AdamWScheduleFree"]):
+        parser = setup_parser()
+        args = parser.parse_args()
+        params_t = torch.tensor([1.5, 1.5])
+
+        param = Parameter(params_t)
+        optimizer_name, optimizer_args, optimizer = get_optimizer(args, [param])
+        assert optimizer_name == "schedulefree.adamw_schedulefree.AdamWScheduleFree"
+        assert optimizer_args == ""
+        assert isinstance(optimizer, schedulefree.adamw_schedulefree.AdamWScheduleFree)
+
+
+def test_all_supported_optimizers():
+    optimizers = [
+        {
+            "name": "bitsandbytes.optim.adamw.AdamW8bit",
+            "alias": "AdamW8bit",
+            "instance": bnb.optim.AdamW8bit,
+        },
+        {
+            "name": "lion_pytorch.lion_pytorch.Lion",
+            "alias": "Lion",
+            "instance": lion_pytorch.Lion,
+        },
+        {
+            "name": "torch.optim.adamw.AdamW",
+            "alias": "AdamW",
+            "instance": torch.optim.AdamW,
+        },
+        {
+            "name": "bitsandbytes.optim.lion.Lion8bit",
+            "alias": "Lion8bit",
+            "instance": bnb.optim.Lion8bit,
+        },
+        {
+            "name": "bitsandbytes.optim.adamw.PagedAdamW8bit",
+            "alias": "PagedAdamW8bit",
+            "instance": bnb.optim.PagedAdamW8bit,
+        },
+        {
+            "name": "bitsandbytes.optim.lion.PagedLion8bit",
+            "alias": "PagedLion8bit",
+            "instance": bnb.optim.PagedLion8bit,
+        },
+        {
+            "name": "bitsandbytes.optim.adamw.PagedAdamW",
+            "alias": "PagedAdamW",
+            "instance": bnb.optim.PagedAdamW,
+        },
+        {
+            "name": "bitsandbytes.optim.adamw.PagedAdamW32bit",
+            "alias": "PagedAdamW32bit",
+            "instance": bnb.optim.PagedAdamW32bit,
+        },
+        {"name": "torch.optim.sgd.SGD", "alias": "SGD", "instance": torch.optim.SGD},
+        {
+            "name": "dadaptation.experimental.dadapt_adam_preprint.DAdaptAdamPreprint",
+            "alias": "DAdaptAdamPreprint",
+            "instance": dadapt_experimental.DAdaptAdamPreprint,
+        },
+        {
+            "name": "dadaptation.dadapt_adagrad.DAdaptAdaGrad",
+            "alias": "DAdaptAdaGrad",
+            "instance": dadaptation.DAdaptAdaGrad,
+        },
+        {
+            "name": "dadaptation.dadapt_adan.DAdaptAdan",
+            "alias": "DAdaptAdan",
+            "instance": dadaptation.DAdaptAdan,
+        },
+        {
+            "name": "dadaptation.experimental.dadapt_adan_ip.DAdaptAdanIP",
+            "alias": "DAdaptAdanIP",
+            "instance": dadapt_experimental.DAdaptAdanIP,
+        },
+        {
+            "name": "dadaptation.dadapt_lion.DAdaptLion",
+            "alias": "DAdaptLion",
+            "instance": dadaptation.DAdaptLion,
+        },
+        {
+            "name": "dadaptation.dadapt_sgd.DAdaptSGD",
+            "alias": "DAdaptSGD",
+            "instance": dadaptation.DAdaptSGD,
+        },
+        {
+            "name": "prodigyopt.prodigy.Prodigy",
+            "alias": "Prodigy",
+            "instance": prodigyopt.Prodigy,
+        },
+        {
+            "name": "transformers.optimization.Adafactor",
+            "alias": "Adafactor",
+            "instance": transformers.optimization.Adafactor,
+        },
+        {
+            "name": "schedulefree.adamw_schedulefree.AdamWScheduleFree",
+            "alias": "AdamWScheduleFree",
+            "instance": sf.AdamWScheduleFree,
+        },
+        {
+            "name": "schedulefree.sgd_schedulefree.SGDScheduleFree",
+            "alias": "SGDScheduleFree",
+            "instance": sf.SGDScheduleFree,
+        },
+    ]
+
+    for opt in optimizers:
+        with patch("sys.argv", ["", "--optimizer_type", opt.get("alias")]):
+            parser = setup_parser()
+            args = parser.parse_args()
+            params_t = torch.tensor([1.5, 1.5])
+
+            param = Parameter(params_t)
+            optimizer_name, _, optimizer = get_optimizer(args, [param])
+            assert optimizer_name == opt.get("name")
+
+            instance = opt.get("instance")
+            assert instance is not None
+            assert isinstance(optimizer, instance)
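With the new pytest.ini pointing testpaths at tests/, the suite runs from the repo root with a bare `pytest`. A programmatic equivalent, for reference:

```
import pytest

# Equivalent to running "pytest -q tests/test_optimizer.py" from the repo root;
# pytest.main returns an exit code (0 on success).
exit_code = pytest.main(["-q", "tests/test_optimizer.py"])
```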
diff --git a/train_controlnet.py b/train_controlnet.py
index 8c7882c8f..177d2b11f 100644
--- a/train_controlnet.py
+++ b/train_controlnet.py
@@ -307,10 +307,12 @@ def __contains__(self, name):
 
     if args.fused_backward_pass:
         import library.adafactor_fused
+
         library.adafactor_fused.patch_adafactor_fused(optimizer)
         for param_group in optimizer.param_groups:
             for parameter in param_group["params"]:
                 if parameter.requires_grad:
+
                     def __grad_hook(tensor: torch.Tensor, param_group=param_group):
                         if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                             accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
@@ -464,9 +466,7 @@ def remove_model(old_ckpt_name):
                 )
 
                 # Sample a random timestep for each image
-                timesteps, huber_c = train_util.get_timesteps_and_huber_c(
-                    args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device
-                )
+                timesteps = train_util.get_timesteps(0, noise_scheduler.config.num_train_timesteps, b_size, latents.device)
 
                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
@@ -498,9 +498,8 @@ def remove_model(old_ckpt_name):
                 else:
                     target = noise
 
-                loss = train_util.conditional_loss(
-                    noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
-                )
+                huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
+                loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
                 loss = loss.mean([1, 2, 3])
 
                 loss_weights = batch["loss_weights"]  # weight for each sample
diff --git a/train_db.py b/train_db.py
index 4038b94e0..5d464f883 100644
--- a/train_db.py
+++ b/train_db.py
@@ -370,9 +370,7 @@ def train(args):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
-                    args, noise_scheduler, latents
-                )
+                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
 
                 # Predict the noise residual
                 with accelerator.autocast():
@@ -384,9 +382,8 @@ def train(args):
                 else:
                     target = noise
 
-                loss = train_util.conditional_loss(
-                    noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
-                )
+                huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
+                loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
                 if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
                     loss = apply_masked_loss(loss, batch)
                 loss = loss.mean([1, 2, 3])
diff --git a/train_network.py b/train_network.py
index 4714e06a9..5fd9be198 100644
--- a/train_network.py
+++ b/train_network.py
@@ -61,6 +61,7 @@ def generate_step_logs(
         avr_loss,
         lr_scheduler,
         lr_descriptions,
+        optimizer=None,
         keys_scaled=None,
         mean_norm=None,
         maximum_norm=None,
@@ -93,6 +94,30 @@ def generate_step_logs(
                 logs[f"lr/d*lr/{lr_desc}"] = (
                     lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
                 )
+            if (
+                args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
+            ):  # tracking d*lr value of unet.
+                logs["lr/d*lr"] = (
+                    optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
+                )
+        else:
+            idx = 0
+            if not args.network_train_unet_only:
+                logs["lr/textencoder"] = float(lrs[0])
+                idx = 1
+
+            for i in range(idx, len(lrs)):
+                logs[f"lr/group{i}"] = float(lrs[i])
+                if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
+                    logs[f"lr/d*lr/group{i}"] = (
+                        lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
+                    )
+                if (
+                    args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
+                ):
+                    logs[f"lr/d*lr/group{i}"] = (
+                        optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
+                    )
 
     return logs
@@ -192,7 +217,7 @@ def get_noise_pred_and_target(
     ):
         # Sample noise, sample a random timestep for each image, and add noise to the latents,
         # with noise offset and/or multires noise if specified
-        noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+        noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
 
         # ensure the hidden state will require grad
         if args.gradient_checkpointing:
@@ -244,7 +269,7 @@ def get_noise_pred_and_target(
                 network.set_multiplier(1.0)  # may be overwritten by "network_multipliers" in the next step
 
             target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
 
-        return noise_pred, target, timesteps, huber_c, None
+        return noise_pred, target, timesteps, None
 
     def post_process_loss(self, loss, args, timesteps, noise_scheduler):
         if args.min_snr_gamma:
@@ -806,6 +831,7 @@ def load_model_hook(models, input_dir):
             "ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength,
             "ss_loss_type": args.loss_type,
             "ss_huber_schedule": args.huber_schedule,
+            "ss_huber_scale": args.huber_scale,
             "ss_huber_c": args.huber_c,
             "ss_fp8_base": bool(args.fp8_base),
             "ss_fp8_base_unet": bool(args.fp8_base_unet),
@@ -1193,7 +1219,7 @@ def remove_model(old_ckpt_name):
                         text_encoder_conds[i] = encoded_text_encoder_conds[i]
 
                 # sample noise, call unet, get target
-                noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(
+                noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
                     args,
                     accelerator,
                     noise_scheduler,
@@ -1206,9 +1232,8 @@ def remove_model(old_ckpt_name):
                     train_unet,
                 )
 
-                loss = train_util.conditional_loss(
-                    noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
-                )
+                huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
+                loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
                 if weighting is not None:
                     loss = loss * weighting
                 if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
@@ -1279,7 +1304,7 @@ def remove_model(old_ckpt_name):
 
                 if len(accelerator.trackers) > 0:
                     logs = self.generate_step_logs(
-                        args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm
+                        args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm
                    )
                     accelerator.log(logs, step=global_step)
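A short sketch of what the new ProdigyPlusScheduleFree branch reads: Prodigy-style optimizers expose their adaptation estimate as "d" in each param group, so the effective step size is d * lr rather than lr alone (this assumes the optimizer actually populates a "d" entry, as the branch's guard implies):

```
# Assumes a Prodigy-style optimizer whose param groups carry a "d" entry.
for i, group in enumerate(optimizer.param_groups):
    logs[f"lr/d*lr/group{i}"] = group["d"] * group["lr"]
```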
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 3cc815cdc..5a48c6adf 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -585,7 +585,7 @@ def remove_model(old_ckpt_name):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
                     args, noise_scheduler, latents
                 )
 
@@ -601,9 +601,8 @@ def remove_model(old_ckpt_name):
                 else:
                     target = noise
 
-                loss = train_util.conditional_loss(
-                    noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
-                )
+                huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
+                loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
                 if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
                     loss = apply_masked_loss(loss, batch)
                 loss = loss.mean([1, 2, 3])
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index 52d525fc5..2a2b42310 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -407,7 +407,9 @@ def train(args):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
+            "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name,
+            config=train_util.get_sanitized_config_or_none(args),
+            init_kwargs=init_kwargs,
         )
 
     # function for saving/removing
@@ -461,7 +463,7 @@ def remove_model(old_ckpt_name):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
 
                 # Predict the noise residual
                 with accelerator.autocast():
@@ -473,7 +475,8 @@ def remove_model(old_ckpt_name):
                 else:
                     target = noise
 
-                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
+                loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
                 if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
                     loss = apply_masked_loss(loss, batch)
                 loss = loss.mean([1, 2, 3])