import torch
import gradio as gr
import spaces
from pipeline import ChatsSDXLPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import CLIPFeatureExtractor
from diffusers.utils import logging
from PIL import Image

logging.set_verbosity_error()

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only used on GPU: many CPU kernels have no float16
# implementation (or are very slow), so the CPU fallback loads in float32.
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
# NOTE(review): the safety checker is loaded with its default dtype while the
# pipeline uses DTYPE; confirm ChatsSDXLPipeline casts the checker input to
# match, otherwise a Half/Float mismatch may surface on GPU.
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")

# Load CHATS-SDXL pipeline
pipe = ChatsSDXLPipeline.from_pretrained(
    "AIDC-AI/CHATS",
    safety_checker=safety_checker,
    feature_extractor=feature_extractor,
    torch_dtype=DTYPE,
)
pipe.to(DEVICE)


@spaces.GPU
def generate(prompt, steps=50, guidance_scale=7.5, height=768, width=512):
    """Run the CHATS-SDXL pipeline on ``prompt`` and return its images.

    NOTE: these keyword defaults apply only to direct callers; the Gradio UI
    below always passes explicit slider values, whose defaults differ.

    Args:
        prompt: Text description of the desired image.
        steps: Number of inference steps forwarded to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.
        height: Output height in pixels.
        width: Output width in pixels.

    Returns:
        The ``images`` entry of the pipeline output, shown in the gallery.

    Raises:
        gr.Error: If ``prompt`` is empty or whitespace-only; Gradio shows the
            message to the user instead of spending GPU time on it.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")
    output = pipe(
        prompt=prompt,
        # Slider values may arrive as floats; the pipeline expects integer
        # step counts and pixel sizes.
        num_inference_steps=int(steps),
        guidance_scale=float(guidance_scale),
        height=int(height),
        width=int(width),
        # Custom ChatsSDXLPipeline kwarg; presumably fixes the RNG seed so
        # identical settings reproduce the same image — confirm in pipeline.py.
        seed=0,
    )
    return output["images"]


with gr.Blocks(title="🔥 CHATS-SDXL Demo") as demo:
    gr.Markdown(
        "## CHATS-SDXL Text-to-Image Demo\n\n"
        "Enter your prompt and click **Generate Image**. All NSFW content will be automatically filtered."
    )

    with gr.Row():
        prompt_input = gr.Textbox(
            label="Prompt",
            placeholder="Enter your description here...",
            lines=2,
        )

    with gr.Row():
        steps_slider = gr.Slider(
            minimum=1,
            maximum=100,
            value=50,
            step=1,
            label="Inference Steps",
        )
        scale_slider = gr.Slider(
            minimum=1.0,
            maximum=14.0,
            value=5.0,
            step=0.1,
            label="Guidance Scale",
        )

    # Step of 64 keeps both dimensions multiples of 64 pixels.
    with gr.Row():
        height_slider = gr.Slider(
            minimum=64,
            maximum=2048,
            value=1024,
            step=64,
            label="Image Height",
        )
        width_slider = gr.Slider(
            minimum=64,
            maximum=2048,
            value=1024,
            step=64,
            label="Image Width",
        )

    generate_button = gr.Button("Generate Image")

    gallery = gr.Gallery(
        label="Generated Images",
        show_label=False,
        columns=2,
        elem_id="gallery",
    )

    generate_button.click(
        fn=generate,
        inputs=[prompt_input, steps_slider, scale_slider, height_slider, width_slider],
        outputs=[gallery],
    )

if __name__ == "__main__":
    demo.launch()