rename backend

backend.py (ADDED, +230 -0)
@@ -0,0 +1,230 @@
# from functools import cache
import importlib

import gradio as gr
import matplotlib.pyplot as plt
import torch
import torchvision
import wandb
from icecream import ic
from torch import nn
from torchvision.transforms.functional import resize
from tqdm import tqdm
from transformers import CLIPModel, CLIPProcessor
import lpips
from edit import blend_paths
from img_processing import *
from img_processing import custom_to_pil
from loaders import load_default
import glob

global log
log = False

# ic.disable()
# ic.enable()
def get_resized_tensor(x):
    if len(x.shape) == 2:
        re = x.unsqueeze(0)
    else:
        re = x
    re = resize(re, (10, 10))
    return re

class ProcessorGradientFlow():
    """
    This wraps the huggingface CLIP processor to allow backprop through the image processing step.
    The original processor forces conversion to PIL images, which breaks gradient flow.
    """
    def __init__(self, device="cuda") -> None:
        self.device = device
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
        self.image_mean = [0.48145466, 0.4578275, 0.40821073]
        self.image_std = [0.26862954, 0.26130258, 0.27577711]
        self.normalize = torchvision.transforms.Normalize(
            self.image_mean,
            self.image_std
        )
        self.resize = torchvision.transforms.Resize(224)
        self.center_crop = torchvision.transforms.CenterCrop(224)

    def preprocess_img(self, images):
        images = self.center_crop(images)
        images = self.resize(images)
        images = self.center_crop(images)
        images = self.normalize(images)
        return images

    def __call__(self, images=[], **kwargs):
        processed_inputs = self.processor(**kwargs)
        processed_inputs["pixel_values"] = self.preprocess_img(images)
        processed_inputs = {key: value.to(self.device) for (key, value) in processed_inputs.items()}
        return processed_inputs

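# --- Hedged usage sketch (illustrative only; not called anywhere in this module).
# It shows that, because ProcessorGradientFlow keeps preprocessing in torch, a loss
# computed on its output backpropagates to the input pixels. The image shape,
# value range and prompt below are assumptions made for the example.
def _example_differentiable_preprocess(device="cpu"):
    processor_flow = ProcessorGradientFlow(device=device)
    img = torch.rand(1, 3, 256, 256, device=device, requires_grad=True)
    inputs = processor_flow(images=img, text=["a face"], return_tensors="pt", padding=True)
    # pixel_values is still connected to `img`, so the gradient reaches the input.
    inputs["pixel_values"].sum().backward()
    return img.grad
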
class ImagePromptOptimizer(nn.Module):
    def __init__(self,
                 vqgan,
                 clip,
                 clip_preprocessor,
                 lpips_fn,
                 iterations=100,
                 lr=0.01,
                 save_vector=True,
                 return_val="vector",
                 quantize=True,
                 make_grid=False,
                 lpips_weight=6.2) -> None:

        super().__init__()
        self.latent = None
        self.device = vqgan.device
        vqgan.eval()
        self.vqgan = vqgan
        self.clip = clip
        self.iterations = iterations
        self.lr = lr
        self.clip_preprocessor = clip_preprocessor
        self.make_grid = make_grid
        self.return_val = return_val
        self.quantize = quantize
        self.lpips_weight = lpips_weight
        self.perceptual_loss = lpips_fn

    def set_latent(self, latent):
        self.latent = latent.detach().to(self.device)

    def set_params(self, lr, iterations, lpips_weight, reconstruction_steps, attn_mask):
        self.attn_mask = attn_mask
        self.iterations = iterations
        self.lr = lr
        self.lpips_weight = lpips_weight
        self.reconstruction_steps = reconstruction_steps

    def forward(self, vector):
        base_latent = self.latent.detach().requires_grad_()
        trans_latent = base_latent + vector
        if self.quantize:
            z_q, *_ = self.vqgan.quantize(trans_latent)
        else:
            z_q = trans_latent
        dec = self.vqgan.decode(z_q)
        return dec
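    # Note on forward(): `vector` is an edit direction in the VQGAN latent space.
    # The stored latent acts as a frozen base, the edit is added to it, optionally
    # snapped to the nearest codebook entries via vqgan.quantize, and decoded back
    # to image space. As a rough orientation (an assumption, not checked here), an
    # f16 VQGAN latent of shape (1, 256, 16, 16) decodes to a (1, 3, 256, 256) image.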
    def _get_clip_similarity(self, prompts, image, weights=None):
        if isinstance(prompts, str):
            prompts = [prompts]
        elif not isinstance(prompts, list):
            raise TypeError("Provide prompts as string or list of strings")
        clip_inputs = self.clip_preprocessor(text=prompts,
                                             images=image, return_tensors="pt", padding=True)
        clip_outputs = self.clip(**clip_inputs)
        similarity_logits = clip_outputs.logits_per_image
        if weights:
            similarity_logits *= weights
        return similarity_logits.sum()

    def get_similarity_loss(self, pos_prompts, neg_prompts, image):
        pos_logits = self._get_clip_similarity(pos_prompts, image)
        if neg_prompts:
            neg_logits = self._get_clip_similarity(neg_prompts, image)
        else:
            neg_logits = torch.tensor([1], device=self.device)
        loss = -torch.log(pos_logits) + torch.log(neg_logits)
        return loss
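    # The similarity loss above works out to
    #     loss = -log(sum_i s(img, pos_i)) + log(sum_j s(img, neg_j))
    # where s(., .) are entries of CLIP's logits_per_image, so raising positive
    # similarity and lowering negative similarity both reduce the loss. With no
    # negative prompts, neg_logits is fixed at 1 and the second term is log 1 = 0.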
    def visualize(self, processed_img):
        if self.make_grid:
            self.index += 1
            plt.subplot(1, 13, self.index)
            plt.imshow(get_pil(processed_img[0]).detach().cpu())
        else:
            plt.imshow(get_pil(processed_img[0]).detach().cpu())
            plt.show()

    def attn_masking(self, grad):
        # print("attnmask 1")
        # print(f"input grad.shape = {grad.shape}")
        # print(f"input grad = {get_resized_tensor(grad)}")
        newgrad = grad
        if self.attn_mask is not None:
            # print("masking mult")
            newgrad = grad * (self.attn_mask)
        # print("output grad, ", get_resized_tensor(newgrad))
        # print("end atn 1")
        return newgrad

    def attn_masking2(self, grad):
        # print("attnmask 2")
        # print(f"input grad.shape = {grad.shape}")
        # print(f"input grad = {get_resized_tensor(grad)}")
        newgrad = grad
        if self.attn_mask is not None:
            # print("masking mult")
            newgrad = grad * ((self.attn_mask - 1) * -1)
        # print("output grad, ", get_resized_tensor(newgrad))
        # print("end atn 2")
        return newgrad

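    # The two hooks above apply complementary gradient masks. With an attention
    # mask m in [0, 1], attn_masking scales gradients by m (the CLIP objective
    # only edits the selected region), while attn_masking2 scales them by (1 - m)
    # (the LPIPS objective only constrains the unselected region). Where m = 1
    # the CLIP gradient passes through untouched and the LPIPS gradient is zeroed;
    # where m = 0 the roles are reversed.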
    def optimize(self, latent, pos_prompts, neg_prompts):
        self.set_latent(latent)
        # self.make_grid=True
        # Decode the unedited latent once; it serves as the LPIPS reference image.
        transformed_img = self(torch.zeros_like(self.latent, requires_grad=True, device=self.device))
        original_img = loop_post_process(transformed_img)
        vector = torch.randn_like(self.latent, requires_grad=True, device=self.device)
        optim = torch.optim.Adam([vector], lr=self.lr)
        if self.make_grid:
            plt.figure(figsize=(35, 25))
            self.index = 1
        # Stage 1: jointly optimise the edit vector against the CLIP prompt loss
        # and the (weighted) LPIPS distance to the original reconstruction.
        for i in tqdm(range(self.iterations)):
            optim.zero_grad()
            transformed_img = self(vector)
            processed_img = loop_post_process(transformed_img)  # * self.attn_mask
            processed_img.retain_grad()
            # Two hooked copies of the same image route masked gradients:
            # the LPIPS copy only receives gradients outside the attention mask,
            # the CLIP copy only receives gradients inside it.
            lpips_input = processed_img.clone()
            lpips_input.register_hook(self.attn_masking2)
            lpips_input.retain_grad()
            clip_clone = processed_img.clone()
            clip_clone.register_hook(self.attn_masking)
            clip_clone.retain_grad()
            with torch.autocast("cuda"):
                clip_loss = self.get_similarity_loss(pos_prompts, neg_prompts, clip_clone)
                print("CLIP loss", clip_loss)
                perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
                print("LPIPS loss: ", perceptual_loss)
                if log:
                    wandb.log({"Perceptual Loss": perceptual_loss})
                    wandb.log({"CLIP Loss": clip_loss})
            clip_loss.backward(retain_graph=True)
            perceptual_loss.backward(retain_graph=True)
            p2 = processed_img.grad  # debug handle, unused
            print("Sum Loss", perceptual_loss + clip_loss)
            optim.step()
            # if i % self.iterations // 10 == 0:
            #     self.visualize(transformed_img)
            yield vector
        if self.make_grid:
            plt.savefig(f"plot {pos_prompts[0]}.png")
            plt.show()
        # Stage 2: LPIPS-only refinement, pulling unmasked regions back towards
        # the original reconstruction.
        print("lpips solo op")
        for i in range(self.reconstruction_steps):
            optim.zero_grad()
            transformed_img = self(vector)
            processed_img = loop_post_process(transformed_img)  # * self.attn_mask
            processed_img.retain_grad()
            lpips_input = processed_img.clone()
            lpips_input.register_hook(self.attn_masking2)
            lpips_input.retain_grad()
            with torch.autocast("cuda"):
                perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
                if log:
                    wandb.log({"Perceptual Loss": perceptual_loss})
                print("LPIPS loss: ", perceptual_loss)
            perceptual_loss.backward(retain_graph=True)
            optim.step()
            yield vector
        # torch.save(vector, "nose_vector.pt")
        # print("")
        # print("DISC STEPS")
        # print("*************")
        # for i in range(self.reconstruction_steps):
        #     optim.zero_grad()
        #     transformed_img = self(vector)
        #     processed_img = loop_post_process(transformed_img)  # * self.attn_mask
        #     disc_logits = self.disc(transformed_img)
        #     disc_loss = self.disc_loss_fn(disc_logits)
        #     print(f"disc_loss = {disc_loss}")
        #     if log:
        #         wandb.log({"Disc Loss": disc_loss})
        #     print("LPIPS loss: ", perceptual_loss)
        #     disc_loss.backward(retain_graph=True)
        #     optim.step()
        #     yield vector
        # Final yield: either the raw edit vector or the edited latent itself.
        yield vector if self.return_val == "vector" else self.latent + vector
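
# --- Hedged end-to-end sketch ------------------------------------------------
# Illustrative wiring only: the load_default call signature, the latent shape
# and the prompts are assumptions, not a tested entry point of this Space.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    vqgan = load_default(device)  # assumed to return a VQGAN with a .device attribute
    clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
    clip_preprocessor = ProcessorGradientFlow(device=device)
    lpips_fn = lpips.LPIPS(net="vgg").to(device)
    optimizer = ImagePromptOptimizer(vqgan, clip, clip_preprocessor, lpips_fn, quantize=True)
    # set_params must be called before optimize(); attn_mask=None disables masking.
    optimizer.set_params(lr=0.01, iterations=20, lpips_weight=6.2,
                         reconstruction_steps=10, attn_mask=None)
    # The latent would normally come from encoding a user image with the VQGAN;
    # a random latent is used here purely as a placeholder.
    latent = torch.randn(1, 256, 16, 16, device=device)
    for vector in optimizer.optimize(latent, pos_prompts=["a smiling face"], neg_prompts=[]):
        pass  # each yielded vector can be decoded with optimizer(vector) for a preview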