squaadai / app.py
forplaytvplus's picture
Update app.py
a896fab verified
raw
history blame
11.2 kB
#!/usr/bin/env python
from __future__ import annotations
import requests
import os
import random
import random
import string
import gradio as gr
import numpy as np
import spaces
import torch
import gc
import cv2
from PIL import Image
from accelerate import init_empty_weights
from io import BytesIO
from diffusers.utils import load_image
from diffusers import StableDiffusionXLControlNetPipeline, StableDiffusionXLControlNetInpaintPipeline, ControlNetModel, AutoencoderKL, DiffusionPipeline, AutoPipelineForImage2Image, AutoPipelineForInpainting, UNet2DConditionModel
from controlnet_aux import HEDdetector
from compel import Compel, ReturnedEmbeddingsType
import threading
# Evaluated once; decides both the CPU warning and the default torch device.
_CUDA_AVAILABLE = torch.cuda.is_available()

# Markdown header shown at the top of the UI; a warning is appended on CPU-only hosts.
DESCRIPTION = "# Run any LoRA or SD Model"
if not _CUDA_AVAILABLE:
    DESCRIPTION = DESCRIPTION + "\n<p>⚠️ This space is running on the CPU. This demo doesn't work on CPU 😞! Run on a GPU by duplicating this space or test our website for free and unlimited by <a href='https://squaadai.com'>clicking here</a>, which provides these and more options.</p>"

# Upper bound of the seed slider / random seeds (largest int32).
MAX_SEED = np.iinfo(np.int32).max

# NOTE(review): this only binds a Python name; nothing writes it to os.environ,
# so it presumably has no effect on CUDA. Kept so the module namespace is unchanged.
CUDA_LAUNCH_BLOCKING = 1

# Maximum width/height offered by the size sliders (overridable via env).
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1824"))


def _env_flag(name: str, default: str = "0") -> bool:
    """Return True when env var ``name`` (or ``default`` if unset) is exactly "1"."""
    return os.getenv(name, default) == "1"


# Feature switches. USE_TORCH_COMPILE and ENABLE_CPU_OFFLOAD are read here but
# not referenced elsewhere in the visible code.
USE_TORCH_COMPILE = _env_flag("USE_TORCH_COMPILE")
ENABLE_CPU_OFFLOAD = _env_flag("ENABLE_CPU_OFFLOAD")
ENABLE_USE_LORA = _env_flag("ENABLE_USE_LORA", "1")
ENABLE_USE_LORA2 = _env_flag("ENABLE_USE_LORA2", "1")
ENABLE_USE_IMG2IMG = _env_flag("ENABLE_USE_IMG2IMG", "1")

device = torch.device("cuda:0") if _CUDA_AVAILABLE else torch.device("cpu")
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Pick the seed for the next generation.

    Args:
        seed: seed currently shown in the UI.
        randomize_seed: when true, ignore ``seed`` and draw a new one.

    Returns:
        A uniform random integer in ``[0, MAX_SEED]`` if ``randomize_seed`` is
        set, otherwise ``seed`` unchanged.
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed
# Process-wide caches populated by generate():
#   cached_pipelines: loaded diffusers pipelines, keyed by a pipeline key.
#   cached_loras: LoRA cache key -> adapter name it was loaded under.
cached_pipelines = dict()
cached_loras = dict()

# Lock that generate() holds while it uses the shared pipelines.
pipeline_lock = threading.Lock()
# Model used when the "Model" textbox is left empty.
_DEFAULT_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"


def _get_pipeline(model: str, use_img2img: bool):
    """Return the pipeline cached under ``(model, use_img2img)``, loading it on first use.

    Text-to-image uses ``DiffusionPipeline``; img2img uses
    ``AutoPipelineForImage2Image``. Weights are loaded in float16 with the
    safety checker disabled, exactly as before.
    """
    pipeline_key = (model, use_img2img)
    pipe = cached_pipelines.get(pipeline_key)
    if pipe is None:
        pipeline_cls = AutoPipelineForImage2Image if use_img2img else DiffusionPipeline
        pipe = pipeline_cls.from_pretrained(
            model,
            safety_checker=None,
            requires_safety_checker=False,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )
        cached_pipelines[pipeline_key] = pipe
    return pipe


