From 905375fef744e80f5daedae8ab3edaa6bc110fa6 Mon Sep 17 00:00:00 2001 From: Nick Triantafillou Date: Mon, 19 Jun 2023 15:40:41 +1000 Subject: [PATCH] add image2image support --- src/rp_handler.py | 41 +++++++++++++++++++++++++++++++---------- src/rp_schemas.py | 10 ++++++++++ 2 files changed, 41 insertions(+), 10 deletions(-) diff --git a/src/rp_handler.py b/src/rp_handler.py index 442aa55..e51ba59 100644 --- a/src/rp_handler.py +++ b/src/rp_handler.py @@ -6,7 +6,8 @@ import os import torch -from diffusers import DiffusionPipeline +from diffusers import DiffusionPipeline, KandinskyImg2ImgPipeline, KandinskyPriorPipeline +from diffusers.utils import load_image import runpod from runpod.serverless.utils import rp_upload, rp_cleanup @@ -22,6 +23,8 @@ "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16).to("cuda") t2i_pipe.enable_xformers_memory_efficient_attention() +i2i_prior = KandinskyPriorPipeline(**pipe_prior.components).to("cuda") +i2i_pipe = KandinskyImg2ImgPipeline(**t2i_pipe.components).to("cuda") def _setup_generator(seed): generator = torch.Generator(device="cuda") @@ -68,18 +71,36 @@ def generate_image(job): validated_input['negative_prompt'], generator=generator).to_tuple() + # Check if an input image is provided, determining if this is text2image or image2image + init_image = None + if job_input.get('init_image', None) is not None: + init_image = load_image(job_input['init_image']) + # List to hold the image URLs image_urls = [] - # Create image - output = t2i_pipe(validated_input['prompt'], - image_embeds=image_embeds, - negative_image_embeds=negative_image_embeds, - height=validated_input['h'], - width=validated_input['w'], - num_inference_steps=validated_input['num_steps'], - guidance_scale=validated_input['guidance_scale'], - num_images_per_prompt=validated_input['num_images']).images + if init_image is None: + # Create text2image + output = t2i_pipe(validated_input['prompt'], + image_embeds=image_embeds, + negative_image_embeds=negative_image_embeds, + height=validated_input['h'], + width=validated_input['w'], + num_inference_steps=validated_input['num_steps'], + guidance_scale=validated_input['guidance_scale'], + num_images_per_prompt=validated_input['num_images']).images + else: + # Create image2image + output = i2i_pipe( + validated_input["prompt"], + image=init_image, + image_embeds=image_embeds, + negative_image_embeds=negative_image_embeds, + height=validated_input['h'], + width=validated_input['w'], + num_inference_steps=validated_input['num_steps'], + strength=validated_input['strength']).images + image_urls = _save_and_upload_images(output, job['id']) diff --git a/src/rp_schemas.py b/src/rp_schemas.py index 2a9329f..d30e7f9 100644 --- a/src/rp_schemas.py +++ b/src/rp_schemas.py @@ -67,6 +67,16 @@ 'required': False, 'default': 1 }, + 'strength': { + 'type': float, + 'required': False, + 'default': 0.2 + }, + 'init_image': { + 'type': str, + 'required': False, + 'default': "" + }, # Included for backwards compatibility 'batch_size': {