# Page metadata captured with this file (Hugging Face file view):
# jjuun's picture
# add github profile
# 0cb1e78
# raw
# history blame
# 3.91 kB
import functools
import random

import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionXLPipeline, AutoencoderKL

from utils import randomize_seed_fn
# Upper bound for the seed slider: the largest signed 32-bit integer.
MAX_SEED = np.iinfo("int32").max
@functools.lru_cache(maxsize=1)
def model_load():
    """Build the SDXL base pipeline with a replacement VAE and the style LoRA.

    The pipeline is cached. Weights are downloaded, loaded and moved to the
    GPU only on the first call, and every later call returns the same
    pipeline object. Previously each generation request rebuilt the whole
    pipeline on CUDA.

    Returns:
        The float16 ``StableDiffusionXLPipeline`` placed on ``cuda``.
    """
    # Replacement VAE; per its repo name, presumably patched for fp16 inference.
    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", vae=vae, torch_dtype=torch.float16
    )
    # Apply the LoRA weights that provide the demo's illustration style.
    pipe.load_lora_weights("jjuun/vivid_color_style")
    return pipe.to('cuda')
def build_sdxl_prompt(prompt, additional_prompt="", trigger="jjj, scratch art style"):
    """Compose the full text prompt sent to the SDXL pipeline.

    The LoRA trigger phrase is prepended and a fixed black-background suffix
    is appended. A non-blank *additional_prompt* is appended last.

    >>> build_sdxl_prompt("a colorful lion")
    'jjj, scratch art style, a colorful lion, with a black background'
    """
    full_prompt = f'{trigger}, {prompt}, with a black background'
    extra = (additional_prompt or "").strip()
    if extra:
        full_prompt = f'{full_prompt}, {extra}'
    return full_prompt


def sdxl_process(seed, prompt, additional_prompt, negative_prompt, num_steps, guidance_scale):
    """Generate one image for the given UI inputs.

    Args:
        seed: RNG seed; coerced to ``int`` for ``torch.Generator.manual_seed``.
        prompt: user prompt; wrapped by ``build_sdxl_prompt``.
        additional_prompt: extra text appended to the prompt when non-blank.
        negative_prompt: passed through to the pipeline.
        num_steps: number of denoising steps; coerced to ``int``.
        guidance_scale: classifier-free guidance scale; coerced to ``float``.

    Returns:
        The first image of the pipeline output.
    """
    pipe = model_load()
    # Gradio sliders/examples may deliver floats or strings; torch and
    # diffusers need real numbers here.
    generator = torch.Generator("cuda").manual_seed(int(seed))
    # The additional prompt is appended to the main prompt rather than passed
    # positionally as SDXL's `prompt_2`. Passed that way, it replaced the whole
    # prompt (trigger included) for the second text encoder.
    output = pipe(
        prompt=build_sdxl_prompt(prompt, additional_prompt),
        negative_prompt=negative_prompt,
        num_inference_steps=int(num_steps),
        guidance_scale=float(guidance_scale),
        generator=generator,
    ).images[0]
    return output
# Heading and usage hint rendered at the top of the demo page.
# The emoji were previously stored as mojibake (UTF-8 bytes mis-decoded).
title = "🌈 Colorful illustration"
description_en = "🚀 How to use: please make sure to include 'a colorful' in prompt and click Run button!"
def create_demo():
    """Assemble the Gradio Blocks UI for the colorful-illustration demo.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    default_negative_prompt = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"

    def resolve_seed(current_seed, randomize):
        # Choose this run's seed once and publish it to BOTH the seed slider
        # (which sdxl_process reads in the chained step) and the "Used seed"
        # box. Previously only the textbox was updated, so the displayed seed
        # was never the one used for generation.
        # NOTE(review): assumes randomize_seed_fn returns a single seed value.
        used_seed = randomize_seed_fn(current_seed, randomize)
        return used_seed, used_seed

    with gr.Blocks() as demo:
        gr.Markdown(f"<h1 style='text-align: center;'>{title}</h1>")
        gr.Markdown(f"<h3 style='text-align: center'>{description_en}</h3>")
        gr.Markdown("<a href='https://github.com/jjuun0'><img src='https://img.shields.io/badge/GitHub-181717?style=flat-square&logo=GitHub&logoColor=white'/></a>")
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt")
                run_button = gr.Button("Run")
                with gr.Accordion("Advanced options", open=False):
                    num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=20, step=1)
                    guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    a_prompt = gr.Textbox(label="Additional prompt", value="")
                    n_prompt = gr.Textbox(
                        label="Negative prompt",
                        value=default_negative_prompt,
                    )
            with gr.Column():
                result = gr.Image(label="Output")
                result_seed = gr.Textbox(label="Used seed")
        # Slider examples are numbers, not strings, so the values fed back
        # into generation are numeric.
        gr.Examples(
            examples=[
                ["a colorful lion", 20, 9, 0, "", default_negative_prompt, "examples/lion.png"],
                ["a colorful messi", 20, 9, 0, "", default_negative_prompt, "examples/messi.png"],
            ],
            inputs=[prompt, num_steps, guidance_scale, seed, a_prompt, n_prompt, result],
        )
        # Order must match sdxl_process's positional parameters.
        inputs = [
            seed,
            prompt,
            a_prompt,
            n_prompt,
            num_steps,
            guidance_scale,
        ]
        run_button.click(
            fn=resolve_seed,
            inputs=[seed, randomize_seed],
            outputs=[seed, result_seed],
            queue=False,
            api_name=False,
        ).then(
            fn=sdxl_process,
            inputs=inputs,
            outputs=result,
            api_name=False,
        )
    return demo
if __name__ == "__main__":
    # Build the UI, turn on request queuing, then start serving.
    demo = create_demo()
    demo.queue()
    demo.launch()