import sys
sys.path.append('./')
from adaface.adaface_wrapper import AdaFaceWrapper
import torch
import numpy as np
import random
import argparse
import os, re
import time
import gradio as gr
import spaces

def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")

def is_running_on_spaces():
    return os.getenv("SPACE_ID") is not None

parser = argparse.ArgumentParser()
parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
                    choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")
parser.add_argument('--adaface_ckpt_path', type=str,
                    default='models/adaface/VGGface2_HQ_masks2025-03-06T03-31-21_zero3-ada-1000.pt',
                    help="Path to the checkpoint of the ID2Ada prompt encoders")
# If adaface_encoder_cfg_scales is not specified, the weights will be set to 6.0 (consistentID) and 1.0 (arc2face).
parser.add_argument('--adaface_encoder_cfg_scales', type=float, nargs="+", default=[6.0, 1.0],
                    help="Scales for the ID2Ada prompt encoders")
parser.add_argument("--enabled_encoders", type=str, nargs="+", default=None,
                    choices=["arc2face", "consistentID"],
                    help="List of enabled encoders (among the list of adaface_encoder_types). Default: None (all enabled)")
parser.add_argument('--model_style_type', type=str, default='photorealistic',
                    choices=["realistic", "anime", "photorealistic"],
                    help="Type of the base model")
parser.add_argument("--guidance_scale", type=float, default=5.0,
                    help="The guidance scale for the diffusion model. Default: 5.0")
parser.add_argument("--unet_uses_attn_lora", type=str2bool, nargs="?", const=True, default=False,
                    help="Whether to use LoRA in the Diffusers UNet model")
# --attn_lora_layer_names and --q_lora_updates_query are only effective
# when --unet_uses_attn_lora is set to True.
parser.add_argument("--attn_lora_layer_names", type=str, nargs="*", default=['q', 'k', 'v', 'out'],
                    choices=['q', 'k', 'v', 'out'],
                    help="Names of the cross-attn components to apply LoRA on")
parser.add_argument("--q_lora_updates_query", type=str2bool, nargs="?", const=True, default=False,
                    help="Whether the q LoRA updates the query in the Diffusers UNet model. "
                         "If False, the q LoRA only updates query2.")
parser.add_argument("--show_disable_adaface_checkbox", type=str2bool, nargs="?", const=True, default=False,
                    help="Whether to show the checkbox for disabling AdaFace")
parser.add_argument('--extra_save_dir', type=str, default=None,
                    help="Directory to save the generated images")
parser.add_argument('--test_ui_only', type=str2bool, nargs="?", const=True, default=False,
                    help="Only test the UI layout, and skip loading the AdaFace model")
parser.add_argument('--gpu', type=int, default=None)
parser.add_argument('--ip', type=str, default="0.0.0.0")
args = parser.parse_args()

# Download the AdaFace model weights from the Hugging Face Hub into the local models/ folder.
from huggingface_hub import snapshot_download
large_files = ["models/*", "models/**/*"]
snapshot_download(repo_id="adaface-neurips/adaface-models", repo_type="model",
                  allow_patterns=large_files, local_dir=".")
os.makedirs("/tmp/gradio", exist_ok=True)

model_style_type2base_model_path = {
    "realistic":      "models/rv51/realisticVisionV51_v51VAE_dste8.safetensors",
    "anime":          "models/aingdiffusion/aingdiffusion_v170_ar.safetensors",
    "photorealistic": "models/sar/sar.safetensors",   # LDM format. Needs to be converted.
}
base_model_path = model_style_type2base_model_path[args.model_style_type]

# global variable
MAX_SEED = np.iinfo(np.int32).max

global adaface
adaface = None

if not args.test_ui_only:
    adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=base_model_path,
                             adaface_encoder_types=args.adaface_encoder_types,
                             adaface_ckpt_paths=args.adaface_ckpt_path,
                             adaface_encoder_cfg_scales=args.adaface_encoder_cfg_scales,
                             enabled_encoders=args.enabled_encoders,
                             unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
                             unet_uses_attn_lora=args.unet_uses_attn_lora,
                             attn_lora_layer_names=args.attn_lora_layer_names,
                             shrink_cross_attn=False,
                             q_lora_updates_query=args.q_lora_updates_query,
                             device='cpu')

