Merge pull request #4 from xelfer/main
add image2image support
justinmerrell authored Jun 26, 2023
2 parents 8ebcf52 + 905375f commit b5892a9
Showing 2 changed files with 41 additions and 10 deletions.
41 changes: 31 additions & 10 deletions src/rp_handler.py
@@ -6,7 +6,8 @@
import os
import torch

-from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline, KandinskyImg2ImgPipeline, KandinskyPriorPipeline
+from diffusers.utils import load_image

import runpod
from runpod.serverless.utils import rp_upload, rp_cleanup
@@ -22,6 +23,8 @@
    "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16).to("cuda")
t2i_pipe.enable_xformers_memory_efficient_attention()

+i2i_prior = KandinskyPriorPipeline(**pipe_prior.components).to("cuda")
+i2i_pipe = KandinskyImg2ImgPipeline(**t2i_pipe.components).to("cuda")

def _setup_generator(seed):
    generator = torch.Generator(device="cuda")
@@ -68,18 +71,36 @@ def generate_image(job):
        validated_input['negative_prompt'],
        generator=generator).to_tuple()

+    # Check if an input image is provided, determining if this is text2image or image2image
+    init_image = None
+    if job_input.get('init_image', None) is not None:
+        init_image = load_image(job_input['init_image'])
+
    # List to hold the image URLs
    image_urls = []

-    # Create image
-    output = t2i_pipe(validated_input['prompt'],
-                      image_embeds=image_embeds,
-                      negative_image_embeds=negative_image_embeds,
-                      height=validated_input['h'],
-                      width=validated_input['w'],
-                      num_inference_steps=validated_input['num_steps'],
-                      guidance_scale=validated_input['guidance_scale'],
-                      num_images_per_prompt=validated_input['num_images']).images
+    if init_image is None:
+        # Create text2image
+        output = t2i_pipe(validated_input['prompt'],
+                          image_embeds=image_embeds,
+                          negative_image_embeds=negative_image_embeds,
+                          height=validated_input['h'],
+                          width=validated_input['w'],
+                          num_inference_steps=validated_input['num_steps'],
+                          guidance_scale=validated_input['guidance_scale'],
+                          num_images_per_prompt=validated_input['num_images']).images
+    else:
+        # Create image2image
+        output = i2i_pipe(
+            validated_input["prompt"],
+            image=init_image,
+            image_embeds=image_embeds,
+            negative_image_embeds=negative_image_embeds,
+            height=validated_input['h'],
+            width=validated_input['w'],
+            num_inference_steps=validated_input['num_steps'],
+            strength=validated_input['strength']).images
+

    image_urls = _save_and_upload_images(output, job['id'])

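The branch added above selects a pipeline purely from the presence of init_image in the job input: when it is set, the handler loads it with diffusers.utils.load_image (which accepts a URL or local file path and returns a PIL image) and runs the Kandinsky img2img pipeline, with strength controlling how far the result may drift from the source image. Below is a minimal sketch of the two job shapes, assuming the standard RunPod job layout with 'id' and 'input' keys; the prompt and image URL are placeholders, not part of the commit.

# Sketch only: placeholder payloads illustrating the two branches in rp_handler.py.
text2image_job = {
    "id": "local-t2i",
    "input": {
        "prompt": "a watercolor painting of a lighthouse",
        "num_steps": 50,
        "guidance_scale": 4.0,
    },
}

image2image_job = {
    "id": "local-i2i",
    "input": {
        "prompt": "a watercolor painting of a lighthouse",
        "init_image": "https://example.com/lighthouse.png",  # placeholder; any URL or path accepted by load_image
        "strength": 0.3,  # higher values allow larger departures from the source image
        "num_steps": 50,
    },
}

# The same presence check the handler now performs to choose a pipeline:
for job in (text2image_job, image2image_job):
    uses_img2img = job["input"].get("init_image", None) is not None
    print(job["id"], "-> image2image" if uses_img2img else "-> text2image")
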
10 changes: 10 additions & 0 deletions src/rp_schemas.py
@@ -67,6 +67,16 @@
        'required': False,
        'default': 1
    },
+    'strength': {
+        'type': float,
+        'required': False,
+        'default': 0.2
+    },
+    'init_image': {
+        'type': str,
+        'required': False,
+        'default': ""
+    },

    # Included for backwards compatibility
    'batch_size': {
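The two new fields slot into the existing input schema; when a request omits them, validation falls back to the defaults above and the handler stays on the text2image path. A hedged client-side sketch follows, assuming a deployed serverless endpoint and the usual RunPod /v2/<endpoint_id>/runsync route; the endpoint ID, API key, prompt, and image URL are placeholders.

# Sketch only: submitting an image2image job to a deployed endpoint.
import requests

ENDPOINT_ID = "your-endpoint-id"   # placeholder
API_KEY = "your-runpod-api-key"    # placeholder

payload = {
    "input": {
        "prompt": "a watercolor painting of a lighthouse",
        "init_image": "https://example.com/lighthouse.png",  # placeholder source image
        "strength": 0.2,  # matches the new schema default
        "num_steps": 50,
    }
}

response = requests.post(
    f"https://api.runpod.ai/v2/{ENDPOINT_ID}/runsync",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json=payload,
    timeout=600,
)
print(response.json())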
