import os
import random
from typing import Dict, Optional

import gradio as gr
import numpy as np
import spaces
import torch
from transformers import CLIPTextModel
from diffusers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    StableDiffusionXLPipeline,
)

MODEL = "eienmojiki/Starry-XL-v5.2"
HF_TOKEN = os.getenv("HF_TOKEN")
MIN_IMAGE_SIZE = int(os.getenv("MIN_IMAGE_SIZE", "512"))
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
MAX_SEED = np.iinfo(np.int32).max

sampler_list = [
    "DPM++ 2M Karras",
    "DPM++ SDE Karras",
    "DPM++ 2M SDE Karras",
    "Euler",
    "Euler a",
    "DDIM",
]

# Favor reproducibility over cuDNN autotuning speed.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a fresh random seed when randomization is enabled, else the given one."""
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


def seed_everything(seed: int) -> torch.Generator:
    """Seed torch, CUDA, and NumPy, and return a seeded CPU generator for the pipeline."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    generator = torch.Generator()
    generator.manual_seed(seed)
    return generator


def get_scheduler(scheduler_config: Dict, name: str):
    """Build the scheduler matching the sampler name, or return None for an unknown name."""
    scheduler_factory_map = {
        "DPM++ 2M Karras": lambda: DPMSolverMultistepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True
        ),
        "DPM++ SDE Karras": lambda: DPMSolverSinglestepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True
        ),
        "DPM++ 2M SDE Karras": lambda: DPMSolverMultistepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
        ),
        "Euler": lambda: EulerDiscreteScheduler.from_config(scheduler_config),
        "Euler a": lambda: EulerAncestralDiscreteScheduler.from_config(scheduler_config),
        "DDIM": lambda: DDIMScheduler.from_config(scheduler_config),
    }
    return scheduler_factory_map.get(name, lambda: None)()


def load_pipeline(model_name: str) -> StableDiffusionXLPipeline:
    pipe = StableDiffusionXLPipeline.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        custom_pipeline="lpw_stable_diffusion_xl",  # long-prompt-weighting community pipeline
        use_safetensors=True,
        add_watermarker=False,
        token=HF_TOKEN,
    )
    pipe.to(device)
    return pipe


@spaces.GPU
def generate(
    prompt: str,
    negative_prompt: Optional[str] = None,
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 5.0,
    num_inference_steps: int = 25,
    sampler: str = "Euler a",
    clip_skip: int = 1,
    progress=gr.Progress(track_tqdm=True),
):
    if pipe is None:
        raise gr.Error("No GPU available: the pipeline was not loaded.")

    generator = seed_everything(seed)

    scheduler = get_scheduler(pipe.scheduler.config, sampler)
    if scheduler is not None:
        pipe.scheduler = scheduler

    # Emulate "clip skip" by loading the text encoder with its last hidden
    # layers truncated. Note this reloads the encoder on every call, and the
    # fresh copy must be moved to the same device as the rest of the pipeline.
    pipe.text_encoder = CLIPTextModel.from_pretrained(
        MODEL,
        subfolder="text_encoder",
        num_hidden_layers=12 - (clip_skip - 1),
        torch_dtype=torch.float16,
    ).to(device)

    try:
        img = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            output_type="pil",
        ).images[0]
        return img, seed
    except Exception as e:
        # Re-raise so Gradio surfaces the error (launch(show_error=True)).
        print(f"An error occurred: {e}")
        raise


if torch.cuda.is_available():
    pipe = load_pipeline(MODEL)
    print("Loaded on Device!")
else:
    pipe = None

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Starry XL 5.2 Demo")
    with gr.Group():
        prompt = gr.Text(label="Prompt", placeholder="Enter your prompt here...")
        negative_prompt = gr.Text(
            label="Negative Prompt",
            placeholder="(Optional) Enter your negative prompt here...",
        )
        with gr.Row():
            width = gr.Slider(
                label="Width",
                minimum=MIN_IMAGE_SIZE,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=1024,
            )
            height = gr.Slider(
                label="Height",
                minimum=MIN_IMAGE_SIZE,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=1024,
            )
        sampler = gr.Dropdown(
            label="Sampler",
            choices=sampler_list,
            interactive=True,
            value="Euler a",
        )
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        with gr.Row():
            guidance_scale = gr.Slider(
                label="Guidance scale",
                minimum=1,
                maximum=20,
                step=0.1,
                value=5.0,
            )
            num_inference_steps = gr.Slider(
                label="Steps",
                minimum=10,
                maximum=100,
                step=1,
                value=25,
            )
        clip_skip = gr.Slider(
            label="Clip Skip",
            minimum=1,
            maximum=2,
            step=1,
            value=1,
        )
        run_button = gr.Button("Run")
    result = gr.Image(label="Result", show_label=False)
    with gr.Group():
        used_seed = gr.Number(label="Used Seed", interactive=False)

    # Resolve the seed first (optionally randomizing it), then run generation.
    gr.on(
        triggers=[
            prompt.submit,
            negative_prompt.submit,
            run_button.click,
        ],
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=generate,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            sampler,
            clip_skip,
        ],
        outputs=[result, used_seed],
        api_name="run",
    )

if __name__ == "__main__":
    demo.queue(max_size=20).launch(show_error=True)
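
# --- Usage sketch ---
# Because the `.then(..., api_name="run")` step above exposes generation as an
# API endpoint, the demo can also be driven programmatically. A minimal sketch
# with gradio_client, assuming the app is running locally on Gradio's default
# port; the prompt strings below are placeholders, not values from this repo:
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860/")
#   image_path, used_seed = client.predict(
#       "1girl, night sky, stars",  # prompt
#       "low quality",              # negative prompt
#       0,                          # seed
#       1024, 1024,                 # width, height
#       5.0,                        # guidance scale
#       25,                         # steps
#       "Euler a",                  # sampler
#       1,                          # clip skip
#       api_name="/run",
#   )
#
# The seed-randomization step is registered with `api_name=False`, so an API
# call invokes `generate` directly with the seed value passed by the client.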