def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed

def swap_to_gallery(images):
    # Update uploaded_files_gallery, show files, hide clear_button_column
    # Or:
    # Update uploaded_init_img_gallery, show init_img_files, hide init_clear_button_column
    return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(value=images, visible=False)

def remove_back_to_files():
    # Hide uploaded_files_gallery, show clear_button_column, hide files, reset init_img_selected_idx
    # Or:
    # Hide uploaded_init_img_gallery, hide init_clear_button_column, show init_img_files, reset init_img_selected_idx
    return gr.update(visible=False), gr.update(visible=False), gr.update(value=None, visible=True), \
           gr.update(value=""), gr.update(value="(none)")

@spaces.GPU
def generate_image(image_paths, image_paths2, guidance_scale, perturb_std,
                   num_images, prompt, negative_prompt, gender, highlight_face,
                   ablate_prompt_embed_type, nonmix_prompt_emb_weight,
                   composition_level, seed, disable_adaface, subj_name_sig,
                   progress=gr.Progress(track_tqdm=True)):
    global adaface, args

    if is_running_on_spaces():
        device = 'cuda:0'
    else:
        if args.gpu is None:
            device = "cuda"
        else:
            device = f"cuda:{args.gpu}"
    print(f"Device: {device}")

    adaface.to(device)
    args.device = device

    if image_paths is None or len(image_paths) == 0:
        raise gr.Error("Cannot find any input face image! Please upload a face image.")

    if image_paths2 is not None and len(image_paths2) > 0:
        image_paths = image_paths + image_paths2

    if prompt is None:
        prompt = ""

    adaface_subj_embs = \
        adaface.prepare_adaface_embeddings(image_paths=image_paths, face_id_embs=None,
                                           avg_at_stage='id_emb',
                                           perturb_at_stage='img_prompt_emb',
                                           perturb_std=perturb_std, update_text_encoder=True)

    if adaface_subj_embs is None:
        raise gr.Error("Failed to detect any faces! Please try with other images")

    # Sometimes the pipeline is on CPU, although we've put it on CUDA (due to some offloading mechanism).
    # Therefore we set the generator to the correct device.
    generator = torch.Generator(device=args.device).manual_seed(seed)
    print(f"Manual seed: {seed}.")
    # Generate two images each time for the user to select from.
    noise = torch.randn(num_images, 3, 512, 512, device=args.device, generator=generator)
    #print(noise.abs().sum())
    # samples: A list of PIL Image instances.
    if highlight_face:
        if "portrait" not in prompt:
            prompt = "face portrait, " + prompt
        else:
            prompt = prompt.replace("portrait", "face portrait")
    if composition_level >= 2:
        if "full body" not in prompt:
            prompt = prompt + ", full body view"

    if gender != "(none)":
        if "portrait" in prompt:
            prompt = prompt.replace("portrait, ", f"portrait, {gender} ")
        else:
            prompt = gender + ", " + prompt

    generator = torch.Generator(device=adaface.pipeline._execution_device).manual_seed(seed)
    samples = adaface(noise, prompt, negative_prompt=negative_prompt,
                      guidance_scale=guidance_scale, out_image_count=num_images,
                      generator=generator,
                      repeat_prompt_for_each_encoder=(composition_level >= 1),
                      ablate_prompt_no_placeholders=disable_adaface,
                      ablate_prompt_embed_type=ablate_prompt_embed_type,
                      nonmix_prompt_emb_weight=nonmix_prompt_emb_weight,
                      verbose=True)

    session_signature = ",".join(image_paths + [prompt, str(seed)])
    temp_folder = os.path.join("/tmp/gradio", f"{hash(session_signature)}")
    os.makedirs(temp_folder, exist_ok=True)

    saved_image_paths = []
    if "models/adaface/" in args.adaface_ckpt_path:
        # The model is loaded from within the project, e.g.
        # models/adaface/VGGface2_HQ_masks2024-10-14T16-09-24_zero3-ada-3500.pt
        matches = re.search(r"models/adaface/\w+\d{4}-(\d{2})-(\d{2})T(\d{2})-\d{2}-\d{2}_zero3-ada-(\d+).pt",
                            args.adaface_ckpt_path)
    else:
        # The model is loaded from the adaprompt folder, e.g.
        # adaface_ckpt_path = "VGGface2_HQ_masks2024-11-28T13-13-20_zero3-ada/checkpoints/embeddings_gs-2000.pt"
        matches = re.search(r"\d{4}-(\d{2})-(\d{2})T(\d{2})-\d{2}-\d{2}_zero3-ada/checkpoints/embeddings_gs-(\d+).pt",
                            args.adaface_ckpt_path)
    # Extract the checkpoint signature, e.g. 112813-2000.
    ckpt_sig = f"{matches.group(1)}{matches.group(2)}{matches.group(3)}-{matches.group(4)}"

    prompt_keywords = ['armor', 'beach', 'chef', 'dancing', 'iron man', 'jedi', 'street',
                       'guitar', 'reading', 'running', 'superman', 'new year', 'mars']
    keywords_reduction = {'iron man': 'ironman', 'dancing': 'dance', 'running': 'run',
                          'reading': 'read', 'new year': 'newyear'}
    prompt_sig = None
    for keyword in prompt_keywords:
        if keyword in prompt.lower():
            prompt_sig = keywords_reduction.get(keyword, keyword)
            break

    if prompt_sig is None:
        prompt_parts = prompt.lower().split(",")
        # Remove the view/shot parts (full body view, long shot, etc.) from the prompt.
        prompt_parts = [part for part in prompt_parts if not re.search(r"\W(view|shot)(\W|$)", part)]
        if len(prompt_parts) > 0:
            # Use the last word of the prompt as the signature.
            prompt_sig = prompt_parts[-1].split()[-1]
        else:
            prompt_sig = "person"

    if len(prompt_sig) > 0:
        prompt_sig = "-" + prompt_sig

    extra_save_dir = args.extra_save_dir
    if extra_save_dir is not None:
        os.makedirs(extra_save_dir, exist_ok=True)

    for i, sample in enumerate(samples):
        filename = f"adaface{ckpt_sig}{prompt_sig}-{i+1}.png"
        if len(subj_name_sig) > 0:
            filename = f"{subj_name_sig.lower()}-{filename}"
        filepath = os.path.join(temp_folder, filename)
        # Save the image
        sample.save(filepath)  # Adjust to your image saving method
        saved_image_paths.append(filepath)
        if extra_save_dir is not None:
            extra_filepath = os.path.join(extra_save_dir, filename)
            sample.save(extra_filepath)
            print(extra_filepath)

    # Solution suggested by o1 to force the client browser to reload images
    # when we change guidance scales only.
