From ccfaa001e74f80798e528b4b3ea6ef811017c07b Mon Sep 17 00:00:00 2001 From: minux302 Date: Fri, 15 Nov 2024 20:21:28 +0900 Subject: [PATCH 01/32] add flux controlnet base module --- flux_train_control_net.py | 573 ++++++++++++++++++++++++++++++++++++++ flux_train_network.py | 5 +- library/flux_models.py | 257 ++++++++++++++++- library/flux_utils.py | 8 + 4 files changed, 841 insertions(+), 2 deletions(-) create mode 100644 flux_train_control_net.py diff --git a/flux_train_control_net.py b/flux_train_control_net.py new file mode 100644 index 000000000..704c4d32e --- /dev/null +++ b/flux_train_control_net.py @@ -0,0 +1,573 @@ +import argparse +import copy +import math +import random +from typing import Any, Optional + +import torch +from accelerate import Accelerator +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util +import train_network +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class FluxNetworkTrainer(train_network.NetworkTrainer): + def __init__(self): + super().__init__() + self.sample_prompts_te_outputs = None + self.is_schnell: Optional[bool] = None + self.is_swapping_blocks: bool = False + + def assert_extra_args(self, args, train_dataset_group): + super().assert_extra_args(args, train_dataset_group) + # sdxl_train_util.verify_sdxl_training_args(args) + + if args.fp8_base_unet: + args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1 + + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # prepare CLIP-L/T5XXL training flags + self.train_clip_l = not args.network_train_unet_only + self.train_t5xxl = False # default is False even if args.network_train_unet_only is False + + if args.max_token_length is not None: + logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + + # deprecated split_mode option + if args.split_mode: + if args.blocks_to_swap is not None: + logger.warning( + "split_mode is deprecated. Because `--blocks_to_swap` is set, `--split_mode` is ignored." + " / split_modeは非推奨です。`--blocks_to_swap`が設定されているため、`--split_mode`は無視されます。" + ) + else: + logger.warning( + "split_mode is deprecated. Please use `--blocks_to_swap` instead. `--blocks_to_swap 18` is automatically set." + " / split_modeは非推奨です。代わりに`--blocks_to_swap`を使用してください。`--blocks_to_swap 18`が自動的に設定されました。" + ) + args.blocks_to_swap = 18 # 18 is safe for most cases + + train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + + def load_target_model(self, args, weight_dtype, accelerator): + # currently offload to cpu for some models + + # if the file is fp8 and we are using fp8_base, we can load it as is (fp8) + loading_dtype = None if args.fp8_base else weight_dtype + + # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future + self.is_schnell, model = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors + ) + if args.fp8_base: + # check dtype of model + if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") + elif model.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 FLUX model") + else: + logger.info( + "Cast FLUX model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint." + " / FLUXモデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。" + ) + model.to(torch.float8_e4m3fn) + + # if args.split_mode: + # model = self.prepare_split_model(model, weight_dtype, accelerator) + + self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if self.is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + model.enable_block_swap(args.blocks_to_swap, accelerator.device) + + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + clip_l.eval() + + # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8) + if args.fp8_base and not args.fp8_base_unet: + loading_dtype = None # as is + else: + loading_dtype = weight_dtype + + # loading t5xxl to cpu takes a long time, so we should load to gpu in future + t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + t5xxl.eval() + if args.fp8_base and not args.fp8_base_unet: + # check dtype of model + if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") + elif t5xxl.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 T5XXL model") + + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + + return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model + + def get_tokenize_strategy(self, args): + _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) + + if args.t5xxl_max_token_length is None: + if is_schnell: + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}") + return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy): + return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask) + + def post_process_network(self, args, accelerator, network, text_encoders, unet): + # check t5xxl is trained or not + self.train_t5xxl = network.train_t5xxl + + if self.train_t5xxl and args.cache_text_encoder_outputs: + raise ValueError( + "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません" + ) + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + if args.cache_text_encoder_outputs: + if self.train_clip_l and not self.train_t5xxl: + return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached + else: + return None # no text encoders are needed for encoding because both are cached + else: + return text_encoders # both CLIP-L and T5XXL are needed for encoding + + def get_text_encoders_train_flags(self, args, text_encoders): + return [self.train_clip_l, self.train_t5xxl] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + # if the text encoders is trained, we need tokenization, so is_partial is True + return strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + is_partial=self.train_clip_l or self.train_t5xxl, + apply_t5_attn_mask=args.apply_t5_attn_mask, + ) + else: + return None + + def cache_text_encoder_outputs_if_needed( + self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype + ): + if args.cache_text_encoder_outputs: + if not args.lowram: + # メモリ消費を減らす + logger.info("move vae and unet to cpu to save memory") + org_vae_device = vae.device + org_unet_device = unet.device + vae.to("cpu") + unet.to("cpu") + clean_memory_on_device(accelerator.device) + + # When TE is not be trained, it will not be prepared so we need to use explicit autocast + logger.info("move text encoders to gpu") + text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 + text_encoders[1].to(accelerator.device) + + if text_encoders[1].dtype == torch.float8_e4m3fn: + # if we load fp8 weights, the model is already fp8, so we use it as is + self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype) + else: + # otherwise, we need to convert it to target dtype + text_encoders[1].to(weight_dtype) + + with accelerator.autocast(): + dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) + + # cache sample prompts + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask + ) + self.sample_prompts_te_outputs = sample_prompts_te_outputs + + accelerator.wait_for_everyone() + + # move back to cpu + if not self.is_train_text_encoder(args): + logger.info("move CLIP-L back to cpu") + text_encoders[0].to("cpu") + logger.info("move t5XXL back to cpu") + text_encoders[1].to("cpu") + clean_memory_on_device(accelerator.device) + + if not args.lowram: + logger.info("move vae and unet back to original device") + vae.to(org_vae_device) + unet.to(org_unet_device) + else: + # Text Encoderから毎回出力を取得するので、GPUに乗せておく + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device) + + # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + + # # get size embeddings + # orig_size = batch["original_sizes_hw"] + # crop_size = batch["crop_top_lefts"] + # target_size = batch["target_sizes_hw"] + # embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # # concat embeddings + # encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds + # vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + # text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) + # return noise_pred + + def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux): + text_encoders = text_encoder # for compatibility + text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) + + flux_train_utils.sample_images( + accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs + ) + # return + + """ + class FluxUpperLowerWrapper(torch.nn.Module): + def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device): + super().__init__() + self.flux_upper = flux_upper + self.flux_lower = flux_lower + self.target_device = device + + def prepare_block_swap_before_forward(self): + pass + + def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None): + self.flux_lower.to("cpu") + clean_memory_on_device(self.target_device) + self.flux_upper.to(self.target_device) + img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask) + self.flux_upper.to("cpu") + clean_memory_on_device(self.target_device) + self.flux_lower.to(self.target_device) + return self.flux_lower(img, txt, vec, pe, txt_attention_mask) + + wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device) + clean_memory_on_device(accelerator.device) + flux_train_utils.sample_images( + accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs + ) + clean_memory_on_device(accelerator.device) + """ + + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) + self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) + return noise_scheduler + + def encode_images_to_latents(self, args, accelerator, vae, images): + return vae.encode(images) + + def shift_scale_latents(self, args, latents): + return latents + + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet: flux_models.Flux, + network, + weight_dtype, + train_unet, + ): + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, accelerator.device, weight_dtype + ) + + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + + # get guidance + # ensure guidance_scale in args is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + noisy_model_input.requires_grad_(True) + for t in text_encoder_conds: + if t is not None and t.dtype.is_floating_point: + t.requires_grad_(True) + img_ids.requires_grad_(True) + guidance_vec.requires_grad_(True) + + # Predict the noise residual + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None + + def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): + # if not args.split_mode: + # normal forward + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = unet( + img=img, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + """ + else: + # split forward to reduce memory usage + assert network.train_blocks == "single", "train_blocks must be single for split mode" + with accelerator.autocast(): + # move flux lower to cpu, and then move flux upper to gpu + unet.to("cpu") + clean_memory_on_device(accelerator.device) + self.flux_upper.to(accelerator.device) + + # upper model does not require grad + with torch.no_grad(): + intermediate_img, intermediate_txt, vec, pe = self.flux_upper( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + # move flux upper back to cpu, and then move flux lower to gpu + self.flux_upper.to("cpu") + clean_memory_on_device(accelerator.device) + unet.to(accelerator.device) + + # lower model requires grad + intermediate_img.requires_grad_(True) + intermediate_txt.requires_grad_(True) + vec.requires_grad_(True) + pe.requires_grad_(True) + model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) + """ + + return model_pred + + model_pred = call_dit( + img=packed_noisy_model_input, + img_ids=img_ids, + t5_out=t5_out, + txt_ids=txt_ids, + l_pooled=l_pooled, + timesteps=timesteps, + guidance_vec=guidance_vec, + t5_attn_mask=t5_attn_mask, + ) + + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + + # flow matching loss: this is different from SD3 + target = noise - latents + + # differential output preservation + if "custom_attributes" in batch: + diff_output_pr_indices = [] + for i, custom_attributes in enumerate(batch["custom_attributes"]): + if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + with torch.no_grad(): + model_pred_prior = call_dit( + img=packed_noisy_model_input[diff_output_pr_indices], + img_ids=img_ids[diff_output_pr_indices], + t5_out=t5_out[diff_output_pr_indices], + txt_ids=txt_ids[diff_output_pr_indices], + l_pooled=l_pooled[diff_output_pr_indices], + timesteps=timesteps[diff_output_pr_indices], + guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None, + t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None, + ) + network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + + model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width) + model_pred_prior, _ = flux_train_utils.apply_model_prediction_type( + args, + model_pred_prior, + noisy_model_input[diff_output_pr_indices], + sigmas[diff_output_pr_indices] if sigmas is not None else None, + ) + target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) + + return model_pred, target, timesteps, None, weighting + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + return loss + + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev") + + def update_metadata(self, metadata, args): + metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask + metadata["ss_weighting_scheme"] = args.weighting_scheme + metadata["ss_logit_mean"] = args.logit_mean + metadata["ss_logit_std"] = args.logit_std + metadata["ss_mode_scale"] = args.mode_scale + metadata["ss_guidance_scale"] = args.guidance_scale + metadata["ss_timestep_sampling"] = args.timestep_sampling + metadata["ss_sigmoid_scale"] = args.sigmoid_scale + metadata["ss_model_prediction_type"] = args.model_prediction_type + metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift + + def is_text_encoder_not_needed_for_training(self, args): + return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) + + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + if index == 0: # CLIP-L + return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder) + else: # T5XXL + text_encoder.encoder.embed_tokens.requires_grad_(True) + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + if index == 0: # CLIP-L + logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}") + text_encoder.to(te_weight_dtype) # fp8 + text_encoder.text_model.embeddings.to(dtype=weight_dtype) + else: # T5XXL + + def prepare_fp8(text_encoder, target_dtype): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in text_encoder.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype: + logger.info(f"T5XXL already prepared for fp8") + else: + logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks") + text_encoder.to(te_weight_dtype) # fp8 + prepare_fp8(text_encoder, weight_dtype) + + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + if not self.is_swapping_blocks: + return super().prepare_unet_with_accelerator(args, accelerator, unet) + + # if we doesn't swap blocks, we can move the model to device + flux: flux_models.Flux = unet + flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks]) + accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + accelerator.unwrap_model(flux).prepare_block_swap_before_forward() + + return flux + + +def setup_parser() -> argparse.ArgumentParser: + parser = train_network.setup_parser() + train_util.add_dit_training_arguments(parser) + flux_train_utils.add_flux_train_arguments(parser) + + parser.add_argument( + "--split_mode", + action="store_true", + # help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" + # + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", + help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead." + " / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + trainer = FluxNetworkTrainer() + trainer.train(args) diff --git a/flux_train_network.py b/flux_train_network.py index 704c4d32e..0feb9b011 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -125,7 +125,10 @@ def load_target_model(self, args, weight_dtype, accelerator): ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) - return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model + controlnet = flux_utils.load_controlnet() + controlnet.train() + + return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model, controlnet def get_tokenize_strategy(self, args): _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) diff --git a/library/flux_models.py b/library/flux_models.py index fa3c7ad2b..a3bd19743 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1013,6 +1013,8 @@ def forward( txt_ids: Tensor, timesteps: Tensor, y: Tensor, + block_controlnet_hidden_states=None, + block_controlnet_single_hidden_states=None, guidance: Tensor | None = None, txt_attention_mask: Tensor | None = None, ) -> Tensor: @@ -1031,18 +1033,29 @@ def forward( ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) + if block_controlnet_hidden_states is not None: + controlnet_depth = len(block_controlnet_hidden_states) + if block_controlnet_single_hidden_states is not None: + controlnet_single_depth = len(block_controlnet_single_hidden_states) if not self.blocks_to_swap: - for block in self.double_blocks: + for block_idx, block in enumerate(self.double_blocks): img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + if block_controlnet_hidden_states is not None: + img = img + block_controlnet_hidden_states[block_idx % controlnet_depth] + img = torch.cat((txt, img), 1) for block in self.single_blocks: img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + if block_controlnet_single_hidden_states is not None: + img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth] else: for block_idx, block in enumerate(self.double_blocks): self.offloader_double.wait_for_block(block_idx) img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + if block_controlnet_hidden_states is not None: + img = img + block_controlnet_hidden_states[block_idx % controlnet_depth] self.offloader_double.submit_move_blocks(self.double_blocks, block_idx) @@ -1052,6 +1065,8 @@ def forward( self.offloader_single.wait_for_block(block_idx) img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + if block_controlnet_single_hidden_states is not None: + img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth] self.offloader_single.submit_move_blocks(self.single_blocks, block_idx) @@ -1066,6 +1081,246 @@ def forward( return img +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +class ControlNetFlux(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, params: FluxParams, controlnet_depth=2): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + for _ in range(0) # TMP + ] + ) + + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + self.blocks_to_swap = None + + self.offloader_double = None + self.offloader_single = None + self.num_double_blocks = len(self.double_blocks) + self.num_single_blocks = len(self.single_blocks) + + # add ControlNet blocks + self.controlnet_blocks_for_double = nn.ModuleList([]) + for _ in range(controlnet_depth): + controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) + controlnet_block = zero_module(controlnet_block) + self.controlnet_blocks_for_double.append(controlnet_block) + self.controlnet_blocks_for_single = nn.ModuleList([]) + for _ in range(controlnet_depth): + controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) + controlnet_block = zero_module(controlnet_block) + self.controlnet_blocks_for_single.append(controlnet_block) + self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.gradient_checkpointing = False + self.input_hint_block = nn.Sequential( + nn.Conv2d(3, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + zero_module(nn.Conv2d(16, 16, 3, padding=1)) + ) + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + self.time_in.enable_gradient_checkpointing() + self.vector_in.enable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.enable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.enable_gradient_checkpointing(cpu_offload=cpu_offload) + + print(f"FLUX: Gradient checkpointing enabled. CPU offload: {cpu_offload}") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + self.time_in.disable_gradient_checkpointing() + self.vector_in.disable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.disable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.disable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing disabled.") + + def enable_block_swap(self, num_blocks: int, device: torch.device): + self.blocks_to_swap = num_blocks + double_blocks_to_swap = num_blocks // 2 + single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + + assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, ( + f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. " + f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks." + ) + + self.offloader_double = custom_offloading_utils.ModelOffloader( + self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True + ) + self.offloader_single = custom_offloading_utils.ModelOffloader( + self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True + ) + print( + f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." + ) + + def move_to_device_except_swap_blocks(self, device: torch.device): + # assume model is on cpu. do not move blocks to device to reduce temporary memory usage + if self.blocks_to_swap: + save_double_blocks = self.double_blocks + save_single_blocks = self.single_blocks + self.double_blocks = None + self.single_blocks = None + + self.to(device) + + if self.blocks_to_swap: + self.double_blocks = save_double_blocks + self.single_blocks = save_single_blocks + + def prepare_block_swap_before_forward(self): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self.offloader_double.prepare_block_devices_before_forward(self.double_blocks) + self.offloader_single.prepare_block_devices_before_forward(self.single_blocks) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + controlnet_cond: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + txt_attention_mask: Tensor | None = None, + ) -> tuple[tuple[Tensor]]: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + controlnet_cond = self.input_hint_block(controlnet_cond) + controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + controlnet_cond = self.pos_embed_input(controlnet_cond) + img = img + controlnet_cond + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + block_samples = () + block_single_samples = () + if not self.blocks_to_swap: + for block_idx, block in enumerate(self.double_blocks): + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + block_samples = block_samples + (img,) + + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + block_single_samples = block_single_samples + (img,) + else: + for block_idx, block in enumerate(self.double_blocks): + self.offloader_double.wait_for_block(block_idx) + + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + block_samples = block_samples + (img,) + + self.offloader_double.submit_move_blocks(self.double_blocks, block_idx) + + img = torch.cat((txt, img), 1) + + for block_idx, block in enumerate(self.single_blocks): + self.offloader_single.wait_for_block(block_idx) + + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + block_single_samples = block_single_samples + (img,) + + self.offloader_single.submit_move_blocks(self.single_blocks, block_idx) + + controlnet_block_samples = () + controlnet_single_block_samples = () + for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_double): + block_sample = controlnet_block(block_sample) + controlnet_block_samples = controlnet_block_samples + (block_sample,) + for block_sample, controlnet_block in zip(block_samples, self.controlnet_single_blocks_for_single): + block_sample = controlnet_block(block_sample) + controlnet_single_block_samples = controlnet_single_block_samples + (block_sample,) + + return controlnet_block_samples, controlnet_single_block_samples + + """ class FluxUpper(nn.Module): "" diff --git a/library/flux_utils.py b/library/flux_utils.py index f3093615d..678efbc8a 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -153,6 +153,14 @@ def load_ae( return ae +def load_controlnet(name, device, transformer=None): + with torch.device(device): + controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params) + if transformer is not None: + controlnet.load_state_dict(transformer.state_dict(), strict=False) + return controlnet + + def load_clip_l( ckpt_path: Optional[str], dtype: torch.dtype, From 42f6edf3a886287b99770bc7a8c0bafd3fa03f39 Mon Sep 17 00:00:00 2001 From: minux302 Date: Fri, 15 Nov 2024 23:48:51 +0900 Subject: [PATCH 02/32] fix for adding controlnet --- flux_train_control_net.py | 1270 +++++++++++++++++++++-------------- flux_train_network.py | 3 - library/flux_train_utils.py | 32 +- library/flux_utils.py | 11 +- 4 files changed, 820 insertions(+), 496 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 704c4d32e..8a7be75f2 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -1,563 +1,860 @@ +# training with captions + +# Swap blocks between CPU and GPU: +# This implementation is inspired by and based on the work of 2kpr. +# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading. +# The original idea has been adapted and extended to fit the current project's needs. + +# Key features: +# - CPU offloading during forward and backward passes +# - Use of fused optimizer and grad_hook for efficient gradient processing +# - Per-block fused optimizer instances + import argparse +from concurrent.futures import ThreadPoolExecutor import copy import math -import random -from typing import Any, Optional +import os +from multiprocessing import Value +import time +from typing import List, Optional, Tuple, Union +import toml + +from tqdm import tqdm import torch -from accelerate import Accelerator +import torch.nn as nn +from library import utils from library.device_utils import init_ipex, clean_memory_on_device init_ipex() -from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util -import train_network -from library.utils import setup_logging +from accelerate.utils import set_seed +from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux +from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler + +import library.train_util as train_util + +from library.utils import setup_logging, add_logging_arguments setup_logging() import logging logger = logging.getLogger(__name__) +import library.config_util as config_util + +# import library.sdxl_train_util as sdxl_train_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + # sdxl_train_util.verify_sdxl_training_args(args) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + # temporary: backward compatibility for deprecated options. remove in the future + if not args.skip_cache_check: + args.skip_cache_check = args.skip_latents_validity_check + + # assert ( + # not args.weighted_captions + # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True -class FluxNetworkTrainer(train_network.NetworkTrainer): - def __init__(self): - super().__init__() - self.sample_prompts_te_outputs = None - self.is_schnell: Optional[bool] = None - self.is_swapping_blocks: bool = False + if args.cpu_offload_checkpointing and not args.gradient_checkpointing: + logger.warning( + "cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります" + ) + args.gradient_checkpointing = True - def assert_extra_args(self, args, train_dataset_group): - super().assert_extra_args(args, train_dataset_group) - # sdxl_train_util.verify_sdxl_training_args(args) + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, ( + "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + ) - if args.fp8_base_unet: - args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1 + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None - if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: - logger.warning( - "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" - ) - args.cache_text_encoder_outputs = True + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if args.cache_latents: + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + logger.info("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + logger.info("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認 + + _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) + if args.debug_dataset: if args.cache_text_encoder_outputs: - assert ( - train_dataset_group.is_text_encoder_output_cacheable() - ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" - - # prepare CLIP-L/T5XXL training flags - self.train_clip_l = not args.network_train_unet_only - self.train_t5xxl = False # default is False even if args.network_train_unet_only is False + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( + strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False + ) + ) + t5xxl_max_token_length = ( + args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512) + ) + strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length)) + + train_dataset_group.set_current_strategies() + train_util.debug_dataset(train_dataset_group, True) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return - if args.max_token_length is not None: - logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + if args.cache_text_encoder_outputs: assert ( - args.blocks_to_swap is None or args.blocks_to_swap == 0 - ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" - # deprecated split_mode option - if args.split_mode: - if args.blocks_to_swap is not None: - logger.warning( - "split_mode is deprecated. Because `--blocks_to_swap` is set, `--split_mode` is ignored." - " / split_modeは非推奨です。`--blocks_to_swap`が設定されているため、`--split_mode`は無視されます。" - ) - else: - logger.warning( - "split_mode is deprecated. Please use `--blocks_to_swap` instead. `--blocks_to_swap 18` is automatically set." - " / split_modeは非推奨です。代わりに`--blocks_to_swap`を使用してください。`--blocks_to_swap 18`が自動的に設定されました。" - ) - args.blocks_to_swap = 18 # 18 is safe for most cases + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) - train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) - def load_target_model(self, args, weight_dtype, accelerator): - # currently offload to cpu for some models + # モデルを読み込む - # if the file is fp8 and we are using fp8_base, we can load it as is (fp8) - loading_dtype = None if args.fp8_base else weight_dtype + # load VAE for caching latents + ae = None + if cache_latents: + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + ae.to(accelerator.device, dtype=weight_dtype) + ae.requires_grad_(False) + ae.eval() - # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future - self.is_schnell, model = flux_utils.load_flow_model( - args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors + train_dataset_group.new_cache_latents(ae, accelerator) + + ae.to("cpu") # if no sampling, vae can be deleted + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # prepare tokenize strategy + if args.t5xxl_max_token_length is None: + if is_schnell: + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + flux_tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length) + strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy) + + # load clip_l, t5xxl for caching text encoder outputs + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + clip_l.eval() + t5xxl.eval() + clip_l.requires_grad_(False) + t5xxl.requires_grad_(False) + + text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + # cache text encoder outputs + sample_prompts_te_outputs = None + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad here + clip_l.to(accelerator.device) + t5xxl.to(accelerator.device) + + text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False, args.apply_t5_attn_mask ) - if args.fp8_base: - # check dtype of model - if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz: - raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") - elif model.dtype == torch.float8_e4m3fn: - logger.info("Loaded fp8 FLUX model") - else: - logger.info( - "Cast FLUX model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint." - " / FLUXモデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。" - ) - model.to(torch.float8_e4m3fn) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) + + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator) + + # cache sample prompt's embeddings to free text encoder's memory + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = flux_tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + flux_tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + + accelerator.wait_for_everyone() + + # now we can delete Text Encoders to free memory + clip_l = None + t5xxl = None + clean_memory_on_device(accelerator.device) - # if args.split_mode: - # model = self.prepare_split_model(model, weight_dtype, accelerator) + # load FLUX + _, flux = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors + ) + flux.requires_grad_(False) - self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 - if self.is_swapping_blocks: - # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. - logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") - model.enable_block_swap(args.blocks_to_swap, accelerator.device) + # load controlnet + controlnet = flux_utils.load_controlnet() + controlnet.requires_grad_(True) - clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) - clip_l.eval() + if args.gradient_checkpointing: + controlnet.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing) - # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8) - if args.fp8_base and not args.fp8_base_unet: - loading_dtype = None # as is - else: - loading_dtype = weight_dtype + # block swap - # loading t5xxl to cpu takes a long time, so we should load to gpu in future - t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) - t5xxl.eval() - if args.fp8_base and not args.fp8_base_unet: - # check dtype of model - if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: - raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") - elif t5xxl.dtype == torch.float8_e4m3fn: - logger.info("Loaded fp8 T5XXL model") + # backward compatibility + if args.blocks_to_swap is None: + blocks_to_swap = args.double_blocks_to_swap or 0 + if args.single_blocks_to_swap is not None: + blocks_to_swap += args.single_blocks_to_swap // 2 + if blocks_to_swap > 0: + logger.warning( + "double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead." + " / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。" + ) + logger.info( + f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}." + ) + args.blocks_to_swap = blocks_to_swap + del blocks_to_swap + + is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + # This idea is based on 2kpr's great work. Thank you! + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + flux.enable_block_swap(args.blocks_to_swap, accelerator.device) + controlnet.enable_block_swap(args.blocks_to_swap, accelerator.device) + + if not cache_latents: + # load VAE here if not cached + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu") + ae.requires_grad_(False) + ae.eval() + ae.to(accelerator.device, dtype=weight_dtype) + + training_models = [] + params_to_optimize = [] + training_models.append(controlnet) + name_and_params = list(controlnet.named_parameters()) + # single param group for now + params_to_optimize.append({"params": [p for _, p in name_and_params], "lr": args.learning_rate}) + param_names = [[n for n, _ in name_and_params]] + + # calculate number of trainable parameters + n_params = 0 + for group in params_to_optimize: + for p in group["params"]: + n_params += p.numel() + + accelerator.print(f"number of trainable parameters: {n_params}") + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + if args.blockwise_fused_optimizers: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters. + # This balances memory usage and management complexity. + + # split params into groups. currently different learning rates are not supported + grouped_params = [] + param_group = {} + for group in params_to_optimize: + named_parameters = list(controlnet.named_parameters()) + assert len(named_parameters) == len(group["params"]), "number of parameters does not match" + for p, np in zip(group["params"], named_parameters): + # determine target layer and block index for each parameter + block_type = "other" # double, single or other + if np[0].startswith("double_blocks"): + block_index = int(np[0].split(".")[1]) + block_type = "double" + elif np[0].startswith("single_blocks"): + block_index = int(np[0].split(".")[1]) + block_type = "single" + else: + block_index = -1 + + param_group_key = (block_type, block_index) + if param_group_key not in param_group: + param_group[param_group_key] = [] + param_group[param_group_key].append(p) + + block_types_and_indices = [] + for param_group_key, param_group in param_group.items(): + block_types_and_indices.append(param_group_key) + grouped_params.append({"params": param_group, "lr": args.learning_rate}) + + num_params = 0 + for p in param_group: + num_params += p.numel() + accelerator.print(f"block {param_group_key}: {num_params} parameters") + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers") + + if train_util.is_schedulefree_optimizer(optimizers[0], args): + raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers") + optimizer_train_fn = lambda: None # dummy function + optimizer_eval_fn = lambda: None # dummy function + else: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) - ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) - return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) - def get_tokenize_strategy(self, args): - _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) + # lr schedulerを用意する + if args.blockwise_fused_optimizers: + # prepare lr schedulers for each optimizer + lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) - if args.t5xxl_max_token_length is None: - if is_schnell: - t5xxl_max_token_length = 256 - else: - t5xxl_max_token_length = 512 - else: - t5xxl_max_token_length = args.t5xxl_max_token_length + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + flux.to(weight_dtype) + controlnet.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + t5xxl.to(weight_dtype) # TODO check works with fp16 or not + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + flux.to(weight_dtype) + controlnet.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + t5xxl.to(weight_dtype) + + # if we don't cache text encoder outputs, move them to device + if not args.cache_text_encoder_outputs: + clip_l.to(accelerator.device) + t5xxl.to(accelerator.device) + + clean_memory_on_device(accelerator.device) + + if args.deepspeed: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=controlnet) + # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] - logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}") - return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir) + else: + # accelerator does some magic + # if we doesn't swap blocks, we can move the model to device + controlnet = accelerator.prepare(controlnet, device_placement=[not is_swapping_blocks]) + if is_swapping_blocks: + accelerator.unwrap_model(controlnet).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + + for param_group, param_name_group in zip(optimizer.param_groups, param_names): + for parameter, param_name in zip(param_group["params"], param_name_group): + if parameter.requires_grad: + + def create_grad_hook(p_name, p_group): + def grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, p_group) + tensor.grad = None + + return grad_hook + + parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group)) + + elif args.blockwise_fused_optimizers: + # prepare for additional optimizers and lr schedulers + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + # counters are used to determine when to step the optimizer + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def grad_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(grad_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print( + # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + # ) + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) - def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy): - return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl] + if is_swapping_blocks: + accelerator.unwrap_model(controlnet).prepare_block_swap_before_forward() - def get_latents_caching_strategy(self, args): - latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False) - return latents_caching_strategy + # For --sample_at_first + optimizer_eval_fn() + flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) + optimizer_train_fn() + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) - def get_text_encoding_strategy(self, args): - return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask) + loss_recorder = train_util.LossRecorder() + epoch = 0 # avoid error when max_train_steps is 0 + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 - def post_process_network(self, args, accelerator, network, text_encoders, unet): - # check t5xxl is trained or not - self.train_t5xxl = network.train_t5xxl + for m in training_models: + m.train() - if self.train_t5xxl and args.cache_text_encoder_outputs: - raise ValueError( - "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません" - ) + for step, batch in enumerate(train_dataloader): + current_step.value = global_step - def get_models_for_text_encoding(self, args, accelerator, text_encoders): - if args.cache_text_encoder_outputs: - if self.train_clip_l and not self.train_t5xxl: - return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached - else: - return None # no text encoders are needed for encoding because both are cached - else: - return text_encoders # both CLIP-L and T5XXL are needed for encoding + if args.blockwise_fused_optimizers: + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step - def get_text_encoders_train_flags(self, args, text_encoders): - return [self.train_clip_l, self.train_t5xxl] + with accelerator.accumulate(*training_models): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) + else: + with torch.no_grad(): + # encode images to latents. images are [-1, 1] + latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list + else: + # not cached or training, so get from text encoders + tokens_and_masks = batch["input_ids_list"] + with torch.no_grad(): + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + text_encoder_conds = text_encoding_strategy.encode_tokens( + flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask + ) + if args.full_fp16: + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] - def get_text_encoder_outputs_caching_strategy(self, args): - if args.cache_text_encoder_outputs: - # if the text encoders is trained, we need tokenization, so is_partial is True - return strategy_flux.FluxTextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, - args.text_encoder_batch_size, - args.skip_cache_check, - is_partial=self.train_clip_l or self.train_t5xxl, - apply_t5_attn_mask=args.apply_t5_attn_mask, - ) - else: - return None + # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps - def cache_text_encoder_outputs_if_needed( - self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype - ): - if args.cache_text_encoder_outputs: - if not args.lowram: - # メモリ消費を減らす - logger.info("move vae and unet to cpu to save memory") - org_vae_device = vae.device - org_unet_device = unet.device - vae.to("cpu") - unet.to("cpu") - clean_memory_on_device(accelerator.device) - - # When TE is not be trained, it will not be prepared so we need to use explicit autocast - logger.info("move text encoders to gpu") - text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 - text_encoders[1].to(accelerator.device) - - if text_encoders[1].dtype == torch.float8_e4m3fn: - # if we load fp8 weights, the model is already fp8, so we use it as is - self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype) - else: - # otherwise, we need to convert it to target dtype - text_encoders[1].to(weight_dtype) - - with accelerator.autocast(): - dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) - - # cache sample prompts - if args.sample_prompts is not None: - logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") - - tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() - text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() - - prompts = train_util.load_prompts(args.sample_prompts) - sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs - with accelerator.autocast(), torch.no_grad(): - for prompt_dict in prompts: - for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: - if p not in sample_prompts_te_outputs: - logger.info(f"cache Text Encoder outputs for prompt: {p}") - tokens_and_masks = tokenize_strategy.tokenize(p) - sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( - tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask - ) - self.sample_prompts_te_outputs = sample_prompts_te_outputs - - accelerator.wait_for_everyone() - - # move back to cpu - if not self.is_train_text_encoder(args): - logger.info("move CLIP-L back to cpu") - text_encoders[0].to("cpu") - logger.info("move t5XXL back to cpu") - text_encoders[1].to("cpu") - clean_memory_on_device(accelerator.device) - - if not args.lowram: - logger.info("move vae and unet back to original device") - vae.to(org_vae_device) - unet.to(org_unet_device) - else: - # Text Encoderから毎回出力を取得するので、GPUに乗せておく - text_encoders[0].to(accelerator.device, dtype=weight_dtype) - text_encoders[1].to(accelerator.device) + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] - # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): - # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype + ) - # # get size embeddings - # orig_size = batch["original_sizes_hw"] - # crop_size = batch["crop_top_lefts"] - # target_size = batch["target_sizes_hw"] - # embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) - # # concat embeddings - # encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds - # vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) - # text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + # get guidance: ensure args.guidance_scale is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) - # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) - # return noise_pred + # call model + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None - def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux): - text_encoders = text_encoder # for compatibility - text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) + with accelerator.autocast(): + block_samples, block_single_samples = controlnet( + img=packed_noisy_model_input, + img_ids=img_ids, + controlnet_cond=batch["control_image"].to(accelerator.device), + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = flux( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + block_controlnet_hidden_states=block_samples, + block_controlnet_single_hidden_states=block_single_samples, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) - flux_train_utils.sample_images( - accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs - ) - # return - - """ - class FluxUpperLowerWrapper(torch.nn.Module): - def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device): - super().__init__() - self.flux_upper = flux_upper - self.flux_lower = flux_lower - self.target_device = device - - def prepare_block_swap_before_forward(self): - pass - - def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None): - self.flux_lower.to("cpu") - clean_memory_on_device(self.target_device) - self.flux_upper.to(self.target_device) - img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask) - self.flux_upper.to("cpu") - clean_memory_on_device(self.target_device) - self.flux_lower.to(self.target_device) - return self.flux_lower(img, txt, vec, pe, txt_attention_mask) - - wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device) - clean_memory_on_device(accelerator.device) - flux_train_utils.sample_images( - accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs - ) - clean_memory_on_device(accelerator.device) - """ - - def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: - noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) - self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) - return noise_scheduler - - def encode_images_to_latents(self, args, accelerator, vae, images): - return vae.encode(images) - - def shift_scale_latents(self, args, latents): - return latents - - def get_noise_pred_and_target( - self, - args, - accelerator, - noise_scheduler, - latents, - batch, - text_encoder_conds, - unet: flux_models.Flux, - network, - weight_dtype, - train_unet, - ): - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) - bsz = latents.shape[0] - - # get noisy model input and timesteps - noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents, noise, accelerator.device, weight_dtype - ) + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) - # pack latents and get img_ids - packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 - packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 - img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) - - # get guidance - # ensure guidance_scale in args is float - guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) - - # ensure the hidden state will require grad - if args.gradient_checkpointing: - noisy_model_input.requires_grad_(True) - for t in text_encoder_conds: - if t is not None and t.dtype.is_floating_point: - t.requires_grad_(True) - img_ids.requires_grad_(True) - guidance_vec.requires_grad_(True) - - # Predict the noise residual - l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds - if not args.apply_t5_attn_mask: - t5_attn_mask = None - - def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): - # if not args.split_mode: - # normal forward - with accelerator.autocast(): - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) - model_pred = unet( - img=img, - img_ids=img_ids, - txt=t5_out, - txt_ids=txt_ids, - y=l_pooled, - timesteps=timesteps / 1000, - guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, + # flow matching loss: this is different from SD3 + target = noise - latents + + # calculate loss + loss = train_util.conditional_loss( + model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None + ) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + loss = loss.mean() + + # backward + accelerator.backward(loss) + + if not (args.fused_backward_pass or args.blockwise_fused_optimizers): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.blockwise_fused_optimizers: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + optimizer_eval_fn() + flux_train_utils.sample_images( + accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs ) - """ - else: - # split forward to reduce memory usage - assert network.train_blocks == "single", "train_blocks must be single for split mode" - with accelerator.autocast(): - # move flux lower to cpu, and then move flux upper to gpu - unet.to("cpu") - clean_memory_on_device(accelerator.device) - self.flux_upper.to(accelerator.device) - # upper model does not require grad - with torch.no_grad(): - intermediate_img, intermediate_txt, vec, pe = self.flux_upper( - img=packed_noisy_model_input, - img_ids=img_ids, - txt=t5_out, - txt_ids=txt_ids, - y=l_pooled, - timesteps=timesteps / 1000, - guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(flux), ) + optimizer_train_fn() - # move flux upper back to cpu, and then move flux lower to gpu - self.flux_upper.to("cpu") - clean_memory_on_device(accelerator.device) - unet.to(accelerator.device) - - # lower model requires grad - intermediate_img.requires_grad_(True) - intermediate_txt.requires_grad_(True) - vec.requires_grad_(True) - pe.requires_grad_(True) - model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) - """ - - return model_pred - - model_pred = call_dit( - img=packed_noisy_model_input, - img_ids=img_ids, - t5_out=t5_out, - txt_ids=txt_ids, - l_pooled=l_pooled, - timesteps=timesteps, - guidance_vec=guidance_vec, - t5_attn_mask=t5_attn_mask, - ) + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if len(accelerator.trackers) > 0: + logs = {"loss": current_loss} + train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) - # unpack latents - model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) - - # apply model prediction type - model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) - - # flow matching loss: this is different from SD3 - target = noise - latents - - # differential output preservation - if "custom_attributes" in batch: - diff_output_pr_indices = [] - for i, custom_attributes in enumerate(batch["custom_attributes"]): - if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: - diff_output_pr_indices.append(i) - - if len(diff_output_pr_indices) > 0: - network.set_multiplier(0.0) - with torch.no_grad(): - model_pred_prior = call_dit( - img=packed_noisy_model_input[diff_output_pr_indices], - img_ids=img_ids[diff_output_pr_indices], - t5_out=t5_out[diff_output_pr_indices], - txt_ids=txt_ids[diff_output_pr_indices], - l_pooled=l_pooled[diff_output_pr_indices], - timesteps=timesteps[diff_output_pr_indices], - guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None, - t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None, - ) - network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + accelerator.log(logs, step=global_step) - model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width) - model_pred_prior, _ = flux_train_utils.apply_model_prediction_type( + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if len(accelerator.trackers) > 0: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + optimizer_eval_fn() + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( args, - model_pred_prior, - noisy_model_input[diff_output_pr_indices], - sigmas[diff_output_pr_indices] if sigmas is not None else None, + True, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(flux), ) - target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) - - return model_pred, target, timesteps, None, weighting - - def post_process_loss(self, loss, args, timesteps, noise_scheduler): - return loss - - def get_sai_model_spec(self, args): - return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev") - - def update_metadata(self, metadata, args): - metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask - metadata["ss_weighting_scheme"] = args.weighting_scheme - metadata["ss_logit_mean"] = args.logit_mean - metadata["ss_logit_std"] = args.logit_std - metadata["ss_mode_scale"] = args.mode_scale - metadata["ss_guidance_scale"] = args.guidance_scale - metadata["ss_timestep_sampling"] = args.timestep_sampling - metadata["ss_sigmoid_scale"] = args.sigmoid_scale - metadata["ss_model_prediction_type"] = args.model_prediction_type - metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift - - def is_text_encoder_not_needed_for_training(self, args): - return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) - - def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): - if index == 0: # CLIP-L - return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder) - else: # T5XXL - text_encoder.encoder.embed_tokens.requires_grad_(True) - - def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): - if index == 0: # CLIP-L - logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}") - text_encoder.to(te_weight_dtype) # fp8 - text_encoder.text_model.embeddings.to(dtype=weight_dtype) - else: # T5XXL - - def prepare_fp8(text_encoder, target_dtype): - def forward_hook(module): - def forward(hidden_states): - hidden_gelu = module.act(module.wi_0(hidden_states)) - hidden_linear = module.wi_1(hidden_states) - hidden_states = hidden_gelu * hidden_linear - hidden_states = module.dropout(hidden_states) - - hidden_states = module.wo(hidden_states) - return hidden_states - - return forward - - for module in text_encoder.modules(): - if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: - # print("set", module.__class__.__name__, "to", target_dtype) - module.to(target_dtype) - if module.__class__.__name__ in ["T5DenseGatedActDense"]: - # print("set", module.__class__.__name__, "hooks") - module.forward = forward_hook(module) - - if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype: - logger.info(f"T5XXL already prepared for fp8") - else: - logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks") - text_encoder.to(te_weight_dtype) # fp8 - prepare_fp8(text_encoder, weight_dtype) - def prepare_unet_with_accelerator( - self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module - ) -> torch.nn.Module: - if not self.is_swapping_blocks: - return super().prepare_unet_with_accelerator(args, accelerator, unet) + flux_train_utils.sample_images( + accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + ) + optimizer_train_fn() + + is_main_process = accelerator.is_main_process + # if is_main_process: + controlnet = accelerator.unwrap_model(controlnet) - # if we doesn't swap blocks, we can move the model to device - flux: flux_models.Flux = unet - flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks]) - accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage - accelerator.unwrap_model(flux).prepare_block_swap_before_forward() + accelerator.end_training() + optimizer_eval_fn() + + if args.save_state or args.save_state_on_train_end: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す - return flux + if is_main_process: + flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux) + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: - parser = train_network.setup_parser() + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) # TODO split this + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + add_custom_train_arguments(parser) # TODO remove this from here train_util.add_dit_training_arguments(parser) flux_train_utils.add_flux_train_arguments(parser) parser.add_argument( - "--split_mode", + "--mem_eff_save", + action="store_true", + help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う", + ) + + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます", + ) + parser.add_argument( + "--blockwise_fused_optimizers", + action="store_true", + help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする", + ) + parser.add_argument( + "--skip_latents_validity_check", + action="store_true", + help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", + ) + parser.add_argument( + "--double_blocks_to_swap", + type=int, + default=None, + help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください", + ) + parser.add_argument( + "--single_blocks_to_swap", + type=int, + default=None, + help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください", + ) + parser.add_argument( + "--cpu_offload_checkpointing", action="store_true", - # help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" - # + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", - help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead." - " / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。", + help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", ) return parser @@ -569,5 +866,4 @@ def setup_parser() -> argparse.ArgumentParser: train_util.verify_command_line_training_args(args) args = train_util.read_config_from_file(args, parser) - trainer = FluxNetworkTrainer() - trainer.train(args) + train(args) diff --git a/flux_train_network.py b/flux_train_network.py index 0feb9b011..6668012e4 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -125,9 +125,6 @@ def load_target_model(self, args, weight_dtype, accelerator): ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) - controlnet = flux_utils.load_controlnet() - controlnet.train() - return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model, controlnet def get_tokenize_strategy(self, args): diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index d90644a25..cc3bcb0ec 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -40,6 +40,7 @@ def sample_images( text_encoders, sample_prompts_te_outputs, prompt_replacement=None, + controlnet=None ): if steps == 0: if not args.sample_at_first: @@ -67,6 +68,8 @@ def sample_images( flux = accelerator.unwrap_model(flux) if text_encoders is not None: text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] + if controlnet is not None: + controlnet = accelerator.unwrap_model(controlnet) # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) prompts = train_util.load_prompts(args.sample_prompts) @@ -98,6 +101,7 @@ def sample_images( steps, sample_prompts_te_outputs, prompt_replacement, + controlnet ) else: # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) @@ -121,6 +125,7 @@ def sample_images( steps, sample_prompts_te_outputs, prompt_replacement, + controlnet ) torch.set_rng_state(rng_state) @@ -142,6 +147,7 @@ def sample_image_inference( steps, sample_prompts_te_outputs, prompt_replacement, + controlnet ): assert isinstance(prompt_dict, dict) # negative_prompt = prompt_dict.get("negative_prompt") @@ -150,7 +156,7 @@ def sample_image_inference( height = prompt_dict.get("height", 512) scale = prompt_dict.get("scale", 3.5) seed = prompt_dict.get("seed") - # controlnet_image = prompt_dict.get("controlnet_image") + controlnet_image = prompt_dict.get("controlnet_image") prompt: str = prompt_dict.get("prompt", "") # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) @@ -169,6 +175,9 @@ def sample_image_inference( # if negative_prompt is None: # negative_prompt = "" + if controlnet_image is not None: + controlnet_image = Image.open(controlnet_image).convert("RGB") + controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) height = max(64, height - height % 16) # round to divisible by 16 width = max(64, width - width % 16) # round to divisible by 16 @@ -224,7 +233,7 @@ def sample_image_inference( t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None with accelerator.autocast(), torch.no_grad(): - x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask) + x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image) x = x.float() x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) @@ -301,18 +310,37 @@ def denoise( timesteps: list[float], guidance: float = 4.0, t5_attn_mask: Optional[torch.Tensor] = None, + controlnet: Optional[flux_models.ControlNetFlux] = None, + controlnet_img: Optional[torch.Tensor] = None, ): # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) model.prepare_block_swap_before_forward() + if controlnet is not None: + block_samples, block_single_samples = controlnet( + img=img, + img_ids=img_ids, + controlnet_cond=controlnet_img, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + else: + block_samples = None + block_single_samples = None pred = model( img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, + block_controlnet_hidden_states=block_samples, + block_controlnet_single_hidden_states=block_single_samples, timesteps=t_vec, guidance=guidance_vec, txt_attention_mask=t5_attn_mask, diff --git a/library/flux_utils.py b/library/flux_utils.py index 678efbc8a..7b538d133 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -153,11 +153,14 @@ def load_ae( return ae -def load_controlnet(name, device, transformer=None): - with torch.device(device): +def load_controlnet(): + # TODO + is_schnell = False + name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL + with torch.device("meta"): controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params) - if transformer is not None: - controlnet.load_state_dict(transformer.state_dict(), strict=False) + # if transformer is not None: + # controlnet.load_state_dict(transformer.state_dict(), strict=False) return controlnet From e358b118afbc93f63dbb5ab6d2412ec553ea9cd7 Mon Sep 17 00:00:00 2001 From: minux302 Date: Sat, 16 Nov 2024 14:49:29 +0900 Subject: [PATCH 03/32] fix dataloader --- flux_train_control_net.py | 84 ++++++++++++++++++++------------------- library/flux_models.py | 17 ++++---- 2 files changed, 52 insertions(+), 49 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 8a7be75f2..ee4d0ebf3 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -11,31 +11,36 @@ # - Per-block fused optimizer instances import argparse -from concurrent.futures import ThreadPoolExecutor import copy import math import os -from multiprocessing import Value import time +from concurrent.futures import ThreadPoolExecutor +from multiprocessing import Value from typing import List, Optional, Tuple, Union -import toml - -from tqdm import tqdm +import toml import torch import torch.nn as nn +from tqdm import tqdm + from library import utils -from library.device_utils import init_ipex, clean_memory_on_device +from library.device_utils import clean_memory_on_device, init_ipex init_ipex() from accelerate.utils import set_seed -from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux -from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler import library.train_util as train_util - -from library.utils import setup_logging, add_logging_arguments +from library import ( + deepspeed_utils, + flux_train_utils, + flux_utils, + strategy_base, + strategy_flux, +) +from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler +from library.utils import add_logging_arguments, setup_logging setup_logging() import logging @@ -46,10 +51,10 @@ # import library.sdxl_train_util as sdxl_train_util from library.config_util import ( - ConfigSanitizer, BlueprintGenerator, + ConfigSanitizer, ) -from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments +from library.custom_train_functions import add_custom_train_arguments, apply_masked_loss def train(args): @@ -85,7 +90,6 @@ def train(args): ) cache_latents = args.cache_latents - use_dreambooth_method = args.in_json is None if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する @@ -103,7 +107,7 @@ def train(args): if args.dataset_config is not None: logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "in_json"] + ignored = ["train_data_dir", "conditioing_data_dir"] if any(getattr(args, attr) is not None for attr in ignored): logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( @@ -111,31 +115,17 @@ def train(args): ) ) else: - if use_dreambooth_method: - logger.info("Using DreamBooth method.") - user_config = { - "datasets": [ - { - "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( - args.train_data_dir, args.reg_data_dir - ) - } - ] - } - else: - logger.info("Training with captions.") - user_config = { - "datasets": [ - { - "subsets": [ - { - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - } - ] - } - ] - } + user_config = { + "datasets": [ + { + "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( + args.train_data_dir, + args.conditioning_data_dir, + args.caption_extension + ) + } + ] + } blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) @@ -648,12 +638,12 @@ def grad_hook(parameter: torch.Tensor): l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds if not args.apply_t5_attn_mask: t5_attn_mask = None - + with accelerator.autocast(): block_samples, block_single_samples = controlnet( img=packed_noisy_model_input, img_ids=img_ids, - controlnet_cond=batch["control_image"].to(accelerator.device), + controlnet_img=batch["conditioing_image"].to(accelerator.device), txt=t5_out, txt_ids=txt_ids, y=l_pooled, @@ -856,6 +846,18 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", ) + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="controlnet model name or path / controlnetのモデル名またはパス", + ) + parser.add_argument( + "--conditioning_data_dir", + type=str, + default=None, + help="conditioning data directory / 条件付けデータのディレクトリ", + ) return parser diff --git a/library/flux_models.py b/library/flux_models.py index a3bd19743..b52ea6f0b 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -2,15 +2,15 @@ # license: Apache-2.0 License -from concurrent.futures import Future, ThreadPoolExecutor -from dataclasses import dataclass import math import os import time +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass from typing import Dict, List, Optional, Union from library import utils -from library.device_utils import init_ipex, clean_memory_on_device +from library.device_utils import clean_memory_on_device, init_ipex init_ipex() @@ -18,6 +18,7 @@ from einops import rearrange from torch import Tensor, nn from torch.utils.checkpoint import checkpoint + from library import custom_offloading_utils # USE_REENTRANT = True @@ -1251,7 +1252,7 @@ def forward( self, img: Tensor, img_ids: Tensor, - controlnet_cond: Tensor, + controlnet_img: Tensor, txt: Tensor, txt_ids: Tensor, timesteps: Tensor, @@ -1264,10 +1265,10 @@ def forward( # running on sequences img img = self.img_in(img) - controlnet_cond = self.input_hint_block(controlnet_cond) - controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) - controlnet_cond = self.pos_embed_input(controlnet_cond) - img = img + controlnet_cond + controlnet_img = self.input_hint_block(controlnet_img) + controlnet_img = rearrange(controlnet_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + controlnet_img = self.pos_embed_input(controlnet_img) + img = img + controlnet_img vec = self.time_in(timestep_embedding(timesteps, 256)) if self.params.guidance_embed: if guidance is None: From b2660bbe7410d7ffa40906a7a09f84a17139cb46 Mon Sep 17 00:00:00 2001 From: minux302 Date: Sun, 17 Nov 2024 10:24:57 +0000 Subject: [PATCH 04/32] train run --- flux_train_control_net.py | 39 ++++++++++++++++++++++--------------- library/flux_models.py | 30 ++++++++++++++-------------- library/flux_train_utils.py | 2 +- library/flux_utils.py | 2 +- 4 files changed, 40 insertions(+), 33 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index ee4d0ebf3..205ff6b6a 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -103,11 +103,11 @@ def train(args): # データセットを準備する if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) if args.dataset_config is not None: logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "conditioing_data_dir"] + ignored = ["train_data_dir", "conditioning_data_dir"] if any(getattr(args, attr) is not None for attr in ignored): logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( @@ -263,10 +263,11 @@ def train(args): args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors ) flux.requires_grad_(False) + flux.to(accelerator.device) # load controlnet controlnet = flux_utils.load_controlnet() - controlnet.requires_grad_(True) + controlnet.train() if args.gradient_checkpointing: controlnet.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing) @@ -443,7 +444,8 @@ def train(args): clean_memory_on_device(accelerator.device) - if args.deepspeed: + # if args.deepspeed: + if True: ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=controlnet) # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -612,8 +614,10 @@ def grad_hook(parameter: torch.Tensor): text_encoder_conds = text_encoding_strategy.encode_tokens( flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask ) - if args.full_fp16: - text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] + # if args.full_fp16: + # text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] + # TODO: check + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps @@ -629,10 +633,10 @@ def grad_hook(parameter: torch.Tensor): # pack latents and get img_ids packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 - img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device).to(weight_dtype) # get guidance: ensure args.guidance_scale is float - guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device, dtype=weight_dtype) # call model l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds @@ -640,10 +644,11 @@ def grad_hook(parameter: torch.Tensor): t5_attn_mask = None with accelerator.autocast(): + print("control start") block_samples, block_single_samples = controlnet( img=packed_noisy_model_input, img_ids=img_ids, - controlnet_img=batch["conditioing_image"].to(accelerator.device), + controlnet_cond=batch["conditioning_images"].to(accelerator.device).to(weight_dtype), txt=t5_out, txt_ids=txt_ids, y=l_pooled, @@ -651,6 +656,8 @@ def grad_hook(parameter: torch.Tensor): guidance=guidance_vec, txt_attention_mask=t5_attn_mask, ) + print("control end") + print("dit start") # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = flux( img=packed_noisy_model_input, @@ -796,7 +803,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) # TODO split this - train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) train_util.add_masked_loss_arguments(parser) deepspeed_utils.add_deepspeed_arguments(parser) @@ -852,12 +859,12 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="controlnet model name or path / controlnetのモデル名またはパス", ) - parser.add_argument( - "--conditioning_data_dir", - type=str, - default=None, - help="conditioning data directory / 条件付けデータのディレクトリ", - ) + # parser.add_argument( + # "--conditioning_data_dir", + # type=str, + # default=None, + # help="conditioning data directory / 条件付けデータのディレクトリ", + # ) return parser diff --git a/library/flux_models.py b/library/flux_models.py index b52ea6f0b..2fc21db9d 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1042,20 +1042,20 @@ def forward( if not self.blocks_to_swap: for block_idx, block in enumerate(self.double_blocks): img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if block_controlnet_hidden_states is not None: + if block_controlnet_hidden_states is not None and controlnet_depth > 0: img = img + block_controlnet_hidden_states[block_idx % controlnet_depth] img = torch.cat((txt, img), 1) - for block in self.single_blocks: + for block_idx, block in enumerate(self.single_blocks): img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if block_controlnet_single_hidden_states is not None: + if block_controlnet_single_hidden_states is not None and controlnet_single_depth > 0: img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth] else: for block_idx, block in enumerate(self.double_blocks): self.offloader_double.wait_for_block(block_idx) img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if block_controlnet_hidden_states is not None: + if block_controlnet_hidden_states is not None and controlnet_depth > 0: img = img + block_controlnet_hidden_states[block_idx % controlnet_depth] self.offloader_double.submit_move_blocks(self.double_blocks, block_idx) @@ -1066,7 +1066,7 @@ def forward( self.offloader_single.wait_for_block(block_idx) img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if block_controlnet_single_hidden_states is not None: + if block_controlnet_single_hidden_states is not None and controlnet_single_depth > 0: img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth] self.offloader_single.submit_move_blocks(self.single_blocks, block_idx) @@ -1121,14 +1121,14 @@ def __init__(self, params: FluxParams, controlnet_depth=2): mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias, ) - for _ in range(params.depth) + for _ in range(controlnet_depth) ] ) self.single_blocks = nn.ModuleList( [ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) - for _ in range(0) # TMP + for _ in range(0) # TODO ] ) @@ -1148,7 +1148,7 @@ def __init__(self, params: FluxParams, controlnet_depth=2): controlnet_block = zero_module(controlnet_block) self.controlnet_blocks_for_double.append(controlnet_block) self.controlnet_blocks_for_single = nn.ModuleList([]) - for _ in range(controlnet_depth): + for _ in range(0): # TODO controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) controlnet_block = zero_module(controlnet_block) self.controlnet_blocks_for_single.append(controlnet_block) @@ -1252,7 +1252,7 @@ def forward( self, img: Tensor, img_ids: Tensor, - controlnet_img: Tensor, + controlnet_cond: Tensor, txt: Tensor, txt_ids: Tensor, timesteps: Tensor, @@ -1265,10 +1265,10 @@ def forward( # running on sequences img img = self.img_in(img) - controlnet_img = self.input_hint_block(controlnet_img) - controlnet_img = rearrange(controlnet_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) - controlnet_img = self.pos_embed_input(controlnet_img) - img = img + controlnet_img + controlnet_cond = self.input_hint_block(controlnet_cond) + controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + controlnet_cond = self.pos_embed_input(controlnet_cond) + img = img + controlnet_cond vec = self.time_in(timestep_embedding(timesteps, 256)) if self.params.guidance_embed: if guidance is None: @@ -1283,7 +1283,7 @@ def forward( block_samples = () block_single_samples = () if not self.blocks_to_swap: - for block_idx, block in enumerate(self.double_blocks): + for block in self.double_blocks: img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) block_samples = block_samples + (img,) @@ -1315,7 +1315,7 @@ def forward( for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_double): block_sample = controlnet_block(block_sample) controlnet_block_samples = controlnet_block_samples + (block_sample,) - for block_sample, controlnet_block in zip(block_samples, self.controlnet_single_blocks_for_single): + for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_single): block_sample = controlnet_block(block_sample) controlnet_single_block_samples = controlnet_single_block_samples + (block_sample,) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index cc3bcb0ec..d82bde91c 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -460,7 +460,7 @@ def get_noisy_model_input_and_timesteps( sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents - return noisy_model_input, timesteps, sigmas + return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas): diff --git a/library/flux_utils.py b/library/flux_utils.py index 7b538d133..4a3817fdb 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -157,7 +157,7 @@ def load_controlnet(): # TODO is_schnell = False name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL - with torch.device("meta"): + with torch.device("cuda:0"): controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params) # if transformer is not None: # controlnet.load_state_dict(transformer.state_dict(), strict=False) From 35778f021897796410372aed8540547ba317c2a3 Mon Sep 17 00:00:00 2001 From: minux302 Date: Sun, 17 Nov 2024 11:09:05 +0000 Subject: [PATCH 05/32] fix sample_images type --- flux_train_control_net.py | 31 ++++++++++++++----------------- library/flux_train_utils.py | 2 +- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 205ff6b6a..791900d17 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -444,8 +444,7 @@ def train(args): clean_memory_on_device(accelerator.device) - # if args.deepspeed: - if True: + if args.deepspeed: ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=controlnet) # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -644,7 +643,6 @@ def grad_hook(parameter: torch.Tensor): t5_attn_mask = None with accelerator.autocast(): - print("control start") block_samples, block_single_samples = controlnet( img=packed_noisy_model_input, img_ids=img_ids, @@ -656,8 +654,6 @@ def grad_hook(parameter: torch.Tensor): guidance=guidance_vec, txt_attention_mask=t5_attn_mask, ) - print("control end") - print("dit start") # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = flux( img=packed_noisy_model_input, @@ -763,18 +759,19 @@ def grad_hook(parameter: torch.Tensor): accelerator.wait_for_everyone() optimizer_eval_fn() - if args.save_every_n_epochs is not None: - if accelerator.is_main_process: - flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( - args, - True, - accelerator, - save_dtype, - epoch, - num_train_epochs, - global_step, - accelerator.unwrap_model(flux), - ) + # TODO: save cn models + # if args.save_every_n_epochs is not None: + # if accelerator.is_main_process: + # flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + # args, + # True, + # accelerator, + # save_dtype, + # epoch, + # num_train_epochs, + # global_step, + # accelerator.unwrap_model(flux), + # ) flux_train_utils.sample_images( accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index d82bde91c..de2ee030a 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -235,7 +235,7 @@ def sample_image_inference( with accelerator.autocast(), torch.no_grad(): x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image) - x = x.float() + # x = x.float() # TODO: check x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) # latent to image From 4dd4cd6ec8c55fa94b53217181ed9c95e59eed56 Mon Sep 17 00:00:00 2001 From: minux302 Date: Mon, 18 Nov 2024 12:47:01 +0000 Subject: [PATCH 06/32] work cn load and validation --- flux_train_control_net.py | 20 ++++---------------- library/flux_models.py | 6 +++--- library/flux_train_utils.py | 18 ++++++++++++++---- library/flux_utils.py | 25 ++++++++++++++++--------- 4 files changed, 37 insertions(+), 32 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 791900d17..cbfac418f 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -266,7 +266,7 @@ def train(args): flux.to(accelerator.device) # load controlnet - controlnet = flux_utils.load_controlnet() + controlnet = flux_utils.load_controlnet(args.controlnet, weight_dtype, "cpu", args.disable_mmap_load_safetensors) controlnet.train() if args.gradient_checkpointing: @@ -568,7 +568,7 @@ def grad_hook(parameter: torch.Tensor): # For --sample_at_first optimizer_eval_fn() - flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) + flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet) optimizer_train_fn() if len(accelerator.trackers) > 0: # log empty object to commit the sample images to wandb @@ -718,7 +718,7 @@ def grad_hook(parameter: torch.Tensor): optimizer_eval_fn() flux_train_utils.sample_images( - accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet ) # 指定ステップごとにモデルを保存 @@ -774,7 +774,7 @@ def grad_hook(parameter: torch.Tensor): # ) flux_train_utils.sample_images( - accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet ) optimizer_train_fn() @@ -850,18 +850,6 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", ) - parser.add_argument( - "--controlnet_model_name_or_path", - type=str, - default=None, - help="controlnet model name or path / controlnetのモデル名またはパス", - ) - # parser.add_argument( - # "--conditioning_data_dir", - # type=str, - # default=None, - # help="conditioning data directory / 条件付けデータのディレクトリ", - # ) return parser diff --git a/library/flux_models.py b/library/flux_models.py index 2fc21db9d..4123b40e5 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1142,11 +1142,11 @@ def __init__(self, params: FluxParams, controlnet_depth=2): self.num_single_blocks = len(self.single_blocks) # add ControlNet blocks - self.controlnet_blocks_for_double = nn.ModuleList([]) + self.controlnet_blocks = nn.ModuleList([]) for _ in range(controlnet_depth): controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) controlnet_block = zero_module(controlnet_block) - self.controlnet_blocks_for_double.append(controlnet_block) + self.controlnet_blocks.append(controlnet_block) self.controlnet_blocks_for_single = nn.ModuleList([]) for _ in range(0): # TODO controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) @@ -1312,7 +1312,7 @@ def forward( controlnet_block_samples = () controlnet_single_block_samples = () - for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_double): + for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks): block_sample = controlnet_block(block_sample) controlnet_block_samples = controlnet_block_samples + (block_sample,) for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_single): diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index de2ee030a..dbbaba734 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -175,10 +175,6 @@ def sample_image_inference( # if negative_prompt is None: # negative_prompt = "" - if controlnet_image is not None: - controlnet_image = Image.open(controlnet_image).convert("RGB") - controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) - height = max(64, height - height % 16) # round to divisible by 16 width = max(64, width - width % 16) # round to divisible by 16 logger.info(f"prompt: {prompt}") @@ -232,6 +228,12 @@ def sample_image_inference( img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None + if controlnet_image is not None: + controlnet_image = Image.open(controlnet_image).convert("RGB") + controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) + controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1) + controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device) + with accelerator.autocast(), torch.no_grad(): x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image) @@ -315,6 +317,8 @@ def denoise( ): # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + + for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) model.prepare_block_swap_before_forward() @@ -560,6 +564,12 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)、float16が前提", ) parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)") + parser.add_argument( + "--controlnet", + type=str, + default=None, + help="path to controlnet (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)" + ) parser.add_argument( "--t5xxl_max_token_length", type=int, diff --git a/library/flux_utils.py b/library/flux_utils.py index 4a3817fdb..fb7a30749 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -153,15 +153,22 @@ def load_ae( return ae -def load_controlnet(): - # TODO - is_schnell = False - name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL - with torch.device("cuda:0"): - controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params) - # if transformer is not None: - # controlnet.load_state_dict(transformer.state_dict(), strict=False) - return controlnet +def load_controlnet( + ckpt_path: Optional[str], dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False +): + logger.info("Building ControlNet") + # is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path) + is_schnell = False + name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL + with torch.device("meta"): + controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params).to(dtype) + + if ckpt_path is not None: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + info = controlnet.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded ControlNet: {info}") + return controlnet def load_clip_l( From 31ca899b6b5425466c814d0d9e2e4e8bfbf93001 Mon Sep 17 00:00:00 2001 From: minux302 Date: Mon, 18 Nov 2024 13:03:28 +0000 Subject: [PATCH 07/32] fix depth value --- library/flux_models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/library/flux_models.py b/library/flux_models.py index 4123b40e5..328ad481d 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1093,7 +1093,7 @@ class ControlNetFlux(nn.Module): Transformer model for flow matching on sequences. """ - def __init__(self, params: FluxParams, controlnet_depth=2): + def __init__(self, params: FluxParams, controlnet_depth=2, controlnet_single_depth=0): super().__init__() self.params = params @@ -1128,7 +1128,7 @@ def __init__(self, params: FluxParams, controlnet_depth=2): self.single_blocks = nn.ModuleList( [ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) - for _ in range(0) # TODO + for _ in range(controlnet_single_depth) ] ) @@ -1148,7 +1148,7 @@ def __init__(self, params: FluxParams, controlnet_depth=2): controlnet_block = zero_module(controlnet_block) self.controlnet_blocks.append(controlnet_block) self.controlnet_blocks_for_single = nn.ModuleList([]) - for _ in range(0): # TODO + for _ in range(controlnet_single_depth): controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) controlnet_block = zero_module(controlnet_block) self.controlnet_blocks_for_single.append(controlnet_block) From 0b5229a9550cb921b83d22472c4785a15c42ba90 Mon Sep 17 00:00:00 2001 From: minux302 Date: Thu, 21 Nov 2024 15:55:27 +0000 Subject: [PATCH 08/32] save cn --- flux_train_control_net.py | 34 +++++++++++++++------------------- library/flux_train_utils.py | 1 - 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index cbfac418f..0f38b7094 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -266,7 +266,7 @@ def train(args): flux.to(accelerator.device) # load controlnet - controlnet = flux_utils.load_controlnet(args.controlnet, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + controlnet = flux_utils.load_controlnet(args.controlnet, torch.float32, "cpu", args.disable_mmap_load_safetensors) controlnet.train() if args.gradient_checkpointing: @@ -613,9 +613,6 @@ def grad_hook(parameter: torch.Tensor): text_encoder_conds = text_encoding_strategy.encode_tokens( flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask ) - # if args.full_fp16: - # text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] - # TODO: check text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps @@ -733,7 +730,7 @@ def grad_hook(parameter: torch.Tensor): epoch, num_train_epochs, global_step, - accelerator.unwrap_model(flux), + accelerator.unwrap_model(controlnet), ) optimizer_train_fn() @@ -759,19 +756,18 @@ def grad_hook(parameter: torch.Tensor): accelerator.wait_for_everyone() optimizer_eval_fn() - # TODO: save cn models - # if args.save_every_n_epochs is not None: - # if accelerator.is_main_process: - # flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( - # args, - # True, - # accelerator, - # save_dtype, - # epoch, - # num_train_epochs, - # global_step, - # accelerator.unwrap_model(flux), - # ) + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(controlnet), + ) flux_train_utils.sample_images( accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet @@ -791,7 +787,7 @@ def grad_hook(parameter: torch.Tensor): del accelerator # この後メモリを使うのでこれは消す if is_main_process: - flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux) + flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, controlnet) logger.info("model saved.") diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index dbbaba734..5e25c7feb 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -237,7 +237,6 @@ def sample_image_inference( with accelerator.autocast(), torch.no_grad(): x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image) - # x = x.float() # TODO: check x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) # latent to image From 420a180d938c7b5a6e3006b1719dbfeaae72a2cc Mon Sep 17 00:00:00 2001 From: recris Date: Wed, 27 Nov 2024 18:11:51 +0000 Subject: [PATCH 09/32] Implement pseudo Huber loss for Flux and SD3 --- fine_tune.py | 6 +-- flux_train.py | 2 +- flux_train_network.py | 2 +- library/train_util.py | 74 ++++++++++++++++------------ sd3_train.py | 2 +- sd3_train_network.py | 2 +- sdxl_train.py | 6 +-- sdxl_train_control_net.py | 4 +- sdxl_train_control_net_lllite.py | 4 +- sdxl_train_control_net_lllite_old.py | 6 ++- train_controlnet.py | 6 +-- train_db.py | 4 +- train_network.py | 9 ++-- train_textual_inversion.py | 4 +- train_textual_inversion_XTI.py | 6 ++- 15 files changed, 76 insertions(+), 61 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 0090bd190..70959a751 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -380,7 +380,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( args, noise_scheduler, latents ) @@ -397,7 +397,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss: # do not mean over batch dimension for snr weight or scale v-pred loss loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) loss = loss.mean([1, 2, 3]) @@ -411,7 +411,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): loss = loss.mean() # mean over batch dimension else: loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) accelerator.backward(loss) diff --git a/flux_train.py b/flux_train.py index a89e2f139..f6e43b27a 100644 --- a/flux_train.py +++ b/flux_train.py @@ -667,7 +667,7 @@ def grad_hook(parameter: torch.Tensor): # calculate loss loss = train_util.conditional_loss( - model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) if weighting is not None: loss = loss * weighting diff --git a/flux_train_network.py b/flux_train_network.py index 679db62b6..04287f399 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -468,7 +468,7 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t ) target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) - return model_pred, target, timesteps, None, weighting + return model_pred, target, timesteps, weighting def post_process_loss(self, loss, args, timesteps, noise_scheduler): return loss diff --git a/library/train_util.py b/library/train_util.py index 25cf7640d..c204ebd38 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3905,7 +3905,14 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "--huber_c", type=float, default=0.1, - help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", + help="The Huber loss decay parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", + ) + + parser.add_argument( + "--huber_scale", + type=float, + default=1.0, + help="The Huber loss scale parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 1.0 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", ) parser.add_argument( @@ -5821,29 +5828,10 @@ def save_sd_model_on_train_end_common( huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) -def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device): - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") - - if args.loss_type == "huber" or args.loss_type == "smooth_l1": - if args.huber_schedule == "exponential": - alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps - huber_c = torch.exp(-alpha * timesteps) - elif args.huber_schedule == "snr": - alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps) - sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 - huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c - elif args.huber_schedule == "constant": - huber_c = torch.full((b_size,), args.huber_c) - else: - raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") - huber_c = huber_c.to(device) - elif args.loss_type == "l2": - huber_c = None # may be anything, as it's not used - else: - raise NotImplementedError(f"Unknown loss type {args.loss_type}") - - timesteps = timesteps.long().to(device) - return timesteps, huber_c +def get_timesteps(min_timestep, max_timestep, b_size, device): + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device) + timesteps = timesteps.long() + return timesteps def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): @@ -5865,7 +5853,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): min_timestep = 0 if args.min_timestep is None else args.min_timestep max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep - timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device) + timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device) # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) @@ -5878,24 +5866,46 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): else: noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - return noise, noisy_latents, timesteps, huber_c + return noise, noisy_latents, timesteps + + +def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch.Tensor: + b_size = timesteps.shape[0] + if args.huber_schedule == "exponential": + alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps + result = torch.exp(-alpha * timesteps) * args.huber_scale + elif args.huber_schedule == "snr": + if not hasattr(noise_scheduler, 'alphas_cumprod'): + raise NotImplementedError(f"Huber schedule 'snr' is not supported with the current model.") + alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) + sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 + result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c + result = result.to(timesteps.device) + elif args.huber_schedule == "constant": + result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device) + else: + raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") + + return result def conditional_loss( - model_pred: torch.Tensor, target: torch.Tensor, reduction: str, loss_type: str, huber_c: Optional[torch.Tensor] + args, model_pred: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor, reduction: str, noise_scheduler ): - if loss_type == "l2": + if args.loss_type == "l2": loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction) - elif loss_type == "l1": + elif args.loss_type == "l1": loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction) - elif loss_type == "huber": + elif args.loss_type == "huber": + huber_c = get_huber_threshold(args, timesteps, noise_scheduler) huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": loss = torch.mean(loss) elif reduction == "sum": loss = torch.sum(loss) - elif loss_type == "smooth_l1": + elif args.loss_type == "smooth_l1": + huber_c = get_huber_threshold(args, timesteps, noise_scheduler) huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": @@ -5903,7 +5913,7 @@ def conditional_loss( elif reduction == "sum": loss = torch.sum(loss) else: - raise NotImplementedError(f"Unsupported Loss Type {loss_type}") + raise NotImplementedError(f"Unsupported Loss Type: {args.loss_type}") return loss diff --git a/sd3_train.py b/sd3_train.py index 96ec951b9..cf2bdf938 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -845,7 +845,7 @@ def grad_hook(parameter: torch.Tensor): # ) # calculate loss loss = train_util.conditional_loss( - model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None + args, model_pred.float(), target.float(), timesteps, "none", noise_scheduler ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) diff --git a/sd3_train_network.py b/sd3_train_network.py index 1726e325f..fb7711bda 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -378,7 +378,7 @@ def get_noise_pred_and_target( target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) - return model_pred, target, timesteps, None, weighting + return model_pred, target, timesteps, weighting def post_process_loss(self, loss, args, timesteps, noise_scheduler): return loss diff --git a/sdxl_train.py b/sdxl_train.py index e26f4aa19..1bc27ec6c 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -695,7 +695,7 @@ def optimizer_hook(parameter: torch.Tensor): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( args, noise_scheduler, latents ) @@ -720,7 +720,7 @@ def optimizer_hook(parameter: torch.Tensor): ): # do not mean over batch dimension for snr weight or scale v-pred loss loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) @@ -738,7 +738,7 @@ def optimizer_hook(parameter: torch.Tensor): loss = loss.mean() # mean over batch dimension else: loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) accelerator.backward(loss) diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index 24080afbd..d0051d18f 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -512,7 +512,7 @@ def remove_model(old_ckpt_name): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( args, noise_scheduler, latents ) @@ -534,7 +534,7 @@ def remove_model(old_ckpt_name): target = noise loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) loss = loss.mean([1, 2, 3]) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 2946c97d4..66214f5df 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -463,7 +463,7 @@ def remove_model(old_ckpt_name): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( args, noise_scheduler, latents ) @@ -485,7 +485,7 @@ def remove_model(old_ckpt_name): target = noise loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) loss = loss.mean([1, 2, 3]) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 2d4465234..5e10654b9 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -406,7 +406,7 @@ def remove_model(old_ckpt_name): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -426,7 +426,9 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler + ) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_controlnet.py b/train_controlnet.py index 8c7882c8f..da7a08d69 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -464,8 +464,8 @@ def remove_model(old_ckpt_name): ) # Sample a random timestep for each image - timesteps, huber_c = train_util.get_timesteps_and_huber_c( - args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device + timesteps = train_util.get_timesteps( + 0, noise_scheduler.config.num_train_timesteps, b_size, latents.device ) # Add noise to the latents according to the noise magnitude at each timestep @@ -499,7 +499,7 @@ def remove_model(old_ckpt_name): target = noise loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) loss = loss.mean([1, 2, 3]) diff --git a/train_db.py b/train_db.py index 51e209f34..a185b31b3 100644 --- a/train_db.py +++ b/train_db.py @@ -370,7 +370,7 @@ def train(args): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( args, noise_scheduler, latents ) @@ -385,7 +385,7 @@ def train(args): target = noise loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) diff --git a/train_network.py b/train_network.py index bbf381f99..c7d4f5dc5 100644 --- a/train_network.py +++ b/train_network.py @@ -192,7 +192,7 @@ def get_noise_pred_and_target( ): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) # ensure the hidden state will require grad if args.gradient_checkpointing: @@ -244,7 +244,7 @@ def get_noise_pred_and_target( network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) - return noise_pred, target, timesteps, huber_c, None + return noise_pred, target, timesteps, None def post_process_loss(self, loss, args, timesteps, noise_scheduler): if args.min_snr_gamma: @@ -806,6 +806,7 @@ def load_model_hook(models, input_dir): "ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength, "ss_loss_type": args.loss_type, "ss_huber_schedule": args.huber_schedule, + "ss_huber_scale": args.huber_scale, "ss_huber_c": args.huber_c, "ss_fp8_base": bool(args.fp8_base), "ss_fp8_base_unet": bool(args.fp8_base_unet), @@ -1193,7 +1194,7 @@ def remove_model(old_ckpt_name): text_encoder_conds[i] = encoded_text_encoder_conds[i] # sample noise, call unet, get target - noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target( + noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target( args, accelerator, noise_scheduler, @@ -1207,7 +1208,7 @@ def remove_model(old_ckpt_name): ) loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) if weighting is not None: loss = loss * weighting diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 5f4657eb9..9e1e57c48 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -585,7 +585,7 @@ def remove_model(old_ckpt_name): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( args, noise_scheduler, latents ) @@ -602,7 +602,7 @@ def remove_model(old_ckpt_name): target = noise loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 52d525fc5..944733602 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -461,7 +461,7 @@ def remove_model(old_ckpt_name): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) # Predict the noise residual with accelerator.autocast(): @@ -473,7 +473,9 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler + ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) From 740ec1d5265fa321659589ae6a75a4a9898ef8be Mon Sep 17 00:00:00 2001 From: recris Date: Thu, 28 Nov 2024 20:38:32 +0000 Subject: [PATCH 10/32] Fix issues found in review --- fine_tune.py | 2 +- library/train_util.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 70959a751..401a40f08 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -411,7 +411,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): loss = loss.mean() # mean over batch dimension else: loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler + args, noise_pred.float(), target.float(), timesteps, "mean", noise_scheduler ) accelerator.backward(loss) diff --git a/library/train_util.py b/library/train_util.py index c204ebd38..eaf6ec004 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5829,8 +5829,8 @@ def save_sd_model_on_train_end_common( def get_timesteps(min_timestep, max_timestep, b_size, device): - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device) - timesteps = timesteps.long() + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") + timesteps = timesteps.long().to(device) return timesteps @@ -5875,8 +5875,8 @@ def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps result = torch.exp(-alpha * timesteps) * args.huber_scale elif args.huber_schedule == "snr": - if not hasattr(noise_scheduler, 'alphas_cumprod'): - raise NotImplementedError(f"Huber schedule 'snr' is not supported with the current model.") + if not hasattr(noise_scheduler, "alphas_cumprod"): + raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c From 575f583fd9cbaf7f7b644a31437ed9094810b99a Mon Sep 17 00:00:00 2001 From: minux302 Date: Fri, 29 Nov 2024 23:55:52 +0900 Subject: [PATCH 11/32] add README --- README.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/README.md b/README.md index f9c85e3ac..2b1ca3f8c 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ Nov 14, 2024: - [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) - [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1) - [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training) +- [FLUX.1 ControlNet training](#flux1-controlnet-training) - [FLUX.1 OFT training](#flux1-oft-training) - [Inference for FLUX.1 with LoRA model](#inference-for-flux1-with-lora-model) - [FLUX.1 fine-tuning](#flux1-fine-tuning) @@ -245,6 +246,22 @@ example: If you specify one of `train_double_block_indices` or `train_single_block_indices`, the other will be trained as usual. +### FLUX.1 ControlNet training +We have added a new training script for ControlNet training. The script is flux_train_control_net.py. See --help for options. + +Sample command is below. It will work with 80GB VRAM GPUs. +``` +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_control_net.py +--pretrained_model_name_or_path flux1-dev.safetensors --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors +--ae ae.safetensors --save_model_as safetensors --sdpa --persistent_data_loader_workers +--max_data_loader_n_workers 1 --seed 42 --gradient_checkpointing --mixed_precision bf16 +--optimizer_type adamw8bit --learning_rate 2e-5 +--highvram --max_train_epochs 1 --save_every_n_steps 1000 --dataset_config dataset.toml +--output_dir /path/to/output/dir --output_name flux-cn +--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 --deepspeed +``` + + ### FLUX.1 OFT training You can train OFT with almost the same options as LoRA, such as `--timestamp_sampling`. The following points are different. From be5860f8e266c5562f123fe9e0cb3febef615290 Mon Sep 17 00:00:00 2001 From: minux302 Date: Sat, 30 Nov 2024 00:08:21 +0900 Subject: [PATCH 12/32] add schnell option to load_cn --- flux_train_control_net.py | 4 ++-- library/flux_utils.py | 14 ++++++-------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index a17c811e3..bb27c35ed 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -259,14 +259,14 @@ def train(args): clean_memory_on_device(accelerator.device) # load FLUX - _, flux = flux_utils.load_flow_model( + is_schnell, flux = flux_utils.load_flow_model( args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors ) flux.requires_grad_(False) flux.to(accelerator.device) # load controlnet - controlnet = flux_utils.load_controlnet(args.controlnet, torch.float32, accelerator.device, args.disable_mmap_load_safetensors) + controlnet = flux_utils.load_controlnet(args.controlnet, is_schnell, torch.float32, accelerator.device, args.disable_mmap_load_safetensors) controlnet.train() if args.gradient_checkpointing: diff --git a/library/flux_utils.py b/library/flux_utils.py index f2759c375..8be1d63ee 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -1,14 +1,14 @@ -from dataclasses import replace import json import os +from dataclasses import replace from typing import List, Optional, Tuple, Union + import einops import torch - -from safetensors.torch import load_file -from safetensors import safe_open from accelerate import init_empty_weights -from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config +from safetensors import safe_open +from safetensors.torch import load_file +from transformers import CLIPConfig, CLIPTextModel, T5Config, T5EncoderModel from library.utils import setup_logging @@ -154,11 +154,9 @@ def load_ae( def load_controlnet( - ckpt_path: Optional[str], dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False + ckpt_path: Optional[str], is_schnell: bool, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False ): logger.info("Building ControlNet") - # is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path) - is_schnell = False name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL with torch.device(device): controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params).to(dtype) From f40632bac6704886a7640c327d64820f8f017df8 Mon Sep 17 00:00:00 2001 From: minux302 Date: Sat, 30 Nov 2024 00:15:47 +0900 Subject: [PATCH 13/32] rm abundant arg --- flux_train_network.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 314335366..fa3810e34 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -6,12 +6,21 @@ import torch from accelerate import Accelerator -from library.device_utils import init_ipex, clean_memory_on_device + +from library.device_utils import clean_memory_on_device, init_ipex init_ipex() -from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util import train_network +from library import ( + flux_models, + flux_train_utils, + flux_utils, + sd3_train_utils, + strategy_base, + strategy_flux, + train_util, +) from library.utils import setup_logging setup_logging() @@ -125,7 +134,7 @@ def load_target_model(self, args, weight_dtype, accelerator): ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) - return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model, controlnet + return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model def get_tokenize_strategy(self, args): _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) From 928b9393daac252d0b6c4c9dd277d549b3dad8e9 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 20 Nov 2024 11:15:30 -0500 Subject: [PATCH 14/32] Allow unknown schedule-free optimizers to continue to module loader --- library/train_util.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 25cf7640d..74050880a 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4600,7 +4600,7 @@ def task(): def get_optimizer(args, trainable_params): # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" - + optimizer_type = args.optimizer_type if args.use_8bit_adam: assert ( @@ -4874,6 +4874,7 @@ def get_optimizer(args, trainable_params): optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type.endswith("schedulefree".lower()): + should_train_optimizer = True try: import schedulefree as sf except ImportError: @@ -4885,10 +4886,10 @@ def get_optimizer(args, trainable_params): optimizer_class = sf.SGDScheduleFree logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}") else: - raise ValueError(f"Unknown optimizer type: {optimizer_type}") - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - # make optimizer as train mode: we don't need to call train again, because eval will not be called in training loop - optimizer.train() + optimizer_class = None + + if optimizer_class is not None: + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) if optimizer is None: # 任意のoptimizerを使う @@ -4990,6 +4991,10 @@ def __instancecheck__(self, instance): optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()]) + if hasattr(optimizer, 'train') and callable(optimizer.train): + # make optimizer as train mode: we don't need to call train again, because eval will not be called in training loop + optimizer.train() + return optimizer_name, optimizer_args, optimizer From 87f5224e2d19254748158939cbca75802fc024f2 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 20 Nov 2024 11:57:15 -0500 Subject: [PATCH 15/32] Support d*lr for ProdigyPlus optimizer --- train_network.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index bbf381f99..65962bd74 100644 --- a/train_network.py +++ b/train_network.py @@ -61,6 +61,7 @@ def generate_step_logs( avr_loss, lr_scheduler, lr_descriptions, + optimizer=None, keys_scaled=None, mean_norm=None, maximum_norm=None, @@ -93,6 +94,30 @@ def generate_step_logs( logs[f"lr/d*lr/{lr_desc}"] = ( lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) + if ( + args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None + ): # tracking d*lr value of unet. + logs["lr/d*lr"] = ( + optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] + ) + else: + idx = 0 + if not args.network_train_unet_only: + logs["lr/textencoder"] = float(lrs[0]) + idx = 1 + + for i in range(idx, len(lrs)): + logs[f"lr/group{i}"] = float(lrs[i]) + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): + logs[f"lr/d*lr/group{i}"] = ( + lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] + ) + if ( + args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None + ): + logs[f"lr/d*lr/group{i}"] = ( + optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] + ) return logs @@ -1279,7 +1304,7 @@ def remove_model(old_ckpt_name): if len(accelerator.trackers) > 0: logs = self.generate_step_logs( - args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm + args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm ) accelerator.log(logs, step=global_step) From 6593cfbec14c0be70407b5d6d85d569ecf8160f1 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 21 Nov 2024 14:41:37 -0500 Subject: [PATCH 16/32] Fix d * lr step log --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 65962bd74..c236a2c95 100644 --- a/train_network.py +++ b/train_network.py @@ -116,7 +116,7 @@ def generate_step_logs( args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None ): logs[f"lr/d*lr/group{i}"] = ( - optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] + optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] ) return logs From c7cadbc8c73b48eaacbfb44b18121d20df373e19 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 29 Nov 2024 15:52:03 -0500 Subject: [PATCH 17/32] Add pytest testing --- .github/workflows/tests.yml | 54 +++++++++++++ library/train_util.py | 4 +- pytest.ini | 7 ++ tests/test_optimizer.py | 153 ++++++++++++++++++++++++++++++++++++ 4 files changed, 216 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/tests.yml create mode 100644 pytest.ini create mode 100644 tests/test_optimizer.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 000000000..50b08243a --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,54 @@ + +name: Python package + +on: [push] + +jobs: + build: + + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: ["3.10", "3.11"] + + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.x' + - name: Install dependencies + run: python -m pip install --upgrade pip setuptools wheel + + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.x' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.x' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + - name: Test with pytest + run: | + pip install pytest pytest-cov + pytest --junitxml=junit/test-results.xml --cov=com --cov-report=xml --cov-report=html + + - name: Upload pytest test results + uses: actions/upload-artifact@v4 + with: + name: pytest-results-${{ matrix.python-version }} + path: junit/test-results-${{ matrix.python-version }}.xml + # Use always() to always run this step to publish test results when there are test failures + if: ${{ always() }} diff --git a/library/train_util.py b/library/train_util.py index 25cf7640d..823cd3663 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -21,7 +21,7 @@ Optional, Sequence, Tuple, - Union, + Union ) from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState import glob @@ -4598,7 +4598,7 @@ def task(): accelerator.load_state(dirname) -def get_optimizer(args, trainable_params): +def get_optimizer(args, trainable_params) -> tuple[str, str, object]: # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" optimizer_type = args.optimizer_type diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 000000000..63e03efc5 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,7 @@ +[pytest] +minversion = 6.0 +testpaths = + tests +filterwarnings = + ignore::DeprecationWarning + ignore::UserWarning diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py new file mode 100644 index 000000000..f6ade91a6 --- /dev/null +++ b/tests/test_optimizer.py @@ -0,0 +1,153 @@ +from unittest.mock import patch +from library.train_util import get_optimizer +from train_network import setup_parser +import torch +from torch.nn import Parameter + +# Optimizer libraries +import bitsandbytes as bnb +from lion_pytorch import lion_pytorch +import schedulefree + +import dadaptation +import dadaptation.experimental as dadapt_experimental + +import prodigyopt +import schedulefree as sf +import transformers + + +def test_default_get_optimizer(): + with patch("sys.argv", [""]): + parser = setup_parser() + args = parser.parse_args() + params_t = torch.tensor([1.5, 1.5]) + + param = Parameter(params_t) + optimizer_name, optimizer_args, optimizer = get_optimizer(args, [param]) + assert optimizer_name == "torch.optim.adamw.AdamW" + assert optimizer_args == "" + assert isinstance(optimizer, torch.optim.AdamW) + + +def test_get_schedulefree_optimizer(): + with patch("sys.argv", ["", "--optimizer_type", "AdamWScheduleFree"]): + parser = setup_parser() + args = parser.parse_args() + params_t = torch.tensor([1.5, 1.5]) + + param = Parameter(params_t) + optimizer_name, optimizer_args, optimizer = get_optimizer(args, [param]) + assert optimizer_name == "schedulefree.adamw_schedulefree.AdamWScheduleFree" + assert optimizer_args == "" + assert isinstance(optimizer, schedulefree.adamw_schedulefree.AdamWScheduleFree) + + +def test_all_supported_optimizers(): + optimizers = [ + { + "name": "bitsandbytes.optim.adamw.AdamW8bit", + "alias": "AdamW8bit", + "instance": bnb.optim.AdamW8bit, + }, + { + "name": "lion_pytorch.lion_pytorch.Lion", + "alias": "Lion", + "instance": lion_pytorch.Lion, + }, + { + "name": "torch.optim.adamw.AdamW", + "alias": "AdamW", + "instance": torch.optim.AdamW, + }, + { + "name": "bitsandbytes.optim.lion.Lion8bit", + "alias": "Lion8bit", + "instance": bnb.optim.Lion8bit, + }, + { + "name": "bitsandbytes.optim.adamw.PagedAdamW8bit", + "alias": "PagedAdamW8bit", + "instance": bnb.optim.PagedAdamW8bit, + }, + { + "name": "bitsandbytes.optim.lion.PagedLion8bit", + "alias": "PagedLion8bit", + "instance": bnb.optim.PagedLion8bit, + }, + { + "name": "bitsandbytes.optim.adamw.PagedAdamW", + "alias": "PagedAdamW", + "instance": bnb.optim.PagedAdamW, + }, + { + "name": "bitsandbytes.optim.adamw.PagedAdamW32bit", + "alias": "PagedAdamW32bit", + "instance": bnb.optim.PagedAdamW32bit, + }, + {"name": "torch.optim.sgd.SGD", "alias": "SGD", "instance": torch.optim.SGD}, + { + "name": "dadaptation.experimental.dadapt_adam_preprint.DAdaptAdamPreprint", + "alias": "DAdaptAdamPreprint", + "instance": dadapt_experimental.DAdaptAdamPreprint, + }, + { + "name": "dadaptation.dadapt_adagrad.DAdaptAdaGrad", + "alias": "DAdaptAdaGrad", + "instance": dadaptation.DAdaptAdaGrad, + }, + { + "name": "dadaptation.dadapt_adan.DAdaptAdan", + "alias": "DAdaptAdan", + "instance": dadaptation.DAdaptAdan, + }, + { + "name": "dadaptation.experimental.dadapt_adan_ip.DAdaptAdanIP", + "alias": "DAdaptAdanIP", + "instance": dadapt_experimental.DAdaptAdanIP, + }, + { + "name": "dadaptation.dadapt_lion.DAdaptLion", + "alias": "DAdaptLion", + "instance": dadaptation.DAdaptLion, + }, + { + "name": "dadaptation.dadapt_sgd.DAdaptSGD", + "alias": "DAdaptSGD", + "instance": dadaptation.DAdaptSGD, + }, + { + "name": "prodigyopt.prodigy.Prodigy", + "alias": "Prodigy", + "instance": prodigyopt.Prodigy, + }, + { + "name": "transformers.optimization.Adafactor", + "alias": "Adafactor", + "instance": transformers.optimization.Adafactor, + }, + { + "name": "schedulefree.adamw_schedulefree.AdamWScheduleFree", + "alias": "AdamWScheduleFree", + "instance": sf.AdamWScheduleFree, + }, + { + "name": "schedulefree.sgd_schedulefree.SGDScheduleFree", + "alias": "SGDScheduleFree", + "instance": sf.SGDScheduleFree, + }, + ] + + for opt in optimizers: + with patch("sys.argv", ["", "--optimizer_type", opt.get("alias")]): + parser = setup_parser() + args = parser.parse_args() + params_t = torch.tensor([1.5, 1.5]) + + param = Parameter(params_t) + optimizer_name, _, optimizer = get_optimizer(args, [param]) + assert optimizer_name == opt.get("name") + + instance = opt.get("instance") + assert instance is not None + assert isinstance(optimizer, instance) From 2dd063a679effae2538c474fece1e7aacad0c9c5 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 29 Nov 2024 15:57:31 -0500 Subject: [PATCH 18/32] add torch torchvision accelerate versions --- .github/workflows/tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 50b08243a..96ab612d8 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -40,6 +40,7 @@ jobs: run: | python -m pip install --upgrade pip pip install -r requirements.txt + pip install torch==2.4.0 torchvision==0.19.0 accelerate==0.33.0 - name: Test with pytest run: | pip install pytest pytest-cov From e59e276fb948a1dc8a64672d8fd6d3a7eb166c80 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 29 Nov 2024 16:03:29 -0500 Subject: [PATCH 19/32] Add dadaptation --- .github/workflows/tests.yml | 26 +++++--------------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 96ab612d8..433c326bf 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -10,7 +10,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: ["3.10", "3.11"] + python-version: ["3.10"] steps: - uses: actions/checkout@v4 @@ -26,30 +26,14 @@ jobs: uses: actions/setup-python@v5 with: python-version: '3.x' + cache: 'pip' # caching pip dependencies - name: Install dependencies run: | python -m pip install --upgrade pip pip install -r requirements.txt - - - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: '3.x' - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -r requirements.txt - pip install torch==2.4.0 torchvision==0.19.0 accelerate==0.33.0 + pip install dadaptation==3.2 torch==2.4.0 torchvision==0.19.0 accelerate==0.33.0 - name: Test with pytest run: | - pip install pytest pytest-cov - pytest --junitxml=junit/test-results.xml --cov=com --cov-report=xml --cov-report=html + pip install pytest + pytest - - name: Upload pytest test results - uses: actions/upload-artifact@v4 - with: - name: pytest-results-${{ matrix.python-version }} - path: junit/test-results-${{ matrix.python-version }}.xml - # Use always() to always run this step to publish test results when there are test failures - if: ${{ always() }} From dd3b846b54814b605bd33ae08ed480ea5075483b Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 29 Nov 2024 16:18:05 -0500 Subject: [PATCH 20/32] Install pytorch first to pin version --- .github/workflows/tests.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 433c326bf..9ae67b0e9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -18,6 +18,7 @@ jobs: uses: actions/setup-python@v5 with: python-version: '3.x' + - name: Install dependencies run: python -m pip install --upgrade pip setuptools wheel @@ -27,11 +28,13 @@ jobs: with: python-version: '3.x' cache: 'pip' # caching pip dependencies + - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt pip install dadaptation==3.2 torch==2.4.0 torchvision==0.19.0 accelerate==0.33.0 + pip install -r requirements.txt + - name: Test with pytest run: | pip install pytest From 89825d6898ba6629b18cc8c1f9fbd93a730ff36e Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 29 Nov 2024 16:27:13 -0500 Subject: [PATCH 21/32] Run typos workflows once where appropriate --- .github/workflows/typos.yml | 6 ++++-- pytest.ini | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml index 0149dcdd3..667146a7a 100644 --- a/.github/workflows/typos.yml +++ b/.github/workflows/typos.yml @@ -1,9 +1,11 @@ --- -# yamllint disable rule:line-length name: Typos -on: # yamllint disable-line rule:truthy +on: push: + branches: + - main + - dev pull_request: types: - opened diff --git a/pytest.ini b/pytest.ini index 63e03efc5..484d3aef6 100644 --- a/pytest.ini +++ b/pytest.ini @@ -5,3 +5,4 @@ testpaths = filterwarnings = ignore::DeprecationWarning ignore::UserWarning + ignore::FutureWarning From 4f7f248071c93f539c12c8a35380b6d983bfff4c Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 29 Nov 2024 16:28:51 -0500 Subject: [PATCH 22/32] Bump typos action --- .github/workflows/typos.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml index 667146a7a..87ebdf894 100644 --- a/.github/workflows/typos.yml +++ b/.github/workflows/typos.yml @@ -20,4 +20,4 @@ jobs: - uses: actions/checkout@v4 - name: typos-action - uses: crate-ci/typos@v1.24.3 + uses: crate-ci/typos@v1.28.1 From 9c885e549dbb5535b37f2a3220b5a8f53ad4d211 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 30 Nov 2024 18:25:50 +0900 Subject: [PATCH 23/32] fix: improve pos_embed handling for oversized images and update resolution_area_to_latent_size, when sample image size > train image size --- library/sd3_models.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index 8b90205db..2f3c82eed 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -1017,22 +1017,35 @@ def cropped_scaled_pos_embed(self, h, w, device=None, dtype=None, random_crop: b patched_size = patched_size_ break if patched_size is None: - raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.") + # raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.") + # use largest latent size + patched_size = self.resolution_area_to_latent_size[-1][1] pos_embed = self.resolution_pos_embeds[patched_size] - pos_embed_size = round(math.sqrt(pos_embed.shape[1])) + pos_embed_size = round(math.sqrt(pos_embed.shape[1])) # max size, patched_size * POS_EMBED_MAX_RATIO if h > pos_embed_size or w > pos_embed_size: # # fallback to normal pos_embed # return self.cropped_pos_embed(h * p, w * p, device=device, random_crop=random_crop) # extend pos_embed size logger.warning( - f"Using normal pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide." + f"Add new pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide." ) - pos_embed_size = max(h, w) - pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, pos_embed_size, sample_size=patched_size) + patched_size = max(h, w) + grid_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO) + pos_embed_size = grid_size + pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, grid_size, sample_size=patched_size) pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0) self.resolution_pos_embeds[patched_size] = pos_embed - logger.info(f"Updated pos_embed for size {pos_embed_size}x{pos_embed_size}") + logger.info(f"Added pos_embed for size {patched_size}x{patched_size}") + + # print(torch.allclose(pos_embed.to(torch.float32).cpu(), self.pos_embed.to(torch.float32).cpu(), atol=5e-2)) + # diff = pos_embed.to(torch.float32).cpu() - self.pos_embed.to(torch.float32).cpu() + # print(diff.abs().max(), diff.abs().mean()) + + # insert to resolution_area_to_latent_size, by adding and sorting + area = pos_embed_size**2 + self.resolution_area_to_latent_size.append((area, patched_size)) + self.resolution_area_to_latent_size = sorted(self.resolution_area_to_latent_size) if not random_crop: top = (pos_embed_size - h) // 2 From 7b61e9eb58e0a004b451e8f06c9f90b861f81b45 Mon Sep 17 00:00:00 2001 From: recris Date: Sat, 30 Nov 2024 11:36:40 +0000 Subject: [PATCH 24/32] Fix issues found in review (pt 2) --- library/train_util.py | 2 +- sd3_train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index eaf6ec004..d5e72323a 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5875,7 +5875,7 @@ def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps result = torch.exp(-alpha * timesteps) * args.huber_scale elif args.huber_schedule == "snr": - if not hasattr(noise_scheduler, "alphas_cumprod"): + if noise_scheduler is None or not hasattr(noise_scheduler, "alphas_cumprod"): raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 diff --git a/sd3_train.py b/sd3_train.py index cf2bdf938..909c5ead6 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -845,7 +845,7 @@ def grad_hook(parameter: torch.Tensor): # ) # calculate loss loss = train_util.conditional_loss( - args, model_pred.float(), target.float(), timesteps, "none", noise_scheduler + args, model_pred.float(), target.float(), timesteps, "none", None ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) From 14f642f88be888ce1a4157b550186347c159ca42 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Dec 2024 13:30:35 +0900 Subject: [PATCH 25/32] fix: huber_schedule exponential not working on sd3_train.py --- library/train_util.py | 2 +- sd3_train.py | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index d5e72323a..eaf6ec004 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5875,7 +5875,7 @@ def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps result = torch.exp(-alpha * timesteps) * args.huber_scale elif args.huber_schedule == "snr": - if noise_scheduler is None or not hasattr(noise_scheduler, "alphas_cumprod"): + if not hasattr(noise_scheduler, "alphas_cumprod"): raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 diff --git a/sd3_train.py b/sd3_train.py index 909c5ead6..73a68aa6a 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -675,8 +675,8 @@ def grad_hook(parameter: torch.Tensor): progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 - # noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) - # noise_scheduler_copy = copy.deepcopy(noise_scheduler) + # only used to get timesteps, etc. TODO manage timesteps etc. separately + dummy_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) if accelerator.is_main_process: init_kwargs = {} @@ -844,9 +844,7 @@ def grad_hook(parameter: torch.Tensor): # 1, # ) # calculate loss - loss = train_util.conditional_loss( - args, model_pred.float(), target.float(), timesteps, "none", None - ) + loss = train_util.conditional_loss(args, model_pred.float(), target.float(), timesteps, "none", dummy_scheduler) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) From 0fe6320f09a61859c3faa134affb810cb42b62cd Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Dec 2024 14:13:37 +0900 Subject: [PATCH 26/32] fix flux_train.py is not working --- flux_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flux_train.py b/flux_train.py index f6e43b27a..cfe14885e 100644 --- a/flux_train.py +++ b/flux_train.py @@ -667,7 +667,7 @@ def grad_hook(parameter: torch.Tensor): # calculate loss loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler + args, model_pred.float(), target.float(), timesteps, "none", noise_scheduler ) if weighting is not None: loss = loss * weighting From cc11989755d0dd61f10eeec85983c751fd7ebb47 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Dec 2024 21:20:28 +0900 Subject: [PATCH 27/32] fix: refactor huber-loss calculation in multiple training scripts --- fine_tune.py | 13 ++++--------- flux_train.py | 5 ++--- library/train_util.py | 21 +++++++++++---------- sd3_train.py | 3 ++- sdxl_train.py | 13 ++++--------- sdxl_train_control_net.py | 9 +++------ sdxl_train_control_net_lllite.py | 9 +++------ sdxl_train_control_net_lllite_old.py | 10 ++++++---- train_controlnet.py | 11 +++++------ train_db.py | 9 +++------ train_network.py | 5 ++--- train_textual_inversion.py | 5 ++--- train_textual_inversion_XTI.py | 9 +++++---- 13 files changed, 52 insertions(+), 70 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 401a40f08..176087065 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -380,9 +380,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) # Predict the noise residual with accelerator.autocast(): @@ -394,11 +392,10 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): else: target = noise + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss: # do not mean over batch dimension for snr weight or scale v-pred loss - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: @@ -410,9 +407,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): loss = loss.mean() # mean over batch dimension else: - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "mean", noise_scheduler - ) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "mean", huber_c) accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: diff --git a/flux_train.py b/flux_train.py index cfe14885e..fced3bef9 100644 --- a/flux_train.py +++ b/flux_train.py @@ -666,9 +666,8 @@ def grad_hook(parameter: torch.Tensor): target = noise - latents # calculate loss - loss = train_util.conditional_loss( - args, model_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c) if weighting is not None: loss = loss * weighting if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): diff --git a/library/train_util.py b/library/train_util.py index eaf6ec004..fe74ddc7e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5869,7 +5869,10 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): return noise, noisy_latents, timesteps -def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch.Tensor: +def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]: + if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"): + return None + b_size = timesteps.shape[0] if args.huber_schedule == "exponential": alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps @@ -5890,22 +5893,20 @@ def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch def conditional_loss( - args, model_pred: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor, reduction: str, noise_scheduler + model_pred: torch.Tensor, target: torch.Tensor, loss_type: str, reduction: str, huber_c: Optional[torch.Tensor] = None ): - if args.loss_type == "l2": + if loss_type == "l2": loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction) - elif args.loss_type == "l1": + elif loss_type == "l1": loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction) - elif args.loss_type == "huber": - huber_c = get_huber_threshold(args, timesteps, noise_scheduler) + elif loss_type == "huber": huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": loss = torch.mean(loss) elif reduction == "sum": loss = torch.sum(loss) - elif args.loss_type == "smooth_l1": - huber_c = get_huber_threshold(args, timesteps, noise_scheduler) + elif loss_type == "smooth_l1": huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": @@ -5913,7 +5914,7 @@ def conditional_loss( elif reduction == "sum": loss = torch.sum(loss) else: - raise NotImplementedError(f"Unsupported Loss Type: {args.loss_type}") + raise NotImplementedError(f"Unsupported Loss Type: {loss_type}") return loss @@ -5923,7 +5924,7 @@ def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True): names.append("unet") names.append("text_encoder1") names.append("text_encoder2") - names.append("text_encoder3") # SD3 + names.append("text_encoder3") # SD3 append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names) diff --git a/sd3_train.py b/sd3_train.py index 73a68aa6a..120455e7b 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -844,7 +844,8 @@ def grad_hook(parameter: torch.Tensor): # 1, # ) # calculate loss - loss = train_util.conditional_loss(args, model_pred.float(), target.float(), timesteps, "none", dummy_scheduler) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, dummy_scheduler) + loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) diff --git a/sdxl_train.py b/sdxl_train.py index 1bc27ec6c..b9d529243 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -695,9 +695,7 @@ def optimizer_hook(parameter: torch.Tensor): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -711,6 +709,7 @@ def optimizer_hook(parameter: torch.Tensor): else: target = noise + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) if ( args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred @@ -719,9 +718,7 @@ def optimizer_hook(parameter: torch.Tensor): or args.masked_loss ): # do not mean over batch dimension for snr weight or scale v-pred loss - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -737,9 +734,7 @@ def optimizer_hook(parameter: torch.Tensor): loss = loss.mean() # mean over batch dimension else: - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "mean", huber_c) accelerator.backward(loss) diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index d0051d18f..01387409a 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -512,9 +512,7 @@ def remove_model(old_ckpt_name): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) @@ -533,9 +531,8 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 66214f5df..365059b75 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -463,9 +463,7 @@ def remove_model(old_ckpt_name): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -484,9 +482,8 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 5e10654b9..5b372befc 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -12,6 +12,7 @@ import torch from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP @@ -324,7 +325,9 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs + "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) loss_recorder = train_util.LossRecorder() @@ -426,9 +429,8 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_controlnet.py b/train_controlnet.py index da7a08d69..177d2b11f 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -307,10 +307,12 @@ def __contains__(self, name): if args.fused_backward_pass: import library.adafactor_fused + library.adafactor_fused.patch_adafactor_fused(optimizer) for param_group in optimizer.param_groups: for parameter in param_group["params"]: if parameter.requires_grad: + def __grad_hook(tensor: torch.Tensor, param_group=param_group): if accelerator.sync_gradients and args.max_grad_norm != 0.0: accelerator.clip_grad_norm_(tensor, args.max_grad_norm) @@ -464,9 +466,7 @@ def remove_model(old_ckpt_name): ) # Sample a random timestep for each image - timesteps = train_util.get_timesteps( - 0, noise_scheduler.config.num_train_timesteps, b_size, latents.device - ) + timesteps = train_util.get_timesteps(0, noise_scheduler.config.num_train_timesteps, b_size, latents.device) # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) @@ -498,9 +498,8 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_db.py b/train_db.py index a185b31b3..ad21f8d1b 100644 --- a/train_db.py +++ b/train_db.py @@ -370,9 +370,7 @@ def train(args): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) # Predict the noise residual with accelerator.autocast(): @@ -384,9 +382,8 @@ def train(args): else: target = noise - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) diff --git a/train_network.py b/train_network.py index c7d4f5dc5..0b4208187 100644 --- a/train_network.py +++ b/train_network.py @@ -1207,9 +1207,8 @@ def remove_model(old_ckpt_name): train_unet, ) - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if weighting is not None: loss = loss * weighting if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 9e1e57c48..65da4859b 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -601,9 +601,8 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 944733602..2a2b42310 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -407,7 +407,9 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs + "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) # function for saving/removing @@ -473,9 +475,8 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) From 14760407871c7eaa26210c7db71ce2740a817c4c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Dec 2024 21:26:39 +0900 Subject: [PATCH 28/32] fix: update help text for huber loss parameters in train_util.py --- library/train_util.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index fe74ddc7e..a40983a68 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3905,14 +3905,16 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "--huber_c", type=float, default=0.1, - help="The Huber loss decay parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", + help="The Huber loss decay parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1" + " / Huber損失の減衰パラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", ) parser.add_argument( "--huber_scale", type=float, default=1.0, - help="The Huber loss scale parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 1.0 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", + help="The Huber loss scale parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 1.0" + " / Huber損失のスケールパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは1.0", ) parser.add_argument( From 34e7f509c41491f9a08c16c8ead2adf5cb210ec1 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Dec 2024 21:36:24 +0900 Subject: [PATCH 29/32] docs: update README for huber loss --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index f9c85e3ac..89a96827c 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,11 @@ The command to install PyTorch is as follows: ### Recent Updates +1 Dec, 2024: + +- Pseudo Huber loss is now available for FLUX.1 and SD3.5 training. See [#1808](https://github.com/kohya-ss/sd-scripts/pull/1808) for details. Thanks to recris! + - Specify `--loss_type huber` or `--loss_type smooth_l1` to use it. `--huber_c` and `--huber_scale` are also available. + Nov 14, 2024: - Improved the implementation of block swap and made it available for both FLUX.1 and SD3 LoRA training. See [FLUX.1 LoRA training](#flux1-lora-training) etc. for how to use the new options. Training is possible with about 8-10GB of VRAM. From 1dc873d9b463d50e27ae8572c28a473ce9a1254f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Dec 2024 22:00:44 +0900 Subject: [PATCH 30/32] update README and clean up code for schedulefree optimizer --- README.md | 4 +++- library/train_util.py | 7 +++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 89a96827c..8db5c4d42 100644 --- a/README.md +++ b/README.md @@ -16,9 +16,11 @@ The command to install PyTorch is as follows: 1 Dec, 2024: -- Pseudo Huber loss is now available for FLUX.1 and SD3.5 training. See [#1808](https://github.com/kohya-ss/sd-scripts/pull/1808) for details. Thanks to recris! +- Pseudo Huber loss is now available for FLUX.1 and SD3.5 training. See PR [#1808](https://github.com/kohya-ss/sd-scripts/pull/1808) for details. Thanks to recris! - Specify `--loss_type huber` or `--loss_type smooth_l1` to use it. `--huber_c` and `--huber_scale` are also available. +- [Prodigy + ScheduleFree](https://github.com/LoganBooker/prodigy-plus-schedule-free) is supported. See PR [#1811](https://github.com/kohya-ss/sd-scripts/pull/1811) for details. Thanks to rockerBOO! + Nov 14, 2024: - Improved the implementation of block swap and made it available for both FLUX.1 and SD3 LoRA training. See [FLUX.1 LoRA training](#flux1-lora-training) etc. for how to use the new options. Training is possible with about 8-10GB of VRAM. diff --git a/library/train_util.py b/library/train_util.py index 289ab8235..6cfd14d5e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4609,7 +4609,7 @@ def task(): def get_optimizer(args, trainable_params): # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" - + optimizer_type = args.optimizer_type if args.use_8bit_adam: assert ( @@ -4883,7 +4883,6 @@ def get_optimizer(args, trainable_params): optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type.endswith("schedulefree".lower()): - should_train_optimizer = True try: import schedulefree as sf except ImportError: @@ -5000,8 +4999,8 @@ def __instancecheck__(self, instance): optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()]) - if hasattr(optimizer, 'train') and callable(optimizer.train): - # make optimizer as train mode: we don't need to call train again, because eval will not be called in training loop + if hasattr(optimizer, "train") and callable(optimizer.train): + # make optimizer as train mode before training for schedulefree optimizer. the optimizer will be in eval mode in sampling and saving. optimizer.train() return optimizer_name, optimizer_args, optimizer From e369b9a252b90d1f57ea20dd6f5d05ec0c287ae1 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Mon, 2 Dec 2024 23:38:54 +0900 Subject: [PATCH 31/32] docs: update README with FLUX.1 ControlNet training details and improve argument help text --- README.md | 10 +++++++++- library/flux_train_utils.py | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 45e3cb7ab..6a5cdd342 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,15 @@ The command to install PyTorch is as follows: ### Recent Updates -1 Dec, 2024: +Dec 2, 2024: + +- FLUX.1 ControlNet training is supported. PR [#1813](https://github.com/kohya-ss/sd-scripts/pull/1813). Thanks to minux302! See PR and [here](#flux1-controlnet-training) for details. + - Not fully tested. Feedback is welcome. + - 80GB VRAM is required for 1024x1024 resolution, and 48GB VRAM is required for 512x512 resolution. + - Currently, it only works in Linux environment (or Windows WSL2) because DeepSpeed is required. + - Multi-GPU training is not tested. + +Dec 1, 2024: - Pseudo Huber loss is now available for FLUX.1 and SD3.5 training. See PR [#1808](https://github.com/kohya-ss/sd-scripts/pull/1808) for details. Thanks to recris! - Specify `--loss_type huber` or `--loss_type smooth_l1` to use it. `--huber_c` and `--huber_scale` are also available. diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 5e25c7feb..de2e2b48d 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -567,7 +567,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): "--controlnet", type=str, default=None, - help="path to controlnet (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)" + help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)" ) parser.add_argument( "--t5xxl_max_token_length", From 8b36d907d8635dca64224574b5cb15013e00809d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 3 Dec 2024 08:43:26 +0900 Subject: [PATCH 32/32] feat: support block_to_swap for FLUX.1 ControlNet training --- README.md | 13 +++++++++++ flux_train_control_net.py | 46 +++++++++++++++++++++++++++------------ 2 files changed, 45 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 6a5cdd342..f02725191 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,11 @@ The command to install PyTorch is as follows: ### Recent Updates + +Dec 3, 2024: + +-`--blocks_to_swap` now works in FLUX.1 ControlNet training. Sample commands for 24GB VRAM and 16GB VRAM are added [here](#flux1-controlnet-training). + Dec 2, 2024: - FLUX.1 ControlNet training is supported. PR [#1813](https://github.com/kohya-ss/sd-scripts/pull/1813). Thanks to minux302! See PR and [here](#flux1-controlnet-training) for details. @@ -276,6 +281,14 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_tr --timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 --deepspeed ``` +For 24GB VRAM GPUs, you can train with 16 blocks swapped and caching latents and text encoder outputs with the batch size of 1. Remove `--deepspeed` . Sample command is below. Not fully tested. +``` + --blocks_to_swap 16 --cache_latents_to_disk --cache_text_encoder_outputs_to_disk +``` + +The training can be done with 16GB VRAM GPUs with around 30 blocks swapped. + +`--gradient_accumulation_steps` is also available. The default value is 1 (no accumulation), but according to the original PR, 8 is used. ### FLUX.1 OFT training diff --git a/flux_train_control_net.py b/flux_train_control_net.py index bb27c35ed..5548fd991 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -119,9 +119,7 @@ def train(args): "datasets": [ { "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( - args.train_data_dir, - args.conditioning_data_dir, - args.caption_extension + args.train_data_dir, args.conditioning_data_dir, args.caption_extension ) } ] @@ -263,13 +261,17 @@ def train(args): args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors ) flux.requires_grad_(False) - flux.to(accelerator.device) # load controlnet - controlnet = flux_utils.load_controlnet(args.controlnet, is_schnell, torch.float32, accelerator.device, args.disable_mmap_load_safetensors) + controlnet_dtype = torch.float32 if args.deepspeed else weight_dtype + controlnet = flux_utils.load_controlnet( + args.controlnet, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors + ) controlnet.train() if args.gradient_checkpointing: + if not args.deepspeed: + flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing) controlnet.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing) # block swap @@ -296,7 +298,11 @@ def train(args): # This idea is based on 2kpr's great work. Thank you! logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") flux.enable_block_swap(args.blocks_to_swap, accelerator.device) - controlnet.enable_block_swap(args.blocks_to_swap, accelerator.device) + flux.move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + # ControlNet only has two blocks, so we can keep it on GPU + # controlnet.enable_block_swap(args.blocks_to_swap, accelerator.device) + else: + flux.to(accelerator.device) if not cache_latents: # load VAE here if not cached @@ -455,9 +461,7 @@ def train(args): else: # accelerator does some magic # if we doesn't swap blocks, we can move the model to device - controlnet = accelerator.prepare(controlnet, device_placement=[not is_swapping_blocks]) - if is_swapping_blocks: - accelerator.unwrap_model(controlnet).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + controlnet = accelerator.prepare(controlnet) # , device_placement=[not is_swapping_blocks]) optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする @@ -564,11 +568,13 @@ def grad_hook(parameter: torch.Tensor): ) if is_swapping_blocks: - accelerator.unwrap_model(controlnet).prepare_block_swap_before_forward() + flux.prepare_block_swap_before_forward() # For --sample_at_first optimizer_eval_fn() - flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet) + flux_train_utils.sample_images( + accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet + ) optimizer_train_fn() if len(accelerator.trackers) > 0: # log empty object to commit the sample images to wandb @@ -629,7 +635,11 @@ def grad_hook(parameter: torch.Tensor): # pack latents and get img_ids packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 - img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device).to(weight_dtype) + img_ids = ( + flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width) + .to(device=accelerator.device) + .to(weight_dtype) + ) # get guidance: ensure args.guidance_scale is float guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device, dtype=weight_dtype) @@ -638,7 +648,7 @@ def grad_hook(parameter: torch.Tensor): l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds if not args.apply_t5_attn_mask: t5_attn_mask = None - + with accelerator.autocast(): block_samples, block_single_samples = controlnet( img=packed_noisy_model_input, @@ -715,7 +725,15 @@ def grad_hook(parameter: torch.Tensor): optimizer_eval_fn() flux_train_utils.sample_images( - accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet + accelerator, + args, + None, + global_step, + flux, + ae, + [clip_l, t5xxl], + sample_prompts_te_outputs, + controlnet=controlnet, ) # 指定ステップごとにモデルを保存