# face-to-all-api / app.py
# Hugging Face Space file by jbilcke-hf (HF Staff).
# Last commit: "Update app.py" (60fdd92, verified), 15.6 kB.
import gradio as gr
import torch
torch.jit.script = lambda f: f
import timm
import time
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler
import lora
import copy
import json
import gc
import random
from urllib.parse import quote
import gdown
import os
import diffusers
from diffusers.utils import load_image
from diffusers.models import ControlNetModel
from diffusers import AutoencoderKL, DPMSolverMultistepScheduler
import cv2
import torch
import numpy as np
from PIL import Image
from io import BytesIO
import base64
import re
from insightface.app import FaceAnalysis
from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline, draw_kps
from controlnet_aux import ZoeDetector
from compel import Compel, ReturnedEmbeddingsType
#import spaces
#from gradio_imageslider import ImageSlider
# Matches the "data:image/<type>;base64," prefix of a data URI
data_uri_pattern = re.compile(r'data:image/(png|jpeg|jpg|webp);base64,')


def readb64(b64):
    """Decode a base64 string (optionally a data URI) into a PIL image.

    Every png/jpeg/jpg/webp data-URI prefix found in ``b64`` is removed, the
    remaining text is base64-decoded and the bytes are opened with PIL.
    """
    payload = data_uri_pattern.sub("", b64)
    raw_bytes = base64.b64decode(payload)
    return Image.open(BytesIO(raw_bytes))
def writeb64(image):
    """Convert a PIL image to base64 text.

    The image is saved as PNG into an in-memory buffer and the buffer's bytes
    are returned base64-encoded as a UTF-8 string (no data-URI prefix).
    """
    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    return base64.b64encode(png_buffer.getvalue()).decode("utf-8")
# Fields every catalogue entry must provide (a missing one raises KeyError)
_REQUIRED_LORA_FIELDS = (
    "image",
    "title",
    "repo",
    "trigger_word",
    "weights",
    "is_compatible",
)
# Optional fields, mapped to the value used when an entry omits them
_OPTIONAL_LORA_FIELDS = {
    "is_pivotal": False,
    "text_embedding_weights": None,
    "likes": 0,
    "downloads": 0,
    "is_nc": False,
    "new": False,
}


def _normalize_lora_entry(item):
    """Return a copy of one catalogue entry restricted to the known fields."""
    entry = {key: item[key] for key in _REQUIRED_LORA_FIELDS}
    for key, fallback in _OPTIONAL_LORA_FIELDS.items():
        entry[key] = item.get(key, fallback)
    return entry


# Catalogue of the available SDXL LoRAs
with open("sdxl_loras.json", "r") as file:
    data = json.load(file)
sdxl_loras_raw = [_normalize_lora_entry(item) for item in data]

# Per-LoRA default settings (run_lora reads the "prompt" template from them)
with open("defaults_data.json", "r") as file:
    lora_defaults = json.load(file)
def getLoraByRepoName(repo_name):
    """Look up a LoRA entry in ``sdxl_loras_raw`` by its ``"repo"`` value.

    Returns the first entry whose repo equals ``repo_name``. When nothing
    matches, the first entry of the list is returned instead, or ``None`` if
    the list is empty.
    """
    found = next(
        (entry for entry in sdxl_loras_raw if entry["repo"] == repo_name),
        None,
    )
    if found is not None:
        return found
    # No match: fall back to the first known LoRA, if there is one
    if sdxl_loras_raw:
        return sdxl_loras_raw[0]
    return None
def getLoraDefaultsByRepoName(repo_name):
    """Return the default values specific to one particular LoRA.

    Searches ``lora_defaults`` for the first entry whose ``"repo"`` equals
    ``repo_name``. When nothing matches, the first entry of the list is
    returned instead, or ``None`` if the list is empty.
    """
    found = next(
        (entry for entry in lora_defaults if entry["repo"] == repo_name),
        None,
    )
    if found is not None:
        return found
    # No match: fall back to the first set of defaults, if there is one
    if lora_defaults:
        return lora_defaults[0]
    return None