saved_image_paths = [f"{url}?t={int(time.time())}" for url in saved_image_paths] return saved_image_paths def check_prompt_and_model_type(prompt, model_style_type, adaface_encoder_cfg_scale1): global adaface model_style_type = model_style_type.lower() # If the base model type is changed, reload the model. if model_style_type != args.model_style_type or adaface_encoder_cfg_scale1 != args.adaface_encoder_cfg_scales[0]: if model_style_type != args.model_style_type: # Update base model type. args.model_style_type = model_style_type print(f"Switching to the base model type: {model_style_type}.") adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=model_style_type2base_model_path[model_style_type], adaface_encoder_types=args.adaface_encoder_types, adaface_ckpt_paths=args.adaface_ckpt_path, adaface_encoder_cfg_scales=args.adaface_encoder_cfg_scales, enabled_encoders=args.enabled_encoders, unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None, unet_uses_attn_lora=args.unet_uses_attn_lora, attn_lora_layer_names=args.attn_lora_layer_names, shrink_cross_attn=False, q_lora_updates_query=args.q_lora_updates_query, device='cpu') if adaface_encoder_cfg_scale1 != args.adaface_encoder_cfg_scales[0]: args.adaface_encoder_cfg_scales[0] = adaface_encoder_cfg_scale1 adaface.set_adaface_encoder_cfg_scales(args.adaface_encoder_cfg_scales) print(f"Updating the scale for consistentID encoder to {adaface_encoder_cfg_scale1}.") if not prompt: raise gr.Error("Prompt cannot be blank") ### Description title = r"""