Lock to prevent concurrent access to model pipeline in playground
This should avoid the most basic conflict with multiple users on
the app at one time. Seems to work in initial testing.

Topic: pipeline_lock
hmartiro committed Jan 16, 2023
1 parent e832b92 commit 1fb9141
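
For context on the approach: a Streamlit server runs every user session as a thread inside a single process, so one process-wide threading.Lock is enough to serialize access to the shared model pipeline. Below is a minimal standalone sketch of that pattern, assuming the threaded execution model; the run_inference helper and user names are illustrative, not part of this commit.

import threading
import time

lock = threading.Lock()

def run_inference(user: str) -> None:
    # Without the lock, both "users" would reach the pipeline at once;
    # with it, the second caller blocks until the first finishes.
    with lock:
        print(f"{user}: acquired pipeline")
        time.sleep(0.1)  # stand-in for a model forward pass
        print(f"{user}: released pipeline")

threads = [threading.Thread(target=run_inference, args=(name,)) for name in ("alice", "bob")]
for t in threads:
    t.start()
for t in threads:
    t.join()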
Showing 1 changed file with 51 additions and 40 deletions.
riffusion/streamlit/util.py
@@ -2,6 +2,7 @@
 Streamlit utilities (mostly cached wrappers around riffusion code).
 """
 import io
+import threading
 import typing as T
 
 import pydub
@@ -106,6 +107,14 @@ def get_scheduler(scheduler: str, config: T.Any) -> T.Any:
         raise ValueError(f"Unknown scheduler {scheduler}")
 
 
+@st.experimental_singleton
+def pipeline_lock() -> threading.Lock:
+    """
+    Singleton lock used to prevent concurrent access to any model pipeline.
+    """
+    return threading.Lock()
+
+
 @st.experimental_singleton
 def load_stable_diffusion_img2img_pipeline(
     checkpoint: str = "riffusion/riffusion-model-v1",
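
st.experimental_singleton caches the function's return value for the lifetime of the server process, so every session that calls pipeline_lock() receives the same threading.Lock instance. A hypothetical identity check, illustrative rather than part of the commit:

# Inside a running Streamlit app, repeated calls return the one
# cached lock rather than constructing a fresh (and useless) new one.
lock_a = pipeline_lock()
lock_b = pipeline_lock()
assert lock_a is lock_b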
@@ -149,22 +158,23 @@ def run_txt2img(
     """
     Run the text to image pipeline with caching.
     """
-    pipeline = load_stable_diffusion_pipeline(device=device, scheduler=scheduler)
-
-    generator_device = "cpu" if device.lower().startswith("mps") else device
-    generator = torch.Generator(device=generator_device).manual_seed(seed)
-
-    output = pipeline(
-        prompt=prompt,
-        num_inference_steps=num_inference_steps,
-        guidance_scale=guidance,
-        negative_prompt=negative_prompt or None,
-        generator=generator,
-        width=width,
-        height=height,
-    )
-
-    return output["images"][0]
+    with pipeline_lock():
+        pipeline = load_stable_diffusion_pipeline(device=device, scheduler=scheduler)
+
+        generator_device = "cpu" if device.lower().startswith("mps") else device
+        generator = torch.Generator(device=generator_device).manual_seed(seed)
+
+        output = pipeline(
+            prompt=prompt,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance,
+            negative_prompt=negative_prompt or None,
+            generator=generator,
+            width=width,
+            height=height,
+        )
+
+        return output["images"][0]
 
 
 @st.experimental_singleton
@@ -269,31 +279,32 @@ def run_img2img(
     scheduler: str = SCHEDULER_OPTIONS[0],
     progress_callback: T.Optional[T.Callable[[float], T.Any]] = None,
 ) -> Image.Image:
-    pipeline = load_stable_diffusion_img2img_pipeline(device=device, scheduler=scheduler)
-
-    generator_device = "cpu" if device.lower().startswith("mps") else device
-    generator = torch.Generator(device=generator_device).manual_seed(seed)
-
-    num_expected_steps = max(int(num_inference_steps * denoising_strength), 1)
-
-    def callback(step: int, tensor: torch.Tensor, foo: T.Any) -> None:
-        if progress_callback is not None:
-            progress_callback(step / num_expected_steps)
-
-    result = pipeline(
-        prompt=prompt,
-        image=init_image,
-        strength=denoising_strength,
-        num_inference_steps=num_inference_steps,
-        guidance_scale=guidance_scale,
-        negative_prompt=negative_prompt or None,
-        num_images_per_prompt=1,
-        generator=generator,
-        callback=callback,
-        callback_steps=1,
-    )
-
-    return result.images[0]
+    with pipeline_lock():
+        pipeline = load_stable_diffusion_img2img_pipeline(device=device, scheduler=scheduler)
+
+        generator_device = "cpu" if device.lower().startswith("mps") else device
+        generator = torch.Generator(device=generator_device).manual_seed(seed)
+
+        num_expected_steps = max(int(num_inference_steps * denoising_strength), 1)
+
+        def callback(step: int, tensor: torch.Tensor, foo: T.Any) -> None:
+            if progress_callback is not None:
+                progress_callback(step / num_expected_steps)
+
+        result = pipeline(
+            prompt=prompt,
+            image=init_image,
+            strength=denoising_strength,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            negative_prompt=negative_prompt or None,
+            num_images_per_prompt=1,
+            generator=generator,
+            callback=callback,
+            callback_steps=1,
+        )
+
+        return result.images[0]
 
 
 class StreamlitCounter:
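
With the lock in place, simultaneous playground requests queue up instead of colliding inside the pipeline. A hedged sketch of the effect under two concurrent callers; the parameter values are made up (in the real app they come from the Streamlit UI), and the keyword names are inferred from the hunk above.

import threading

common = dict(
    num_inference_steps=30,
    guidance=7.0,
    negative_prompt="",
    seed=42,
    width=512,
    height=512,
    device="cuda",
)

# The singleton lock inside run_txt2img serializes the two pipeline calls,
# so the second thread blocks until the first finishes its inference.
t1 = threading.Thread(target=run_txt2img, kwargs=dict(prompt="jazzy trumpet", **common))
t2 = threading.Thread(target=run_txt2img, kwargs=dict(prompt="lofi beats", **common))
t1.start()
t2.start()
t1.join()
t2.join()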