device = "cuda"

# Download every LoRA weight file up front; keep both the local path and the
# parsed weights, keyed by repo id.
state_dicts = {}
for item in sdxl_loras_raw:
    saved_name = hf_hub_download(item["repo"], item["weights"])
    # NOTE(review): torch.load may unpickle arbitrary objects from the
    # downloaded file; only the .safetensors path avoids that.
    read_weights = load_file if saved_name.endswith('.safetensors') else torch.load
    state_dicts[item["repo"]] = {
        "saved_name": saved_name,
        "state_dict": read_weights(saved_name),
    }

# Entries flagged "new" stay in state_dicts but are removed from the list
# that getLoraByRepoName searches.
sdxl_loras_raw = [entry for entry in sdxl_loras_raw if entry.get("new") != True]
# download models: (repo id, file name) pairs fetched into /data/checkpoints
for _repo_id, _filename in (
    ("InstantX/InstantID", "ControlNetModel/config.json"),
    ("InstantX/InstantID", "ControlNetModel/diffusion_pytorch_model.safetensors"),
    ("InstantX/InstantID", "ip-adapter.bin"),
    ("latent-consistency/lcm-lora-sdxl", "pytorch_lora_weights.safetensors"),
):
    hf_hub_download(
        repo_id=_repo_id,
        filename=_filename,
        local_dir="/data/checkpoints",
    )
# download antelopev2
# The archive is fetched from Google Drive only when /data/antelopev2.zip does
# not exist yet, then extracted into /data/models/ with the system `unzip`.
# NOTE(review): the unzip call is assumed to run only when the archive was
# just downloaded (i.e. it belongs to the `if` branch) — confirm.
# NOTE(review): the exit status of `unzip` is not checked.
if not os.path.exists("/data/antelopev2.zip"):
    gdown.download(url="https://drive.google.com/file/d/18wEUfMNohBJ4K3Ly5wpTejPfDzp-8fI8/view?usp=sharing", output="/data/", quiet=False, fuzzy=True)
    os.system("unzip /data/antelopev2.zip -d /data/models/")
# Face analysis model rooted at /data, restricted to the CPU execution
# provider; run_lora calls app.get() on the input image.
app = FaceAnalysis(name='antelopev2', root='/data', providers=['CPUExecutionProvider'])
app.prepare(ctx_id=0, det_size=(640, 640))
# prepare models under /data/checkpoints
face_adapter = '/data/checkpoints/ip-adapter.bin'
controlnet_path = '/data/checkpoints/ControlNetModel'


def _report_elapsed(message, started):
    """Print ``message``, the seconds elapsed since ``started``, and 'seconds'."""
    print(message, time.time() - started, 'seconds')


# load IdentityNet and the Zoe depth ControlNet
_started = time.time()
identitynet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
zoedepthnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-zoe-depth-sdxl-1.0",
    torch_dtype=torch.float16,
)
_report_elapsed('Loading ControlNet took: ', _started)

_started = time.time()
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
_report_elapsed('Loading VAE took: ', _started)

# Img2img InstantID pipeline driven by both ControlNets
_started = time.time()
pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
    "rubbrband/albedobaseXL_v21",
    vae=vae,
    controlnet=[identitynet, zoedepthnet],
    torch_dtype=torch.float16,
)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config,
    use_karras_sigmas=True,
)
pipe.load_ip_adapter_instantid(face_adapter)
pipe.set_ip_adapter_scale(0.8)
_report_elapsed('Loading pipeline took: ', _started)

# Prompt encoder covering both SDXL tokenizers / text encoders
_started = time.time()
compel = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
)
_report_elapsed('Loading Compel took: ', _started)

_started = time.time()
zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
_report_elapsed('Loading Zoe took: ', _started)

zoe.to(device)
pipe.to(device)