def _activate_loras(pipe, pipeline_key, requested) -> None:
    """Make exactly the LoRAs in ``requested`` active on ``pipe``.

    Args:
        pipe: pipeline returned by ``_get_pipeline``.
        pipeline_key: the ``(model, use_img2img)`` key ``pipe`` is cached under.
        requested: list of ``(lora_repo, scale)`` pairs; may be empty.

    Adapters are loaded into one specific pipeline, so they are cached per
    ``(pipeline_key, lora_repo)``. Previously the cache was shared across
    models, which could hand one pipeline an adapter name loaded into another.
    The scale is applied via ``set_adapters`` on every call, so moving a
    slider no longer loads the same weights again under a new name.
    When nothing is requested, adapters left on this pipeline by an earlier
    request are disabled so they do not leak into this generation.
    """
    adapter_names = []
    adapter_weights = []
    for lora_repo, scale in requested:
        cache_key = (pipeline_key, lora_repo)
        adapter_name = cached_loras.get(cache_key)
        if adapter_name is None:
            # cached_loras only ever grows, so its size yields a unique name
            # (the old random 5-letter names could collide).
            adapter_name = f"lora_{len(cached_loras)}"
            pipe.load_lora_weights(lora_repo, adapter_name=adapter_name)
            cached_loras[cache_key] = adapter_name
        adapter_names.append(adapter_name)
        adapter_weights.append(scale)

    if adapter_names:
        # Re-enable in case an earlier LoRA-free request disabled the adapters.
        pipe.enable_lora()
        pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
    elif any(key[0] == pipeline_key for key in cached_loras):
        pipe.disable_lora()


@spaces.GPU
def generate(
    prompt: str = "",
    negative_prompt: str = "",
    use_negative_prompt: bool = False,
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    guidance_scale_base: float = 5.0,
    num_inference_steps_base: int = 25,
    strength_img2img: float = 0.7,
    use_lora: bool = False,
    use_lora2: bool = False,
    model=_DEFAULT_MODEL,
    lora="",
    lora2="",
    lora_scale: float = 0.7,
    lora_scale2: float = 0.7,
    use_img2img: bool = False,
    url="",
):
    """Generate one image with the requested model, LoRAs and settings.

    Args:
        prompt / negative_prompt: text prompts; the negative prompt is only
            used when ``use_negative_prompt`` is true.
        seed: seed for a CPU ``torch.Generator``.
        width, height, guidance_scale_base, num_inference_steps_base: passed
            to the pipeline call.
        strength_img2img: img2img strength, used only with ``use_img2img``.
        use_lora / use_lora2: each enables only its own LoRA slot
            (``lora`` with ``lora_scale``, ``lora2`` with ``lora_scale2``).
            Previously "Use Lora 2" also forced LoRA 1 on.
        model: model repo id; empty falls back to ``_DEFAULT_MODEL``.
        use_img2img, url: run image-to-image on the image loaded from ``url``.

    Returns:
        The first image of the pipeline output.

    Raises:
        gr.Error: when no GPU is available or an enabled input is empty.
    """
    if not torch.cuda.is_available():
        # Previously this path reached `del pipe` / `return result` with both
        # names unbound and crashed with UnboundLocalError.
        raise gr.Error("This demo requires a GPU and cannot run on CPU.")

    model = (model or "").strip() or _DEFAULT_MODEL

    requested_loras = []
    if use_lora:
        requested_loras.append(((lora or "").strip(), lora_scale))
    if use_lora2:
        requested_loras.append(((lora2 or "").strip(), lora_scale2))
    if any(not repo for repo, _ in requested_loras):
        raise gr.Error("A LoRA is enabled but its repository field is empty.")

    # Gradio sliders may deliver floats; these arguments must be ints.
    pipe_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt if use_negative_prompt else None,
        # NOTE(review): img2img pipelines may ignore width/height; they are
        # passed through as before.
        "width": int(width),
        "height": int(height),
        "guidance_scale": guidance_scale_base,
        "num_inference_steps": int(num_inference_steps_base),
        "generator": torch.Generator().manual_seed(int(seed)),
    }
    if use_img2img:
        image_url = (url or "").strip()
        if not image_url:
            raise gr.Error("Img2Img is enabled but the image URL is empty.")
        # Fetched before taking the lock so network I/O does not block others.
        pipe_kwargs["image"] = load_image(image_url)
        pipe_kwargs["strength"] = strength_img2img

    pipeline_key = (model, use_img2img)
    # The lock covers loading and LoRA configuration too: the pipelines are
    # shared, so reconfiguring adapters while another request is generating
    # with the same pipeline would race (the queue allows 4 concurrent runs).
    with pipeline_lock:
        try:
            pipe = _get_pipeline(model, use_img2img)
            _activate_loras(pipe, pipeline_key, requested_loras)
            # Called on every request after any LoRA loading, as before.
            pipe.enable_model_cpu_offload()
            result = pipe(**pipe_kwargs).images[0]
        finally:
            # Collect Python garbage first so freed tensors can be returned
            # to the driver by empty_cache; runs on failure as well.
            gc.collect()
            torch.cuda.empty_cache()
    return result
