File size: 8,025 Bytes
76ccb95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ee7393
76ccb95
 
 
8ee7393
 
76ccb95
8ee7393
 
 
76ccb95
 
 
 
 
 
8ee7393
76ccb95
8ee7393
76ccb95
 
 
 
 
8ee7393
76ccb95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ee7393
76ccb95
 
8ee7393
 
76ccb95
 
 
8ee7393
76ccb95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ee7393
76ccb95
 
 
 
 
 
 
8ee7393
 
 
76ccb95
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from adaface.adaface_wrapper import AdaFaceWrapper
import torch
#import torch.nn.functional as F
from PIL import Image
import numpy as np
import os, argparse, glob, re

def save_images(images, num_images_per_row, subject_name, prompt, perturb_std, save_dir = "samples-ada"):
    """Tile `images` into a single grid image and save it as a PNG in `save_dir`.

    Every image is resized to 512x512 and pasted in row-major order. The file
    is named "{subject_name}-{prompt_sig}-perturb{perturb_std:.02f}.png", where
    prompt_sig is `prompt` with spaces and commas replaced by underscores. If
    that file already exists, "-2", "-3", ... is appended to the stem until an
    unused name is found, so earlier results are never overwritten.

    Args:
        images: Non-empty sequence of PIL images.
        num_images_per_row: Maximum number of images per grid row (>= 1).
            Clamped to len(images) so the grid has no empty columns.
        subject_name: Prefix of the output file name.
        prompt: Text used to build the file name.
        perturb_std: Perturbation std, written into the file name with two decimals.
        save_dir: Output directory; created if it does not exist.

    Raises:
        ValueError: If `images` is empty or `num_images_per_row` is less than 1.
    """
    # Validate before touching the filesystem. Without these checks an empty
    # list clamps num_images_per_row to 0 and fails with ZeroDivisionError.
    if len(images) == 0:
        raise ValueError("save_images() requires at least one image")
    if num_images_per_row < 1:
        raise ValueError(f"num_images_per_row must be >= 1, got {num_images_per_row}")

    tile_size = 512
    num_images_per_row = min(num_images_per_row, len(images))
    # Ceiling division: the last row may be only partially filled.
    num_rows = (len(images) + num_images_per_row - 1) // num_images_per_row

    os.makedirs(save_dir, exist_ok=True)

    grid_image = Image.new('RGB', (tile_size * num_images_per_row, tile_size * num_rows))
    for i, image in enumerate(images):
        row, col = divmod(i, num_images_per_row)
        grid_image.paste(image.resize((tile_size, tile_size)), (tile_size * col, tile_size * row))

    prompt_sig = prompt.replace(" ", "_").replace(",", "_")
    file_stem = f"{subject_name}-{prompt_sig}-perturb{perturb_std:.02f}"
    grid_filepath = os.path.join(save_dir, f"{file_stem}.png")
    # On a name collision, try "<stem>-2.png", "<stem>-3.png", ... in turn.
    grid_count = 2
    while os.path.exists(grid_filepath):
        grid_filepath = os.path.join(save_dir, f"{file_stem}-{grid_count}.png")
        grid_count += 1

    grid_image.save(grid_filepath)
    print(f"Saved to {grid_filepath}")

def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch random number generators with `seed`.

    Also switches cuDNN to deterministic mode with benchmarking disabled, and
    exports the seed through the PL_GLOBAL_SEED environment variable.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Imported locally because the file does not import `random` at module level.
    import random

    # Python's own RNG was previously left unseeded.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # NOTE(review): presumably read by PyTorch Lightning-based code; confirm the consumer.
    os.environ["PL_GLOBAL_SEED"] = str(seed)