# State shared across generate_image calls: repo of the LoRA used by the last
# completed generation, and whether a LoRA is currently fused into the pipeline.
last_lora = ""
last_fused = False
def center_crop_image_as_square(img):
    """Return the largest centered square crop of a PIL image.

    The crop box is computed with integer arithmetic so the result is always
    exactly ``min(img.size)`` pixels on each side. Passing float coordinates
    lets PIL round each edge independently, which for some odd sizes yields a
    crop that is one pixel wider than tall (or vice versa).

    Args:
        img: image exposing ``size``, ``width``, ``height`` and ``crop``.

    Returns:
        The cropped image returned by ``img.crop``.
    """
    side = min(img.size)
    left = (img.width - side) // 2
    top = (img.height - side) // 2
    return img.crop((left, top, left + side, top + side))
def check_selected(selected_state):
    """Raise a Gradio error when ``selected_state`` is falsy (no style chosen)."""
    if selected_state:
        return
    raise gr.Error("You must select a style")
def merge_incompatible_lora(full_path_lora, lora_scale):
    """Merge a LoRA weights file into the pipeline's text encoder and UNet.

    Args:
        full_path_lora: path of the weights file. It may carry an explicit
            multiplier as ``"<path>;<multiplier>"``; the part after the last
            ``;`` is then parsed as a float and takes precedence over
            ``lora_scale``.
        lora_scale: multiplier used when the path carries none.

    Raises:
        ValueError: if the text after ``;`` is not a valid float.
    """
    if ";" in full_path_lora:
        weights_file, multiplier_text = full_path_lora.rsplit(";", 1)
        multiplier = float(multiplier_text)
    else:
        weights_file, multiplier = full_path_lora, lora_scale

    # Load from the parsed path: the raw argument may still contain the
    # ";<multiplier>" suffix, which is not part of the file name.
    lora_model, weights_sd = lora.create_network_from_weights(
        multiplier,
        weights_file,
        pipe.vae,
        pipe.text_encoder,
        pipe.unet,
        for_inference=True,
    )
    lora_model.merge_to(
        pipe.text_encoder, pipe.unet, weights_sd, torch.float16, "cuda"
    )
