Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inference straight from the cli #119

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,6 @@ dmypy.json

# Pyre type checker
.pyre/

# Output files for cli
output/
107 changes: 107 additions & 0 deletions riffusion/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,116 @@
import tqdm
from PIL import Image

from riffusion.datatypes import InferenceInput, PromptInput, RawInferenceOutput
from riffusion.riffusion_pipeline import RiffusionPipeline
from riffusion.server import compute_request
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.util import image_util

"""
Where built-in seed images are stored

To use custom seed images, such as for audio to audio, place them in this folder as well.
Use this cli's audio to image to generate spectrograms from your desired starting audio
"""
SEED_IMAGES_DIR = Path(Path(__file__).resolve().parent.parent, "seed_images")


@argh.arg("--out", help="The path to place outputs from inference, default is ../../output")
def inference(*, out: str = "output"):
    """
    Interactive command-line interface for running inference requests.

    Loads the riffusion pipeline once, then loops reading comma-separated
    parameters from stdin, running inference, and writing the resulting
    spectrogram image (.jpeg) and audio clip (.mp3) into the output
    directory. Exit with ctrl-c.
    """

    # Model checkpoint to load
    checkpoint: str = "riffusion/riffusion-model-v1"

    # Initialize the model once up front; loading the checkpoint is expensive.
    pipeline = RiffusionPipeline.load_checkpoint(
        checkpoint=checkpoint,
        use_traced_unet=True,
        device="cuda",
    )

    # Output path, resolved relative to the repository root.
    output_dir_path = Path(Path(__file__).resolve().parent.parent, out)
    output_dir_path.mkdir(parents=True, exist_ok=True)

    def get_default(fields, i, default):
        """
        Return fields[i] stripped of whitespace, or `default` when the index
        is out of range or the user entered the literal string "none".
        """
        try:
            value = fields[i].strip()
        except IndexError:
            return default
        return default if value.lower() == "none" else value

    print("\n")
    print("Please enter the parameters for inference, comma separated")
    print("ctrl-c to quit")
    print("Everything except starting prompt is optional")
    print(
        "Example input: bossa nova with distorted guitar, 674, 0.75, 7, 0, 50, og_beat, None, "
        "bossa nova with distorted guitar, 675, 0.75, 7, outputfilenames"
    )
    print(
        "Details in InferenceInput, order is starting PromptInput, modifiers, ending PromptInput, "
        "outputfilenames"
    )

    try:
        while True:
            fields = input(">> ").split(",")

            # Positions 0-3: starting prompt (text, seed, denoising, guidance).
            startprompt = PromptInput(
                get_default(fields, 0, "Sad trombone noises"),
                int(get_default(fields, 1, 123)),
                None,
                float(get_default(fields, 2, 0.75)),
                float(get_default(fields, 3, 7.0)),
            )
            # Positions 8-11: ending prompt, same layout as the starting prompt.
            endprompt = PromptInput(
                get_default(fields, 8, "Sad trombone noises"),
                int(get_default(fields, 9, 123)),
                None,
                float(get_default(fields, 10, 0.75)),
                float(get_default(fields, 11, 7.0)),
            )
            # Positions 4-7: interpolation alpha, steps, seed image id, mask image id.
            query = InferenceInput(
                startprompt,
                endprompt,
                float(get_default(fields, 4, 0)),
                int(get_default(fields, 5, 50)),
                get_default(fields, 6, "og_beat"),
                get_default(fields, 7, None),
            )
            outputfilenames = get_default(fields, 12, "output")

            results = compute_request(
                inputs=query,
                seed_images_dir=str(SEED_IMAGES_DIR),
                pipeline=pipeline,
            )

            # compute_request signals failure as an (error_message, status_code) tuple.
            if isinstance(results, tuple):
                print(results)
                raise RuntimeError(f"Couldn't run inference, error {results[0]}")
            elif isinstance(results, RawInferenceOutput):
                # Write the raw JPEG and MP3 buffers to disk without re-encoding.
                with open(Path(output_dir_path, f"{outputfilenames}.jpeg"), "wb") as f:
                    f.write(results.image.getbuffer())
                with open(Path(output_dir_path, f"{outputfilenames}.mp3"), "wb") as f:
                    f.write(results.audio.getbuffer())
                print(f"Finished, took {results.duration_s} seconds")
            else:
                raise RuntimeError("Unexpected output from compute_request")

    except KeyboardInterrupt:
        print("Goodbye!")


