# Hugging Face Space — running on an A10G GPU.
import gradio as gr | |
from PIL import Image | |
import torch | |
from muse import PipelineMuse | |
from diffusers import AutoPipelineForText2Image, UniPCMultistepScheduler | |
def _load_muse(repo_id):
    """Load the MUSE pipeline at *repo_id* onto CUDA in fp16 with xformers attention."""
    pipe = PipelineMuse.from_pretrained(repo_id).to("cuda", dtype=torch.float16)
    pipe.transformer.enable_xformers_memory_efficient_attention()
    return pipe


# The two MUSE checkpoints being compared (base and finetuned).
muse_512 = _load_muse("openMUSE/muse-512")
muse_512_fine = _load_muse("openMUSE/muse-512-finetuned")

# Stable Diffusion v1.5 baseline. Passing torch_dtype together with
# variant="fp16" keeps the half-precision weights in fp16 at load time;
# without it diffusers loads them as fp32 and they only get cast back to fp16
# by the later .to() call (extra host memory and load time).
sdv1_5 = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    variant="fp16",
    torch_dtype=torch.float16,
).to("cuda")
sdv1_5.scheduler = UniPCMultistepScheduler.from_config(sdv1_5.scheduler.config)
sdv1_5.enable_xformers_memory_efficient_attention()
def infer(prompt, negative=None, guidance_scale=10):
    """Generate one image per model for *prompt*.

    The UI (text/negative submit, button click and gr.Examples) calls this
    with three inputs, so it must accept all three.

    Args:
        prompt: Text prompt shared by all three pipelines.
        negative: Optional negative prompt. Only the Stable Diffusion
            pipeline receives it. An empty string or None means no
            negative prompt.
        guidance_scale: Guidance scale for the two MUSE pipelines. The
            default of 10 matches the value previously hard-coded here and
            the UI slider's default. SD v1.5 keeps its own default.

    Returns:
        A list ``[muse_512 image, muse_512_finetuned image, SD v1.5 image]``,
        in the order the gallery displays them.
    """
    print("Generating:")
    # Both MUSE checkpoints are sampled with identical settings so the
    # comparison is apples-to-apples.
    muse_kwargs = dict(
        timesteps=16,
        guidance_scale=guidance_scale,
        transformer_seq_len=1024,
        use_fp16=True,
        temperature=(2, 0),
    )
    muse_512_image = muse_512(prompt, **muse_kwargs)[0]
    muse_512_fine_image = muse_512_fine(prompt, **muse_kwargs)[0]
    sdv1_5_image = sdv1_5(
        prompt,
        # Normalise "" (empty textbox) to None, i.e. "no negative prompt".
        negative_prompt=negative or None,
        num_inference_steps=25,
    ).images[0]
    return [muse_512_image, muse_512_fine_image, sdv1_5_image]
# Rows for gr.Examples, one per example, in the order of the UI inputs:
# [prompt, negative prompt, guidance scale]. Every row uses guidance 10.
examples = [
    [example_prompt, example_negative, 10]
    for example_prompt, example_negative in (
        ("A high tech solarpunk utopia in the Amazon rainforest", "low quality"),
        ("A pikachu fine dining with a view to the Eiffel Tower", "low quality"),
        ("A mecha robot in a favela in expressionist style", "low quality, 3d, photorealistic"),
        ("an insect robot preparing a delicious meal", "low quality, illustration"),
        ("A small cabin on top of a snowy mountain in the style of Disney, artstation", "low quality, ugly"),
    )
]
css = """ | |
h1 { | |
text-align: center; | |
} | |
#component-0 { | |
max-width: 730px; | |
margin: auto; | |
} | |
""" | |
# UI layout:
# - a prompt and a negative-prompt box with a Generate button beside them;
# - a gallery showing one image per model;
# - a guidance-scale slider in an "Advanced settings" accordion;
# - clickable examples.
# Submitting either textbox or clicking the button runs `infer`.
block = gr.Blocks(css=css)
with block:
    gr.Markdown("MUSE is an upcoming fast text2image model.")
    with gr.Group():
        # Layout options are constructor arguments; the old `.style()` API is
        # deprecated and was removed in Gradio 4 (mobile_collapse is gone).
        with gr.Row(elem_id="prompt-container", equal_height=True):
            with gr.Column():
                text = gr.Textbox(
                    label="Enter your prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    container=False,
                )
                # Referenced by the examples and the submit handler below but
                # previously never created, which raised NameError at startup.
                negative = gr.Textbox(
                    label="Enter your negative prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter a negative prompt",
                    container=False,
                )
            btn = gr.Button("Generate image", scale=0)
        gallery = gr.Gallery(label="Generated images", show_label=False)
        with gr.Accordion("Advanced settings", open=False):
            guidance_scale = gr.Slider(
                label="Guidance Scale", minimum=0, maximum=20, value=10, step=0.1
            )
        # Order must match infer(prompt, negative, guidance_scale).
        infer_inputs = [text, negative, guidance_scale]
        ex = gr.Examples(
            examples=examples,
            fn=infer,
            inputs=infer_inputs,
            outputs=gallery,
            cache_examples=False,
        )
        ex.dataset.headers = [""]
        for trigger in (text.submit, negative.submit, btn.click):
            trigger(infer, inputs=infer_inputs, outputs=gallery)

block.launch()