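"""Prompt-mixing image generation script.

Generates an image for a prompt containing an object of interest, then
produces shape variations of that object by mixing proxy-word prompts into
a chosen range of denoising steps, with optional attention-based shape
localization and controllable background preservation (see LPMConfig).
"""
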
import json
import os
from dataclasses import dataclass, field
from typing import List

import pyrallis
import torch
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torchvision.transforms import ToTensor
from tqdm import tqdm

from src.prompt_to_prompt_controllers import AttentionStore, AttentionReplace
from src.null_text_inversion import invert_image
from src.prompt_utils import get_proxy_prompts
from src.prompt_mixing import PromptMixing
from src.diffusion_model_wrapper import DiffusionModelWrapper, get_stable_diffusion_model, get_stable_diffusion_config, \
    generate_original_image


def save_args_dict(args, similar_words):
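    """Create the experiment directory and dump the run arguments,
    including the chosen proxy words, to opt.json. Returns the path."""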
    exp_path = os.path.join(args.exp_dir, args.prompt.replace(' ', '-'), f"seed={args.seed}_{args.exp_name}")
    os.makedirs(exp_path, exist_ok=True)

    args_dict = vars(args)
    args_dict['similar_words'] = similar_words
    with open(os.path.join(exp_path, "opt.json"), 'w') as fp:
        json.dump(args_dict, fp, sort_keys=True, indent=4)

    return exp_path


def main(args):
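    """Generate the original image and one variation per proxy prompt.

    The original prompt is denoised once to obtain reference latents, an
    object mask and averaged attention maps; each proxy prompt is then mixed
    in over the configured step range to produce an object-level variation.
    All images are written under the experiment directory, plus a summary grid.
    """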
    ldm_stable = get_stable_diffusion_model(args)
    ldm_stable_config = get_stable_diffusion_config(args)

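    # Build the original prompt and the proxy-word prompts (proxy words are
    # chosen automatically unless args.proxy_words is set).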
    similar_words, prompts, another_prompts = get_proxy_prompts(args, ldm_stable)
    exp_path = save_args_dict(args, similar_words)

    images = []
    x_t = None
    uncond_embeddings = None

    if args.real_image_path != "":
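        # Null-text inversion: recover the initial noise latent and per-step
        # unconditional embeddings that reconstruct the given real image.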
        x_t, uncond_embeddings = invert_image(args, ldm_stable, ldm_stable_config, prompts, exp_path)

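    # Generate the original image once, keeping its latents, object mask and
    # averaged attention maps for reuse by every variation below.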
    image, x_t, orig_all_latents, orig_mask, average_attention = generate_original_image(args, ldm_stable, ldm_stable_config, prompts, x_t, uncond_embeddings)
    save_image(ToTensor()(image[0]), f"{exp_path}/{similar_words[0]}.jpg")
    save_image(torch.from_numpy(orig_mask).float(), f"{exp_path}/{similar_words[0]}_mask.jpg")
    images.append(image[0])

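    # Token index of the object of interest in the prompt; the +1 presumably
    # accounts for the tokenizer's prepended start-of-text token.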
    object_of_interest_index = args.prompt.split().index('{word}') + 1
    pm = PromptMixing(args, object_of_interest_index, average_attention)

    do_other_obj_self_attn_masking = len(args.objects_to_preserve) > 0 and args.end_preserved_obj_self_attn_masking > 0
    do_self_or_cross_attn_inject = args.cross_attn_inject_steps != 0.0 or args.self_attn_inject_steps != 0.0
    if do_other_obj_self_attn_masking:
        print("Do self attn other obj masking")
    if do_self_or_cross_attn_inject:
        print(f'Do self attn inject for {args.self_attn_inject_steps} steps')
        print(f'Do cross attn inject for {args.cross_attn_inject_steps} steps')

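    # Batch the proxy prompts; entry 0 is skipped, as it presumably
    # corresponds to the object of interest itself.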
    another_prompts_dataloader = DataLoader(another_prompts[1:], batch_size=args.batch_size, shuffle=False)

    for another_prompt_batch in tqdm(another_prompts_dataloader):
        batch_size = len(another_prompt_batch["word"])
        batch_prompts = prompts * batch_size
        batch_another_prompt = another_prompt_batch["prompt"]
        if do_self_or_cross_attn_inject or do_other_obj_self_attn_masking:
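            # Add an unedited copy of the original prompt (index 0 of the
            # other-prompt list) as the attention source / mask reference;
            # its output image is skipped when saving below.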
            batch_prompts.append(prompts[0])
            batch_another_prompt.insert(0, prompts[0])

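        # With attention injection enabled, AttentionReplace copies attention
        # maps from the unedited prompt into the variations for the requested
        # fraction of steps; otherwise attention is only recorded.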
        if do_self_or_cross_attn_inject:
            controller = AttentionReplace(batch_another_prompt, ldm_stable.tokenizer, ldm_stable.device,
                                          ldm_stable_config["low_resource"], ldm_stable_config["num_diffusion_steps"],
                                          cross_replace_steps=args.cross_attn_inject_steps,
                                          self_replace_steps=args.self_attn_inject_steps)
        else:
            controller = AttentionStore(ldm_stable_config["low_resource"])

        diffusion_model_wrapper = DiffusionModelWrapper(args, ldm_stable, ldm_stable_config, controller, prompt_mixing=pm)
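        # Denoise the whole batch from the shared initial latent, applying
        # prompt mixing and optional background post-processing.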
        with torch.no_grad():
            image, x_t, _, mask = diffusion_model_wrapper.forward(batch_prompts, latent=x_t, other_prompt=batch_another_prompt,
                                                                  post_background=args.background_post_process, orig_all_latents=orig_all_latents,
                                                                  orig_mask=orig_mask, uncond_embeddings=uncond_embeddings)

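        # Save each variation; when the unedited prompt was added to the
        # batch, its image sits at index 0 and is skipped here.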
        for i in range(batch_size):
            image_index = i + 1 if do_self_or_cross_attn_inject or do_other_obj_self_attn_masking else i
            save_image(ToTensor()(image[image_index]), f"{exp_path}/{another_prompt_batch['word'][i]}.jpg")
            if mask is not None:
                save_image(torch.from_numpy(mask).float(), f"{exp_path}/{another_prompt_batch['word'][i]}_mask.jpg")
            images.append(image[image_index])

    images = [ToTensor()(image) for image in images]
    # Use the largest divisor of the image count in [2, 7] as the grid row
    # width, falling back to 8 when no such divisor exists (e.g. a prime
    # count) instead of crashing on max() over an empty sequence.
    nrow = max((i for i in range(2, 8) if len(images) % i == 0), default=8)
    save_image(images, f"{exp_path}/grid.jpg", nrow=nrow)
    return images, similar_words


@dataclass
class LPMConfig:

    # general config
    seed: int = 10
    batch_size: int = 1
    exp_dir: str = "results"
    exp_name: str = ""
    display_images: bool = False
    gpu_id: int = 0

    # Stable Diffusion config
    auth_token: str = ""
    low_resource: bool = True
    num_diffusion_steps: int = 50
    guidance_scale: float = 7.5
    max_num_words: int = 77

    # prompt-mixing
    prompt: str = "a {word} in the field eats an apple"
    object_of_interest: str = "snake"                                   # The object for which we generate variations
    proxy_words: List[str] = field(default_factory=list)               # Leave empty for automatic proxy words
    number_of_variations: int = 20
    start_prompt_range: int = 7                                         # Denoising step at which prompt-mixing begins
    end_prompt_range: int = 17                                          # Denoising step at which prompt-mixing ends

    # attention based shape localization
    objects_to_preserve: List[str] = field(default_factory=list)       # Objects to which attention-based shape localization is applied
    remove_obj_from_self_mask: bool = True                              # If set to True, removes the object of interest from the self-attention mask
    obj_pixels_injection_threshold: float = 0.05
    end_preserved_obj_self_attn_masking: int = 40

    # real image
    real_image_path: str = ""

    # controllable background preservation
    background_post_process: bool = True
    background_nouns: List[str] = field(default_factory=list)          # Objects to take from the original image in addition to the background
    num_segments: int = 5                                               # Number of clusters for the segmentation
    background_segment_threshold: float = 0.3                           # Threshold for labeling a segment as background
    background_blend_timestep: int = 35                                 # Denoising step at which the background is blended back in

    # other
    cross_attn_inject_steps: float = 0.0
    self_attn_inject_steps: float = 0.0


if __name__ == '__main__':
    args = pyrallis.parse(config_class=LPMConfig)

    print(args)
    main(args)
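
# Example invocation (a sketch: pyrallis maps command-line flags onto
# LPMConfig fields; the file name main.py is an assumption about how this
# script is saved):
#   python main.py --prompt "a {word} on the beach" --object_of_interest cat --seed 48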