Spaces:
Sleeping
Sleeping
import functools
import os
import random
from typing import Callable, Dict, Optional, Tuple

import gradio as gr
import numpy as np
import PIL.Image
import spaces
import torch
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    StableDiffusionXLPipeline,
)
from transformers import CLIPTextModel
# --- Model / runtime configuration ------------------------------------------
MODEL = "eienmojiki/Starry-XL-v5.2"
HF_TOKEN = os.environ.get("HF_TOKEN")  # None when the variable is unset

# Image-size bounds, overridable through the environment.
MIN_IMAGE_SIZE = int(os.environ.get("MIN_IMAGE_SIZE", "512"))
MAX_IMAGE_SIZE = int(os.environ.get("MAX_IMAGE_SIZE", "2048"))

# Largest seed offered to users: the int32 maximum.
MAX_SEED = np.iinfo(np.int32).max

# Examples are pre-computed only on a GPU machine AND when CACHE_EXAMPLES=1.
CACHE_EXAMPLES = torch.cuda.is_available() and os.environ.get("CACHE_EXAMPLES") == "1"

# Sampler names shown in the UI; get_scheduler() translates each name into a
# diffusers scheduler instance.
sampler_list = [
    # DPM-Solver++ variants with Karras sigmas
    "DPM++ 2M Karras",
    "DPM++ SDE Karras",
    "DPM++ 2M SDE Karras",
    # Euler variants (plain and ancestral)
    "Euler",
    "Euler a",
    # DDIM
    "DDIM",
]
# Example prompts for gr.Examples. The backslash-escaped parentheses such as
# "\(blue archive\)" are kept literally in the prompt text, so a raw string is
# used: in a normal string "\(" is an invalid escape sequence (a
# DeprecationWarning, and a SyntaxWarning since Python 3.12) even though it
# happens to produce the same characters.
examples = [
    r"""
1girl,
midori \(blue archive\), blue archive,
(ningen mame:0.9), ciloranko, sho \(sho lwlw\), (tianliang duohe fangdongye:0.8), ask \(askzy\), wlop,
indoors, plant, hair bow, cake, cat ears, food, smile, animal ear headphones, bare legs, short shorts, drawing \(object\), feet, legs, on back, bed, solo, green eyes, cat, table, window blinds, headphones, nintendo switch, toes, bow, toenails, looking at viewer, chips \(food\), potted plant, halo, calendar \(object\), tray, blonde hair, green halo, lying, barefoot, bare shoulders, blunt bangs, green shorts, picture frame, fake animal ears, closed mouth, shorts, handheld game console, green bow, animal ears, on bed, medium hair, knees up, upshorts, eating, potato chips, pillow, blush, dolphin shorts, ass, character doll, alternate costume,
masterpiece, newest, absurdres
"""
]
# Prefer reproducible results over cuDNN autotuning: turn off the
# benchmark-driven algorithm search and request deterministic kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# First GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Pick the seed for the next generation.

    Returns a new random integer in ``[0, MAX_SEED]`` when *randomize_seed*
    is truthy; otherwise *seed* is passed through unchanged.
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed
def seed_everything(seed: int) -> torch.Generator:
    """Seed torch (CPU and all CUDA devices) and NumPy's global RNG.

    Returns:
        A new CPU ``torch.Generator`` seeded with *seed*, to be passed to the
        pipeline call.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seeder(seed)
    # Generator.manual_seed returns the generator itself.
    return torch.Generator().manual_seed(seed)
def get_scheduler(scheduler_config: Dict, name: str) -> Optional[Callable]:
    """Build the diffusers scheduler matching a UI sampler name.

    Args:
        scheduler_config: Scheduler config handed to the chosen class's
            ``from_config``.
        name: One of the entries in ``sampler_list``.

    Returns:
        A freshly constructed scheduler instance, or ``None`` when *name* is
        not a recognised sampler.
    """
    # Each scheduler class is only referenced on its own branch, so nothing is
    # constructed for names that were not selected.
    if name == "DPM++ 2M Karras":
        return DPMSolverMultistepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True
        )
    if name == "DPM++ SDE Karras":
        # NOTE(review): other UIs usually map "DPM++ SDE" to diffusers'
        # DPMSolverSDEScheduler; this uses the singlestep solver — confirm intent.
        return DPMSolverSinglestepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True
        )
    if name == "DPM++ 2M SDE Karras":
        return DPMSolverMultistepScheduler.from_config(
            scheduler_config,
            use_karras_sigmas=True,
            algorithm_type="sde-dpmsolver++",
        )
    if name == "Euler":
        return EulerDiscreteScheduler.from_config(scheduler_config)
    if name == "Euler a":
        return EulerAncestralDiscreteScheduler.from_config(scheduler_config)
    if name == "DDIM":
        return DDIMScheduler.from_config(scheduler_config)
    return None
def load_pipeline(model_name: str) -> StableDiffusionXLPipeline:
    """Load the SDXL long-prompt-weighting pipeline and move it to ``device``.

    Previously the pipeline was only created when CUDA was available, so on a
    CPU-only machine ``pipe`` was unbound and ``pipe.to(device)`` raised
    ``UnboundLocalError``, even though ``device`` explicitly falls back to CPU.
    The model is now always loaded: in float16 on CUDA (as before) and in
    float32 on CPU, where half precision is poorly supported.

    Args:
        model_name: Hugging Face Hub repo id (or local path) of the checkpoint.

    Returns:
        The pipeline, already moved to ``device``.
    """
    dtype = torch.float16 if device.type == "cuda" else torch.float32
    pipe = StableDiffusionXLPipeline.from_pretrained(
        model_name,
        torch_dtype=dtype,
        # Community pipeline that supports weighted / long prompts.
        custom_pipeline="lpw_stable_diffusion_xl",
        safety_checker=None,
        use_safetensors=True,
        add_watermarker=False,
        use_auth_token=HF_TOKEN,
    )
    pipe.to(device)
    return pipe
@functools.lru_cache(maxsize=1)
def _get_pipeline() -> Tuple[StableDiffusionXLPipeline, Dict]:
    """Load the pipeline once and return it with its pristine scheduler config."""
    pipe = load_pipeline(MODEL)
    # Snapshot the checkpoint's own scheduler config. Every sampler is built
    # from this snapshot, so options set for one sampler (use_karras_sigmas,
    # algorithm_type="sde-dpmsolver++") cannot leak into the next choice.
    return pipe, pipe.scheduler.config


@functools.lru_cache(maxsize=4)
def _get_text_encoder(clip_skip: int, dtype: torch.dtype) -> CLIPTextModel:
    """Load (once per clip_skip/dtype) the CLIP text encoder used for clip skip.

    The encoder is loaded with ``12 - (clip_skip - 1)`` hidden layers, i.e.
    the last ``clip_skip - 1`` layers are dropped. Caching avoids re-reading
    the weights on every request.
    """
    # NOTE(review): assumes the checkpoint's text_encoder has 12 layers, as the
    # original hard-coded value did — confirm against its config.
    if not 1 <= clip_skip <= 12:
        raise ValueError(f"clip_skip must be between 1 and 12, got {clip_skip}")
    encoder = CLIPTextModel.from_pretrained(
        MODEL,
        subfolder="text_encoder",
        num_hidden_layers=12 - (clip_skip - 1),
        torch_dtype=dtype,
    )
    return encoder.to(device)


def generate(
    prompt: str,
    negative_prompt: Optional[str] = None,
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 5.0,
    num_inference_steps: int = 26,
    sampler: str = "Euler a",
    clip_skip: int = 1,
) -> Tuple[list, int]:
    """Generate images for *prompt* with the Starry XL pipeline.

    Args:
        prompt: Positive prompt text.
        negative_prompt: Optional negative prompt.
        seed: Seed for the global RNGs and the sampling generator. It is
            coerced to ``int`` in case the UI/API delivers a float.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Classifier-free guidance scale.
        num_inference_steps: Number of denoising steps.
        sampler: One of ``sampler_list``.
        clip_skip: 1 uses the full text encoder; ``n`` drops its last
            ``n - 1`` layers.

    Returns:
        ``(images, seed)``: the list of PIL images and the integer seed used.

    Raises:
        gr.Error: if loading or generation fails, so the message is shown in
            the UI instead of the function silently returning ``None``.
    """
    seed = int(seed)
    clip_skip = int(clip_skip)
    try:
        pipe, base_scheduler_config = _get_pipeline()

        scheduler = get_scheduler(base_scheduler_config, sampler)
        if scheduler is None:
            raise ValueError(f"Unknown sampler {sampler!r}; expected one of {sampler_list}")
        pipe.scheduler = scheduler
        pipe.text_encoder = _get_text_encoder(clip_skip, pipe.unet.dtype)

        # Seed right before sampling so one-time model loading above cannot
        # perturb the global RNG state between runs.
        generator = seed_everything(seed)
        images = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=int(width),
            height=int(height),
            guidance_scale=guidance_scale,
            num_inference_steps=int(num_inference_steps),
            generator=generator,
            output_type="pil",
        ).images
    except Exception as e:
        print(f"An error occurred: {e}")
        raise gr.Error(f"An error occurred: {e}") from e
    return images, seed
# Gradio UI. Inside each `with` context, the order in which components are
# created determines the on-page layout.
# NOTE(review): the container nesting below (which components sit inside each
# gr.Group / gr.Row) was reconstructed from a copy whose indentation was lost;
# confirm against the deployed layout.
with gr.Blocks(
    theme=gr.themes.Soft()
) as demo:
    gr.Markdown("# Starry XL 5.2 Demo")
    with gr.Group():
        prompt = gr.Text(
            label="Prompt",
            placeholder="Enter your prompt here..."
        )
        negative_prompt = gr.Text(
            label="Negative Prompt",
            placeholder="(Optional) Enter your negative prompt here..."
        )
        # NOTE(review): the size sliders use a hard-coded minimum of 256, while
        # the env-configurable MIN_IMAGE_SIZE is defined but never used; confirm
        # which bound is intended.
        with gr.Row():
            width = gr.Slider(
                label="Width",
                minimum=256,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=1024,
            )
            height = gr.Slider(
                label="Height",
                minimum=256,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=1024,
            )
        # Choices come from sampler_list; generate() maps the name to a scheduler.
        sampler = gr.Dropdown(
            label="Sampler",
            choices=sampler_list,
            interactive=True,
            value="Euler a",
        )
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        with gr.Row():
            guidance_scale = gr.Slider(
                label="Guidance scale",
                minimum=1,
                maximum=20,
                step=0.1,
                value=5.0,
            )
            # UI default is 25 steps; generate()'s own default (used by the
            # examples, which pass only the prompt) is 26.
            num_inference_steps = gr.Slider(
                label="Steps",
                minimum=10,
                maximum=100,
                step=1,
                value=25,
            )
            clip_skip = gr.Slider(
                label="Clip Skip",
                minimum=1,
                maximum=2,
                step=1,
                value=1
            )
    run_button = gr.Button("Run")
    result = gr.Gallery(
        label="Result",
        columns=1,
        height="512px",
        preview=True,
        show_label=False
    )
    with gr.Group():
        used_seed = gr.Number(label="Used Seed", interactive=False)
    # Examples fill only the prompt box; when run, the lambda forwards that single
    # value to generate(), so all other arguments take generate()'s defaults.
    # Outputs are pre-computed only when CACHE_EXAMPLES is true.
    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=[result, used_seed],
        fn=lambda *args, **kwargs: generate(*args, **kwargs),
        cache_examples=CACHE_EXAMPLES,
    )
    # Two-step event chain, triggered by Enter in either text box or the Run button:
    # 1) randomize_seed_fn runs unqueued and without an API endpoint, writing
    #    the (possibly new) seed back into the seed slider;
    # 2) generate then runs with the updated seed and is exposed as the "run" API.
    gr.on(
        triggers=[
            prompt.submit,
            negative_prompt.submit,
            run_button.click,
        ],
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=generate,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            sampler,
            clip_skip
        ],
        outputs=[result, used_seed],
        api_name="run"
    )
if __name__ == "__main__":
    # Enable the request queue (at most 20 pending jobs), then start the
    # server with Python errors surfaced in the UI.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch(show_error=True)