# sdxl-gdsc / app.py
# Author: JacobLinCool
# Last commit: c447cda (verified) - "Update app.py" (3.85 kB)
import json
import random
import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import DiffusionPipeline, LCMScheduler
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only used on CUDA. Many CPU kernels have historically not
# implemented float16, so the CPU fallback runs in float32 instead of failing.
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# SDXL base model with two LoRA adapters active at full weight:
#   "lora" - jasperai/flash-sdxl, used together with the LCM scheduler set below
#            (the UI exposes only 4-8 inference steps);
#   "gdsc" - JacobLinCool/sdxl-lora-gdsc-1 (presumably the GDSC style fine-tune).
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
# Create the pipeline in the target dtype from the start rather than relying
# solely on the later cast.
pipe = DiffusionPipeline.from_pretrained(model_id, variant="fp16", torch_dtype=DTYPE)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights("jasperai/flash-sdxl", adapter_name="lora")
pipe.load_lora_weights("JacobLinCool/sdxl-lora-gdsc-1", adapter_name="gdsc")
pipe.set_adapters(["lora", "gdsc"], adapter_weights=[1.0, 1.0])
# Move the whole pipeline, including the LoRA layers loaded above, to the
# chosen device and dtype.
pipe.to(device=DEVICE, dtype=DTYPE)

# Largest seed accepted from the UI or drawn at random (int32 max).
MAX_SEED = np.iinfo(np.int32).max
# NOTE(review): not referenced anywhere in the visible code.
MAX_IMAGE_SIZE = 1024
@spaces.GPU
def infer(
    pre_prompt,
    prompt,
    seed,
    randomize_seed,
    num_inference_steps,
    negative_prompt,
    guidance_scale,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a single image with the LoRA-adapted SDXL pipeline.

    Args:
        pre_prompt: Optional text prepended to ``prompt`` with a separating
            space; skipped when empty or None.
        prompt: The user's prompt text.
        seed: Seed for the random generator. Replaced by a random value in
            ``[0, MAX_SEED]`` when ``randomize_seed`` is truthy.
        randomize_seed: If truthy, ignore ``seed`` and draw a fresh one.
        num_inference_steps: Number of denoising steps passed to the pipeline.
        negative_prompt: Negative prompt passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        progress: Gradio progress tracker. ``track_tqdm=True`` mirrors the
            pipeline's tqdm bar in the UI.

    Returns:
        The first entry of the pipeline output's ``images`` list.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Gradio slider values may arrive as floats. torch's Generator.manual_seed
    # only accepts integers, and the scheduler needs an integer step count.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)

    # torch.Generator() with no device argument is a CPU generator. The
    # diffusers docs recommend CPU generators for reproducible results.
    generator = torch.Generator().manual_seed(seed)

    if pre_prompt:
        prompt = f"{pre_prompt} {prompt}"

    # NOTE(review): in diffusers' SDXL pipeline, classifier-free guidance (and
    # with it negative_prompt) is only applied when guidance_scale > 1; the UI
    # default is 1. Verify against the installed diffusers version.
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
    ).images[0]
    return image
# Stylesheet passed to gr.Blocks: centre <h1> headings and justify <p> text.
# Assembled line by line. The leading "" and trailing "" entries reproduce the
# newline at the start and end of the stylesheet.
css = "\n".join(
    [
        "",
        "h1 {",
        "text-align: center;",
        "display:block;",
        "}",
        "p {",
        "text-align: justify;",
        "display:block;",
        "}",
        "",
    ]
)
# Human-readable name of the hardware this app runs on, based on the same CUDA
# availability check as DEVICE.
# NOTE(review): not referenced anywhere in the visible code.
power_device = "GPU" if torch.cuda.is_available() else "CPU"
# Gradio UI: a prompt row with a Run button, the result image, and an
# "Advanced Settings" accordion holding the remaining infer() parameters.
# NOTE(review): the original indentation of this block was lost, so the
# container nesting below is a reconstruction. Check the layout visually.
with gr.Blocks(css=css) as demo:
    with gr.Row():
        with gr.Column():
            with gr.Row():
                prompt = gr.Text(
                    label="Prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    container=False,
                    scale=5,
                )
                run_button = gr.Button("Run", scale=1)
            result = gr.Image(label="Result", show_label=False)
            with gr.Accordion("Advanced Settings", open=False):
                pre_prompt = gr.Text(
                    label="Pre-Prompt",
                    show_label=True,
                    max_lines=1,
                    placeholder="Pre Prompt from the LoRA config",
                    container=True,
                    scale=5,
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                with gr.Row():
                    num_inference_steps = gr.Slider(
                        label="Number of inference steps",
                        minimum=4,
                        maximum=8,
                        step=1,
                        value=4,
                    )
                with gr.Row():
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1,
                        maximum=6,
                        step=0.5,
                        value=1,
                    )
                negative_prompt = gr.Text(
                    label="Negative Prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter a negative Prompt",
                    container=False,
                )

    # Components in the positional order of infer()'s parameters.
    infer_inputs = [
        pre_prompt,
        prompt,
        seed,
        randomize_seed,
        num_inference_steps,
        negative_prompt,
        guidance_scale,
    ]
    # Generate on a Run-button click, and also when Enter is pressed in the
    # single-line prompt box, which previously did nothing.
    run_button.click(fn=infer, inputs=infer_inputs, outputs=[result])
    prompt.submit(fn=infer, inputs=infer_inputs, outputs=[result])

# Enable the request queue (needed for progress tracking) and start the server.
demo.queue().launch()