#@spaces.GPU
def generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, lora, full_path_lora, lora_scale, st):
    """Run the InstantID img2img pipeline for one face image and one LoRA.

    Args:
        prompt: positive prompt text, encoded with ``compel``.
        negative: negative prompt text, or a falsy value for none.
        face_emb: face embedding, passed to the pipeline as ``image_embeds``.
        face_image: source image; used as the img2img input and as the input
            of the Zoe depth estimator.
        face_kps: keypoint image used as the first control image; its size
            also defines the size of the depth control image.
        image_strength: the pipeline is called with ``strength = 1 - image_strength``.
        guidance_scale: forwarded to the pipeline.
        face_strength: conditioning scale of the first (keypoint) ControlNet.
        depth_control_scale: conditioning scale of the second (depth) ControlNet.
        lora: LoRA entry; ``"repo"``, ``"is_pivotal"`` and, for pivotal
            LoRAs, ``"text_embedding_weights"`` are read.
        full_path_lora: local path of the LoRA weights to load.
        lora_scale: scale used when fusing the LoRA.
        st: ``time.time()`` stamp taken by the caller, used only to log how
            long it took to enter this function.

    Returns:
        The first image produced by the pipeline.
    """
    global last_fused, last_lora

    print('Getting into the decorated function took: ', time.time() - st, 'seconds')
    print("Last LoRA: ", last_lora)
    print("Current LoRA: ", lora["repo"])
    print("Last fused: ", last_fused)

    # prepare face zoe
    st = time.time()
    with torch.no_grad():
        image_zoe = zoe(face_image)
    width, height = face_kps.size
    # PIL's resize expects (width, height); the order only matters when the
    # keypoint image is not square.
    images = [face_kps, image_zoe.resize((width, height))]
    print('Zoe Depth calculations took: ', time.time() - st, 'seconds')

    # NOTE(review): the LoRA is re-fused only when the repo changes, so a new
    # lora_scale for the same repo is not applied — confirm this is intended.
    if last_lora != lora["repo"]:
        if last_fused:
            st = time.time()
            pipe.unfuse_lora()
            pipe.unload_lora_weights()
            print('Unfuse and unload LoRA took: ', time.time() - st, 'seconds')
        st = time.time()
        pipe.load_lora_weights(full_path_lora)
        pipe.fuse_lora(lora_scale)
        print('Fuse and load LoRA took: ', time.time() - st, 'seconds')
        last_fused = True
        if lora["is_pivotal"]:
            # Add the textual inversion embeddings from pivotal tuning models
            text_embedding_name = lora["text_embedding_weights"]
            embedding_path = hf_hub_download(repo_id=lora["repo"], filename=text_embedding_name, repo_type="model")
            state_dict_embedding = load_file(embedding_path)
            # Embedding files use one of two key namings; pick the one that is
            # present instead of relying on a catch-all exception handler.
            if "clip_l" in state_dict_embedding and "clip_g" in state_dict_embedding:
                key_encoder_1, key_encoder_2 = "clip_l", "clip_g"
            else:
                key_encoder_1, key_encoder_2 = "text_encoders_0", "text_encoders_1"
            pipe.unload_textual_inversion()
            pipe.load_textual_inversion(state_dict_embedding[key_encoder_1], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
            pipe.load_textual_inversion(state_dict_embedding[key_encoder_2], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)

    print("Processing prompt...")
    st = time.time()
    conditioning, pooled = compel(prompt)
    if negative:
        negative_conditioning, negative_pooled = compel(negative)
    else:
        negative_conditioning, negative_pooled = None, None
    print('Prompt processing took: ', time.time() - st, 'seconds')

    print("Processing image...")
    st = time.time()
    image = pipe(
        prompt_embeds=conditioning,
        pooled_prompt_embeds=pooled,
        negative_prompt_embeds=negative_conditioning,
        negative_pooled_prompt_embeds=negative_pooled,
        width=1024,
        height=1024,
        image_embeds=face_emb,
        image=face_image,
        strength=1 - image_strength,
        control_image=images,
        num_inference_steps=20,
        guidance_scale=guidance_scale,
        controlnet_conditioning_scale=[face_strength, depth_control_scale],
    ).images[0]
    print('Image processing took: ', time.time() - st, 'seconds')

    # Recorded only after a successful generation
    last_lora = lora["repo"]
    return image
def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, lora_repo_name):
    """Generate a stylized image from a base64 face image and return it as base64.

    Args:
        face_image: base64 text of the input image (a data-URI prefix is accepted).
        prompt: subject description; substituted for ``<subject>`` in the
            LoRA's default prompt template when one exists, and replaced by
            ``"a person"`` when the resulting prompt is empty.
        negative: negative prompt; an empty string means none.
        lora_scale: scale used when fusing the LoRA.
        selected_state: must be truthy, otherwise a ``gr.Error`` is raised.
        face_strength: conditioning scale of the face keypoint ControlNet.
        image_strength: similarity to the input image (see ``generate_image``).
        guidance_scale: forwarded to the pipeline.
        depth_control_scale: conditioning scale of the depth ControlNet.
        lora_repo_name: repo id used to look up the LoRA and its defaults.

    Returns:
        The generated image, PNG-encoded, as a base64 string.

    Raises:
        gr.Error: if no face can be extracted from the image or no LoRA is
            selected/available.
    """
    # get the lora and its default values
    lora = getLoraByRepoName(lora_repo_name)
    default_values = getLoraDefaultsByRepoName(lora_repo_name)

    st = time.time()
    # Decoded uploads may be RGBA or palette images; the RGB->BGR conversion
    # below needs exactly three channels.
    face_image = readb64(face_image).convert("RGB")
    face_image = center_crop_image_as_square(face_image)
    try:
        faces = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
        # only use the largest face: area = (x2 - x1) * (y2 - y1)
        face_info = max(
            faces,
            key=lambda face: (face['bbox'][2] - face['bbox'][0]) * (face['bbox'][3] - face['bbox'][1]),
        )
        face_emb = face_info['embedding']
        face_kps = draw_kps(face_image, face_info['kps'])
    except Exception as err:
        # Includes the empty-detection case (max() of an empty list)
        raise gr.Error("No face found in your image. Only face images work here. Try again") from err
    print('Cropping and calculating face embeds took: ', time.time() - st, 'seconds')

    st = time.time()
    if default_values:
        prompt_full = default_values.get("prompt", None)
        if prompt_full:
            prompt = prompt_full.replace("<subject>", prompt)
    print("Prompt:", prompt)
    if prompt == "":
        prompt = "a person"
    if negative == "":
        negative = None
    # lora is None only when the LoRA catalogue is empty
    if not selected_state or lora is None:
        raise gr.Error("You must select a LoRA")
    full_path_lora = state_dicts[lora["repo"]]["saved_name"]
    print('Small content processing took: ', time.time() - st, 'seconds')

    st = time.time()
    image = generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, lora, full_path_lora, lora_scale, st)
    return writeb64(image)
