erwann committed
Commit 3f6a58a · 1 Parent(s): bea83f6

rename backend

Files changed (1)
  1. backend.py +230 -0
backend.py ADDED
@@ -0,0 +1,230 @@
+ # from functools import cache
+ import importlib
+
+ import gradio as gr
+ import matplotlib.pyplot as plt
+ import torch
+ import torchvision
+ import wandb
+ from icecream import ic
+ from torch import nn
+ from torchvision.transforms.functional import resize
+ from tqdm import tqdm
+ from transformers import CLIPModel, CLIPProcessor
+ import lpips
+ from edit import blend_paths
+ from img_processing import *
+ from img_processing import custom_to_pil
+ from loaders import load_default
+ import glob
+
+ # Module-level flag that toggles Weights & Biases logging inside the optimization loops.
+ log = False
+
+ # ic.disable()
+ # ic.enable()
+ def get_resized_tensor(x):
+     """Downsample a tensor to 10x10 for compact debug printing."""
+     if len(x.shape) == 2:
+         re = x.unsqueeze(0)
+     else:
+         re = x
+     re = resize(re, (10, 10))
+     return re
+
+ class ProcessorGradientFlow():
+     """
+     This wraps the Hugging Face CLIP processor to allow backprop through the image processing step.
+     The original processor forces conversion to PIL images, which breaks gradient flow.
+     """
+     def __init__(self, device="cuda") -> None:
+         self.device = device
+         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+         self.image_mean = [0.48145466, 0.4578275, 0.40821073]
+         self.image_std = [0.26862954, 0.26130258, 0.27577711]
+         self.normalize = torchvision.transforms.Normalize(
+             self.image_mean,
+             self.image_std
+         )
+         self.resize = torchvision.transforms.Resize(224)
+         self.center_crop = torchvision.transforms.CenterCrop(224)
+
+     def preprocess_img(self, images):
+         # Replicate CLIP's preprocessing (resize, center-crop, normalize) with
+         # differentiable torchvision transforms instead of a PIL round-trip.
+         images = self.resize(images)
+         images = self.center_crop(images)
+         images = self.normalize(images)
+         return images
+
+     def __call__(self, images=[], **kwargs):
+         # Tokenize the text with the stock processor, but send the images through
+         # the differentiable pipeline above so gradients reach the input pixels.
+         processed_inputs = self.processor(**kwargs)
+         processed_inputs["pixel_values"] = self.preprocess_img(images)
+         processed_inputs = {key: value.to(self.device) for (key, value) in processed_inputs.items()}
+         return processed_inputs
+
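+ # Illustrative usage sketch for the wrapper above: `imgs` (a float image batch of
+ # shape (N, 3, H, W) that requires grad) and the prompt text are assumptions, not
+ # values defined in this module.
+ #
+ #     preprocessor = ProcessorGradientFlow(device="cuda")
+ #     clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda")
+ #     inputs = preprocessor(text=["a photo of a face"], images=imgs,
+ #                           return_tensors="pt", padding=True)
+ #     sim = clip(**inputs).logits_per_image.sum()
+ #     sim.backward()   # gradients flow back into `imgs`, unlike with CLIPProcessor
+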
+ class ImagePromptOptimizer(nn.Module):
+     def __init__(self,
+                  vqgan,
+                  clip,
+                  clip_preprocessor,
+                  lpips_fn,
+                  iterations=100,
+                  lr=0.01,
+                  save_vector=True,
+                  return_val="vector",
+                  quantize=True,
+                  make_grid=False,
+                  lpips_weight=6.2) -> None:
+         super().__init__()
+         self.latent = None
+         self.device = vqgan.device
+         vqgan.eval()
+         self.vqgan = vqgan
+         self.clip = clip
+         self.iterations = iterations
+         self.lr = lr
+         self.clip_preprocessor = clip_preprocessor
+         self.make_grid = make_grid
+         self.return_val = return_val
+         self.quantize = quantize
+         self.lpips_weight = lpips_weight
+         self.perceptual_loss = lpips_fn
+
+     def set_latent(self, latent):
+         self.latent = latent.detach().to(self.device)
+
+     def set_params(self, lr, iterations, lpips_weight, reconstruction_steps, attn_mask):
+         self.attn_mask = attn_mask
+         self.iterations = iterations
+         self.lr = lr
+         self.lpips_weight = lpips_weight
+         self.reconstruction_steps = reconstruction_steps
+
+     def forward(self, vector):
+         # Decode the (optionally quantized) latent shifted by the learned edit vector.
+         base_latent = self.latent.detach().requires_grad_()
+         trans_latent = base_latent + vector
+         if self.quantize:
+             z_q, *_ = self.vqgan.quantize(trans_latent)
+         else:
+             z_q = trans_latent
+         dec = self.vqgan.decode(z_q)
+         return dec
+
+     def _get_clip_similarity(self, prompts, image, weights=None):
+         if isinstance(prompts, str):
+             prompts = [prompts]
+         elif not isinstance(prompts, list):
+             raise TypeError("Provide prompts as a string or a list of strings")
+         clip_inputs = self.clip_preprocessor(text=prompts,
+                                              images=image, return_tensors="pt", padding=True)
+         clip_outputs = self.clip(**clip_inputs)
+         similarity_logits = clip_outputs.logits_per_image
+         if weights is not None:
+             similarity_logits *= weights
+         return similarity_logits.sum()
+
+     def get_similarity_loss(self, pos_prompts, neg_prompts, image):
+         # Maximize similarity to the positive prompts, minimize it to the negative ones.
+         pos_logits = self._get_clip_similarity(pos_prompts, image)
+         if neg_prompts:
+             neg_logits = self._get_clip_similarity(neg_prompts, image)
+         else:
+             neg_logits = torch.tensor([1], device=self.device)
+         loss = -torch.log(pos_logits) + torch.log(neg_logits)
+         return loss
+
+     def visualize(self, processed_img):
+         if self.make_grid:
+             self.index += 1
+             plt.subplot(1, 13, self.index)
+             plt.imshow(get_pil(processed_img[0]).detach().cpu())
+         else:
+             plt.imshow(get_pil(processed_img[0]).detach().cpu())
+             plt.show()
+
+     def attn_masking(self, grad):
+         # Keep gradients only inside the attention mask (the editable region).
+         newgrad = grad
+         if self.attn_mask is not None:
+             newgrad = grad * (self.attn_mask)
+         return newgrad
+
+     def attn_masking2(self, grad):
+         # Complementary mask: keep gradients only outside the attention mask.
+         newgrad = grad
+         if self.attn_mask is not None:
+             newgrad = grad * ((self.attn_mask - 1) * -1)
+         return newgrad
+
+     def optimize(self, latent, pos_prompts, neg_prompts):
+         self.set_latent(latent)
+         # self.make_grid=True
+         transformed_img = self(torch.zeros_like(self.latent, requires_grad=True, device=self.device))
+         original_img = loop_post_process(transformed_img)
+         vector = torch.randn_like(self.latent, requires_grad=True, device=self.device)
+         optim = torch.optim.Adam([vector], lr=self.lr)
+         if self.make_grid:
+             plt.figure(figsize=(35, 25))
+             self.index = 1
+         for i in tqdm(range(self.iterations)):
+             optim.zero_grad()
+             transformed_img = self(vector)
+             processed_img = loop_post_process(transformed_img) #* self.attn_mask
+             processed_img.retain_grad()
+             # The LPIPS branch only receives gradients outside the attention mask and
+             # the CLIP branch only inside it, so edits stay local to the masked region.
+             lpips_input = processed_img.clone()
+             lpips_input.register_hook(self.attn_masking2)
+             lpips_input.retain_grad()
+             clip_clone = processed_img.clone()
+             clip_clone.register_hook(self.attn_masking)
+             clip_clone.retain_grad()
+             with torch.autocast("cuda"):
+                 clip_loss = self.get_similarity_loss(pos_prompts, neg_prompts, clip_clone)
+                 print("CLIP loss", clip_loss)
+                 perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
+                 print("LPIPS loss: ", perceptual_loss)
+                 if log:
+                     wandb.log({"Perceptual Loss": perceptual_loss})
+                     wandb.log({"CLIP Loss": clip_loss})
+             clip_loss.backward(retain_graph=True)
+             perceptual_loss.backward(retain_graph=True)
+             p2 = processed_img.grad
+             print("Sum Loss", perceptual_loss + clip_loss)
+             optim.step()
+             # if i % self.iterations // 10 == 0:
+             #     self.visualize(transformed_img)
+             yield vector
+         if self.make_grid:
+             plt.savefig(f"plot {pos_prompts[0]}.png")
+             plt.show()
+         print("lpips solo op")
+         # A few extra LPIPS-only steps pull the unmasked region back toward the original image.
+         for i in range(self.reconstruction_steps):
+             optim.zero_grad()
+             transformed_img = self(vector)
+             processed_img = loop_post_process(transformed_img) #* self.attn_mask
+             processed_img.retain_grad()
+             lpips_input = processed_img.clone()
+             lpips_input.register_hook(self.attn_masking2)
+             lpips_input.retain_grad()
+             with torch.autocast("cuda"):
+                 perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
+                 if log:
+                     wandb.log({"Perceptual Loss": perceptual_loss})
+                 print("LPIPS loss: ", perceptual_loss)
+             perceptual_loss.backward(retain_graph=True)
+             optim.step()
+             yield vector
+         # torch.save(vector, "nose_vector.pt")
+         # print("")
+         # print("DISC STEPS")
+         # print("*************")
+         # for i in range(self.reconstruction_steps):
+         #     optim.zero_grad()
+         #     transformed_img = self(vector)
+         #     processed_img = loop_post_process(transformed_img) #* self.attn_mask
+         #     disc_logits = self.disc(transformed_img)
+         #     disc_loss = self.disc_loss_fn(disc_logits)
+         #     print(f"disc_loss = {disc_loss}")
+         #     if log:
+         #         wandb.log({"Disc Loss": disc_loss})
+         #     print("LPIPS loss: ", perceptual_loss)
+         #     disc_loss.backward(retain_graph=True)
+         #     optim.step()
+         #     yield vector
+         yield vector if self.return_val == "vector" else self.latent + vector
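
A minimal sketch of how these two classes might be wired together as a driver script. It assumes load_default() returns a ready VQGAN exposing .device, .quantize() and .decode(), that custom_to_pil converts a decoded tensor to a PIL image, and it invents a placeholder latent shape, mask shape, prompt, and LPIPS backbone; none of those specifics are defined in this commit.

import torch
import lpips
from transformers import CLIPModel

from backend import ImagePromptOptimizer, ProcessorGradientFlow
from img_processing import custom_to_pil   # assumed helper from this repo
from loaders import load_default           # assumed to return a ready VQGAN

device = "cuda" if torch.cuda.is_available() else "cpu"

vqgan = load_default(device)                                  # signature assumed
clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
preprocessor = ProcessorGradientFlow(device=device)
lpips_fn = lpips.LPIPS(net="vgg").to(device)                  # backbone choice assumed

optimizer = ImagePromptOptimizer(vqgan, clip, preprocessor, lpips_fn, quantize=True)

# Placeholder latent; in practice this would be the VQGAN encoding of the source image.
latent = torch.randn(1, 256, 16, 16, device=device)
# Ones where CLIP may edit; the LPIPS hook preserves the complementary region.
# The shape is assumed to broadcast against the post-processed image the hooks act on.
attn_mask = torch.ones(1, 1, 256, 256, device=device)

optimizer.set_params(lr=0.01, iterations=100, lpips_weight=6.2,
                     reconstruction_steps=20, attn_mask=attn_mask)

# optimize() is a generator: each yielded vector is an intermediate edit direction.
for vector in optimizer.optimize(latent, pos_prompts=["a smiling face"], neg_prompts=[]):
    with torch.no_grad():
        frame = custom_to_pil(optimizer(vector).squeeze(0))   # decode for display

The generator interface is what lets the Gradio frontend stream intermediate frames while the edit vector is still being optimized, rather than waiting for all iterations to finish.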