Commit: reduce memory usage

Files changed:
- ImageState.py   +7 -3
- app.py          +18 -13
- app_backend.py  +0 -230
- masking.py      +1 -1
ImageState.py (CHANGED)

@@ -1,6 +1,6 @@
 # from align import align_from_path
 from animation import clear_img_dir
-from
+from backend import ImagePromptOptimizer, log
 import importlib
 import gradio as gr
 import matplotlib.pyplot as plt
@@ -13,12 +13,13 @@ from torchvision.transforms.functional import resize
 from tqdm import tqdm
 from transformers import CLIPModel, CLIPProcessor
 import lpips
-from
+from backend import get_resized_tensor
 from edit import blend_paths
 from img_processing import *
 from img_processing import custom_to_pil
 from loaders import load_default
-
+# from app import vqgan
+global vqgan
 num = 0
 class PromptTransformHistory():
     def __init__(self, iterations) -> None:
@@ -27,6 +28,7 @@ class PromptTransformHistory():
 
 class ImageState:
     def __init__(self, vqgan, prompt_optimizer: ImagePromptOptimizer) -> None:
+        # global vqgan
         self.vqgan = vqgan
         self.device = vqgan.device
         self.blend_latent = None
@@ -59,6 +61,7 @@ class ImageState:
         new_latent = torch.lerp(src, src + vector, 1)
         return new_latent
     def _decode_latent_to_pil(self, latent):
+        # global vqgan
         current_im = self.vqgan.decode(latent.to(self.device))[0]
         return custom_to_pil(current_im)
     # def _get_current_vector_transforms(self):
@@ -95,6 +98,7 @@ class ImageState:
     @torch.no_grad()
     def _render_all_transformations(self, return_twice=True):
         global num
+        # global vqgan
         current_vector_transforms = (self.blue_eyes, self.lip_size, self.hair_gp, self.asian_transform, sum(self.current_prompt_transforms))
         new_latent = self.blend_latent + sum(current_vector_transforms)
         if self.quant:
app.py (CHANGED)

@@ -3,29 +3,33 @@ import os
 import sys
 
 import wandb
+import torch
 
 from configs import set_major_global, set_major_local, set_small_local
 
 sys.path.append("taming-transformers")
-import functools
 
 import gradio as gr
 from transformers import CLIPModel, CLIPProcessor
+from lpips import LPIPS
 
 import edit
-from
+from backend import ImagePromptOptimizer, ProcessorGradientFlow
 from ImageState import ImageState
 from loaders import load_default
 from animation import create_gif
 from prompts import get_random_prompts
 
-device = "cpu"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+global vqgan
 vqgan = load_default(device)
 vqgan.eval()
 processor = ProcessorGradientFlow(device=device)
-clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
-
-
+# clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
+lpips_fn = LPIPS(net='vgg').to(device)
+clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
+promptoptim = ImagePromptOptimizer(vqgan, clip, processor, lpips_fn=lpips_fn, quantize=True)
 def set_img_from_example(state, img):
     return state.update_images(img, img, 0)
 def get_cleared_mask():
@@ -40,7 +44,8 @@ class StateWrapper:
     def apply_lip_vector(state, *args, **kwargs):
         return state, *state[0].apply_lip_vector(*args, **kwargs)
     def apply_prompts(state, *args, **kwargs):
-
+        for image in state[0].apply_prompts(*args, **kwargs):
+            yield state, *image
    def apply_rb_vector(state, *args, **kwargs):
         return state, *state[0].apply_rb_vector(*args, **kwargs)
     def blend(state, *args, **kwargs):
@@ -56,7 +61,7 @@ class StateWrapper:
     def rewind(state, *args, **kwargs):
         return state, *state[0].rewind(*args, **kwargs)
     def set_mask(state, *args, **kwargs):
-        return state,
+        return state, state[0].set_mask(*args, **kwargs)
     def update_images(state, *args, **kwargs):
         return state, *state[0].update_images(*args, **kwargs)
     def update_requant(state, *args, **kwargs):
@@ -191,7 +196,7 @@ with gr.Blocks(css="styles.css") as demo:
                 # step=1,
                 # value=0,
                 # label="Steps to run at the end, optimizing only the discriminator loss. This helps to reduce artefacts, but because the model is trained on CelebA, this will make your generations look more like generic white celebrities")
-
+    clear.click(state.clear_transforms, inputs=[state], outputs=[state, out, mask])
     asian_weight.change(StateWrapper.apply_asian_vector, inputs=[state, asian_weight], outputs=[state, out, mask])
     lip_size.change(StateWrapper.apply_lip_vector, inputs=[state, lip_size], outputs=[state, out, mask])
     # hair_green_purple.change(StateWrapper.apply_gp_vector, inputs=[state, hair_green_purple], outputs=[state, out, mask])
@@ -200,11 +205,11 @@ with gr.Blocks(css="styles.css") as demo:
     # requantize.change(StateWrapper.update_requant, inputs=[state, requantize], outputs=[state, out, mask])
     base_img.change(StateWrapper.update_images, inputs=[state, base_img, blend_img, blend_weight], outputs=[state, out, mask])
     blend_img.change(StateWrapper.update_images, inputs=[state, base_img, blend_img, blend_weight], outputs=[state, out, mask])
-    small_local.click(set_small_local, outputs=[
-    major_local.click(set_major_local, outputs=[
-    major_global.click(set_major_global, outputs=[
+    small_local.click(set_small_local, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
+    major_local.click(set_major_local, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
+    major_global.click(set_major_global, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
     apply_prompts.click(StateWrapper.apply_prompts, inputs=[state, positive_prompts, negative_prompts, learning_rate, iterations, lpips_weight, reconstruction_steps], outputs=[state, out, mask])
     rewind.change(StateWrapper.rewind, inputs=[state, rewind], outputs=[state, out, mask])
-    set_mask.click(StateWrapper.set_mask, inputs=mask, outputs=testim)
+    set_mask.click(StateWrapper.set_mask, inputs=[state, mask], outputs=[state, testim])
 demo.queue()
 demo.launch(debug=True, enable_queue=True)
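The memory reduction in app.py comes from the start-up changes above: the CLIP checkpoint is swapped from openai/clip-vit-large-patch14 to the much smaller openai/clip-vit-base-patch32, a single LPIPS(net='vgg') network is built once and handed to ImagePromptOptimizer via its lpips_fn argument instead of being re-created downstream, and the device is chosen at runtime rather than hard-coded. A minimal standalone sketch of that start-up pattern, assuming only the public transformers and lpips APIs (the Space-specific calls to load_default, ProcessorGradientFlow and ImagePromptOptimizer are omitted here):

# Sketch of the start-up pattern from app.py above: pick the device once,
# load the smaller CLIP checkpoint, and build one shared LPIPS network.
import torch
from lpips import LPIPS
from transformers import CLIPModel, CLIPProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

# ViT-B/32 is roughly a third the size of ViT-L/14, which is the main saving.
clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# One LPIPS instance, reused wherever a perceptual loss is needed
# (in the Space it is passed to ImagePromptOptimizer as lpips_fn).
lpips_fn = LPIPS(net="vgg").to(device).eval()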
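StateWrapper.apply_prompts also changes from a plain return into a generator that re-yields every intermediate image produced by ImageState.apply_prompts, so the UI streams optimization progress instead of blocking until the final step. A small self-contained illustration of that generator-handler pattern in Gradio (the textbox demo below is hypothetical; the Space yields PIL images into its image outputs):

# Illustration of the streaming-handler pattern used by apply_prompts:
# a Gradio event callback that yields updates progressively. Streaming
# handlers need a queue, which app.py enables with demo.queue().
import time
import gradio as gr

def run_steps(prompt):
    # Stand-in for ImageState.apply_prompts, which yields one image per
    # optimization step; plain strings keep the sketch runnable anywhere.
    for step in range(1, 6):
        time.sleep(0.2)
        yield f"{prompt}: step {step}/5"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="prompt")
    progress = gr.Textbox(label="progress")
    prompt.submit(run_steps, inputs=prompt, outputs=progress)

if __name__ == "__main__":
    demo.queue()
    demo.launch()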
app_backend.py (DELETED)

@@ -1,230 +0,0 @@
-# from functools import cache
-import importlib
-
-import gradio as gr
-import matplotlib.pyplot as plt
-import torch
-import torchvision
-import wandb
-from icecream import ic
-from torch import nn
-from torchvision.transforms.functional import resize
-from tqdm import tqdm
-from transformers import CLIPModel, CLIPProcessor
-import lpips
-from edit import blend_paths
-from img_processing import *
-from img_processing import custom_to_pil
-from loaders import load_default
-import glob
-# global log
-log=False
-
-# ic.disable()
-# ic.enable()
-def get_resized_tensor(x):
-    if len(x.shape) == 2:
-        re = x.unsqueeze(0)
-    else: re = x
-    re = resize(re, (10, 10))
-    return re
-class ProcessorGradientFlow():
-    """
-    This wraps the huggingface CLIP processor to allow backprop through the image processing step.
-    The original processor forces conversion to PIL images, which breaks gradient flow.
-    """
-    def __init__(self, device="cuda") -> None:
-        self.device = device
-        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
-        self.image_mean = [0.48145466, 0.4578275, 0.40821073]
-        self.image_std = [0.26862954, 0.26130258, 0.27577711]
-        self.normalize = torchvision.transforms.Normalize(
-            self.image_mean,
-            self.image_std
-        )
-        self.resize = torchvision.transforms.Resize(224)
-        self.center_crop = torchvision.transforms.CenterCrop(224)
-    def preprocess_img(self, images):
-        images = self.center_crop(images)
-        images = self.resize(images)
-        images = self.center_crop(images)
-        images = self.normalize(images)
-        return images
-    def __call__(self, images=[], **kwargs):
-        processed_inputs = self.processor(**kwargs)
-        processed_inputs["pixel_values"] = self.preprocess_img(images)
-        processed_inputs = {key:value.to(self.device) for (key, value) in processed_inputs.items()}
-        return processed_inputs
-
-class ImagePromptOptimizer(nn.Module):
-    def __init__(self,
-                 vqgan,
-                 clip,
-                 clip_preprocessor,
-                 iterations=100,
-                 lr = 0.01,
-                 save_vector=True,
-                 return_val="vector",
-                 quantize=True,
-                 make_grid=False,
-                 lpips_weight = 6.2) -> None:
-
-        super().__init__()
-        self.latent = None
-        self.device = vqgan.device
-        vqgan.eval()
-        self.vqgan = vqgan
-        self.clip = clip
-        self.iterations = iterations
-        self.lr = lr
-        self.clip_preprocessor = clip_preprocessor
-        self.make_grid = make_grid
-        self.return_val = return_val
-        self.quantize = quantize
-        self.lpips_weight = lpips_weight
-        self.perceptual_loss = lpips.LPIPS(net='vgg').to(self.device)
-    def disc_loss_fn(self, logits):
-        return -torch.mean(logits)
-    def set_latent(self, latent):
-        self.latent = latent.detach().to(self.device)
-    def set_params(self, lr, iterations, lpips_weight, reconstruction_steps, attn_mask):
-        self.attn_mask = attn_mask
-        self.iterations = iterations
-        self.lr = lr
-        self.lpips_weight = lpips_weight
-        self.reconstruction_steps = reconstruction_steps
-    def forward(self, vector):
-        base_latent = self.latent.detach().requires_grad_()
-        trans_latent = base_latent + vector
-        if self.quantize:
-            z_q, *_ = self.vqgan.quantize(trans_latent)
-        else:
-            z_q = trans_latent
-        dec = self.vqgan.decode(z_q)
-        return dec
-    def _get_clip_similarity(self, prompts, image, weights=None):
-        if isinstance(prompts, str):
-            prompts = [prompts]
-        elif not isinstance(prompts, list):
-            raise TypeError("Provide prompts as string or list of strings")
-        clip_inputs = self.clip_preprocessor(text=prompts,
-            images=image, return_tensors="pt", padding=True)
-        clip_outputs = self.clip(**clip_inputs)
-        similarity_logits = clip_outputs.logits_per_image
-        if weights:
-            similarity_logits *= weights
-        return similarity_logits.sum()
-    def get_similarity_loss(self, pos_prompts, neg_prompts, image):
-        pos_logits = self._get_clip_similarity(pos_prompts, image)
-        if neg_prompts:
-            neg_logits = self._get_clip_similarity(neg_prompts, image)
-        else:
-            neg_logits = torch.tensor([1], device=self.device)
-        loss = -torch.log(pos_logits) + torch.log(neg_logits)
-        return loss
-    def visualize(self, processed_img):
-        if self.make_grid:
-            self.index += 1
-            plt.subplot(1, 13, self.index)
-            plt.imshow(get_pil(processed_img[0]).detach().cpu())
-        else:
-            plt.imshow(get_pil(processed_img[0]).detach().cpu())
-            plt.show()
-    def attn_masking(self, grad):
-        # print("attnmask 1")
-        # print(f"input grad.shape = {grad.shape}")
-        # print(f"input grad = {get_resized_tensor(grad)}")
-        newgrad = grad
-        if self.attn_mask is not None:
-            # print("masking mult")
-            newgrad = grad * (self.attn_mask)
-        # print("output grad, ", get_resized_tensor(newgrad))
-        # print("end atn 1")
-        return newgrad
-    def attn_masking2(self, grad):
-        # print("attnmask 2")
-        # print(f"input grad.shape = {grad.shape}")
-        # print(f"input grad = {get_resized_tensor(grad)}")
-        newgrad = grad
-        if self.attn_mask is not None:
-            # print("masking mult")
-            newgrad = grad * ((self.attn_mask - 1) * -1)
-        # print("output grad, ", get_resized_tensor(newgrad))
-        # print("end atn 2")
-        return newgrad
-
-    def optimize(self, latent, pos_prompts, neg_prompts):
-        self.set_latent(latent)
-        # self.make_grid=True
-        transformed_img = self(torch.zeros_like(self.latent, requires_grad=True, device=self.device))
-        original_img = loop_post_process(transformed_img)
-        vector = torch.randn_like(self.latent, requires_grad=True, device=self.device)
-        optim = torch.optim.Adam([vector], lr=self.lr)
-        if self.make_grid:
-            plt.figure(figsize=(35, 25))
-            self.index = 1
-        for i in tqdm(range(self.iterations)):
-            optim.zero_grad()
-            transformed_img = self(vector)
-            processed_img = loop_post_process(transformed_img) #* self.attn_mask
-            processed_img.retain_grad()
-            lpips_input = processed_img.clone()
-            lpips_input.register_hook(self.attn_masking2)
-            lpips_input.retain_grad()
-            clip_clone = processed_img.clone()
-            clip_clone.register_hook(self.attn_masking)
-            clip_clone.retain_grad()
-            with torch.autocast("cuda"):
-                clip_loss = self.get_similarity_loss(pos_prompts, neg_prompts, clip_clone)
-                print("CLIP loss", clip_loss)
-                perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
-                print("LPIPS loss: ", perceptual_loss)
-                if log:
-                    wandb.log({"Perceptual Loss": perceptual_loss})
-                    wandb.log({"CLIP Loss": clip_loss})
-            clip_loss.backward(retain_graph=True)
-            perceptual_loss.backward(retain_graph=True)
-            p2 = processed_img.grad
-            print("Sum Loss", perceptual_loss + clip_loss)
-            optim.step()
-            # if i % self.iterations // 10 == 0:
-            #     self.visualize(transformed_img)
-            yield vector
-        if self.make_grid:
-            plt.savefig(f"plot {pos_prompts[0]}.png")
-            plt.show()
-        print("lpips solo op")
-        for i in range(self.reconstruction_steps):
-            optim.zero_grad()
-            transformed_img = self(vector)
-            processed_img = loop_post_process(transformed_img) #* self.attn_mask
-            processed_img.retain_grad()
-            lpips_input = processed_img.clone()
-            lpips_input.register_hook(self.attn_masking2)
-            lpips_input.retain_grad()
-            with torch.autocast("cuda"):
-                perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
-            if log:
-                wandb.log({"Perceptual Loss": perceptual_loss})
-            print("LPIPS loss: ", perceptual_loss)
-            perceptual_loss.backward(retain_graph=True)
-            optim.step()
-            yield vector
-        # torch.save(vector, "nose_vector.pt")
-        # print("")
-        # print("DISC STEPS")
-        # print("*************")
-        # for i in range(self.reconstruction_steps):
-        #     optim.zero_grad()
-        #     transformed_img = self(vector)
-        #     processed_img = loop_post_process(transformed_img) #* self.attn_mask
-        #     disc_logits = self.disc(transformed_img)
-        #     disc_loss = self.disc_loss_fn(disc_logits)
-        #     print(f"disc_loss = {disc_loss}")
-        #     if log:
-        #         wandb.log({"Disc Loss": disc_loss})
-        #     print("LPIPS loss: ", perceptual_loss)
-        #     disc_loss.backward(retain_graph=True)
-        #     optim.step()
-        #     yield vector
-        yield vector if self.return_val == "vector" else self.latent + vector
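app_backend.py is removed, and the names it defined (ProcessorGradientFlow, ImagePromptOptimizer, get_resized_tensor, log) are now imported from backend in the other files above. Its docstring captures the one non-obvious idea worth keeping in mind: the stock CLIPProcessor converts inputs to PIL images, which detaches them from the autograd graph, so ProcessorGradientFlow redoes the crop/resize/normalize steps with torchvision ops that keep gradients flowing back to the VQGAN latent. A tiny sketch of that property (the random tensor is only a stand-in for a decoded image; the constants are the CLIP image mean/std from the deleted file):

# Why differentiable preprocessing matters: gradients must reach the input
# tensor, which a PIL round-trip would break.
import torch
import torchvision

preprocess = torchvision.transforms.Compose([
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.Resize(224),
    torchvision.transforms.Normalize(
        [0.48145466, 0.4578275, 0.40821073],   # CLIP image mean
        [0.26862954, 0.26130258, 0.27577711],  # CLIP image std
    ),
])

decoded = torch.rand(1, 3, 256, 256, requires_grad=True)  # stand-in for vqgan.decode output
pixel_values = preprocess(decoded)   # stays in the autograd graph
pixel_values.sum().backward()
assert decoded.grad is not None      # gradient flows back to the image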
masking.py (CHANGED)

@@ -13,7 +13,7 @@ from transformers import CLIPModel, CLIPProcessor
 import edit
 # import importlib
 # importlib.reload(edit)
-from
+from backend import ImagePromptOptimizer, ImageState, ProcessorGradientFlow
 from loaders import load_default
 
 device = "cuda"