import sys
sys.path.append('./')
from adaface.adaface_wrapper import AdaFaceWrapper
import argparse
import torch
import numpy as np
import random
import os, re
import time
import gradio as gr
import spaces

def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")
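
# Hugging Face Spaces sets the SPACE_ID environment variable automatically,
# so its presence indicates that we are running on the Spaces platform.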
def is_running_on_spaces():
    return os.getenv("SPACE_ID") is not None

parser = argparse.ArgumentParser()
parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
                    choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")
parser.add_argument('--adaface_ckpt_path', type=str, default='models/adaface/VGGface2_HQ_masks2025-03-06T03-31-21_zero3-ada-1000.pt',
                    help="Path to the checkpoint of the ID2Ada prompt encoders")
# If adaface_encoder_cfg_scales is not specified, the weights will be set to 6.0 (consistentID) and 1.0 (arc2face).
parser.add_argument('--adaface_encoder_cfg_scales', type=float, nargs="+", default=[6.0, 1.0],
                    help="Scales for the ID2Ada prompt encoders")
parser.add_argument("--enabled_encoders", type=str, nargs="+", default=None,
                    choices=["arc2face", "consistentID"],
                    help="List of enabled encoders (among the list of adaface_encoder_types). Default: None (all enabled)")
parser.add_argument('--model_style_type', type=str, default='photorealistic',
                    choices=["realistic", "anime", "photorealistic"], help="Type of the base model")
parser.add_argument("--guidance_scale", type=float, default=5.0,
                    help="The guidance scale for the diffusion model. Default: 5.0")
parser.add_argument("--unet_uses_attn_lora", type=str2bool, nargs="?", const=True, default=False,
                    help="Whether to use LoRA in the Diffusers UNet model")
# --attn_lora_layer_names and --q_lora_updates_query are only effective
# when --unet_uses_attn_lora is set to True.
parser.add_argument("--attn_lora_layer_names", type=str, nargs="*", default=['q', 'k', 'v', 'out'],
                    choices=['q', 'k', 'v', 'out'], help="Names of the cross-attn components to apply LoRA on")
parser.add_argument("--q_lora_updates_query", type=str2bool, nargs="?", const=True, default=False,
                    help="Whether the q LoRA updates the query in the Diffusers UNet model. "
                         "If False, the q LoRA only updates query2.")
parser.add_argument("--show_disable_adaface_checkbox", type=str2bool, nargs="?", const=True, default=False,
                    help="Whether to show the checkbox for disabling AdaFace")
parser.add_argument('--extra_save_dir', type=str, default=None, help="Directory to save the generated images")
parser.add_argument('--test_ui_only', type=str2bool, nargs="?", const=True, default=False,
                    help="Only test the UI layout, and skip loading the adaface model")
parser.add_argument('--gpu', type=int, default=None)
parser.add_argument('--ip', type=str, default="0.0.0.0")
args = parser.parse_args()

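# Download the model weights (base models and AdaFace checkpoints) from the
# Hugging Face Hub into the local models/ directory.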
from huggingface_hub import snapshot_download
large_files = ["models/*", "models/**/*"]
snapshot_download(repo_id="adaface-neurips/adaface-models", repo_type="model", allow_patterns=large_files, local_dir=".")

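# Gradio caches uploaded and generated files under /tmp/gradio; the file
# uploaders below point their GRADIO_CACHE at the same folder.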
os.makedirs("/tmp/gradio", exist_ok=True)

model_style_type2base_model_path = {
    "realistic":      "models/rv51/realisticVisionV51_v51VAE_dste8.safetensors",
    "anime":          "models/aingdiffusion/aingdiffusion_v170_ar.safetensors",
    "photorealistic": "models/sar/sar.safetensors",  # LDM format. Needs to be converted.
}
base_model_path = model_style_type2base_model_path[args.model_style_type]

# Global variables.
MAX_SEED = np.iinfo(np.int32).max
adaface = None
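
# The wrapper is created on the CPU and moved to the GPU inside generate_image();
# on ZeroGPU Spaces a GPU is only available while a request is being processed.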
if not args.test_ui_only:
    adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=base_model_path,
                             adaface_encoder_types=args.adaface_encoder_types,
                             adaface_ckpt_paths=args.adaface_ckpt_path,
                             adaface_encoder_cfg_scales=args.adaface_encoder_cfg_scales,
                             enabled_encoders=args.enabled_encoders,
                             unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
                             unet_uses_attn_lora=args.unet_uses_attn_lora,
                             attn_lora_layer_names=args.attn_lora_layer_names,
                             shrink_cross_attn=False,
                             q_lora_updates_query=args.q_lora_updates_query,
                             device='cpu')

def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed

def swap_to_gallery(images):
    # Update uploaded_files_gallery, show files, hide clear_button_column.
    # Or:
    # Update uploaded_init_img_gallery, show init_img_files, hide init_clear_button_column.
    return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(value=images, visible=False)

def remove_back_to_files():
    # Hide uploaded_files_gallery, hide clear_button_column, show files,
    # and reset the subject nickname and gender fields.
    return gr.update(visible=False), gr.update(visible=False), gr.update(value=None, visible=True), \
           gr.update(value=""), gr.update(value="(none)")

def generate_image(image_paths, image_paths2, guidance_scale, perturb_std,
                   num_images, prompt, negative_prompt, gender, highlight_face,
                   ablate_prompt_embed_type, nonmix_prompt_emb_weight,
                   composition_level, seed, disable_adaface, subj_name_sig, progress=gr.Progress(track_tqdm=True)):
    global adaface

    if is_running_on_spaces():
        device = 'cuda:0'
    else:
        if args.gpu is None:
            device = "cuda"
        else:
            device = f"cuda:{args.gpu}"

    print(f"Device: {device}")
    adaface.to(device)

    if image_paths is None or len(image_paths) == 0:
        raise gr.Error("Cannot find any input face image! Please upload a face image.")
    if image_paths2 is not None and len(image_paths2) > 0:
        image_paths = image_paths + image_paths2
    if prompt is None:
        prompt = ""

    adaface_subj_embs = \
        adaface.prepare_adaface_embeddings(image_paths=image_paths, face_id_embs=None,
                                           avg_at_stage='id_emb',
                                           perturb_at_stage='img_prompt_emb',
                                           perturb_std=perturb_std, update_text_encoder=True)
    if adaface_subj_embs is None:
        raise gr.Error("Failed to detect any faces! Please try with other images.")

    # Sometimes the pipeline is on the CPU, although we've put it on CUDA (due to some offloading mechanism).
    # Therefore we set the generator to the correct device.
    generator = torch.Generator(device=device).manual_seed(seed)
    print(f"Manual seed: {seed}.")
    # Generate num_images images each time for the user to select from.
    noise = torch.randn(num_images, 3, 512, 512, device=device, generator=generator)
    #print(noise.abs().sum())
    # samples: A list of PIL Image instances.
    if highlight_face:
        if "portrait" not in prompt:
            prompt = "face portrait, " + prompt
        else:
            prompt = prompt.replace("portrait", "face portrait")
    if composition_level >= 2:
        if "full body" not in prompt:
            prompt = prompt + ", full body view"
    if gender != "(none)":
        if "portrait" in prompt:
            prompt = prompt.replace("portrait, ", f"portrait, {gender} ")
        else:
            prompt = gender + ", " + prompt

    generator = torch.Generator(device=adaface.pipeline._execution_device).manual_seed(seed)
    samples = adaface(noise, prompt, negative_prompt=negative_prompt,
                      guidance_scale=guidance_scale,
                      out_image_count=num_images, generator=generator,
                      repeat_prompt_for_each_encoder=(composition_level >= 1),
                      ablate_prompt_no_placeholders=disable_adaface,
                      ablate_prompt_embed_type=ablate_prompt_embed_type,
                      nonmix_prompt_emb_weight=nonmix_prompt_emb_weight,
                      verbose=True)
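
    # Key the temp folder off the inputs so that each (images, prompt, seed)
    # combination gets its own folder. Note that Python's built-in hash() is
    # salted per process, so folder names are stable within a run but not across restarts.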
    session_signature = ",".join(image_paths + [prompt, str(seed)])
    temp_folder = os.path.join("/tmp/gradio", f"{hash(session_signature)}")
    os.makedirs(temp_folder, exist_ok=True)
    saved_image_paths = []

    if "models/adaface/" in args.adaface_ckpt_path:
        # The model is loaded from within the project, e.g.
        # models/adaface/VGGface2_HQ_masks2024-10-14T16-09-24_zero3-ada-3500.pt
        matches = re.search(r"models/adaface/\w+\d{4}-(\d{2})-(\d{2})T(\d{2})-\d{2}-\d{2}_zero3-ada-(\d+).pt", args.adaface_ckpt_path)
    else:
        # The model is loaded from the adaprompt folder, e.g.
        # adaface_ckpt_path = "VGGface2_HQ_masks2024-11-28T13-13-20_zero3-ada/checkpoints/embeddings_gs-2000.pt"
        matches = re.search(r"\d{4}-(\d{2})-(\d{2})T(\d{2})-\d{2}-\d{2}_zero3-ada/checkpoints/embeddings_gs-(\d+).pt", args.adaface_ckpt_path)
    # Extract the checkpoint signature, e.g. 112813-2000 (MMDDHH-step).
    # Fall back to "unknown" if the checkpoint path matches neither pattern.
    ckpt_sig = f"{matches.group(1)}{matches.group(2)}{matches.group(3)}-{matches.group(4)}" if matches else "unknown"

    prompt_keywords = ['armor', 'beach', 'chef', 'dancing', 'iron man', 'jedi',
                       'street', 'guitar', 'reading', 'running', 'superman', 'new year', 'mars']
    keywords_reduction = { 'iron man': 'ironman', 'dancing': 'dance',
                           'running': 'run', 'reading': 'read', 'new year': 'newyear' }
    prompt_sig = None
    for keyword in prompt_keywords:
        if keyword in prompt.lower():
            prompt_sig = keywords_reduction.get(keyword, keyword)
            break

    if prompt_sig is None:
        prompt_parts = prompt.lower().split(",")
        # Remove the view/shot parts (full body view, long shot, etc.) from the prompt.
        prompt_parts = [ part for part in prompt_parts if not re.search(r"\W(view|shot)(\W|$)", part) ]
        if len(prompt_parts) > 0:
            # Use the last word of the prompt as the signature.
            prompt_sig = prompt_parts[-1].split()[-1]
        else:
            prompt_sig = "person"

    if len(prompt_sig) > 0:
        prompt_sig = "-" + prompt_sig

    extra_save_dir = args.extra_save_dir
    if extra_save_dir is not None:
        os.makedirs(extra_save_dir, exist_ok=True)

    for i, sample in enumerate(samples):
        filename = f"adaface{ckpt_sig}{prompt_sig}-{i+1}.png"
        if len(subj_name_sig) > 0:
            filename = f"{subj_name_sig.lower()}-{filename}"
        filepath = os.path.join(temp_folder, filename)
        # Save the image.
        sample.save(filepath)
        saved_image_paths.append(filepath)
        if extra_save_dir is not None:
            extra_filepath = os.path.join(extra_save_dir, filename)
            sample.save(extra_filepath)
            print(extra_filepath)

    # Append a timestamp query string to force the client browser to reload the images
    # when only the guidance scale changes (solution suggested by o1).
    saved_image_paths = [f"{url}?t={int(time.time())}" for url in saved_image_paths]
    return saved_image_paths

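# Runs before generate_image() on each submission: validates the prompt, and
# reloads the base model / updates the encoder scale if the user changed them in the UI.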
def check_prompt_and_model_type(prompt, model_style_type, adaface_encoder_cfg_scale1):
    global adaface
    model_style_type = model_style_type.lower()

    # If the base model type or the consistentID encoder scale has changed, update the model accordingly.
    if model_style_type != args.model_style_type or adaface_encoder_cfg_scale1 != args.adaface_encoder_cfg_scales[0]:
        if model_style_type != args.model_style_type:
            # Update the base model type and reload the model.
            args.model_style_type = model_style_type
            print(f"Switching to the base model type: {model_style_type}.")
            adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=model_style_type2base_model_path[model_style_type],
                                     adaface_encoder_types=args.adaface_encoder_types,
                                     adaface_ckpt_paths=args.adaface_ckpt_path,
                                     adaface_encoder_cfg_scales=args.adaface_encoder_cfg_scales,
                                     enabled_encoders=args.enabled_encoders,
                                     unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
                                     unet_uses_attn_lora=args.unet_uses_attn_lora,
                                     attn_lora_layer_names=args.attn_lora_layer_names,
                                     shrink_cross_attn=False,
                                     q_lora_updates_query=args.q_lora_updates_query,
                                     device='cpu')
        if adaface_encoder_cfg_scale1 != args.adaface_encoder_cfg_scales[0]:
            args.adaface_encoder_cfg_scales[0] = adaface_encoder_cfg_scale1
            adaface.set_adaface_encoder_cfg_scales(args.adaface_encoder_cfg_scales)
            print(f"Updating the scale of the consistentID encoder to {adaface_encoder_cfg_scale1}.")

    if not prompt:
        raise gr.Error("Prompt cannot be blank")

### Description
title = r"""
<h1>AdaFace: A Versatile Face Encoder for Zero-Shot Diffusion Model Personalization</h1>
"""

description = r"""
<b>Official demo</b> for our working paper <b>AdaFace: A Versatile Face Encoder for Zero-Shot Diffusion Model Personalization</b>.<br>

❗️**What's New**❗️
- Support for switching between three model styles: **Photorealistic**, **Realistic** and **Anime**.
- If you've just changed the model style, the first image/video generation will take an extra 20~30 seconds to load the new model weights.

❗️**Tips**❗️
1. Upload one or more images of a person. If multiple faces are detected, we use the largest one.
2. Check "Highlight face" to highlight fine facial features.
3. AdaFace Text-to-Video: <a href="https://huggingface.co/spaces/adaface-neurips/adaface-animate" style="display: inline-flex; align-items: center;">
AdaFace-Animate
<img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-yellow" alt="Hugging Face Spaces" style="margin-left: 5px;">
</a>

**TODO:**
- ControlNet integration.
"""

css = '''
.gradio-container {width: 95% !important}
.custom-gallery {
    height: 800px !important;
    width: 100%;
    margin: 10px auto;
    padding: 0px;
    overflow-y: auto !important;
}
.tight-row {
    gap: 0 !important;      /* removes the horizontal gap between columns */
    margin: 0 !important;   /* remove any extra margin if needed */
    padding: 0 !important;  /* remove any extra padding if needed */
}
'''

with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
    # Description.
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        with gr.Column():
            # Upload face image(s).
            # img_file = gr.Image(label="Upload a photo with a face", type="filepath")
            img_files = gr.File(
                label="Drag / Select 1 or more photos of a person's face",
                file_types=["image"],
                file_count="multiple"
            )
            img_files.GRADIO_CACHE = "/tmp/gradio"
            # When files are uploaded, show the images in the gallery and hide the file uploader.
            uploaded_files_gallery = gr.Gallery(label="Subject images", visible=False, columns=3, rows=1, height=300)
            with gr.Column(visible=False) as clear_button_column:
                remove_and_reupload = gr.ClearButton(value="Remove and upload subject images",
                                                     components=img_files, size="sm")

            with gr.Accordion("Second Subject (Optional)", open=False):
                img_files2 = gr.File(
                    label="Drag / Select 1 or more photos of the second subject's face (optional)",
                    file_types=["image"],
                    file_count="multiple"
                )
                img_files2.GRADIO_CACHE = "/tmp/gradio"
                uploaded_files_gallery2 = gr.Gallery(label="2nd Subject images (optional)", visible=False, columns=3, rows=1, height=300)
                with gr.Column(visible=False) as clear_button_column2:
                    remove_and_reupload2 = gr.ClearButton(value="Remove and upload 2nd Subject images",
                                                          components=img_files2, size="sm")

            with gr.Row(elem_classes="tight-row"):
                with gr.Column(scale=1, min_width=100):
                    gender = gr.Dropdown(label="Gender", value="(none)",
                                         info="Gender prefix. Select only when the model errs.",
                                         container=False,
                                         choices=["(none)", "person", "man", "woman", "girl", "boy"])
                with gr.Column(scale=100):
                    prompt = gr.Dropdown(label="Prompt",
                                         info="Try something like 'walking on the beach'. If the face is not in focus, try checking 'Highlight face'.",
                                         value="portrait, highlighted hair, futuristic silver armor suit, confident stance, living room, smiling, head tilted, perfect smooth skin",
                                         allow_custom_value=True,
                                         choices=[
                                             "portrait, highlighted hair, futuristic silver armor suit, confident stance, living room, smiling, head tilted, perfect smooth skin",
                                             "portrait, walking on the beach, sunset, orange sky, front view",
                                             "portrait, in a white apron and chef hat, garnishing a gourmet dish",
                                             "portrait, waving hands, dancing pose among folks in a park",
                                             "portrait, in iron man costume, the sky ablaze with hues of orange and purple",
                                             "portrait, jedi wielding a lightsaber, star wars",
                                             "portrait, night view of tokyo street, neon light",
                                             "portrait, playing guitar on a boat, ocean waves",
                                             "portrait, with a passion for reading, curled up with a book in a cozy nook near a window, front view",
                                             "portrait, celebrating new year, fireworks",
                                             "portrait, running pose in a park",
                                             "portrait, in space suit, space helmet, walking on mars",
                                             "portrait, in superman costume, the sky ablaze with hues of orange and purple",
                                             "in a wheelchair",
                                             "on a horse"
                                         ])

            highlight_face = gr.Checkbox(label="Highlight face", value=False,
                                         info="Enhance the facial features by prepending 'face portrait' to the prompt")
            composition_level = \
                gr.Slider(label="Composition Level", visible=True,
                          info="The degree of overall composition, 0~2. Challenging prompts like 'In a wheelchair' and 'on a horse' need level 2",
                          minimum=0, maximum=2, step=1, value=0)
            ablate_prompt_embed_type = gr.Dropdown(label="Ablate prompt embeddings type",
                                                   choices=["ada", "ada-nonmix", "img"], value="ada", visible=False,
                                                   info="Use this type of prompt embeddings for ablation study")
            nonmix_prompt_emb_weight = gr.Slider(label="Weight of ada-nonmix ID embeddings",
                                                 minimum=0.0, maximum=0.5, step=0.1, value=0,
                                                 info="Weight of ada-nonmix ID embeddings in the prompt embeddings",
                                                 visible=False)
            subj_name_sig = gr.Textbox(
                label="Nickname of Subject (optional; used to name saved images)",
                value="",
            )
            subj_name_sig2 = gr.Textbox(
                label="Nickname of 2nd Subject (optional; used to name saved images)",
                value="",
                visible=False,
            )

            submit = gr.Button("Submit", variant="primary")
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                value="sagging face, sagging cheeks, wrinkles, flaws in the eyes, flaws in the face, lowres, "
                      "non-HDRi, low quality, worst quality, artifacts, noise, text, watermark, glitch, "
                      "mutated, ugly, disfigured, hands, partially rendered objects, partially rendered eyes, "
                      "deformed eyeballs, cross-eyed, extra legs, extra arms, blurry, mutation, duplicate, "
                      "out of frame, cropped, mutilated, bad anatomy, deformed, bad proportions, "
                      "nude, naked, nsfw, topless, bare breasts",
            )
            guidance_scale = gr.Slider(
                label="Guidance scale",
                minimum=1.0,
                maximum=8.0,
                step=0.5,
                value=args.guidance_scale,
            )
            adaface_encoder_cfg_scale1 = gr.Slider(
                label="Scale for consistentID encoder",
                minimum=1.0,
                maximum=12.0,
                step=1.0,
                value=args.adaface_encoder_cfg_scales[0],
                visible=False,
            )
            model_style_type = gr.Dropdown(
                label="Base Model Style Type",
                info="Switching the base model type will take 10~20 seconds to reload the model",
                value=args.model_style_type.capitalize(),
                choices=["Realistic", "Anime", "Photorealistic"],
                allow_custom_value=False,
                filterable=False,
            )
            perturb_std = gr.Slider(
                label="Std of noise added to input (may help stabilize face embeddings)",
                minimum=0.0,
                maximum=0.05,
                step=0.025,
                value=0.0,
                visible=False,
            )
            num_images = gr.Slider(
                label="Number of output images",
                minimum=1,
                maximum=8,
                step=1,
                value=4,
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True,
                                         info="Uncheck for reproducible results")
            disable_adaface = gr.Checkbox(label="Disable AdaFace", value=False,
                                          info="Disable AdaFace for ablation. If checked, the results are no longer personalized.",
                                          visible=args.show_disable_adaface_checkbox)
        with gr.Column():
            out_gallery = gr.Gallery(label="Generated Images", interactive=False, columns=2, rows=4, height=800,
                                     elem_classes="custom-gallery")

    img_files.upload(fn=swap_to_gallery, inputs=img_files, outputs=[uploaded_files_gallery, clear_button_column, img_files])
    img_files2.upload(fn=swap_to_gallery, inputs=img_files2, outputs=[uploaded_files_gallery2, clear_button_column2, img_files2])
    remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files_gallery, clear_button_column,
                                                                img_files, subj_name_sig, gender])
    remove_and_reupload2.click(fn=remove_back_to_files, outputs=[uploaded_files_gallery2, clear_button_column2,
                                                                 img_files2, subj_name_sig2, gender])

    check_prompt_and_model_type_call_dict = {
        'fn': check_prompt_and_model_type,
        'inputs': [prompt, model_style_type, adaface_encoder_cfg_scale1],
        'outputs': None
    }
    randomize_seed_fn_call_dict = {
        'fn': randomize_seed_fn,
        'inputs': [seed, randomize_seed],
        'outputs': seed
    }
    generate_image_call_dict = {
        'fn': generate_image,
        'inputs': [img_files, img_files2, guidance_scale, perturb_std, num_images, prompt,
                   negative_prompt, gender, highlight_face, ablate_prompt_embed_type,
                   nonmix_prompt_emb_weight, composition_level, seed, disable_adaface, subj_name_sig],
        'outputs': [out_gallery]
    }

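    # Event chain: validate the prompt and (re)load the model first; only if that
    # succeeds, randomize the seed, and then generate the images.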
    submit.click(**check_prompt_and_model_type_call_dict).success(**randomize_seed_fn_call_dict).then(**generate_image_call_dict)
    subj_name_sig.submit(**check_prompt_and_model_type_call_dict).success(**randomize_seed_fn_call_dict).then(**generate_image_call_dict)

demo.launch(share=True, server_name=args.ip, ssl_verify=False)
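
# Example local launch (assuming this file is saved as app.py; all flags are defined above):
#   python app.py --model_style_type photorealistic --gpu 0 --extra_save_dir ./outputs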