# Gradio UI: model / LoRA / Img2Img inputs, the prompt box and result image,
# then an "Advanced options" accordion holding the generation parameters.
with gr.Blocks(theme=gr.themes.Soft(), css="style.css") as demo:
    gr.HTML(
        "<p><center>📙 For any additional support, join our <a href='https://discord.gg/JprjXpjt9K'>Discord</a></center></p>"
    )
    gr.Markdown(DESCRIPTION, elem_id="description")
    with gr.Group():
        # Prefilled with generate()'s default model: the textbox used to start
        # empty, so pressing Run right away submitted '' as the model id.
        model = gr.Text(
            label='Model',
            placeholder='e.g. stabilityai/stable-diffusion-xl-base-1.0',
            value='stabilityai/stable-diffusion-xl-base-1.0',
        )
        lora = gr.Text(label='LoRA 1', placeholder='e.g. nerijs/pixel-art-xl')
        lora2 = gr.Text(label='LoRA 2', placeholder='e.g. nerijs/pixel-art-xl')
        lora_scale = gr.Slider(
            info="The closer to 1, the more it will resemble LoRA, but errors may be visible.",
            label="Lora Scale 1",
            minimum=0.01,
            maximum=1,
            step=0.01,
            value=0.7,
        )
        lora_scale2 = gr.Slider(
            info="The closer to 1, the more it will resemble LoRA, but errors may be visible.",
            label="Lora Scale 2",
            minimum=0.01,
            maximum=1,
            step=0.01,
            value=0.7,
        )
        url = gr.Text(label='URL (Img2Img)')
        with gr.Row():
            prompt = gr.Text(
                placeholder="Input prompt",
                label="Prompt",
                show_label=False,
                max_lines=1,
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Image(label="Result", show_label=False)
    with gr.Accordion("Advanced options", open=False):
        with gr.Row():
            # The ENABLE_USE_* env flags hide the corresponding checkbox.
            use_img2img = gr.Checkbox(label='Use Img2Img', value=False, visible=ENABLE_USE_IMG2IMG)
            use_lora = gr.Checkbox(label='Use Lora 1', value=False, visible=ENABLE_USE_LORA)
            use_lora2 = gr.Checkbox(label='Use Lora 2', value=False, visible=ENABLE_USE_LORA2)
            use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False)
        negative_prompt = gr.Text(
            placeholder="Input Negative Prompt",
            label="Negative prompt",
            max_lines=1,
            visible=False,
        )
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        with gr.Row():
            width = gr.Slider(
                label="Width",
                minimum=256,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=1024,
            )
            height = gr.Slider(
                label="Height",
                minimum=256,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=1024,
            )
        with gr.Row():
            guidance_scale_base = gr.Slider(
                info="Scale for classifier-free guidance",
                label="Guidance scale",
                minimum=1,
                maximum=20,
                step=0.1,
                value=5.0,
            )
        with gr.Row():
            num_inference_steps_base = gr.Slider(
                info="Number of denoising steps",
                label="Number of inference steps",
                minimum=10,
                maximum=100,
                step=1,
                value=25,
            )
        with gr.Row():
            strength_img2img = gr.Slider(
                info="Strength for Img2Img",
                label="Strength",
                minimum=0,
                maximum=1,
                step=0.01,
                value=0.7,
            )

    # Each checkbox shows/hides the input it controls. These are the same four
    # registrations, in the same order, that used to be copy-pasted. The lambda
    # closes over no loop variable, so late binding is not an issue.
    for toggle, target in (
        (use_negative_prompt, negative_prompt),
        (use_lora, lora),
        (use_lora2, lora2),
        (use_img2img, url),
    ):
        toggle.change(
            fn=lambda x: gr.update(visible=x),
            inputs=toggle,
            outputs=target,
            queue=False,
            api_name=False,
        )

    # Any submit/click first (optionally) re-rolls the seed, then generates.
    # The input order must match generate()'s positional parameters.
    gr.on(
        triggers=[
            prompt.submit,
            negative_prompt.submit,
            run_button.click,
        ],
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=generate,
        inputs=[
            prompt,
            negative_prompt,
            use_negative_prompt,
            seed,
            width,
            height,
            guidance_scale_base,
            num_inference_steps_base,
            strength_img2img,
            use_lora,
            use_lora2,
            model,
            lora,
            lora2,
            lora_scale,
            lora_scale2,
            use_img2img,
            url,
        ],
        outputs=result,
        api_name="run",
    )
if __name__ == "__main__":
    # Enable Gradio's request queue (max_size=4, default_concurrency_limit=4),
    # then start the web server.
    app = demo.queue(max_size=4, default_concurrency_limit=4)
    app.launch()