def _run_lora_api(face_image, prompt, negative, lora_scale, face_strength, image_strength, guidance_scale, depth_control_scale, lora_repo_name):
    """Adapter between the nine API inputs and the ten-parameter ``run_lora``.

    The API has no style-selection widget, so ``selected_state`` is supplied
    here as ``True``; the LoRA itself is chosen through ``lora_repo_name``.
    Arguments are passed by keyword so each input reaches the parameter of the
    same name.
    """
    return run_lora(
        face_image=face_image,
        prompt=prompt,
        negative=negative,
        lora_scale=lora_scale,
        selected_state=True,
        face_strength=face_strength,
        image_strength=image_strength,
        guidance_scale=guidance_scale,
        depth_control_scale=depth_control_scale,
        lora_repo_name=lora_repo_name,
    )


# Minimal UI: the page only shows a notice; the components below exist to
# define the inputs and output of the "run" API endpoint.
with gr.Blocks() as demo:
    gr.HTML("""
<div style="z-index: 100; position: fixed; top: 0px; right: 0px; left: 0px; bottom: 0px; width: 100%; height: 100%; background: white; display: flex; align-items: center; justify-content: center; color: black;">
<div style="text-align: center; color: black;">
<p style="color: black;">This space is a REST API to programmatically generate an image from a face.</p>
<p style="color: black;">Interested in using it through an UI? Please use the <a href="https://huggingface.co/spaces/multimodalart/face-to-all" target="_blank">original space</a>, thank you!</p>
</div>
</div>""")
    input_image_base64 = gr.Text()
    lora_repo_name = gr.Text(label="name of the LoRA repo nape on HF")
    prompt = gr.Textbox(label="Prompt", show_label=False, lines=1, max_lines=1, info="Describe your subject (optional)", value="a person", elem_id="prompt")
    negative = gr.Textbox(label="Negative Prompt")
    # initial value was 0.9
    weight = gr.Slider(0, 10, value=6, step=0.1, label="LoRA weight")
    # initial value was 0.85
    face_strength = gr.Slider(0, 1, value=0.75, step=0.01, label="Face strength", info="Higher values increase the face likeness but reduce the creative liberty of the models")
    # initial value was 0.15
    image_strength = gr.Slider(0, 1, value=0.15, step=0.01, label="Image strength", info="Higher values increase the similarity with the structure/colors of the original photo")
    # initial value was 7
    guidance_scale = gr.Slider(0, 50, value=7, step=0.1, label="Guidance Scale")
    # initial value was 1
    depth_control_scale = gr.Slider(0, 4, value=0.8, step=0.01, label="Zoe Depth ControlNet strenght")
    button = gr.Button(value="Generate")
    output_image_base64 = gr.Text()
    # Inputs are passed positionally, in this order, to _run_lora_api
    button.click(
        fn=_run_lora_api,
        inputs=[
            input_image_base64,
            prompt,
            negative,
            weight,
            face_strength,
            image_strength,
            guidance_scale,
            depth_control_scale,
            lora_repo_name,
        ],
        outputs=output_image_base64,
        api_name='run',
    )

demo.queue(max_size=20)
demo.launch()