@argh.arg("--step-size-ms", help="Duration of one pixel in the X axis of the spectrogram image")
@argh.arg("--num-frequencies", help="Number of Y axes in the spectrogram image")
Expand Down Expand Up @@ -274,5 +380,6 @@ def process_one(audio_path: Path) -> None:
print_exif,
audio_to_images_batch,
sample_clips_batch,
inference,
]
)
17 changes: 17 additions & 0 deletions riffusion/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
from __future__ import annotations

import io
import typing as T
from dataclasses import dataclass

Expand Down Expand Up @@ -57,6 +58,22 @@ class InferenceInput:
mask_image_id: T.Optional[str] = None


@dataclass(frozen=True)
class RawInferenceOutput:
    """
    Raw (un-encoded) output of the inference model, before it is
    serialized into a web response.
    """

    # JPEG-encoded spectrogram image, held in an in-memory buffer
    image: io.BytesIO

    # MP3-encoded audio, held in an in-memory buffer
    audio: io.BytesIO

    # The duration of the audio clip, in seconds
    duration_s: float


@dataclass(frozen=True)
class InferenceOutput:
"""
Expand Down
37 changes: 27 additions & 10 deletions riffusion/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import PIL
from flask_cors import CORS

from riffusion.datatypes import InferenceInput, InferenceOutput
from riffusion.datatypes import InferenceInput, InferenceOutput, RawInferenceOutput
from riffusion.riffusion_pipeline import RiffusionPipeline
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams
Expand Down Expand Up @@ -101,10 +101,12 @@ def run_inference():
logging.info(json_data)
return str(exception), 400

response = compute_request(
inputs=inputs,
seed_images_dir=SEED_IMAGES_DIR,
pipeline=PIPELINE,
response = form_response(
compute_request(
inputs=inputs,
seed_images_dir=SEED_IMAGES_DIR,
pipeline=PIPELINE,
)
)

# Log the total time
Expand All @@ -113,11 +115,26 @@ def run_inference():
return response


def form_response(inference_result: RawInferenceOutput) -> T.Union[str, T.Tuple[str, int]]:
    """
    Serialize a RawInferenceOutput into the JSON payload served over HTTP.
    """

    # Base64-encode the raw media buffers into data URIs for the web client.
    image_uri = "data:image/jpeg;base64," + base64_util.encode(inference_result.image)
    audio_uri = "data:audio/mpeg;base64," + base64_util.encode(inference_result.audio)

    output = InferenceOutput(
        image=image_uri,
        audio=audio_uri,
        duration_s=inference_result.duration_s,
    )

    return json.dumps(dataclasses.asdict(output))


def compute_request(
inputs: InferenceInput,
pipeline: RiffusionPipeline,
seed_images_dir: str,
) -> T.Union[str, T.Tuple[str, int]]:
) -> T.Union[RawInferenceOutput, T.Tuple[str, int]]:
"""
Does all the heavy lifting of the request.

Expand Down Expand Up @@ -174,13 +191,13 @@ def compute_request(
image_bytes.seek(0)

# Assemble the output dataclass
output = InferenceOutput(
image="data:image/jpeg;base64," + base64_util.encode(image_bytes),
audio="data:audio/mpeg;base64," + base64_util.encode(mp3_bytes),
output = RawInferenceOutput(
image=image_bytes,
audio=mp3_bytes,
duration_s=segment.duration_seconds,
)

return json.dumps(dataclasses.asdict(output))
return output


if __name__ == "__main__":
Expand Down