def parse_args(argv=None):
    """Define the command-line options of this script and parse them.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default), argparse reads the arguments from sys.argv[1:].

    Returns:
        argparse.Namespace holding the parsed options. Note that --scale is
        stored under the attribute name `guidance_scale`.

    Raises:
        SystemExit: Raised by argparse on invalid arguments, including when
            neither --subject nor --randface is given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--pipeline", type=str, default="text2img", 
                        choices=["text2img", "text2imgxl", "img2img", "text2img3", "flux"],
                        help="Type of pipeline to use (default: text2img)")
    parser.add_argument("--base_model_path", type=str, default=None, 
                        help="Type of checkpoints to use (default: None, using the official model)")
    parser.add_argument('--adaface_ckpt_path', type=str, required=True)
    parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
                        choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")   
    parser.add_argument("--enabled_encoders", type=str, nargs="+", default=None,
                        choices=["arc2face", "consistentID"], 
                        help="List of enabled encoders (among the list of adaface_encoder_types). Default: None (all enabled)")    
    # If adaface_encoder_cfg_scales is not specified, the weights will be set to 6.0 (consistentID) and 1.0 (arc2face).
    parser.add_argument('--adaface_encoder_cfg_scales', type=float, nargs="+", default=None,    
                        help="CFG scales of output embeddings of the ID2Ada prompt encoders")
    parser.add_argument("--main_unet_filepath", type=str, default=None,
                        help="Path to the checkpoint of the main UNet model, if you want to replace the default UNet within --base_model_path")
    parser.add_argument("--extra_unet_dirpaths", type=str, nargs="*", 
                        default=[], 
                        help="Extra paths to the checkpoints of the UNet models")
    parser.add_argument('--unet_weights_in_ensemble', type=float, nargs="+", default=[1], 
                        help="Weights for the UNet models")    
    parser.add_argument("--subject", type=str,
                        help="Path to a subject image file, or to a folder of subject images "
                             "(required unless --randface is given)")
    parser.add_argument("--example_image_count", type=int, default=-1, help="Number of example images to use")
    parser.add_argument("--out_image_count",     type=int, default=4,  help="Number of images to generate")
    parser.add_argument("--prompt", type=str, default="a woman z in superman costume")
    parser.add_argument("--perturb_std", type=float, default=0)
    parser.add_argument("--randface", action="store_true")
    parser.add_argument("--scale", dest='guidance_scale', type=float, default=4, 
                        help="Guidance scale for the diffusion model")
    parser.add_argument("--num_images_per_row", type=int, default=4,
                        help="Number of images to display in a row in the output grid image.")
    parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on")
    parser.add_argument("--seed", type=int, default=42, 
                        help="the seed (for reproducible sampling). Set to -1 to disable.")
    args = parser.parse_args(argv)

    # The script reads its example images from --subject unless --randface is
    # set. Reject the missing option here with a usage message instead of
    # failing later on a None path.
    if not args.randface and args.subject is None:
        parser.error("--subject is required unless --randface is given")

    return args

if __name__ == "__main__":
    # Script entry point: parse options, resolve the subject images, build the
    # AdaFace wrapper, generate images and save them as one grid image.
    args = parse_args()
    if args.seed != -1:
        seed_everything(args.seed)

    # A bare device index such as "1" is shorthand for "cuda:1".
    if re.match(r"^\d+$", args.device):
        args.device = f"cuda:{args.device}"
    print(f"Using device {args.device}")

    # The UNet ensemble options are only forwarded for the text2img / img2img pipelines.
    if args.pipeline not in ["text2img", "img2img"]:
        args.extra_unet_dirpaths = None
        args.unet_weights_in_ensemble = None

    # Resolve the input images before loading any model, so that bad input fails fast.
    if args.randface:
        subject_name = None
        image_paths = None
    else:
        if args.subject is None:
            raise SystemExit("--subject is required unless --randface is given")

        image_folder = args.subject
        if image_folder.endswith("/"):
            image_folder = image_folder[:-1]

        if os.path.isfile(image_folder):
            # A single image file: the subject name is its parent folder's name.
            subject_name = os.path.basename(os.path.dirname(image_folder))
            image_paths = [image_folder]
        else:
            subject_name = os.path.basename(image_folder)
            image_paths = []
            for image_type in ["*.jpg", "*.png", "*.jpeg"]:
                # glob returns full paths in arbitrary order. Sort them so that the
                # --example_image_count truncation below selects the same files on every run.
                image_paths.extend(sorted(glob.glob(os.path.join(image_folder, image_type))))

            # Filter out images of "*_mask.png". Test the end of the path, so that a
            # directory whose name contains "_mask.png" does not drop every image.
            image_paths = [image_path for image_path in image_paths
                           if not image_path.endswith("_mask.png")]

            # image_paths contain at most args.example_image_count full image paths.
            if args.example_image_count > 0:
                image_paths = image_paths[:args.example_image_count]

            if not image_paths:
                print(f"Warning: no .jpg/.png/.jpeg images found under {image_folder}")

    adaface = AdaFaceWrapper(args.pipeline, args.base_model_path, 
                             args.adaface_encoder_types, args.adaface_ckpt_path, 
                             args.adaface_encoder_cfg_scales, args.enabled_encoders,
                             unet_types=None,
                             main_unet_filepath=args.main_unet_filepath,
                             extra_unet_dirpaths=args.extra_unet_dirpaths,
                             unet_weights_in_ensemble=args.unet_weights_in_ensemble, device=args.device)

    if args.randface:
        # torch.seed() re-seeds the global RNG with a fresh non-deterministic value,
        # which would silently discard --seed. Only draw a new seed when seeding is
        # disabled (--seed -1); otherwise report the seed already in effect.
        face_seed = torch.seed() if args.seed == -1 else torch.initial_seed()
        subject_name = "randface-" + str(face_seed)

    # Drawn even without --randface, so the RNG stream that feeds init_noise is
    # the same in both modes.
    rand_init_id_embs = torch.randn(1, 512)
    init_id_embs = rand_init_id_embs if args.randface else None
    # Sample on the CPU generator, then move to the requested device (honours --device
    # rather than always using the default CUDA device).
    init_noise = torch.randn(args.out_image_count, 4, 64, 64).to(args.device)
    # args.perturb_std: the *relative* std of the noise added to the face embeddings.
    # A noise level of 0.08 could change gender, but 0.06 is usually safe.
    # The returned subject embeddings are not used. They are generated for the purpose of
    # updating the text encoder (within this function call).
    adaface.prepare_adaface_embeddings(image_paths, init_id_embs, 
                                       perturb_at_stage='img_prompt_emb',
                                       perturb_std=args.perturb_std, update_text_encoder=True)    
    images = adaface(init_noise, args.prompt, None, None, 
                     'append', args.guidance_scale, 
                     args.out_image_count, verbose=True)
    save_images(images, args.num_images_per_row, subject_name, f"guide{args.guidance_scale}", args.perturb_std)