Tonic's picture
Update app.py
0687eaf verified
raw
history blame
5.94 kB
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
import torch
import gradio as gr
import inversion
import numpy as np
from PIL import Image
import sa_handler
# import spaces
# Model setup: one SDXL pipeline with a DDIM scheduler, shared by every request.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only used on GPU; many float16 kernels are missing or very
# slow on CPU, so the CPU fallback runs in float32. The fp16 weight variant is
# still downloaded (smaller) and upcast when loaded with torch_dtype=float32.
model_dtype = torch.float16 if device == "cuda" else torch.float32
scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=model_dtype,
    variant="fp16",
    use_safetensors=True,
    scheduler=scheduler,
).to(device)
# @spaces.GPU
def run(image, src_style, src_prompt, prompts, shared_score_shift, shared_score_scale, guidance_scale, num_inference_steps, seed, large=True):
    """Generate a batch of images that share the style of a reference image.

    The reference image is DDIM-inverted under ``src_prompt``. The inverted
    latent becomes the first sample of the batch, and the StyleAligned handler
    from ``sa_handler`` is registered on the global ``pipeline`` for the
    duration of the call so the other samples share its style.

    Args:
        image: Reference image (PIL, from the ``gr.Image(type="pil")`` input).
        src_style: Style description appended to every target prompt.
        src_prompt: Description of the reference image, used for inversion
            and as the prompt of the first sample.
        prompts: Newline-separated target prompts; blank lines are ignored.
        shared_score_shift: Slider value; its natural log is what is passed
            to ``StyleAlignedArgs`` (so it must be > 0).
        shared_score_scale: Passed unchanged to ``StyleAlignedArgs``.
        guidance_scale: Guidance scale forwarded to the pipeline call.
        num_inference_steps: Step count used for both inversion and sampling.
        seed: Seed for the initial noise; values <= 0 skip manual seeding.
        large: If True work at 1024x1024 (128x128 latents), else 512x512
            (64x64 latents).

    Returns:
        The pipeline output's ``.images``: one entry for ``src_prompt``
        followed by one per non-blank target prompt.

    Raises:
        gr.Error: If no reference image was supplied.
    """
    if image is None:
        raise gr.Error("Please upload a reference image.")
    # Gradio sliders may deliver floats; torch seeding and step counts need ints.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)

    # Skip blank lines so they do not turn into bare ", <style>." prompts.
    targets = [line.strip() for line in prompts.splitlines() if line.strip()]
    # Index 0 is the reference prompt; every target gets the style suffix.
    prompts = [src_prompt] + [f'{target}, {src_style}.' for target in targets]

    dim, d = (1024, 128) if large else (512, 64)
    # Force 3 channels: an RGBA/greyscale upload would otherwise give a
    # differently shaped array (presumably the inversion encoder expects RGB).
    x0 = np.array(image.convert("RGB").resize((dim, dim)))
    zts = inversion.ddim_inversion(pipeline, x0, src_prompt, num_inference_steps, 2)
    # Offset into the inverted latents, capped so it is a valid index into zts.
    offset = min(5, len(zts) - 1)
    zT, inversion_callback = inversion.make_inversion_callback(zts, offset=offset)

    sa_args = sa_handler.StyleAlignedArgs(
        share_group_norm=True, share_layer_norm=True, share_attention=True,
        adain_queries=True, adain_keys=True, adain_values=False,
        # The slider value is passed through np.log before reaching the handler.
        shared_score_shift=np.log(shared_score_shift),
        shared_score_scale=shared_score_scale,
    )

    # Draw noise with a CPU generator, then move it to the pipeline's device.
    g_cpu = torch.Generator(device='cpu')
    if seed > 0:
        g_cpu.manual_seed(seed)
    latents = torch.randn(
        len(prompts), 4, d, d, device='cpu', generator=g_cpu, dtype=pipeline.unet.dtype,
    ).to(device)
    # The first sample starts from the inverted reference latent.
    latents[0] = zT

    handler = sa_handler.Handler(pipeline)
    handler.register(sa_args)
    try:
        images_a = pipeline(
            prompts,
            latents=latents,
            callback_on_step_end=inversion_callback,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        ).images
    finally:
        # ``pipeline`` is a module-level global shared by all requests, so the
        # handler must be removed even when generation fails; otherwise the
        # style-sharing hooks would leak into every later call.
        handler.remove()
        torch.cuda.empty_cache()
    return images_a
# Gradio UI: a reference image plus text descriptions feed ``run``; the
# generated images are shown in a gallery.
with gr.Blocks() as demo:
    gr.Markdown("""# Welcome to🌟Tonic's🤵🏻Style📐Align
Here you can generate images with a style from a reference image using [transfer style from sdxl](https://huggingface.co/docs/diffusers/main/en/using-diffusers/sdxl). Add a reference picture, describe the style and add prompts to generate images in that style. It's the most interesting with your own art! You can also use [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) by cloning this space. 🧬🔬🔍 Simply click here: <a style="display:inline-block" href="https://huggingface.co/spaces/Tonic1/TonicsStyleAlign?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a>
Join us : 🌟TeamTonic🌟 is always making cool demos! Join our active builder's🛠️community 👻 [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/GWpVpekp) On 🤗Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to 🌟 [DataTonic](https://github.com/Tonic-AI/DataTonic) 🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
""")
    with gr.Row():
        image_input = gr.Image(label="Reference image", type="pil")
    with gr.Row():
        style_input = gr.Textbox(label="Describe the reference style")
        image_desc_input = gr.Textbox(label="Describe the reference image")
        prompts_input = gr.Textbox(label="Prompts to generate images (separate with new lines)", lines=5)
    with gr.Accordion(label="Advanced Settings"):
        with gr.Row():
            shared_score_shift_input = gr.Slider(value=1.5, label="shared_score_shift", minimum=1.0, maximum=2.0, step=0.05)
            shared_score_scale_input = gr.Slider(value=0.5, label="shared_score_scale", minimum=0.0, maximum=1.0, step=0.05)
            guidance_scale_input = gr.Slider(value=10.0, label="guidance_scale", minimum=5.0, maximum=20.0, step=1)
            num_inference_steps_input = gr.Slider(value=12, label="num_inference_steps", minimum=12, maximum=300, step=1)
            # NOTE(review): step=42 only allows seeds that are multiples of 42;
            # confirm this is intended rather than step=1.
            seed_input = gr.Slider(value=0, label="seed", minimum=0, maximum=1000000, step=42)
    with gr.Row():
        run_button = gr.Button("Generate Images")
    with gr.Row():
        output_gallery = gr.Gallery()

    # Single source of truth for the argument order of ``run`` (its positional
    # parameters, excluding ``large``), shared by the button and the examples.
    run_inputs = [
        image_input,
        style_input,
        image_desc_input,
        prompts_input,
        shared_score_shift_input,
        shared_score_scale_input,
        guidance_scale_input,
        num_inference_steps_input,
        seed_input,
    ]
    run_button.click(run, inputs=run_inputs, outputs=output_gallery)

    # Each row lists values in the same order as ``run_inputs``.
    examples = [
        ["download (8).jpg", "picasso blue period", "a portrait of a man playing guitar",
         "an astronaut holding a cocktail glass\nan astronaut in space holding a laptop\nan astronaut in space with an explosion of iridescent powder",
         1.7, 0.7, 20, 144, 245112]
    ]
    gr.Examples(
        examples=examples,
        inputs=run_inputs,
        outputs=output_gallery,
        fn=run,
        cache_examples=True,
    )

if __name__ == "__main__":
    demo.launch()