From e8369ca166d6147879681ac7bbe98d52305d4251 Mon Sep 17 00:00:00 2001 From: Hayk Martiros Date: Mon, 16 Jan 2023 07:36:47 +0000 Subject: [PATCH] [WIP] Low Rank Adaptation First implementation of the paper "Low-rank Adaptation for Fast Text-to-Image Diffusion Fine-tuning" for riffusion. Still needs to be integrated a lot more. Reference: https://github.com/cloneofsimo/lora Topic: lora_1 Relative: pipeline_lock --- pyproject.toml | 4 + requirements.txt | 2 + riffusion/external/lora/__init__.py | 0 .../external/lora/run_lora_db_unet_only.sh | 24 + riffusion/external/lora/train_lora.py | 49 + riffusion/external/lora/train_lora.sh | 37 + .../external/lora/train_lora_dreambooth.py | 958 ++++++++++++++++++ riffusion/streamlit/pages/audio_to_audio.py | 7 +- riffusion/streamlit/pages/text_to_audio.py | 5 + riffusion/streamlit/util.py | 54 +- 10 files changed, 1136 insertions(+), 4 deletions(-) create mode 100644 riffusion/external/lora/__init__.py create mode 100644 riffusion/external/lora/run_lora_db_unet_only.sh create mode 100644 riffusion/external/lora/train_lora.py create mode 100755 riffusion/external/lora/train_lora.sh create mode 100644 riffusion/external/lora/train_lora_dreambooth.py diff --git a/pyproject.toml b/pyproject.toml index 88d8f05..987789e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,10 @@ ignore_missing_imports = true module = "diffusers.*" ignore_missing_imports = true +[[tool.mypy.overrides]] +module = "lora_diffusion.*" +ignore_missing_imports = true + [[tool.mypy.overrides]] module = "numpy.*" ignore_missing_imports = true diff --git a/requirements.txt b/requirements.txt index 49e1f32..dbfe559 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,4 +16,6 @@ sox streamlit>=1.10.0 torch torchaudio +torchvision transformers +git+https://github.com/cloneofsimo/lora.git diff --git a/riffusion/external/lora/__init__.py b/riffusion/external/lora/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/riffusion/external/lora/run_lora_db_unet_only.sh b/riffusion/external/lora/run_lora_db_unet_only.sh new file mode 100644 index 0000000..722c397 --- /dev/null +++ b/riffusion/external/lora/run_lora_db_unet_only.sh @@ -0,0 +1,24 @@ +export MODEL_NAME="riffusion/riffusion-model-v1" +export INSTANCE_DIR="/tmp/sample_clips_tdlcqdfi/images" +export OUTPUT_DIR="/home/ubuntu/lora_dreambooth_waterfalls_2k" + +accelerate launch\ + --num_machines 1 \ + --num_processes 8 \ + --dynamo_backend=no \ + --mixed_precision="fp16" \ + riffusion/external/lora/train_lora_dreambooth.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --instance_prompt="style of sks" \ + --resolution=512 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=1 \ + --learning_rate=1e-4 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=2000 + +# TODO try mixed_precision=fp16 +# TODO try num_processes = 8 diff --git a/riffusion/external/lora/train_lora.py b/riffusion/external/lora/train_lora.py new file mode 100644 index 0000000..7fa76c8 --- /dev/null +++ b/riffusion/external/lora/train_lora.py @@ -0,0 +1,49 @@ +from lora_diffusion.cli_lora_pti import train +from lora_diffusion.dataset import STYLE_TEMPLATE + +MODEL_NAME = "riffusion/riffusion-model-v1" +INSTANCE_DIR = "/tmp/sample_clips_xzv8p57g/images" +OUTPUT_DIR = "./lora_output_acoustic" + +if __name__ == "__main__": + entries = [ + "music in the style of {}", + "sound in the style of {}", + "vibe in the style of {}", + "audio in the 
style of {}", + "groove in the style of {}", + ] + for i in range(len(STYLE_TEMPLATE)): + STYLE_TEMPLATE[i] = entries[i % len(entries)] + print(STYLE_TEMPLATE) + + train( + pretrained_model_name_or_path=MODEL_NAME, + instance_data_dir=INSTANCE_DIR, + output_dir=OUTPUT_DIR, + train_text_encoder=True, + resolution=512, + train_batch_size=1, + gradient_accumulation_steps=4, + scale_lr=True, + learning_rate_unet=1e-4, + learning_rate_text=1e-5, + learning_rate_ti=5e-4, + color_jitter=False, + lr_scheduler="linear", + lr_warmup_steps=0, + placeholder_tokens="|", + use_template="style", + save_steps=100, + max_train_steps_ti=1000, + max_train_steps_tuning=1000, + perform_inversion=True, + clip_ti_decay=True, + weight_decay_ti=0.000, + weight_decay_lora=0.001, + continue_inversion=True, + continue_inversion_lr=1e-4, + device="cuda:0", + lora_rank=1, + use_face_segmentation_condition=False, + ) diff --git a/riffusion/external/lora/train_lora.sh b/riffusion/external/lora/train_lora.sh new file mode 100755 index 0000000..b47f898 --- /dev/null +++ b/riffusion/external/lora/train_lora.sh @@ -0,0 +1,37 @@ +export MODEL_NAME="riffusion/riffusion-model-v1" +export INSTANCE_DIR="/tmp/sample_clips_xzv8p57g/images" +export OUTPUT_DIR="./lora_output_acoustic" + +lora_pti \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --train_text_encoder \ + --resolution=512 \ + # Started as 1 + --train_batch_size=4 \ + --gradient_accumulation_steps=4 \ + --scale_lr \ + --learning_rate_unet=1e-4 \ + --learning_rate_text=1e-5 \ + --learning_rate_ti=5e-4 \ +# --color_jitter \ + --lr_scheduler="linear" \ + --lr_warmup_steps=0 \ + --placeholder_tokens="" \ +# initializer tokens +# class prompt +# --use_template="style"\ + --save_steps=100 \ + --max_train_steps_ti=1000 \ + --max_train_steps_tuning=1000 \ + --perform_inversion=True \ + --clip_ti_decay \ + --weight_decay_ti=0.000 \ + --weight_decay_lora=0.001\ + --continue_inversion \ + --continue_inversion_lr=1e-4 \ + --device="cuda:0" \ + # 1 or 4? + --lora_rank=4 \ +# --use_face_segmentation_condition\ diff --git a/riffusion/external/lora/train_lora_dreambooth.py b/riffusion/external/lora/train_lora_dreambooth.py new file mode 100644 index 0000000..7488683 --- /dev/null +++ b/riffusion/external/lora/train_lora_dreambooth.py @@ -0,0 +1,958 @@ +# Bootstrapped from: +# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py + +# ruff: noqa +# mypy: ignore-errors + +import argparse +import hashlib +import inspect +import itertools +import math +import os +from pathlib import Path + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.optimization import get_scheduler +from lora_diffusion import ( + extract_lora_ups_down, + inject_trainable_lora, + safetensors_available, + save_lora_weight, + save_safeloras, +) +from lora_diffusion.xformers_utils import set_use_memory_efficient_attention_xformers +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. 
+ It pre-processes the images and the tokenizes prompts. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + tokenizer, + class_data_root=None, + class_prompt=None, + size=512, + center_crop=False, + color_jitter=False, + h_flip=False, + resize=False, + ): + self.size = size + self.center_crop = center_crop + self.tokenizer = tokenizer + self.resize = resize + + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + self.instance_images_path = list(Path(instance_data_root).iterdir()) + self.num_instance_images = len(self.instance_images_path) + self.instance_prompt = instance_prompt + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + self.class_prompt = class_prompt + else: + self.class_data_root = None + + img_transforms = [] + + if resize: + img_transforms.append( + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + ) + if center_crop: + img_transforms.append(transforms.CenterCrop(size)) + if color_jitter: + img_transforms.append(transforms.ColorJitter(0.2, 0.1)) + if h_flip: + img_transforms.append(transforms.RandomHorizontalFlip()) + + self.image_transforms = transforms.Compose( + [*img_transforms, transforms.ToTensor(), transforms.Normalize([0.5], [0.5])] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["instance_images"] = self.image_transforms(instance_image) + example["instance_prompt_ids"] = self.tokenizer( + self.instance_prompt, + padding="do_not_pad", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt_ids"] = self.tokenizer( + self.class_prompt, + padding="do_not_pad", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids + + return example + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." 
+ + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +logger = get_logger(__name__) + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_vae_name_or_path", + type=str, + default=None, + help="Path to pretrained vae or vae identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument( + "--prior_loss_weight", + type=float, + default=1.0, + help="The weight of prior preservation loss.", + ) + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If not have enough images, additional images will be" + " sampled with class_prompt." 
+ ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="text-inversion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--output_format", + type=str, + choices=["pt", "safe", "both"], + default="both", + help="The output format of the model predicitions and checkpoints.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + action="store_true", + help="Whether to center crop images before resizing to resolution", + ) + parser.add_argument( + "--color_jitter", + action="store_true", + help="Whether to apply color jitter to images", + ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder", + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=4, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--sample_batch_size", + type=int, + default=4, + help="Batch size (per device) for sampling images.", + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--save_steps", + type=int, + default=500, + help="Save checkpoint every X updates steps.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--lora_rank", + type=int, + default=4, + help="Rank of LoRA approximation.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=None, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--learning_rate_text", + type=float, + default=5e-6, + help="Initial learning rate for text encoder (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.", + ) + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes.", + ) + parser.add_argument( + "--adam_beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam optimizer.", + ) + parser.add_argument( + "--adam_beta2", + type=float, + default=0.999, + help="The beta2 parameter for the Adam optimizer.", + ) + parser.add_argument( + "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use." 
+ ) + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the model to the Hub.", + ) + parser.add_argument( + "--hub_token", + type=str, + default=None, + help="The token to use to push to the Model Hub.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--local_rank", + type=int, + default=-1, + help="For distributed training: local_rank", + ) + parser.add_argument( + "--resume_unet", + type=str, + default=None, + help=("File path for unet lora to resume training."), + ) + parser.add_argument( + "--resume_text_encoder", + type=str, + default=None, + help=("File path for text encoder lora to resume training."), + ) + parser.add_argument( + "--resize", + type=bool, + default=True, + required=False, + help="Should images be resized to --resolution before training?", + ) + parser.add_argument( + "--use_xformers", action="store_true", help="Whether or not to use xformers" + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + if args.class_data_dir is not None: + logger.warning("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + logger.warning("You need not use --class_prompt without --with_prior_preservation.") + + if not safetensors_available: + if args.output_format == "both": + print( + "Safetensors is not available - changing output format to just output PyTorch files" + ) + args.output_format = "pt" + elif args.output_format == "safe": + raise ValueError( + "Safetensors is not available - either install it, or change output_format." + ) + + return args + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with="tensorboard", + logging_dir=logging_dir, + ) + + # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate + # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. + # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. 
+ if ( + args.train_text_encoder + and args.gradient_accumulation_steps > 1 + and accelerator.num_processes > 1 + ): + raise ValueError( + "Gradient accumulation is not supported when training the text encoder in distributed training. " + "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." + ) + + if args.seed is not None: + set_seed(args.seed) + + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + safety_checker=None, + revision=args.revision, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader( + sample_dataset, batch_size=args.sample_batch_size + ) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, + desc="Generating class images", + disable=not accelerator.is_local_main_process, + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = ( + class_images_dir + / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + ) + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Handle the repository creation + if accelerator.is_main_process: + + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = CLIPTokenizer.from_pretrained( + args.tokenizer_name, + revision=args.revision, + ) + elif args.pretrained_model_name_or_path: + tokenizer = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) + + # Load models and create wrapper for stable diffusion + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path, + subfolder=None if args.pretrained_vae_name_or_path else "vae", + revision=None if args.pretrained_vae_name_or_path else args.revision, + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.revision, + ) + unet.requires_grad_(False) + unet_lora_params, _ = inject_trainable_lora(unet, r=args.lora_rank, loras=args.resume_unet) + + for _up, _down in extract_lora_ups_down(unet): + print("Before training: Unet First Layer lora up", _up.weight.data) + print("Before training: Unet First Layer lora down", _down.weight.data) + break + + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + + if args.train_text_encoder: + text_encoder_lora_params, _ = inject_trainable_lora( + text_encoder, + target_replace_module=["CLIPAttention"], + r=args.lora_rank, + ) + for _up, _down in extract_lora_ups_down( + text_encoder, 
target_replace_module=["CLIPAttention"] + ): + print("Before training: text encoder First Layer lora up", _up.weight.data) + print("Before training: text encoder First Layer lora down", _down.weight.data) + break + + if args.use_xformers: + set_use_memory_efficient_attention_xformers(unet, True) + set_use_memory_efficient_attention_xformers(vae, True) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder.gradient_checkpointing_enable() + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate + * args.gradient_accumulation_steps + * args.train_batch_size + * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + text_lr = args.learning_rate if args.learning_rate_text is None else args.learning_rate_text + + params_to_optimize = ( + [ + {"params": itertools.chain(*unet_lora_params), "lr": args.learning_rate}, + { + "params": itertools.chain(*text_encoder_lora_params), + "lr": text_lr, + }, + ] + if args.train_text_encoder + else itertools.chain(*unet_lora_params) + ) + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + noise_scheduler = DDPMScheduler.from_config( + args.pretrained_model_name_or_path, subfolder="scheduler" + ) + + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_prompt=args.class_prompt, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + color_jitter=args.color_jitter, + resize=args.resize, + ) + + def collate_fn(examples): + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if args.with_prior_preservation: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = tokenizer.pad( + {"input_ids": input_ids}, + padding="max_length", + max_length=tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + batch = { + "input_ids": input_ids, + "pixel_values": pixel_values, + } + return batch + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=collate_fn, + num_workers=1, + ) + + # Scheduler and math around the number of training steps. 
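+ # If --max_train_steps is not set, derive it from --num_train_epochs; it is
+ # recomputed after accelerator.prepare(), since the dataloader length may change.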
+ overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + if args.train_text_encoder: + ( + unet, + text_encoder, + optimizer, + train_dataloader, + lr_scheduler, + ) = accelerator.prepare(unet, text_encoder, optimizer, train_dataloader, lr_scheduler) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move text_encode and vae to gpu. + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + vae.to(accelerator.device, dtype=weight_dtype) + if not args.train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("dreambooth", config=vars(args)) + + # Train! + total_batch_size = ( + args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + ) + + print("***** Running training *****") + print(f" Num examples = {len(train_dataset)}") + print(f" Num batches each epoch = {len(train_dataloader)}") + print(f" Num Epochs = {args.num_train_epochs}") + print(f" Instantaneous batch size per device = {args.train_batch_size}") + print( + f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" + ) + print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + print(f" Total optimization steps = {args.max_train_steps}") + # Only show the progress bar once on each machine. 
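+ # global_step counts optimizer updates across epochs; last_save records the step
+ # of the most recent checkpoint, so LoRA weights are saved every --save_steps steps.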
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + global_step = 0 + last_save = 0 + + for epoch in range(args.num_train_epochs): + unet.train() + if args.train_text_encoder: + text_encoder.train() + + for step, batch in enumerate(train_dataloader): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint( + 0, + noise_scheduler.config.num_train_timesteps, + (bsz,), + device=latents.device, + ) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0] + + # Predict the noise residual + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError( + f"Unknown prediction type {noise_scheduler.config.prediction_type}" + ) + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute instance loss + loss = ( + F.mse_loss(model_pred.float(), target.float(), reduction="none") + .mean([1, 2, 3]) + .mean() + ) + + # Compute prior loss + prior_loss = F.mse_loss( + model_pred_prior.float(), target_prior.float(), reduction="mean" + ) + + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder + else unet.parameters() + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + progress_bar.update(1) + optimizer.zero_grad() + + global_step += 1 + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + if args.save_steps and global_step - last_save >= args.save_steps: + if accelerator.is_main_process: + # newer versions of accelerate allow the 'keep_fp32_wrapper' arg. without passing + # it, the models will be unwrapped, and when they are then used for further training, + # we will crash. pass this, but only to newer versions of accelerate. 
fixes + # https://github.com/huggingface/diffusers/issues/1566 + accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set( + inspect.signature(accelerator.unwrap_model).parameters.keys() + ) + extra_args = ( + {"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {} + ) + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet, **extra_args), + text_encoder=accelerator.unwrap_model(text_encoder, **extra_args), + revision=args.revision, + ) + + filename_unet = f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.pt" + filename_text_encoder = ( + f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.text_encoder.pt" + ) + print(f"save weights {filename_unet}, {filename_text_encoder}") + save_lora_weight(pipeline.unet, filename_unet) + if args.train_text_encoder: + save_lora_weight( + pipeline.text_encoder, + filename_text_encoder, + target_replace_module=["CLIPAttention"], + ) + + for _up, _down in extract_lora_ups_down(pipeline.unet): + print( + "First Unet Layer's Up Weight is now : ", + _up.weight.data, + ) + print( + "First Unet Layer's Down Weight is now : ", + _down.weight.data, + ) + break + if args.train_text_encoder: + for _up, _down in extract_lora_ups_down( + pipeline.text_encoder, + target_replace_module=["CLIPAttention"], + ): + print( + "First Text Encoder Layer's Up Weight is now : ", + _up.weight.data, + ) + print( + "First Text Encoder Layer's Down Weight is now : ", + _down.weight.data, + ) + break + + last_save = global_step + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + accelerator.wait_for_everyone() + + # Create the pipeline using using the trained modules and save it. 
+ if accelerator.is_main_process: + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=accelerator.unwrap_model(text_encoder), + revision=args.revision, + ) + + print("\n\nLora TRAINING DONE!\n\n") + + if args.output_format == "pt" or args.output_format == "both": + save_lora_weight(pipeline.unet, args.output_dir + "/lora_weight.pt") + if args.train_text_encoder: + save_lora_weight( + pipeline.text_encoder, + args.output_dir + "/lora_weight.text_encoder.pt", + target_replace_module=["CLIPAttention"], + ) + + if args.output_format == "safe" or args.output_format == "both": + loras = {} + loras["unet"] = (pipeline.unet, {"CrossAttention", "Attention", "GEGLU"}) + if args.train_text_encoder: + loras["text_encoder"] = (pipeline.text_encoder, {"CLIPAttention"}) + + save_safeloras(loras, args.output_dir + "/lora_weight.safetensors") + + if args.push_to_hub: + repo.push_to_hub( + commit_message="End of training", + blocking=False, + auto_lfs_prune=True, + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/riffusion/streamlit/pages/audio_to_audio.py b/riffusion/streamlit/pages/audio_to_audio.py index b1271df..d66f5ae 100644 --- a/riffusion/streamlit/pages/audio_to_audio.py +++ b/riffusion/streamlit/pages/audio_to_audio.py @@ -46,6 +46,9 @@ def render_audio_to_audio() -> None: device = streamlit_util.select_device(st.sidebar) extension = streamlit_util.select_audio_extension(st.sidebar) + lora_path = st.sidebar.text_input("Lora Path", "") + lora_scale = st.sidebar.number_input("Lora Scale", value=1.0) + with st.sidebar: num_inference_steps = T.cast( int, @@ -149,8 +152,6 @@ def render_audio_to_audio() -> None: if counter.value == 0: return - st.write(f"## Counter: {counter.value}") - params = SpectrogramParams() if interpolate: @@ -217,6 +218,8 @@ def render_audio_to_audio() -> None: progress_callback=progress_callback, device=device, scheduler=scheduler, + lora_path=lora_path, + lora_scale=lora_scale, ) # Resize back to original size diff --git a/riffusion/streamlit/pages/text_to_audio.py b/riffusion/streamlit/pages/text_to_audio.py index 3d41534..ec4ce07 100644 --- a/riffusion/streamlit/pages/text_to_audio.py +++ b/riffusion/streamlit/pages/text_to_audio.py @@ -29,6 +29,9 @@ def render_text_to_audio() -> None: device = streamlit_util.select_device(st.sidebar) extension = streamlit_util.select_audio_extension(st.sidebar) + lora_path = st.sidebar.text_input("Lora Path", "") + lora_scale = st.sidebar.number_input("Lora Scale", value=1.0) + with st.form("Inputs"): prompt = st.text_input("Prompt") negative_prompt = st.text_input("Negative prompt") @@ -93,6 +96,8 @@ def render_text_to_audio() -> None: height=512, device=device, scheduler=scheduler, + lora_path=lora_path, + lora_scale=lora_scale, ) st.image(image) diff --git a/riffusion/streamlit/util.py b/riffusion/streamlit/util.py index 7f4c865..28234bd 100644 --- a/riffusion/streamlit/util.py +++ b/riffusion/streamlit/util.py @@ -4,6 +4,7 @@ import io import threading import typing as T +from pathlib import Path import pydub import streamlit as st @@ -53,6 +54,8 @@ def load_stable_diffusion_pipeline( device: str = "cuda", dtype: torch.dtype = torch.float16, scheduler: str = SCHEDULER_OPTIONS[0], + lora_path: T.Optional[str] = None, + lora_scale: float = 1.0, ) -> StableDiffusionPipeline: """ Load the riffusion pipeline. 
@@ -72,6 +75,21 @@ def load_stable_diffusion_pipeline( pipeline.scheduler = get_scheduler(scheduler, config=pipeline.scheduler.config) + if lora_path: + if not Path(lora_path).is_file() or Path(lora_path).is_dir(): + raise RuntimeError("Bad lora path") + + from lora_diffusion import patch_pipe, tune_lora_scale + + patch_pipe( + pipeline, + lora_path, + patch_text=True, + patch_ti=True, + patch_unet=True, + ) + tune_lora_scale(pipeline.unet, lora_scale) + return pipeline @@ -121,6 +139,8 @@ def load_stable_diffusion_img2img_pipeline( device: str = "cuda", dtype: torch.dtype = torch.float16, scheduler: str = SCHEDULER_OPTIONS[0], + lora_path: T.Optional[str] = None, + lora_scale: float = 1.0, ) -> StableDiffusionImg2ImgPipeline: """ Load the image to image pipeline. @@ -140,6 +160,22 @@ def load_stable_diffusion_img2img_pipeline( pipeline.scheduler = get_scheduler(scheduler, config=pipeline.scheduler.config) + # TODO reduce duplication + if lora_path: + if not Path(lora_path).is_file() or Path(lora_path).is_dir(): + raise RuntimeError("Bad lora path") + + from lora_diffusion import patch_pipe, tune_lora_scale + + patch_pipe( + pipeline, + lora_path, + patch_text=True, + patch_ti=True, + patch_unet=True, + ) + tune_lora_scale(pipeline.unet, lora_scale) + return pipeline @@ -154,12 +190,19 @@ def run_txt2img( height: int, device: str = "cuda", scheduler: str = SCHEDULER_OPTIONS[0], + lora_path: T.Optional[str] = None, + lora_scale: float = 1.0, ) -> Image.Image: """ Run the text to image pipeline with caching. """ with pipeline_lock(): - pipeline = load_stable_diffusion_pipeline(device=device, scheduler=scheduler) + pipeline = load_stable_diffusion_pipeline( + device=device, + scheduler=scheduler, + lora_path=lora_path, + lora_scale=lora_scale, + ) generator_device = "cpu" if device.lower().startswith("mps") else device generator = torch.Generator(device=generator_device).manual_seed(seed) @@ -278,9 +321,16 @@ def run_img2img( device: str = "cuda", scheduler: str = SCHEDULER_OPTIONS[0], progress_callback: T.Optional[T.Callable[[float], T.Any]] = None, + lora_path: T.Optional[str] = None, + lora_scale: float = 1.0, ) -> Image.Image: with pipeline_lock(): - pipeline = load_stable_diffusion_img2img_pipeline(device=device, scheduler=scheduler) + pipeline = load_stable_diffusion_img2img_pipeline( + device=device, + scheduler=scheduler, + lora_path=lora_path, + lora_scale=lora_scale, + ) generator_device = "cpu" if device.lower().startswith("mps") else device generator = torch.Generator(device=generator_device).manual_seed(seed)