Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add flux compile example #217

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions examples/dynamo/flux/flux.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import argparse
import torch
from diffusers import FluxPipeline

import torch_migraphx

torch._dynamo.reset()

parser = argparse.ArgumentParser(description='Conversion parameters')

parser.add_argument('--num_steps',
type=int,
default=50,
help='Number of steps to run unet')

parser.add_argument('--fname',
type=str,
default='output.png',
help='Output file name')

parser.add_argument('--prompts',
nargs='*',
default=["A cat holding a sign that says hello world"],
help='Prompts to use as input')

parser.add_argument('--prompt2',
nargs='*',
default=None,
help='Prompts to use as input')

parser.add_argument('--model_repo',
type=str,
default='black-forest-labs/FLUX.1-dev',
help='Huggingface repo path')

parser.add_argument('--fp16',
action='store_true',
help='Load fp16 version of the pipeline')

parser.add_argument("-d",
"--image-height",
type=int,
default=1024,
help="Output Image height, default 1024")

parser.add_argument("-w",
"--image-width",
type=int,
default=1024,
help="Output Image width, default 1024")


def run(args):
dtype = torch.float16 if args.fp16 else torch.float32
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype)

pipe = pipe.to("cuda")

# pipe.text_encoder = torch.compile(pipe.text_encoder, backend='migraphx')
# pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, backend='migraphx')
pipe.transformer = torch.compile(pipe.transformer, backend='migraphx')
# pipe.vae.decoder = torch.compile(pipe.vae.decoder, backend='migraphx')

image = pipe(prompt=args.prompts,
height=args.image_height,
width=args.image_width,
guidance_scale=3.5,
num_inference_steps=args.num_steps,
max_sequence_length=512).images[0]

image.save(args.fname)


if __name__ == '__main__':
args = parser.parse_args()

run(args)
5 changes: 5 additions & 0 deletions examples/dynamo/flux/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
diffusers
transformers
accelerate
sentencepiece
protobuf
1 change: 1 addition & 0 deletions py/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def build_extension(self, ext):
description='Intergrate PyTorch with MIGraphX acceleration engine',
long_description_content_type='text/markdown',
long_description=long_description,
setup_requires=["cmake"],
install_requires=[
"torch>=1.11.0",
"numpy>=1.20.0",
Expand Down
Loading