Spaces:
Runtime error
Runtime error
yuvalalaluf
commited on
Commit
Β·
82ef366
1
Parent(s):
991d8d3
initial commit
Browse files- appearance_transfer_model.py +177 -0
- config.py +66 -0
- constants.py +3 -0
- demo.py +96 -0
- environment/environment.yaml +10 -0
- environment/requirements.txt +17 -0
- inputs/chocolate_cake.jpg +0 -0
- inputs/duomo.png +0 -0
- inputs/giraffe.png +0 -0
- inputs/red_velvet_cake.jpg +0 -0
- inputs/taj_mahal.jpg +0 -0
- inputs/zebra.png +0 -0
- models/__init__.py +0 -0
- models/stable_diffusion.py +240 -0
- models/unet_2d_condition.py +345 -0
- utils/__init__.py +0 -0
- utils/adain.py +45 -0
- utils/attention_utils.py +37 -0
- utils/ddpm_inversion.py +323 -0
- utils/image_utils.py +59 -0
- utils/latent_utils.py +81 -0
- utils/model_utils.py +16 -0
- utils/segmentation.py +111 -0
appearance_transfer_model.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional, Callable
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from config import RunConfig
|
7 |
+
from constants import OUT_INDEX, STRUCT_INDEX, STYLE_INDEX
|
8 |
+
from models.stable_diffusion import CrossImageAttentionStableDiffusionPipeline
|
9 |
+
from utils import attention_utils
|
10 |
+
from utils.adain import masked_adain
|
11 |
+
from utils.model_utils import get_stable_diffusion_model
|
12 |
+
from utils.segmentation import Segmentor
|
13 |
+
|
14 |
+
|
15 |
+
class AppearanceTransferModel:
|
16 |
+
|
17 |
+
def __init__(self, config: RunConfig, pipe: Optional[CrossImageAttentionStableDiffusionPipeline] = None):
|
18 |
+
self.config = config
|
19 |
+
self.pipe = get_stable_diffusion_model() if pipe is None else pipe
|
20 |
+
self.register_attention_control()
|
21 |
+
self.segmentor = Segmentor(prompt=config.prompt, object_nouns=[config.object_noun])
|
22 |
+
self.latents_app, self.latents_struct = None, None
|
23 |
+
self.zs_app, self.zs_struct = None, None
|
24 |
+
self.image_app_mask_32, self.image_app_mask_64 = None, None
|
25 |
+
self.image_struct_mask_32, self.image_struct_mask_64 = None, None
|
26 |
+
self.enable_edit = False
|
27 |
+
self.step = 0
|
28 |
+
|
29 |
+
def set_latents(self, latents_app: torch.Tensor, latents_struct: torch.Tensor):
|
30 |
+
self.latents_app = latents_app
|
31 |
+
self.latents_struct = latents_struct
|
32 |
+
|
33 |
+
def set_noise(self, zs_app: torch.Tensor, zs_struct: torch.Tensor):
|
34 |
+
self.zs_app = zs_app
|
35 |
+
self.zs_struct = zs_struct
|
36 |
+
|
37 |
+
def set_masks(self, masks: List[torch.Tensor]):
|
38 |
+
self.image_app_mask_32, self.image_struct_mask_32, self.image_app_mask_64, self.image_struct_mask_64 = masks
|
39 |
+
|
40 |
+
def get_adain_callback(self):
|
41 |
+
|
42 |
+
def callback(st: int, timestep: int, latents: torch.FloatTensor) -> Callable:
|
43 |
+
self.step = st
|
44 |
+
# Compute the masks using prompt mixing self-segmentation and use the masks for AdaIN operation
|
45 |
+
if self.step == self.config.adain_range.start:
|
46 |
+
masks = self.segmentor.get_object_masks()
|
47 |
+
self.set_masks(masks)
|
48 |
+
# Apply AdaIN operation using the computed masks
|
49 |
+
if self.config.adain_range.start <= self.step < self.config.adain_range.end:
|
50 |
+
latents[0] = masked_adain(latents[0], latents[1], self.image_struct_mask_64, self.image_app_mask_64)
|
51 |
+
|
52 |
+
return callback
|
53 |
+
|
54 |
+
def register_attention_control(self):
|
55 |
+
|
56 |
+
model_self = self
|
57 |
+
|
58 |
+
class AttentionProcessor:
|
59 |
+
|
60 |
+
def __init__(self, place_in_unet: str):
|
61 |
+
self.place_in_unet = place_in_unet
|
62 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
63 |
+
raise ImportError("AttnProcessor2_0 requires torch 2.0, to use it, please upgrade torch to 2.0.")
|
64 |
+
|
65 |
+
def __call__(self,
|
66 |
+
attn,
|
67 |
+
hidden_states: torch.Tensor,
|
68 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
69 |
+
attention_mask=None,
|
70 |
+
temb=None,
|
71 |
+
perform_swap: bool = False):
|
72 |
+
|
73 |
+
residual = hidden_states
|
74 |
+
|
75 |
+
if attn.spatial_norm is not None:
|
76 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
77 |
+
|
78 |
+
input_ndim = hidden_states.ndim
|
79 |
+
|
80 |
+
if input_ndim == 4:
|
81 |
+
batch_size, channel, height, width = hidden_states.shape
|
82 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
83 |
+
|
84 |
+
batch_size, sequence_length, _ = (
|
85 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
86 |
+
)
|
87 |
+
|
88 |
+
if attention_mask is not None:
|
89 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
90 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
91 |
+
|
92 |
+
if attn.group_norm is not None:
|
93 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
94 |
+
|
95 |
+
query = attn.to_q(hidden_states)
|
96 |
+
|
97 |
+
is_cross = encoder_hidden_states is not None
|
98 |
+
if not is_cross:
|
99 |
+
encoder_hidden_states = hidden_states
|
100 |
+
elif attn.norm_cross:
|
101 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
102 |
+
|
103 |
+
key = attn.to_k(encoder_hidden_states)
|
104 |
+
value = attn.to_v(encoder_hidden_states)
|
105 |
+
|
106 |
+
inner_dim = key.shape[-1]
|
107 |
+
head_dim = inner_dim // attn.heads
|
108 |
+
should_mix = False
|
109 |
+
|
110 |
+
# Potentially apply our cross image attention operation
|
111 |
+
# To do so, we need to be in a self-attention alyer in the decoder part of the denoising network
|
112 |
+
if perform_swap and not is_cross and "up" in self.place_in_unet and model_self.enable_edit:
|
113 |
+
if attention_utils.should_mix_keys_and_values(model_self, hidden_states):
|
114 |
+
should_mix = True
|
115 |
+
if model_self.step % 5 == 0 and model_self.step < 40:
|
116 |
+
# Inject the structure's keys and values
|
117 |
+
key[OUT_INDEX] = key[STRUCT_INDEX]
|
118 |
+
value[OUT_INDEX] = value[STRUCT_INDEX]
|
119 |
+
else:
|
120 |
+
# Inject the appearance's keys and values
|
121 |
+
key[OUT_INDEX] = key[STYLE_INDEX]
|
122 |
+
value[OUT_INDEX] = value[STYLE_INDEX]
|
123 |
+
|
124 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
125 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
126 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
127 |
+
|
128 |
+
# Compute the cross attention and apply our contrasting operation
|
129 |
+
hidden_states, attn_weight = attention_utils.compute_scaled_dot_product_attention(
|
130 |
+
query, key, value,
|
131 |
+
edit_map=perform_swap and model_self.enable_edit and should_mix,
|
132 |
+
is_cross=is_cross,
|
133 |
+
contrast_strength=model_self.config.contrast_strength,
|
134 |
+
)
|
135 |
+
|
136 |
+
# Update attention map for segmentation
|
137 |
+
if model_self.config.use_masked_adain and model_self.step == model_self.config.adain_range.start - 1:
|
138 |
+
model_self.segmentor.update_attention(attn_weight, is_cross)
|
139 |
+
|
140 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
141 |
+
hidden_states = hidden_states.to(query[OUT_INDEX].dtype)
|
142 |
+
|
143 |
+
# linear proj
|
144 |
+
hidden_states = attn.to_out[0](hidden_states)
|
145 |
+
# dropout
|
146 |
+
hidden_states = attn.to_out[1](hidden_states)
|
147 |
+
|
148 |
+
if input_ndim == 4:
|
149 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
150 |
+
|
151 |
+
if attn.residual_connection:
|
152 |
+
hidden_states = hidden_states + residual
|
153 |
+
|
154 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
155 |
+
|
156 |
+
return hidden_states
|
157 |
+
|
158 |
+
def register_recr(net_, count, place_in_unet):
|
159 |
+
if net_.__class__.__name__ == 'ResnetBlock2D':
|
160 |
+
pass
|
161 |
+
if net_.__class__.__name__ == 'Attention':
|
162 |
+
net_.set_processor(AttentionProcessor(place_in_unet + f"_{count + 1}"))
|
163 |
+
return count + 1
|
164 |
+
elif hasattr(net_, 'children'):
|
165 |
+
for net__ in net_.children():
|
166 |
+
count = register_recr(net__, count, place_in_unet)
|
167 |
+
return count
|
168 |
+
|
169 |
+
cross_att_count = 0
|
170 |
+
sub_nets = self.pipe.unet.named_children()
|
171 |
+
for net in sub_nets:
|
172 |
+
if "down" in net[0]:
|
173 |
+
cross_att_count += register_recr(net[1], 0, "down")
|
174 |
+
elif "up" in net[0]:
|
175 |
+
cross_att_count += register_recr(net[1], 0, "up")
|
176 |
+
elif "mid" in net[0]:
|
177 |
+
cross_att_count += register_recr(net[1], 0, "mid")
|
config.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import NamedTuple, Optional
|
4 |
+
|
5 |
+
|
6 |
+
class Range(NamedTuple):
|
7 |
+
start: int
|
8 |
+
end: int
|
9 |
+
|
10 |
+
|
11 |
+
@dataclass
|
12 |
+
class RunConfig:
|
13 |
+
# Appearance image path
|
14 |
+
app_image_path: Path
|
15 |
+
# Struct image path
|
16 |
+
struct_image_path: Path
|
17 |
+
# Domain name (e.g., buildings, animals)
|
18 |
+
domain_name: Optional[str] = None
|
19 |
+
# Output path
|
20 |
+
output_path: Path = Path('./output')
|
21 |
+
# Random seed
|
22 |
+
seed: int = 42
|
23 |
+
# Input prompt for inversion (will use domain name as default)
|
24 |
+
prompt: Optional[str] = None
|
25 |
+
# Number of timesteps
|
26 |
+
num_timesteps: int = 100
|
27 |
+
# Whether to use a binary mask for performing AdaIN
|
28 |
+
use_masked_adain: bool = True
|
29 |
+
# Timesteps to apply cross-attention on 64x64 layers
|
30 |
+
cross_attn_64_range: Range = Range(start=10, end=90)
|
31 |
+
# Timesteps to apply cross-attention on 32x32 layers
|
32 |
+
cross_attn_32_range: Range = Range(start=10, end=70)
|
33 |
+
# Timesteps to apply AdaIn
|
34 |
+
adain_range: Range = Range(start=20, end=100)
|
35 |
+
# Guidance scale
|
36 |
+
guidance_scale: float = 7.5
|
37 |
+
# Swap guidance scale
|
38 |
+
swap_guidance_scale: float = 3.5
|
39 |
+
# Attention contrasting strength
|
40 |
+
contrast_strength: float = 1.67
|
41 |
+
# Object nouns to use for self-segmentation (will use the domain name as default)
|
42 |
+
object_noun: Optional[str] = None
|
43 |
+
# Whether to load previously saved inverted latent codes
|
44 |
+
load_latents: bool = True
|
45 |
+
# Number of steps to skip in the denoising process (used value from original edit-friendly DDPM paper)
|
46 |
+
skip_steps: int = 32
|
47 |
+
|
48 |
+
def __post_init__(self):
|
49 |
+
self.output_path = self.output_path / self.domain_name
|
50 |
+
self.output_path.mkdir(parents=True, exist_ok=True)
|
51 |
+
|
52 |
+
# Handle the domain name, prompt, and object nouns used for masking, etc.
|
53 |
+
if self.use_masked_adain and self.domain_name is None:
|
54 |
+
raise ValueError("Must provide --domain_name and --prompt when using masked AdaIN")
|
55 |
+
if not self.use_masked_adain and self.domain_name is None:
|
56 |
+
self.domain_name = "object"
|
57 |
+
if self.prompt is None:
|
58 |
+
self.prompt = f"A photo of a {self.domain_name}"
|
59 |
+
if self.object_noun is None:
|
60 |
+
self.object_noun = self.domain_name
|
61 |
+
|
62 |
+
# Define the paths to store the inverted latents to
|
63 |
+
self.latents_path = Path(self.output_path) / "latents"
|
64 |
+
self.latents_path.mkdir(parents=True, exist_ok=True)
|
65 |
+
self.app_latent_save_path = self.latents_path / f"{self.app_image_path.stem}.pt"
|
66 |
+
self.struct_latent_save_path = self.latents_path / f"{self.struct_image_path.stem}.pt"
|
constants.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
OUT_INDEX = 0
|
2 |
+
STYLE_INDEX = 1
|
3 |
+
STRUCT_INDEX = 2
|
demo.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
from appearance_transfer_model import AppearanceTransferModel
|
9 |
+
from run import run_appearance_transfer
|
10 |
+
from utils.latent_utils import load_latents_or_invert_images
|
11 |
+
from utils.model_utils import get_stable_diffusion_model
|
12 |
+
|
13 |
+
sys.path.append(".")
|
14 |
+
sys.path.append("..")
|
15 |
+
|
16 |
+
from config import RunConfig
|
17 |
+
|
18 |
+
DESCRIPTION = '''
|
19 |
+
<h1 style="text-align: center;"> Cross-Image Attention for Zero-Shot Appearance Transfer </h1>
|
20 |
+
<p style="text-align: center;">
|
21 |
+
This is a demo for our <a href="https://arxiv.org/abs/2311.03335">paper</a>:
|
22 |
+
''Cross-Image Attention for Zero-Shot Appearance Transfer''.
|
23 |
+
<br>
|
24 |
+
Given two images depicting a source structure and a target appearance, our method generates an image merging
|
25 |
+
the structure of one image with the appearance of the other.
|
26 |
+
<br>
|
27 |
+
We do so in a zero-shot manner, with no optimization or model training required while supporting appearance
|
28 |
+
transfer across images that may differ in size and shape.
|
29 |
+
</p>
|
30 |
+
'''
|
31 |
+
|
32 |
+
pipe = get_stable_diffusion_model()
|
33 |
+
|
34 |
+
|
35 |
+
def main_pipeline(app_image_path: str,
|
36 |
+
struct_image_path: str,
|
37 |
+
domain_name: str,
|
38 |
+
seed: int,
|
39 |
+
prompt: Optional[str] = None) -> Image.Image:
|
40 |
+
if prompt == "":
|
41 |
+
prompt = None
|
42 |
+
config = RunConfig(
|
43 |
+
app_image_path=Path(app_image_path),
|
44 |
+
struct_image_path=Path(struct_image_path),
|
45 |
+
domain_name=domain_name,
|
46 |
+
prompt=prompt,
|
47 |
+
seed=seed,
|
48 |
+
load_latents=False
|
49 |
+
)
|
50 |
+
model = AppearanceTransferModel(config=config, pipe=pipe)
|
51 |
+
latents_app, latents_struct, noise_app, noise_struct = load_latents_or_invert_images(model=model, cfg=config)
|
52 |
+
model.set_latents(latents_app, latents_struct)
|
53 |
+
model.set_noise(noise_app, noise_struct)
|
54 |
+
print("Running appearance transfer...")
|
55 |
+
images = run_appearance_transfer(model=model, cfg=config)
|
56 |
+
print("Done.")
|
57 |
+
return [images[0]]
|
58 |
+
|
59 |
+
|
60 |
+
with gr.Blocks(css='style.css') as demo:
|
61 |
+
gr.Markdown(DESCRIPTION)
|
62 |
+
|
63 |
+
gr.HTML('''<a href="https://huggingface.co/spaces/yuvalalaluf/cross-image-attention?duplicate=true"><img src="https://bit.ly/3gLdBN6"
|
64 |
+
alt="Duplicate Space"></a>''')
|
65 |
+
|
66 |
+
with gr.Row():
|
67 |
+
with gr.Column():
|
68 |
+
app_image_path = gr.Image(label="Upload appearance image", type="filepath")
|
69 |
+
struct_image_path = gr.Image(label="Upload structure image", type="filepath")
|
70 |
+
domain_name = gr.Text(label="Domain name", max_lines=1,
|
71 |
+
info="Specifies the domain the objects are coming from (e.g., 'animal', 'building', etc).")
|
72 |
+
prompt = gr.Text(label="Prompt to use for inversion.", value='',
|
73 |
+
info='If this kept empty, we will use the domain name to define '
|
74 |
+
'the prompt as "A photo of a <domain_name>".')
|
75 |
+
random_seed = gr.Number(value=42, label="Random seed", precision=0)
|
76 |
+
run_button = gr.Button('Generate')
|
77 |
+
|
78 |
+
with gr.Column():
|
79 |
+
result = gr.Gallery(label='Result')
|
80 |
+
inputs = [app_image_path, struct_image_path, domain_name, random_seed, prompt]
|
81 |
+
outputs = [result]
|
82 |
+
run_button.click(fn=main_pipeline, inputs=inputs, outputs=outputs)
|
83 |
+
|
84 |
+
with gr.Row():
|
85 |
+
examples = [
|
86 |
+
['inputs/zebra.png', 'inputs/giraffe.png', 'animal', 20, None],
|
87 |
+
['inputs/taj_mahal.jpg', 'inputs/duomo.png', 'building', 42, None],
|
88 |
+
['inputs/red_velvet_cake.jpg', 'inputs/chocolate_cake.jpg', 'cake', 42, 'A photo of cake'],
|
89 |
+
]
|
90 |
+
gr.Examples(examples=examples,
|
91 |
+
inputs=[app_image_path, struct_image_path, domain_name, random_seed, prompt],
|
92 |
+
outputs=[result],
|
93 |
+
fn=main_pipeline,
|
94 |
+
cache_examples=True)
|
95 |
+
|
96 |
+
demo.launch(share=False, server_name="127.0.0.1", server_port=8888)
|
environment/environment.yaml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: cross_image
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- defaults
|
5 |
+
dependencies:
|
6 |
+
- python=3.8.5
|
7 |
+
- pip=20.3
|
8 |
+
- cudatoolkit=11.3
|
9 |
+
- pip:
|
10 |
+
- -r requirements.txt
|
environment/requirements.txt
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
matplotlib==3.6.3
|
2 |
+
matplotlib-inline==0.1.6
|
3 |
+
jupyter==1.0.0
|
4 |
+
numpy==1.24.1
|
5 |
+
pyrallis==0.3.1
|
6 |
+
torch==2.0.1
|
7 |
+
torchvision==0.15.2
|
8 |
+
diffusers==0.19.3
|
9 |
+
transformers==4.30.2
|
10 |
+
accelerate==0.20.3
|
11 |
+
huggingface-hub==0.16.4
|
12 |
+
xformers==0.0.21
|
13 |
+
tokenizers==0.13.3
|
14 |
+
nltk==3.8.1
|
15 |
+
Pillow==10.1.0
|
16 |
+
scikit_learn==1.3.0
|
17 |
+
tqdm==4.64.1
|
inputs/chocolate_cake.jpg
ADDED
inputs/duomo.png
ADDED
inputs/giraffe.png
ADDED
inputs/red_velvet_cake.jpg
ADDED
inputs/taj_mahal.jpg
ADDED
inputs/zebra.png
ADDED
models/__init__.py
ADDED
File without changes
|
models/stable_diffusion.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from diffusers import StableDiffusionPipeline
|
6 |
+
from diffusers.models import AutoencoderKL
|
7 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
8 |
+
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
|
9 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
10 |
+
from tqdm import tqdm
|
11 |
+
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
|
12 |
+
|
13 |
+
from config import Range
|
14 |
+
from models.unet_2d_condition import FreeUUNet2DConditionModel
|
15 |
+
|
16 |
+
|
17 |
+
class CrossImageAttentionStableDiffusionPipeline(StableDiffusionPipeline):
|
18 |
+
""" A modification of the standard StableDiffusionPipeline to incorporate our cross-image attention."""
|
19 |
+
|
20 |
+
def __init__(self, vae: AutoencoderKL,
|
21 |
+
text_encoder: CLIPTextModel,
|
22 |
+
tokenizer: CLIPTokenizer,
|
23 |
+
unet: FreeUUNet2DConditionModel,
|
24 |
+
scheduler: KarrasDiffusionSchedulers,
|
25 |
+
safety_checker: StableDiffusionSafetyChecker,
|
26 |
+
feature_extractor: CLIPImageProcessor,
|
27 |
+
requires_safety_checker: bool = True):
|
28 |
+
super().__init__(
|
29 |
+
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
|
30 |
+
)
|
31 |
+
|
32 |
+
@torch.no_grad()
|
33 |
+
def __call__(
|
34 |
+
self,
|
35 |
+
prompt: Union[str, List[str]] = None,
|
36 |
+
height: Optional[int] = None,
|
37 |
+
width: Optional[int] = None,
|
38 |
+
num_inference_steps: int = 50,
|
39 |
+
guidance_scale: float = 7.5,
|
40 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
41 |
+
num_images_per_prompt: Optional[int] = 1,
|
42 |
+
eta: float = 0.0,
|
43 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
44 |
+
latents: Optional[torch.FloatTensor] = None,
|
45 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
46 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
47 |
+
output_type: Optional[str] = "pil",
|
48 |
+
return_dict: bool = True,
|
49 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
50 |
+
callback_steps: int = 1,
|
51 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
52 |
+
guidance_rescale: float = 0.0,
|
53 |
+
swap_guidance_scale: float = 1.0,
|
54 |
+
cross_image_attention_range: Range = Range(10, 90),
|
55 |
+
# DDPM addition
|
56 |
+
zs: Optional[List[torch.Tensor]] = None
|
57 |
+
):
|
58 |
+
|
59 |
+
# 0. Default height and width to unet
|
60 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
61 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
62 |
+
|
63 |
+
# 1. Check inputs. Raise error if not correct
|
64 |
+
self.check_inputs(
|
65 |
+
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
|
66 |
+
)
|
67 |
+
|
68 |
+
# 2. Define call parameters
|
69 |
+
if prompt is not None and isinstance(prompt, str):
|
70 |
+
batch_size = 1
|
71 |
+
elif prompt is not None and isinstance(prompt, list):
|
72 |
+
batch_size = len(prompt)
|
73 |
+
else:
|
74 |
+
batch_size = prompt_embeds.shape[0]
|
75 |
+
|
76 |
+
device = self._execution_device
|
77 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
78 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
79 |
+
# corresponds to doing no classifier free guidance.
|
80 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
81 |
+
|
82 |
+
# 3. Encode input prompt
|
83 |
+
text_encoder_lora_scale = (
|
84 |
+
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
85 |
+
)
|
86 |
+
prompt_embeds = self._encode_prompt(
|
87 |
+
prompt,
|
88 |
+
device,
|
89 |
+
num_images_per_prompt,
|
90 |
+
do_classifier_free_guidance,
|
91 |
+
negative_prompt,
|
92 |
+
prompt_embeds=prompt_embeds,
|
93 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
94 |
+
lora_scale=text_encoder_lora_scale,
|
95 |
+
)
|
96 |
+
|
97 |
+
# 4. Prepare timesteps
|
98 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
99 |
+
timesteps = self.scheduler.timesteps
|
100 |
+
t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs[0].shape[0]:])}
|
101 |
+
timesteps = timesteps[-zs[0].shape[0]:]
|
102 |
+
|
103 |
+
# 5. Prepare latent variables
|
104 |
+
num_channels_latents = self.unet.config.in_channels
|
105 |
+
latents = self.prepare_latents(
|
106 |
+
batch_size * num_images_per_prompt,
|
107 |
+
num_channels_latents,
|
108 |
+
height,
|
109 |
+
width,
|
110 |
+
prompt_embeds.dtype,
|
111 |
+
device,
|
112 |
+
generator,
|
113 |
+
latents,
|
114 |
+
)
|
115 |
+
|
116 |
+
# 7. Denoising loop
|
117 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
118 |
+
|
119 |
+
op = tqdm(timesteps[-zs[0].shape[0]:])
|
120 |
+
n_timesteps = len(timesteps[-zs[0].shape[0]:])
|
121 |
+
|
122 |
+
count = 0
|
123 |
+
for t in op:
|
124 |
+
i = t_to_idx[int(t)]
|
125 |
+
|
126 |
+
# expand the latents if we are doing classifier free guidance
|
127 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
128 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
129 |
+
|
130 |
+
noise_pred_swap = self.unet(
|
131 |
+
latent_model_input,
|
132 |
+
t,
|
133 |
+
encoder_hidden_states=prompt_embeds,
|
134 |
+
cross_attention_kwargs={'perform_swap': True},
|
135 |
+
return_dict=False,
|
136 |
+
)[0]
|
137 |
+
noise_pred_no_swap = self.unet(
|
138 |
+
latent_model_input,
|
139 |
+
t,
|
140 |
+
encoder_hidden_states=prompt_embeds,
|
141 |
+
cross_attention_kwargs={'perform_swap': False},
|
142 |
+
return_dict=False,
|
143 |
+
)[0]
|
144 |
+
|
145 |
+
# perform guidance
|
146 |
+
if do_classifier_free_guidance:
|
147 |
+
_, noise_swap_pred_text = noise_pred_swap.chunk(2)
|
148 |
+
noise_no_swap_pred_uncond, _ = noise_pred_no_swap.chunk(2)
|
149 |
+
noise_pred = noise_no_swap_pred_uncond + guidance_scale * (
|
150 |
+
noise_swap_pred_text - noise_no_swap_pred_uncond)
|
151 |
+
else:
|
152 |
+
is_cross_image_step = cross_image_attention_range.start <= i <= cross_image_attention_range.end
|
153 |
+
if swap_guidance_scale > 1.0 and is_cross_image_step:
|
154 |
+
swapping_strengths = np.linspace(swap_guidance_scale,
|
155 |
+
max(swap_guidance_scale / 2, 1.0),
|
156 |
+
n_timesteps)
|
157 |
+
swapping_strength = swapping_strengths[count]
|
158 |
+
noise_pred = noise_pred_no_swap + swapping_strength * (noise_pred_swap - noise_pred_no_swap)
|
159 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_swap, guidance_rescale=guidance_rescale)
|
160 |
+
else:
|
161 |
+
noise_pred = noise_pred_swap
|
162 |
+
|
163 |
+
latents = torch.stack([
|
164 |
+
self.perform_ddpm_step(t_to_idx, zs[latent_idx], latents[latent_idx], t, noise_pred[latent_idx], eta)
|
165 |
+
for latent_idx in range(latents.shape[0])
|
166 |
+
])
|
167 |
+
|
168 |
+
# call the callback, if provided
|
169 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
170 |
+
# progress_bar.update()
|
171 |
+
if callback is not None and i % callback_steps == 0:
|
172 |
+
callback(i, t, latents)
|
173 |
+
|
174 |
+
count += 1
|
175 |
+
|
176 |
+
if not output_type == "latent":
|
177 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
178 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
179 |
+
else:
|
180 |
+
image = latents
|
181 |
+
has_nsfw_concept = None
|
182 |
+
|
183 |
+
if has_nsfw_concept is None:
|
184 |
+
do_denormalize = [True] * image.shape[0]
|
185 |
+
else:
|
186 |
+
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
187 |
+
|
188 |
+
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
189 |
+
|
190 |
+
# Offload last model to CPU
|
191 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
192 |
+
self.final_offload_hook.offload()
|
193 |
+
|
194 |
+
if not return_dict:
|
195 |
+
return (image, has_nsfw_concept)
|
196 |
+
|
197 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
198 |
+
|
199 |
+
def perform_ddpm_step(self, t_to_idx, zs, latents, t, noise_pred, eta):
|
200 |
+
idx = t_to_idx[int(t)]
|
201 |
+
z = zs[idx] if not zs is None else None
|
202 |
+
# 1. get previous step value (=t-1)
|
203 |
+
prev_timestep = t - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
|
204 |
+
# 2. compute alphas, betas
|
205 |
+
alpha_prod_t = self.scheduler.alphas_cumprod[t]
|
206 |
+
alpha_prod_t_prev = self.scheduler.alphas_cumprod[
|
207 |
+
prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
|
208 |
+
beta_prod_t = 1 - alpha_prod_t
|
209 |
+
# 3. compute predicted original sample from predicted noise also called
|
210 |
+
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
211 |
+
pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
|
212 |
+
# 5. compute variance: "sigma_t(Ξ·)" -> see formula (16)
|
213 |
+
# Ο_t = sqrt((1 β Ξ±_tβ1)/(1 β Ξ±_t)) * sqrt(1 β Ξ±_t/Ξ±_tβ1)
|
214 |
+
# variance = self.scheduler._get_variance(timestep, prev_timestep)
|
215 |
+
variance = self.get_variance(t)
|
216 |
+
std_dev_t = eta * variance ** (0.5)
|
217 |
+
# Take care of asymetric reverse process (asyrp)
|
218 |
+
model_output_direction = noise_pred
|
219 |
+
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
220 |
+
# pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output_direction
|
221 |
+
pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction
|
222 |
+
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
223 |
+
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
|
224 |
+
# 8. Add noice if eta > 0
|
225 |
+
if eta > 0:
|
226 |
+
if z is None:
|
227 |
+
z = torch.randn(noise_pred.shape, device=self.device)
|
228 |
+
sigma_z = eta * variance ** (0.5) * z
|
229 |
+
prev_sample = prev_sample + sigma_z
|
230 |
+
return prev_sample
|
231 |
+
|
232 |
+
def get_variance(self, timestep):
|
233 |
+
prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
|
234 |
+
alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
|
235 |
+
alpha_prod_t_prev = self.scheduler.alphas_cumprod[
|
236 |
+
prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
|
237 |
+
beta_prod_t = 1 - alpha_prod_t
|
238 |
+
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
239 |
+
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
|
240 |
+
return variance
|
models/unet_2d_condition.py
ADDED
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.utils.checkpoint
|
5 |
+
from diffusers import UNet2DConditionModel
|
6 |
+
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
7 |
+
from diffusers.utils import logging
|
8 |
+
from torch.fft import fftn, ifftn, fftshift, ifftshift
|
9 |
+
|
10 |
+
"""
|
11 |
+
This is a small extension of the standard UNet2DConditionModel with the small addition of the
|
12 |
+
Free-U trick (https://github.com/ChenyangSi/FreeU).
|
13 |
+
"""
|
14 |
+
|
15 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
16 |
+
|
17 |
+
|
18 |
+
def Fourier_filter(x, threshold, scale):
|
19 |
+
# FFT
|
20 |
+
x_freq = fftn(x, dim=(-2, -1))
|
21 |
+
x_freq = fftshift(x_freq, dim=(-2, -1))
|
22 |
+
|
23 |
+
B, C, H, W = x_freq.shape
|
24 |
+
mask = torch.ones((B, C, H, W)).cuda() # CUDA için
|
25 |
+
|
26 |
+
crow, ccol = H // 2, W // 2
|
27 |
+
mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
|
28 |
+
x_freq = x_freq * mask
|
29 |
+
|
30 |
+
# IFFT
|
31 |
+
x_freq = ifftshift(x_freq, dim=(-2, -1))
|
32 |
+
x_filtered = ifftn(x_freq, dim=(-2, -1)).real
|
33 |
+
|
34 |
+
return x_filtered
|
35 |
+
|
36 |
+
|
37 |
+
class FreeUUNet2DConditionModel(UNet2DConditionModel):
|
38 |
+
|
39 |
+
def forward(
|
40 |
+
self,
|
41 |
+
sample: torch.FloatTensor,
|
42 |
+
timestep: Union[torch.Tensor, float, int],
|
43 |
+
encoder_hidden_states: torch.Tensor,
|
44 |
+
class_labels: Optional[torch.Tensor] = None,
|
45 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
46 |
+
attention_mask: Optional[torch.Tensor] = None,
|
47 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
48 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
49 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
50 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
51 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
52 |
+
return_dict: bool = True,
|
53 |
+
) -> Union[UNet2DConditionOutput, Tuple]:
|
54 |
+
r"""
|
55 |
+
The [`UNet2DConditionModel`] forward method.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
sample (`torch.FloatTensor`):
|
59 |
+
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
60 |
+
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
61 |
+
encoder_hidden_states (`torch.FloatTensor`):
|
62 |
+
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
|
63 |
+
encoder_attention_mask (`torch.Tensor`):
|
64 |
+
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
65 |
+
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
66 |
+
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
67 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
68 |
+
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
69 |
+
tuple.
|
70 |
+
cross_attention_kwargs (`dict`, *optional*):
|
71 |
+
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
72 |
+
added_cond_kwargs: (`dict`, *optional*):
|
73 |
+
A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
|
74 |
+
are passed along to the UNet blocks.
|
75 |
+
|
76 |
+
Returns:
|
77 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
78 |
+
If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
|
79 |
+
a `tuple` is returned where the first element is the sample tensor.
|
80 |
+
"""
|
81 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
82 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
83 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
84 |
+
# on the fly if necessary.
|
85 |
+
default_overall_up_factor = 2 ** self.num_upsamplers
|
86 |
+
|
87 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
88 |
+
forward_upsample_size = False
|
89 |
+
upsample_size = None
|
90 |
+
|
91 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
92 |
+
logger.info("Forward upsample size to force interpolation output size.")
|
93 |
+
forward_upsample_size = True
|
94 |
+
|
95 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
96 |
+
# expects mask of shape:
|
97 |
+
# [batch, key_tokens]
|
98 |
+
# adds singleton query_tokens dimension:
|
99 |
+
# [batch, 1, key_tokens]
|
100 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
101 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
102 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
103 |
+
if attention_mask is not None:
|
104 |
+
# assume that mask is expressed as:
|
105 |
+
# (1 = keep, 0 = discard)
|
106 |
+
# convert mask into a bias that can be added to attention scores:
|
107 |
+
# (keep = +0, discard = -10000.0)
|
108 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
109 |
+
attention_mask = attention_mask.unsqueeze(1)
|
110 |
+
|
111 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
112 |
+
if encoder_attention_mask is not None:
|
113 |
+
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
114 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
115 |
+
|
116 |
+
# 0. center input if necessary
|
117 |
+
if self.config.center_input_sample:
|
118 |
+
sample = 2 * sample - 1.0
|
119 |
+
|
120 |
+
# 1. time
|
121 |
+
timesteps = timestep
|
122 |
+
if not torch.is_tensor(timesteps):
|
123 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
124 |
+
is_mps = sample.device.type == "mps"
|
125 |
+
if isinstance(timestep, float):
|
126 |
+
dtype = torch.float32 if is_mps else torch.float64
|
127 |
+
else:
|
128 |
+
dtype = torch.int32 if is_mps else torch.int64
|
129 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
130 |
+
elif len(timesteps.shape) == 0:
|
131 |
+
timesteps = timesteps[None].to(sample.device)
|
132 |
+
|
133 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
134 |
+
timesteps = timesteps.expand(sample.shape[0])
|
135 |
+
|
136 |
+
t_emb = self.time_proj(timesteps)
|
137 |
+
|
138 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
139 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
140 |
+
# there might be better ways to encapsulate this.
|
141 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
142 |
+
|
143 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
144 |
+
aug_emb = None
|
145 |
+
|
146 |
+
if self.class_embedding is not None:
|
147 |
+
if class_labels is None:
|
148 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
149 |
+
|
150 |
+
if self.config.class_embed_type == "timestep":
|
151 |
+
class_labels = self.time_proj(class_labels)
|
152 |
+
|
153 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
154 |
+
# there might be better ways to encapsulate this.
|
155 |
+
class_labels = class_labels.to(dtype=sample.dtype)
|
156 |
+
|
157 |
+
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
|
158 |
+
|
159 |
+
if self.config.class_embeddings_concat:
|
160 |
+
emb = torch.cat([emb, class_emb], dim=-1)
|
161 |
+
else:
|
162 |
+
emb = emb + class_emb
|
163 |
+
|
164 |
+
if self.config.addition_embed_type == "text":
|
165 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
166 |
+
elif self.config.addition_embed_type == "text_image":
|
167 |
+
# Kandinsky 2.1 - style
|
168 |
+
if "image_embeds" not in added_cond_kwargs:
|
169 |
+
raise ValueError(
|
170 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
171 |
+
)
|
172 |
+
|
173 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
174 |
+
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
|
175 |
+
aug_emb = self.add_embedding(text_embs, image_embs)
|
176 |
+
elif self.config.addition_embed_type == "text_time":
|
177 |
+
# SDXL - style
|
178 |
+
if "text_embeds" not in added_cond_kwargs:
|
179 |
+
raise ValueError(
|
180 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
181 |
+
)
|
182 |
+
text_embeds = added_cond_kwargs.get("text_embeds")
|
183 |
+
if "time_ids" not in added_cond_kwargs:
|
184 |
+
raise ValueError(
|
185 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
186 |
+
)
|
187 |
+
time_ids = added_cond_kwargs.get("time_ids")
|
188 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
189 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
190 |
+
|
191 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
192 |
+
add_embeds = add_embeds.to(emb.dtype)
|
193 |
+
aug_emb = self.add_embedding(add_embeds)
|
194 |
+
elif self.config.addition_embed_type == "image":
|
195 |
+
# Kandinsky 2.2 - style
|
196 |
+
if "image_embeds" not in added_cond_kwargs:
|
197 |
+
raise ValueError(
|
198 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
199 |
+
)
|
200 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
201 |
+
aug_emb = self.add_embedding(image_embs)
|
202 |
+
elif self.config.addition_embed_type == "image_hint":
|
203 |
+
# Kandinsky 2.2 - style
|
204 |
+
if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
|
205 |
+
raise ValueError(
|
206 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
|
207 |
+
)
|
208 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
209 |
+
hint = added_cond_kwargs.get("hint")
|
210 |
+
aug_emb, hint = self.add_embedding(image_embs, hint)
|
211 |
+
sample = torch.cat([sample, hint], dim=1)
|
212 |
+
|
213 |
+
emb = emb + aug_emb if aug_emb is not None else emb
|
214 |
+
|
215 |
+
if self.time_embed_act is not None:
|
216 |
+
emb = self.time_embed_act(emb)
|
217 |
+
|
218 |
+
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
|
219 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
|
220 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
|
221 |
+
# Kadinsky 2.1 - style
|
222 |
+
if "image_embeds" not in added_cond_kwargs:
|
223 |
+
raise ValueError(
|
224 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
225 |
+
)
|
226 |
+
|
227 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
228 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
|
229 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
|
230 |
+
# Kandinsky 2.2 - style
|
231 |
+
if "image_embeds" not in added_cond_kwargs:
|
232 |
+
raise ValueError(
|
233 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
234 |
+
)
|
235 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
236 |
+
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
|
237 |
+
# 2. pre-process
|
238 |
+
sample = self.conv_in(sample)
|
239 |
+
|
240 |
+
# 3. down
|
241 |
+
|
242 |
+
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
|
243 |
+
is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
|
244 |
+
|
245 |
+
down_block_res_samples = (sample,)
|
246 |
+
for downsample_block in self.down_blocks:
|
247 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
248 |
+
# For t2i-adapter CrossAttnDownBlock2D
|
249 |
+
additional_residuals = {}
|
250 |
+
if is_adapter and len(down_block_additional_residuals) > 0:
|
251 |
+
additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
|
252 |
+
|
253 |
+
sample, res_samples = downsample_block(
|
254 |
+
hidden_states=sample,
|
255 |
+
temb=emb,
|
256 |
+
encoder_hidden_states=encoder_hidden_states,
|
257 |
+
attention_mask=attention_mask,
|
258 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
259 |
+
encoder_attention_mask=encoder_attention_mask,
|
260 |
+
**additional_residuals,
|
261 |
+
)
|
262 |
+
else:
|
263 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
264 |
+
|
265 |
+
if is_adapter and len(down_block_additional_residuals) > 0:
|
266 |
+
sample += down_block_additional_residuals.pop(0)
|
267 |
+
|
268 |
+
down_block_res_samples += res_samples
|
269 |
+
|
270 |
+
if is_controlnet:
|
271 |
+
new_down_block_res_samples = ()
|
272 |
+
|
273 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
274 |
+
down_block_res_samples, down_block_additional_residuals
|
275 |
+
):
|
276 |
+
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
277 |
+
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
|
278 |
+
|
279 |
+
down_block_res_samples = new_down_block_res_samples
|
280 |
+
|
281 |
+
# 4. mid
|
282 |
+
if self.mid_block is not None:
|
283 |
+
sample = self.mid_block(
|
284 |
+
sample,
|
285 |
+
emb,
|
286 |
+
encoder_hidden_states=encoder_hidden_states,
|
287 |
+
attention_mask=attention_mask,
|
288 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
289 |
+
encoder_attention_mask=encoder_attention_mask,
|
290 |
+
)
|
291 |
+
|
292 |
+
if is_controlnet:
|
293 |
+
sample = sample + mid_block_additional_residual
|
294 |
+
|
295 |
+
# 5. up
|
296 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
297 |
+
is_final_block = i == len(self.up_blocks) - 1
|
298 |
+
|
299 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
|
300 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
301 |
+
|
302 |
+
# Add the Free-U trick here!
|
303 |
+
# Fourier Filter
|
304 |
+
if sample.shape[1] == 1280:
|
305 |
+
sample[:, :640] *= 1.2 # 1.1 # For SD2.1
|
306 |
+
sample = Fourier_filter(sample, threshold=1, scale=0.9)
|
307 |
+
|
308 |
+
if sample.shape[1] == 640:
|
309 |
+
sample[:, :320] *= 1.4 # 1.2 # For SD2.1
|
310 |
+
sample = Fourier_filter(sample, threshold=1, scale=0.2)
|
311 |
+
|
312 |
+
# if we have not reached the final block and need to forward the
|
313 |
+
# upsample size, we do it here
|
314 |
+
if not is_final_block and forward_upsample_size:
|
315 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
316 |
+
|
317 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
318 |
+
sample = upsample_block(
|
319 |
+
hidden_states=sample,
|
320 |
+
temb=emb,
|
321 |
+
res_hidden_states_tuple=res_samples,
|
322 |
+
encoder_hidden_states=encoder_hidden_states,
|
323 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
324 |
+
upsample_size=upsample_size,
|
325 |
+
attention_mask=attention_mask,
|
326 |
+
encoder_attention_mask=encoder_attention_mask,
|
327 |
+
)
|
328 |
+
else:
|
329 |
+
sample = upsample_block(
|
330 |
+
hidden_states=sample,
|
331 |
+
temb=emb,
|
332 |
+
res_hidden_states_tuple=res_samples,
|
333 |
+
upsample_size=upsample_size
|
334 |
+
)
|
335 |
+
|
336 |
+
# 6. post-process
|
337 |
+
if self.conv_norm_out:
|
338 |
+
sample = self.conv_norm_out(sample)
|
339 |
+
sample = self.conv_act(sample)
|
340 |
+
sample = self.conv_out(sample)
|
341 |
+
|
342 |
+
if not return_dict:
|
343 |
+
return (sample,)
|
344 |
+
|
345 |
+
return UNet2DConditionOutput(sample=sample)
|
utils/__init__.py
ADDED
File without changes
|
utils/adain.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def masked_adain(content_feat, style_feat, content_mask, style_mask):
|
2 |
+
assert (content_feat.size()[:2] == style_feat.size()[:2])
|
3 |
+
size = content_feat.size()
|
4 |
+
style_mean, style_std = calc_mean_std(style_feat, mask=style_mask)
|
5 |
+
content_mean, content_std = calc_mean_std(content_feat, mask=content_mask)
|
6 |
+
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
7 |
+
style_normalized_feat = normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
8 |
+
return content_feat * (1 - content_mask) + style_normalized_feat * content_mask
|
9 |
+
|
10 |
+
|
11 |
+
def calc_mean_std(feat, eps=1e-5, mask=None):
|
12 |
+
# eps is a small value added to the variance to avoid divide-by-zero.
|
13 |
+
size = feat.size()
|
14 |
+
if len(size) == 2:
|
15 |
+
return calc_mean_std_2d(feat, eps, mask)
|
16 |
+
|
17 |
+
assert (len(size) == 3)
|
18 |
+
C = size[0]
|
19 |
+
if mask is not None:
|
20 |
+
feat_var = feat.view(C, -1)[:, mask.view(-1) == 1].var(dim=1) + eps
|
21 |
+
feat_std = feat_var.sqrt().view(C, 1, 1)
|
22 |
+
feat_mean = feat.view(C, -1)[:, mask.view(-1) == 1].mean(dim=1).view(C, 1, 1)
|
23 |
+
else:
|
24 |
+
feat_var = feat.view(C, -1).var(dim=1) + eps
|
25 |
+
feat_std = feat_var.sqrt().view(C, 1, 1)
|
26 |
+
feat_mean = feat.view(C, -1).mean(dim=1).view(C, 1, 1)
|
27 |
+
|
28 |
+
return feat_mean, feat_std
|
29 |
+
|
30 |
+
|
31 |
+
def calc_mean_std_2d(feat, eps=1e-5, mask=None):
|
32 |
+
# eps is a small value added to the variance to avoid divide-by-zero.
|
33 |
+
size = feat.size()
|
34 |
+
assert (len(size) == 2)
|
35 |
+
C = size[0]
|
36 |
+
if mask is not None:
|
37 |
+
feat_var = feat.view(C, -1)[:, mask.view(-1) == 1].var(dim=1) + eps
|
38 |
+
feat_std = feat_var.sqrt().view(C, 1)
|
39 |
+
feat_mean = feat.view(C, -1)[:, mask.view(-1) == 1].mean(dim=1).view(C, 1)
|
40 |
+
else:
|
41 |
+
feat_var = feat.view(C, -1).var(dim=1) + eps
|
42 |
+
feat_std = feat_var.sqrt().view(C, 1)
|
43 |
+
feat_mean = feat.view(C, -1).mean(dim=1).view(C, 1)
|
44 |
+
|
45 |
+
return feat_mean, feat_std
|
utils/attention_utils.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from constants import OUT_INDEX
|
5 |
+
|
6 |
+
|
7 |
+
def should_mix_keys_and_values(model, hidden_states: torch.Tensor) -> bool:
|
8 |
+
""" Verify whether we should perform the mixing in the current timestep. """
|
9 |
+
is_in_32_timestep_range = (
|
10 |
+
model.config.cross_attn_32_range.start <= model.step < model.config.cross_attn_32_range.end
|
11 |
+
)
|
12 |
+
is_in_64_timestep_range = (
|
13 |
+
model.config.cross_attn_64_range.start <= model.step < model.config.cross_attn_64_range.end
|
14 |
+
)
|
15 |
+
is_hidden_states_32_square = (hidden_states.shape[1] == 32 ** 2)
|
16 |
+
is_hidden_states_64_square = (hidden_states.shape[1] == 64 ** 2)
|
17 |
+
should_mix = (is_in_32_timestep_range and is_hidden_states_32_square) or \
|
18 |
+
(is_in_64_timestep_range and is_hidden_states_64_square)
|
19 |
+
return should_mix
|
20 |
+
|
21 |
+
|
22 |
+
def compute_scaled_dot_product_attention(Q, K, V, edit_map=False, is_cross=False, contrast_strength=1.0):
|
23 |
+
""" Compute the scale dot product attention, potentially with our contrasting operation. """
|
24 |
+
attn_weight = torch.softmax((Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))), dim=-1)
|
25 |
+
if edit_map and not is_cross:
|
26 |
+
attn_weight[OUT_INDEX] = torch.stack([
|
27 |
+
torch.clip(enhance_tensor(attn_weight[OUT_INDEX][head_idx], contrast_factor=contrast_strength),
|
28 |
+
min=0.0, max=1.0)
|
29 |
+
for head_idx in range(attn_weight.shape[1])
|
30 |
+
])
|
31 |
+
return attn_weight @ V, attn_weight
|
32 |
+
|
33 |
+
|
34 |
+
def enhance_tensor(tensor: torch.Tensor, contrast_factor: float = 1.67) -> torch.Tensor:
|
35 |
+
""" Compute the attention map contrasting. """
|
36 |
+
adjusted_tensor = (tensor - tensor.mean(dim=-1)) * contrast_factor + tensor.mean(dim=-1)
|
37 |
+
return adjusted_tensor
|
utils/ddpm_inversion.py
ADDED
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import abc
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import inference_mode
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
"""
|
8 |
+
Inversion code taken from:
|
9 |
+
1. The official implementation of Edit-Friendly DDPM Inversion: https://github.com/inbarhub/DDPM_inversion
|
10 |
+
2. The LEDITS demo: https://huggingface.co/spaces/editing-images/ledits/tree/main
|
11 |
+
"""
|
12 |
+
|
13 |
+
LOW_RESOURCE = True
|
14 |
+
|
15 |
+
|
16 |
+
def invert(x0, pipe, prompt_src="", num_diffusion_steps=100, cfg_scale_src=3.5, eta=1):
|
17 |
+
# inverts a real image according to Algorihm 1 in https://arxiv.org/pdf/2304.06140.pdf,
|
18 |
+
# based on the code in https://github.com/inbarhub/DDPM_inversion
|
19 |
+
# returns wt, zs, wts:
|
20 |
+
# wt - inverted latent
|
21 |
+
# wts - intermediate inverted latents
|
22 |
+
# zs - noise maps
|
23 |
+
pipe.scheduler.set_timesteps(num_diffusion_steps)
|
24 |
+
with inference_mode():
|
25 |
+
w0 = (pipe.vae.encode(x0).latent_dist.mode() * 0.18215).float()
|
26 |
+
wt, zs, wts = inversion_forward_process(pipe, w0, etas=eta, prompt=prompt_src, cfg_scale=cfg_scale_src,
|
27 |
+
prog_bar=True, num_inference_steps=num_diffusion_steps)
|
28 |
+
return zs, wts
|
29 |
+
|
30 |
+
|
31 |
+
def inversion_forward_process(model, x0,
|
32 |
+
etas=None,
|
33 |
+
prog_bar=False,
|
34 |
+
prompt="",
|
35 |
+
cfg_scale=3.5,
|
36 |
+
num_inference_steps=50, eps=None
|
37 |
+
):
|
38 |
+
if not prompt == "":
|
39 |
+
text_embeddings = encode_text(model, prompt)
|
40 |
+
uncond_embedding = encode_text(model, "")
|
41 |
+
timesteps = model.scheduler.timesteps.to(model.device)
|
42 |
+
variance_noise_shape = (
|
43 |
+
num_inference_steps,
|
44 |
+
model.unet.in_channels,
|
45 |
+
model.unet.sample_size,
|
46 |
+
model.unet.sample_size)
|
47 |
+
if etas is None or (type(etas) in [int, float] and etas == 0):
|
48 |
+
eta_is_zero = True
|
49 |
+
zs = None
|
50 |
+
else:
|
51 |
+
eta_is_zero = False
|
52 |
+
if type(etas) in [int, float]: etas = [etas] * model.scheduler.num_inference_steps
|
53 |
+
xts = sample_xts_from_x0(model, x0, num_inference_steps=num_inference_steps)
|
54 |
+
alpha_bar = model.scheduler.alphas_cumprod
|
55 |
+
zs = torch.zeros(size=variance_noise_shape, device=model.device)
|
56 |
+
|
57 |
+
t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
|
58 |
+
xt = x0
|
59 |
+
op = tqdm(reversed(timesteps)) if prog_bar else reversed(timesteps)
|
60 |
+
|
61 |
+
for t in op:
|
62 |
+
idx = t_to_idx[int(t)]
|
63 |
+
# 1. predict noise residual
|
64 |
+
if not eta_is_zero:
|
65 |
+
xt = xts[idx][None]
|
66 |
+
|
67 |
+
with torch.no_grad():
|
68 |
+
out = model.unet.forward(xt, timestep=t, encoder_hidden_states=uncond_embedding)
|
69 |
+
if not prompt == "":
|
70 |
+
cond_out = model.unet.forward(xt, timestep=t, encoder_hidden_states=text_embeddings)
|
71 |
+
|
72 |
+
if not prompt == "":
|
73 |
+
## classifier free guidance
|
74 |
+
noise_pred = out.sample + cfg_scale * (cond_out.sample - out.sample)
|
75 |
+
else:
|
76 |
+
noise_pred = out.sample
|
77 |
+
|
78 |
+
if eta_is_zero:
|
79 |
+
# 2. compute more noisy image and set x_t -> x_t+1
|
80 |
+
xt = forward_step(model, noise_pred, t, xt)
|
81 |
+
|
82 |
+
else:
|
83 |
+
xtm1 = xts[idx + 1][None]
|
84 |
+
# pred of x0
|
85 |
+
pred_original_sample = (xt - (1 - alpha_bar[t]) ** 0.5 * noise_pred) / alpha_bar[t] ** 0.5
|
86 |
+
|
87 |
+
# direction to xt
|
88 |
+
prev_timestep = t - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
|
89 |
+
alpha_prod_t_prev = model.scheduler.alphas_cumprod[
|
90 |
+
prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
|
91 |
+
|
92 |
+
variance = get_variance(model, t)
|
93 |
+
pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** (0.5) * noise_pred
|
94 |
+
|
95 |
+
mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
|
96 |
+
|
97 |
+
z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5)
|
98 |
+
zs[idx] = z
|
99 |
+
|
100 |
+
# correction to avoid error accumulation
|
101 |
+
xtm1 = mu_xt + (etas[idx] * variance ** 0.5) * z
|
102 |
+
xts[idx + 1] = xtm1
|
103 |
+
|
104 |
+
if not zs is None:
|
105 |
+
zs[-1] = torch.zeros_like(zs[-1])
|
106 |
+
|
107 |
+
return xt, zs, xts
|
108 |
+
|
109 |
+
|
110 |
+
def encode_text(model, prompts):
|
111 |
+
text_input = model.tokenizer(
|
112 |
+
prompts,
|
113 |
+
padding="max_length",
|
114 |
+
max_length=model.tokenizer.model_max_length,
|
115 |
+
truncation=True,
|
116 |
+
return_tensors="pt",
|
117 |
+
)
|
118 |
+
with torch.no_grad():
|
119 |
+
text_encoding = model.text_encoder(text_input.input_ids.to(model.device))[0]
|
120 |
+
return text_encoding
|
121 |
+
|
122 |
+
|
123 |
+
def sample_xts_from_x0(model, x0, num_inference_steps=50):
|
124 |
+
"""
|
125 |
+
Samples from P(x_1:T|x_0)
|
126 |
+
"""
|
127 |
+
# torch.manual_seed(43256465436)
|
128 |
+
alpha_bar = model.scheduler.alphas_cumprod
|
129 |
+
sqrt_one_minus_alpha_bar = (1 - alpha_bar) ** 0.5
|
130 |
+
alphas = model.scheduler.alphas
|
131 |
+
betas = 1 - alphas
|
132 |
+
variance_noise_shape = (
|
133 |
+
num_inference_steps,
|
134 |
+
model.unet.in_channels,
|
135 |
+
model.unet.sample_size,
|
136 |
+
model.unet.sample_size)
|
137 |
+
|
138 |
+
timesteps = model.scheduler.timesteps.to(model.device)
|
139 |
+
t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
|
140 |
+
xts = torch.zeros(variance_noise_shape).to(x0.device)
|
141 |
+
for t in reversed(timesteps):
|
142 |
+
idx = t_to_idx[int(t)]
|
143 |
+
xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t]
|
144 |
+
xts = torch.cat([xts, x0], dim=0)
|
145 |
+
|
146 |
+
return xts
|
147 |
+
|
148 |
+
|
149 |
+
def forward_step(model, model_output, timestep, sample):
|
150 |
+
next_timestep = min(model.scheduler.config.num_train_timesteps - 2,
|
151 |
+
timestep + model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps)
|
152 |
+
|
153 |
+
# 2. compute alphas, betas
|
154 |
+
alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
|
155 |
+
|
156 |
+
beta_prod_t = 1 - alpha_prod_t
|
157 |
+
|
158 |
+
# 3. compute predicted original sample from predicted noise also called
|
159 |
+
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
160 |
+
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
161 |
+
next_sample = model.scheduler.add_noise(pred_original_sample,
|
162 |
+
model_output,
|
163 |
+
torch.LongTensor([next_timestep]))
|
164 |
+
return next_sample
|
165 |
+
|
166 |
+
|
167 |
+
def get_variance(model, timestep):
|
168 |
+
prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
|
169 |
+
alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
|
170 |
+
alpha_prod_t_prev = model.scheduler.alphas_cumprod[
|
171 |
+
prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
|
172 |
+
beta_prod_t = 1 - alpha_prod_t
|
173 |
+
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
174 |
+
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
|
175 |
+
return variance
|
176 |
+
|
177 |
+
|
178 |
+
class AttentionControl(abc.ABC):
|
179 |
+
|
180 |
+
def step_callback(self, x_t):
|
181 |
+
return x_t
|
182 |
+
|
183 |
+
def between_steps(self):
|
184 |
+
return
|
185 |
+
|
186 |
+
@property
|
187 |
+
def num_uncond_att_layers(self):
|
188 |
+
return self.num_att_layers if LOW_RESOURCE else 0
|
189 |
+
|
190 |
+
@abc.abstractmethod
|
191 |
+
def forward(self, attn, is_cross: bool, place_in_unet: str):
|
192 |
+
raise NotImplementedError
|
193 |
+
|
194 |
+
def __call__(self, attn, is_cross: bool, place_in_unet: str):
|
195 |
+
if self.cur_att_layer >= self.num_uncond_att_layers:
|
196 |
+
if LOW_RESOURCE:
|
197 |
+
attn = self.forward(attn, is_cross, place_in_unet)
|
198 |
+
else:
|
199 |
+
h = attn.shape[0]
|
200 |
+
attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
|
201 |
+
self.cur_att_layer += 1
|
202 |
+
if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
|
203 |
+
self.cur_att_layer = 0
|
204 |
+
self.cur_step += 1
|
205 |
+
self.between_steps()
|
206 |
+
return attn
|
207 |
+
|
208 |
+
def reset(self):
|
209 |
+
self.cur_step = 0
|
210 |
+
self.cur_att_layer = 0
|
211 |
+
|
212 |
+
def __init__(self):
|
213 |
+
self.cur_step = 0
|
214 |
+
self.num_att_layers = -1
|
215 |
+
self.cur_att_layer = 0
|
216 |
+
|
217 |
+
|
218 |
+
class AttentionStore(AttentionControl):
|
219 |
+
|
220 |
+
@staticmethod
|
221 |
+
def get_empty_store():
|
222 |
+
return {"down_cross": [], "mid_cross": [], "up_cross": [],
|
223 |
+
"down_self": [], "mid_self": [], "up_self": []}
|
224 |
+
|
225 |
+
def forward(self, attn, is_cross: bool, place_in_unet: str):
|
226 |
+
key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
|
227 |
+
if attn.shape[1] <= 32 ** 2: # avoid memory overhead
|
228 |
+
self.step_store[key].append(attn)
|
229 |
+
return attn
|
230 |
+
|
231 |
+
def between_steps(self):
|
232 |
+
if len(self.attention_store) == 0:
|
233 |
+
self.attention_store = self.step_store
|
234 |
+
else:
|
235 |
+
for key in self.attention_store:
|
236 |
+
for i in range(len(self.attention_store[key])):
|
237 |
+
self.attention_store[key][i] += self.step_store[key][i]
|
238 |
+
self.step_store = self.get_empty_store()
|
239 |
+
|
240 |
+
def get_average_attention(self):
|
241 |
+
average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in
|
242 |
+
self.attention_store}
|
243 |
+
return average_attention
|
244 |
+
|
245 |
+
def reset(self):
|
246 |
+
super(AttentionStore, self).reset()
|
247 |
+
self.step_store = self.get_empty_store()
|
248 |
+
self.attention_store = {}
|
249 |
+
|
250 |
+
def __init__(self):
|
251 |
+
super(AttentionStore, self).__init__()
|
252 |
+
self.step_store = self.get_empty_store()
|
253 |
+
self.attention_store = {}
|
254 |
+
|
255 |
+
|
256 |
+
def register_attention_control(model, controller):
|
257 |
+
def ca_forward(self, place_in_unet):
|
258 |
+
to_out = self.to_out
|
259 |
+
if type(to_out) is torch.nn.modules.container.ModuleList:
|
260 |
+
to_out = self.to_out[0]
|
261 |
+
else:
|
262 |
+
to_out = self.to_out
|
263 |
+
|
264 |
+
def forward(x, context=None, mask=None):
|
265 |
+
batch_size, sequence_length, dim = x.shape
|
266 |
+
h = self.heads
|
267 |
+
q = self.to_q(x)
|
268 |
+
is_cross = context is not None
|
269 |
+
context = context if is_cross else x
|
270 |
+
k = self.to_k(context)
|
271 |
+
v = self.to_v(context)
|
272 |
+
q = self.reshape_heads_to_batch_dim(q)
|
273 |
+
k = self.reshape_heads_to_batch_dim(k)
|
274 |
+
v = self.reshape_heads_to_batch_dim(v)
|
275 |
+
|
276 |
+
sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
|
277 |
+
|
278 |
+
if mask is not None:
|
279 |
+
mask = mask.reshape(batch_size, -1)
|
280 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
281 |
+
mask = mask[:, None, :].repeat(h, 1, 1)
|
282 |
+
sim.masked_fill_(~mask, max_neg_value)
|
283 |
+
|
284 |
+
# attention, what we cannot get enough of
|
285 |
+
attn = sim.softmax(dim=-1)
|
286 |
+
attn = controller(attn, is_cross, place_in_unet)
|
287 |
+
out = torch.einsum("b i j, b j d -> b i d", attn, v)
|
288 |
+
out = self.reshape_batch_dim_to_heads(out)
|
289 |
+
return to_out(out)
|
290 |
+
|
291 |
+
return forward
|
292 |
+
|
293 |
+
class DummyController:
|
294 |
+
|
295 |
+
def __call__(self, *args):
|
296 |
+
return args[0]
|
297 |
+
|
298 |
+
def __init__(self):
|
299 |
+
self.num_att_layers = 0
|
300 |
+
|
301 |
+
if controller is None:
|
302 |
+
controller = DummyController()
|
303 |
+
|
304 |
+
def register_recr(net_, count, place_in_unet):
|
305 |
+
if net_.__class__.__name__ == 'CrossAttention':
|
306 |
+
net_.forward = ca_forward(net_, place_in_unet)
|
307 |
+
return count + 1
|
308 |
+
elif hasattr(net_, 'children'):
|
309 |
+
for net__ in net_.children():
|
310 |
+
count = register_recr(net__, count, place_in_unet)
|
311 |
+
return count
|
312 |
+
|
313 |
+
cross_att_count = 0
|
314 |
+
sub_nets = model.unet.named_children()
|
315 |
+
for net in sub_nets:
|
316 |
+
if "down" in net[0]:
|
317 |
+
cross_att_count += register_recr(net[1], 0, "down")
|
318 |
+
elif "up" in net[0]:
|
319 |
+
cross_att_count += register_recr(net[1], 0, "up")
|
320 |
+
elif "mid" in net[0]:
|
321 |
+
cross_att_count += register_recr(net[1], 0, "mid")
|
322 |
+
|
323 |
+
controller.num_att_layers = cross_att_count
|
utils/image_utils.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pathlib
|
2 |
+
from typing import Optional, Tuple
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
from config import RunConfig
|
8 |
+
|
9 |
+
|
10 |
+
def load_images(cfg: RunConfig, save_path: Optional[pathlib.Path] = None) -> Tuple[Image.Image, Image.Image]:
|
11 |
+
image_style = load_size(cfg.app_image_path)
|
12 |
+
image_struct = load_size(cfg.struct_image_path)
|
13 |
+
if save_path is not None:
|
14 |
+
Image.fromarray(image_style).save(save_path / f"in_style.png")
|
15 |
+
Image.fromarray(image_struct).save(save_path / f"in_struct.png")
|
16 |
+
return image_style, image_struct
|
17 |
+
|
18 |
+
|
19 |
+
def load_size(image_path: pathlib.Path,
|
20 |
+
left: int = 0,
|
21 |
+
right: int = 0,
|
22 |
+
top: int = 0,
|
23 |
+
bottom: int = 0,
|
24 |
+
size: int = 512) -> Image.Image:
|
25 |
+
if type(image_path) is str or type(image_path) is pathlib.PosixPath:
|
26 |
+
image = np.array(Image.open(image_path).convert('RGB'))
|
27 |
+
else:
|
28 |
+
image = image_path
|
29 |
+
|
30 |
+
h, w, c = image.shape
|
31 |
+
|
32 |
+
left = min(left, w - 1)
|
33 |
+
right = min(right, w - left - 1)
|
34 |
+
top = min(top, h - left - 1)
|
35 |
+
bottom = min(bottom, h - top - 1)
|
36 |
+
image = image[top:h - bottom, left:w - right]
|
37 |
+
|
38 |
+
h, w, c = image.shape
|
39 |
+
|
40 |
+
if h < w:
|
41 |
+
offset = (w - h) // 2
|
42 |
+
image = image[:, offset:offset + h]
|
43 |
+
elif w < h:
|
44 |
+
offset = (h - w) // 2
|
45 |
+
image = image[offset:offset + w]
|
46 |
+
|
47 |
+
image = np.array(Image.fromarray(image).resize((size, size)))
|
48 |
+
return image
|
49 |
+
|
50 |
+
|
51 |
+
def save_generated_masks(model, cfg: RunConfig):
|
52 |
+
tensor2im(model.image_app_mask_32).save(cfg.output_path / f"mask_style_32.png")
|
53 |
+
tensor2im(model.image_struct_mask_32).save(cfg.output_path / f"mask_struct_32.png")
|
54 |
+
tensor2im(model.image_app_mask_64).save(cfg.output_path / f"mask_style_64.png")
|
55 |
+
tensor2im(model.image_struct_mask_64).save(cfg.output_path / f"mask_struct_64.png")
|
56 |
+
|
57 |
+
|
58 |
+
def tensor2im(x) -> Image.Image:
|
59 |
+
return Image.fromarray(x.cpu().numpy().astype(np.uint8) * 255)
|
utils/latent_utils.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from typing import Tuple
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
from appearance_transfer_model import AppearanceTransferModel
|
9 |
+
from config import RunConfig
|
10 |
+
from utils import image_utils
|
11 |
+
from utils.ddpm_inversion import invert
|
12 |
+
|
13 |
+
|
14 |
+
def load_latents_or_invert_images(model: AppearanceTransferModel, cfg: RunConfig):
|
15 |
+
if cfg.load_latents and cfg.app_latent_save_path.exists() and cfg.struct_latent_save_path.exists():
|
16 |
+
print("Loading existing latents...")
|
17 |
+
latents_app, latents_struct = load_latents(cfg.app_latent_save_path, cfg.struct_latent_save_path)
|
18 |
+
noise_app, noise_struct = load_noise(cfg.app_latent_save_path, cfg.struct_latent_save_path)
|
19 |
+
print("Done.")
|
20 |
+
else:
|
21 |
+
print("Inverting images...")
|
22 |
+
app_image, struct_image = image_utils.load_images(cfg=cfg, save_path=cfg.output_path)
|
23 |
+
model.enable_edit = False # Deactivate the cross-image attention layers
|
24 |
+
latents_app, latents_struct, noise_app, noise_struct = invert_images(app_image=app_image,
|
25 |
+
struct_image=struct_image,
|
26 |
+
sd_model=model.pipe,
|
27 |
+
cfg=cfg)
|
28 |
+
model.enable_edit = True
|
29 |
+
print("Done.")
|
30 |
+
return latents_app, latents_struct, noise_app, noise_struct
|
31 |
+
|
32 |
+
|
33 |
+
def load_latents(app_latent_save_path: Path, struct_latent_save_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
|
34 |
+
latents_app = torch.load(app_latent_save_path)
|
35 |
+
latents_struct = torch.load(struct_latent_save_path)
|
36 |
+
if type(latents_struct) == list:
|
37 |
+
latents_app = [l.to("cuda") for l in latents_app]
|
38 |
+
latents_struct = [l.to("cuda") for l in latents_struct]
|
39 |
+
else:
|
40 |
+
latents_app = latents_app.to("cuda")
|
41 |
+
latents_struct = latents_struct.to("cuda")
|
42 |
+
return latents_app, latents_struct
|
43 |
+
|
44 |
+
|
45 |
+
def load_noise(app_latent_save_path: Path, struct_latent_save_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
|
46 |
+
latents_app = torch.load(app_latent_save_path.parent / (app_latent_save_path.stem + "_ddpm_noise.pt"))
|
47 |
+
latents_struct = torch.load(struct_latent_save_path.parent / (struct_latent_save_path.stem + "_ddpm_noise.pt"))
|
48 |
+
latents_app = latents_app.to("cuda")
|
49 |
+
latents_struct = latents_struct.to("cuda")
|
50 |
+
return latents_app, latents_struct
|
51 |
+
|
52 |
+
|
53 |
+
def invert_images(sd_model: AppearanceTransferModel, app_image: Image.Image, struct_image: Image.Image, cfg: RunConfig):
|
54 |
+
input_app = torch.from_numpy(np.array(app_image)).float() / 127.5 - 1.0
|
55 |
+
input_struct = torch.from_numpy(np.array(struct_image)).float() / 127.5 - 1.0
|
56 |
+
zs_app, latents_app = invert(x0=input_app.permute(2, 0, 1).unsqueeze(0).to('cuda'),
|
57 |
+
pipe=sd_model,
|
58 |
+
prompt_src=cfg.prompt,
|
59 |
+
num_diffusion_steps=cfg.num_timesteps,
|
60 |
+
cfg_scale_src=3.5)
|
61 |
+
zs_struct, latents_struct = invert(x0=input_struct.permute(2, 0, 1).unsqueeze(0).to('cuda'),
|
62 |
+
pipe=sd_model,
|
63 |
+
prompt_src=cfg.prompt,
|
64 |
+
num_diffusion_steps=cfg.num_timesteps,
|
65 |
+
cfg_scale_src=3.5)
|
66 |
+
# Save the inverted latents and noises
|
67 |
+
torch.save(latents_app, cfg.latents_path / f"{cfg.app_image_path.stem}.pt")
|
68 |
+
torch.save(latents_struct, cfg.latents_path / f"{cfg.struct_image_path.stem}.pt")
|
69 |
+
torch.save(zs_app, cfg.latents_path / f"{cfg.app_image_path.stem}_ddpm_noise.pt")
|
70 |
+
torch.save(zs_struct, cfg.latents_path / f"{cfg.struct_image_path.stem}_ddpm_noise.pt")
|
71 |
+
return latents_app, latents_struct, zs_app, zs_struct
|
72 |
+
|
73 |
+
|
74 |
+
def get_init_latents_and_noises(model: AppearanceTransferModel, cfg: RunConfig) -> Tuple[torch.Tensor, torch.Tensor]:
|
75 |
+
# If we stored all the latents along the diffusion process, select the desired one based on the skip_steps
|
76 |
+
if model.latents_struct.dim() == 4 and model.latents_app.dim() == 4 and model.latents_app.shape[0] > 1:
|
77 |
+
model.latents_struct = model.latents_struct[cfg.skip_steps]
|
78 |
+
model.latents_app = model.latents_app[cfg.skip_steps]
|
79 |
+
init_latents = torch.stack([model.latents_struct, model.latents_app, model.latents_struct])
|
80 |
+
init_zs = [model.zs_struct[cfg.skip_steps:], model.zs_app[cfg.skip_steps:], model.zs_struct[cfg.skip_steps:]]
|
81 |
+
return init_latents, init_zs
|
utils/model_utils.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from diffusers import DDIMScheduler
|
3 |
+
|
4 |
+
from models.stable_diffusion import CrossImageAttentionStableDiffusionPipeline
|
5 |
+
from models.unet_2d_condition import FreeUUNet2DConditionModel
|
6 |
+
|
7 |
+
|
8 |
+
def get_stable_diffusion_model() -> CrossImageAttentionStableDiffusionPipeline:
|
9 |
+
print("Loading Stable Diffusion model...")
|
10 |
+
device = torch.device(f'cuda') if torch.cuda.is_available() else torch.device('cpu')
|
11 |
+
pipe = CrossImageAttentionStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
|
12 |
+
safety_checker=None).to(device)
|
13 |
+
pipe.unet = FreeUUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet").to(device)
|
14 |
+
pipe.scheduler = DDIMScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
|
15 |
+
print("Done.")
|
16 |
+
return pipe
|
utils/segmentation.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple, List
|
2 |
+
|
3 |
+
import nltk
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from sklearn.cluster import KMeans
|
7 |
+
|
8 |
+
from constants import STYLE_INDEX, STRUCT_INDEX
|
9 |
+
|
10 |
+
nltk.download('punkt')
|
11 |
+
nltk.download('averaged_perceptron_tagger')
|
12 |
+
|
13 |
+
"""
|
14 |
+
Self-segmentation technique taken from Prompt Mixing: https://github.com/orpatashnik/local-prompt-mixing
|
15 |
+
"""
|
16 |
+
|
17 |
+
class Segmentor:
|
18 |
+
|
19 |
+
def __init__(self, prompt: str, object_nouns: List[str], num_segments: int = 5, res: int = 32):
|
20 |
+
self.prompt = prompt
|
21 |
+
self.num_segments = num_segments
|
22 |
+
self.resolution = res
|
23 |
+
self.object_nouns = object_nouns
|
24 |
+
tokenized_prompt = nltk.word_tokenize(prompt)
|
25 |
+
forbidden_words = [word.upper() for word in ["photo", "image", "picture"]]
|
26 |
+
self.nouns = [(i, word) for (i, (word, pos)) in enumerate(nltk.pos_tag(tokenized_prompt))
|
27 |
+
if pos[:2] == 'NN' and word.upper() not in forbidden_words]
|
28 |
+
|
29 |
+
def update_attention(self, attn, is_cross):
|
30 |
+
res = int(attn.shape[2] ** 0.5)
|
31 |
+
if is_cross:
|
32 |
+
if res == 16:
|
33 |
+
self.cross_attention_32 = attn
|
34 |
+
elif res == 32:
|
35 |
+
self.cross_attention_64 = attn
|
36 |
+
else:
|
37 |
+
if res == 32:
|
38 |
+
self.self_attention_32 = attn
|
39 |
+
elif res == 64:
|
40 |
+
self.self_attention_64 = attn
|
41 |
+
|
42 |
+
def __call__(self, *args, **kwargs):
|
43 |
+
clusters = self.cluster()
|
44 |
+
cluster2noun = self.cluster2noun(clusters)
|
45 |
+
return cluster2noun
|
46 |
+
|
47 |
+
def cluster(self, res: int = 32):
|
48 |
+
np.random.seed(1)
|
49 |
+
self_attn = self.self_attention_32 if res == 32 else self.self_attention_64
|
50 |
+
|
51 |
+
style_attn = self_attn[STYLE_INDEX].mean(dim=0).cpu().numpy()
|
52 |
+
style_kmeans = KMeans(n_clusters=self.num_segments, n_init=10).fit(style_attn)
|
53 |
+
style_clusters = style_kmeans.labels_.reshape(res, res)
|
54 |
+
|
55 |
+
struct_attn = self_attn[STRUCT_INDEX].mean(dim=0).cpu().numpy()
|
56 |
+
struct_kmeans = KMeans(n_clusters=self.num_segments, n_init=10).fit(struct_attn)
|
57 |
+
struct_clusters = struct_kmeans.labels_.reshape(res, res)
|
58 |
+
|
59 |
+
return style_clusters, struct_clusters
|
60 |
+
|
61 |
+
def cluster2noun(self, clusters, cross_attn, attn_index):
|
62 |
+
result = {}
|
63 |
+
res = int(cross_attn.shape[2] ** 0.5)
|
64 |
+
nouns_indices = [index for (index, word) in self.nouns]
|
65 |
+
cross_attn = cross_attn[attn_index].mean(dim=0).reshape(res, res, -1)
|
66 |
+
nouns_maps = cross_attn.cpu().numpy()[:, :, [i + 1 for i in nouns_indices]]
|
67 |
+
normalized_nouns_maps = np.zeros_like(nouns_maps).repeat(2, axis=0).repeat(2, axis=1)
|
68 |
+
for i in range(nouns_maps.shape[-1]):
|
69 |
+
curr_noun_map = nouns_maps[:, :, i].repeat(2, axis=0).repeat(2, axis=1)
|
70 |
+
normalized_nouns_maps[:, :, i] = (curr_noun_map - np.abs(curr_noun_map.min())) / curr_noun_map.max()
|
71 |
+
|
72 |
+
max_score = 0
|
73 |
+
all_scores = []
|
74 |
+
for c in range(self.num_segments):
|
75 |
+
cluster_mask = np.zeros_like(clusters)
|
76 |
+
cluster_mask[clusters == c] = 1
|
77 |
+
score_maps = [cluster_mask * normalized_nouns_maps[:, :, i] for i in range(len(nouns_indices))]
|
78 |
+
scores = [score_map.sum() / cluster_mask.sum() for score_map in score_maps]
|
79 |
+
all_scores.append(max(scores))
|
80 |
+
max_score = max(max(scores), max_score)
|
81 |
+
|
82 |
+
all_scores.remove(max_score)
|
83 |
+
mean_score = sum(all_scores) / len(all_scores)
|
84 |
+
|
85 |
+
for c in range(self.num_segments):
|
86 |
+
cluster_mask = np.zeros_like(clusters)
|
87 |
+
cluster_mask[clusters == c] = 1
|
88 |
+
score_maps = [cluster_mask * normalized_nouns_maps[:, :, i] for i in range(len(nouns_indices))]
|
89 |
+
scores = [score_map.sum() / cluster_mask.sum() for score_map in score_maps]
|
90 |
+
result[c] = self.nouns[np.argmax(np.array(scores))] if max(scores) > 1.4 * mean_score else "BG"
|
91 |
+
|
92 |
+
return result
|
93 |
+
|
94 |
+
def create_mask(self, clusters, cross_attention, attn_index):
|
95 |
+
cluster2noun = self.cluster2noun(clusters, cross_attention, attn_index)
|
96 |
+
mask = clusters.copy()
|
97 |
+
obj_segments = [c for c in cluster2noun if cluster2noun[c][1] in self.object_nouns]
|
98 |
+
for c in range(self.num_segments):
|
99 |
+
mask[clusters == c] = 1 if c in obj_segments else 0
|
100 |
+
return torch.from_numpy(mask).to("cuda")
|
101 |
+
|
102 |
+
def get_object_masks(self) -> Tuple[torch.Tensor]:
|
103 |
+
clusters_style_32, clusters_struct_32 = self.cluster(res=32)
|
104 |
+
clusters_style_64, clusters_struct_64 = self.cluster(res=64)
|
105 |
+
|
106 |
+
mask_style_32 = self.create_mask(clusters_style_32, self.cross_attention_32, STYLE_INDEX)
|
107 |
+
mask_struct_32 = self.create_mask(clusters_struct_32, self.cross_attention_32, STRUCT_INDEX)
|
108 |
+
mask_style_64 = self.create_mask(clusters_style_64, self.cross_attention_64, STYLE_INDEX)
|
109 |
+
mask_struct_64 = self.create_mask(clusters_struct_64, self.cross_attention_64, STRUCT_INDEX)
|
110 |
+
|
111 |
+
return mask_style_32, mask_struct_32, mask_style_64, mask_struct_64
|