import gradio as gr
import numpy as np
import random
import torch
from PIL import Image
from diffusers import (
    DiffusionPipeline,
    StableDiffusionControlNetPipeline,
    ControlNetModel,
)
from peft import PeftModel

device = "cuda" if torch.cuda.is_available() else "cpu"

LORA_MODEL = "akaUNik/hw5-homm3-lora-15"
LORA_BASE_MODEL = "runwayml/stable-diffusion-v1-5"

# Model list including the LoRA model
MODEL_LIST = [
    "runwayml/stable-diffusion-v1-5",
    "stabilityai/sdxl-turbo",
    "stabilityai/stable-diffusion-2-1",
    LORA_MODEL,  # LoRA model option
]

# ControlNet modes list with human-readable aliases
CONTROLNET_MODES = {
    "Canny Edge Detection": "lllyasviel/control_v11p_sd15_canny",
    "Pixel to Pixel": "lllyasviel/control_v11e_sd15_ip2p",
    "Inpainting": "lllyasviel/control_v11p_sd15_inpaint",
    "Multi-Level Line Segments": "lllyasviel/control_v11p_sd15_mlsd",
    "Depth Estimation": "lllyasviel/control_v11f1p_sd15_depth",
    "Surface Normal Estimation": "lllyasviel/control_v11p_sd15_normalbae",
    "Image Segmentation": "lllyasviel/control_v11p_sd15_seg",
    "Line Art Generation": "lllyasviel/control_v11p_sd15_lineart",
    "Anime Line Art": "lllyasviel/control_v11p_sd15_lineart_anime",
    "Human Pose Estimation": "lllyasviel/control_v11p_sd15_openpose",
    "Scribble-Based Generation": "lllyasviel/control_v11p_sd15_scribble",
    "Soft Edge Generation": "lllyasviel/control_v11p_sd15_softedge",
    "Image Shuffling": "lllyasviel/control_v11e_sd15_shuffle",
    "Image Tiling": "lllyasviel/control_v11f1e_sd15_tile",
}

# Use half precision on GPU, full precision on CPU
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Cache to avoid re-initializing pipelines repeatedly
model_cache = {}

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 512


def infer(
    model_id,
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    lora_scale,
    controlnet_enable,
    controlnet_mode,
    controlnet_strength,
    controlnet_image,
    ip_adapter_enable,
    ip_adapter_scale,
    ip_adapter_image,
    progress=gr.Progress(track_tqdm=True),
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)

    # Cache
    # if (model_id, controlnet_enable, controlnet_image, controlnet_mode) in model_cache:
    #     pipe = model_cache[(model_id, controlnet_enable, controlnet_image, controlnet_mode)]
    # else:
    pipe = None
    if controlnet_enable and controlnet_image:
        controlnet_model = ControlNetModel.from_pretrained(
            CONTROLNET_MODES.get(controlnet_mode), torch_dtype=torch_dtype
        )
        # For the LoRA entry, build the ControlNet pipeline on its base
        # checkpoint; the adapter weights themselves are only attached in the
        # non-ControlNet path below.
        base_model = LORA_BASE_MODEL if model_id == LORA_MODEL else model_id
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            base_model, controlnet=controlnet_model, torch_dtype=torch_dtype
        )
    else:
        if model_id == LORA_MODEL:
            # Use the specified base model for your LoRA adapter.
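            # (PeftModel.from_pretrained layers the adapter weights on top of
            # the base UNet and text encoder at load time; the "unet" and
            # "text_encoder" subfolders are assumed to match the layout of the
            # adapter repo.)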
            pipe = DiffusionPipeline.from_pretrained(
                LORA_BASE_MODEL, torch_dtype=torch_dtype
            )
            # Load the LoRA weights
            pipe.unet = PeftModel.from_pretrained(
                pipe.unet, model_id, subfolder="unet", torch_dtype=torch_dtype
            )
            pipe.text_encoder = PeftModel.from_pretrained(
                pipe.text_encoder,
                model_id,
                subfolder="text_encoder",
                torch_dtype=torch_dtype,
            )
        else:
            pipe = DiffusionPipeline.from_pretrained(
                model_id, torch_dtype=torch_dtype
            )

    if ip_adapter_enable:
        pipe.load_ip_adapter(
            "h94/IP-Adapter",
            subfolder="models",
            weight_name="ip-adapter-plus_sd15.bin",
        )
        pipe.set_ip_adapter_scale(ip_adapter_scale)

    pipe.safety_checker = None
    pipe.to(device)
    # model_cache[(model_id, controlnet_enable, controlnet_image, controlnet_mode)] = pipe

    # Build the call arguments conditionally: the plain text-to-image pipeline
    # does not accept the ControlNet-specific keywords, so they are only passed
    # when a ControlNet pipeline was actually constructed above.
    call_kwargs = dict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
        cross_attention_kwargs={"scale": lora_scale},
    )
    if controlnet_enable and controlnet_image:
        call_kwargs["image"] = controlnet_image
        call_kwargs["controlnet_conditioning_scale"] = controlnet_strength
    if ip_adapter_enable:
        call_kwargs["ip_adapter_image"] = ip_adapter_image

    image = pipe(**call_kwargs).images[0]

    return image, seed


# @title Gradio
examples = [
    "homm3_spell_icon midivial sticker of a cartoon character of a man in a lab coat and glasses, old lady screaming and laughing",
    "homm3_spell_icon midivial sticker of a cartoon man with a mustache and a hat on, portrait bender from futurama, telegram sticker",
    "homm3_spell_icon midivial sticker of a cartoon character with a gun in his hand",
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # Text-to-Image Gradio Template")

        with gr.Row():
            # Dropdown to select the model from Hugging Face
            model_id = gr.Dropdown(
                label="Model",
                choices=MODEL_LIST,
                value=MODEL_LIST[0],  # Default model
            )

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0, variant="primary")

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,  # Default seed
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=20.0,
                    step=0.5,
                    value=7.0,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=20,
                )

            # New slider for LoRA scale.
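            # (This value is forwarded to the pipeline as
            # cross_attention_kwargs={"scale": ...}; it only has a visible
            # effect when the selected model carries LoRA layers.)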
            lora_scale = gr.Slider(
                label="LoRA Scale",
                minimum=0.0,
                maximum=2.0,
                step=0.1,
                value=1.0,
                info="Adjust the influence of the LoRA weights",
            )

        # --- ControlNet Settings ---
        with gr.Accordion("ControlNet Settings", open=False):
            controlnet_enable = gr.Checkbox(label="Enable ControlNet", value=False)
            with gr.Group(visible=False) as controlnet_group:
                controlnet_mode = gr.Dropdown(
                    label="ControlNet Mode",
                    choices=list(CONTROLNET_MODES.keys()),
                    value=list(CONTROLNET_MODES.keys())[0],
                )
                controlnet_strength = gr.Slider(
                    label="ControlNet Conditioning Scale",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    value=0.7,
                )
                controlnet_image = gr.Image(label="ControlNet Image", type="pil")

            def show_controlnet_options(enable):
                return {controlnet_group: gr.update(visible=enable)}

            controlnet_enable.change(
                fn=show_controlnet_options,
                inputs=controlnet_enable,
                outputs=controlnet_group,
            )

        # --- IP-adapter Settings ---
        with gr.Accordion("IP-adapter Settings", open=False):
            ip_adapter_enable = gr.Checkbox(label="Enable IP-adapter", value=False)
            with gr.Group(visible=False) as ip_adapter_group:
                ip_adapter_scale = gr.Slider(
                    label="IP-adapter Scale",
                    minimum=0.0,
                    maximum=2.0,
                    step=0.1,
                    value=1.0,
                )
                ip_adapter_image = gr.Image(label="IP-adapter Image", type="pil")

            # Show/hide IP-adapter parameters when checkbox is toggled
            def show_ip_adapter_options(enable):
                return {ip_adapter_group: gr.update(visible=enable)}

            ip_adapter_enable.change(
                fn=show_ip_adapter_options,
                inputs=ip_adapter_enable,
                outputs=ip_adapter_group,
            )

        gr.Examples(examples=examples, inputs=[prompt])

    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            model_id,
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            lora_scale,
            controlnet_enable,
            controlnet_mode,
            controlnet_strength,
            controlnet_image,
            ip_adapter_enable,
            ip_adapter_scale,
            ip_adapter_image,
        ],
        outputs=[result, seed],
    )

# @title Run
if __name__ == "__main__":
    demo.launch(debug=True)  # show errors in colab notebook
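    # (launch() also accepts share=True, which exposes a temporary public URL;
    # handy when running the demo on a remote machine.)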