refactoring optimization loop
- ImageState.py +4 -4
- animation.py +0 -4
- backend.py +49 -47
ImageState.py
CHANGED
@@ -102,7 +102,7 @@ class ImageState:
         x = Image.fromarray(x, "L")
         return x

-    @torch.
+    @torch.no_grad()
     def _render_all_transformations(self, return_twice=True):
         global num
         current_vector_transforms = (
@@ -150,7 +150,7 @@ class ImageState:
         clear_img_dir(self.img_dir)
         return self.blend(blend_weight)

-    @torch.
+    @torch.no_grad()
     def blend(self, weight):
         _, latent = blend_paths(
             self.vqgan,
@@ -163,7 +163,7 @@ class ImageState:
         self.blend_latent = latent
         return self._render_all_transformations()

-    @torch.
+    @torch.no_grad()
     def rewind(self, index):
         if not self.transform_history:
             print("No history")
@@ -221,7 +221,7 @@ class ImageState:
         ):
             transform_log.transforms.append(transform.detach().cpu())
             self.current_prompt_transforms[-1] = transform
-        with torch.
+        with torch.no_grad():
             image = self._render_all_transformations(return_twice=False)
         if log:
             wandb.log({"image": wandb.Image(image)})
animation.py
CHANGED
@@ -4,10 +4,6 @@ import os
 
 
 def clear_img_dir(img_dir):
-    if not os.path.exists("img_history"):
-        os.mkdir("img_history")
-    if not os.path.exists(img_dir):
-        os.mkdir(img_dir)
     for filename in glob.glob(img_dir + "/*"):
         os.remove(filename)
 
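
With the existence checks gone, clear_img_dir does one job and will fail if the directory is missing, so directory creation presumably happens elsewhere before the first call. A caller-side sketch under that assumption (the path and setup location are hypothetical, not part of this commit):

import glob
import os

def clear_img_dir(img_dir):
    # Deletes files only; the directory must already exist.
    for filename in glob.glob(img_dir + "/*"):
        os.remove(filename)

img_dir = "img_history/session"      # hypothetical path
os.makedirs(img_dir, exist_ok=True)  # setup now happens before clearing
clear_img_dir(img_dir)
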
backend.py
CHANGED
@@ -140,7 +140,7 @@ class ImagePromptEditor(nn.Module):
         return newgrad

     def _get_next_inputs(self, transformed_img):
-        processed_img = loop_post_process(transformed_img)
+        processed_img = loop_post_process(transformed_img)
         processed_img.retain_grad()

         lpips_input = processed_img.clone()
@@ -154,51 +154,53 @@ class ImagePromptEditor(nn.Module):
         return (processed_img, lpips_input, clip_input)

     def _optimize_CLIP_LPIPS(self, optim, original_img, vector, pos_prompts, neg_prompts):
-        with torch.autocast("cuda"):
-            clip_loss = self._get_CLIP_loss(pos_prompts, neg_prompts, clip_input)
-            print("CLIP loss", clip_loss)
-            perceptual_loss = (
-                self.perceptual_loss(lpips_input, original_img.clone())
-                * self.lpips_weight
-            )
+        for i in (range(self.iterations)):
+            optim.zero_grad()
+            transformed_img = self(vector)
+            processed_img, lpips_input, clip_input = self._get_next_inputs(
+                transformed_img
+            )
+            with torch.autocast("cuda"):
+                clip_loss = self._get_CLIP_loss(pos_prompts, neg_prompts, clip_input)
+                print("CLIP loss", clip_loss)
+                perceptual_loss = (
+                    self.perceptual_loss(lpips_input, original_img.clone())
+                    * self.lpips_weight
+                )
+            print("LPIPS loss: ", perceptual_loss)
+            print("Sum Loss", perceptual_loss + clip_loss)
+            if log:
+                wandb.log({"Perceptual Loss": perceptual_loss})
+                wandb.log({"CLIP Loss": clip_loss})
+
+            # These gradients will be masked if attn_mask has been set
+            clip_loss.backward(retain_graph=True)
+            perceptual_loss.backward(retain_graph=True)
+
+            optim.step()
+            yield vector

     def _optimize_LPIPS(self, vector, original_img, optim):
+        for i in range(self.reconstruction_steps):
+            optim.zero_grad()
+            transformed_img = self(vector)
+            processed_img = loop_post_process(transformed_img)
+            processed_img.retain_grad()
+
+            lpips_input = processed_img.clone()
+            lpips_input.register_hook(self._attn_mask_inverse)
+            lpips_input.retain_grad()
+            with torch.autocast("cuda"):
+                perceptual_loss = (
+                    self.perceptual_loss(lpips_input, original_img.clone())
+                    * self.lpips_weight
+                )
+            if log:
+                wandb.log({"Perceptual Loss": perceptual_loss})
+            print("LPIPS loss: ", perceptual_loss)
+            perceptual_loss.backward(retain_graph=True)
+            optim.step()
+            yield vector

     def optimize(self, latent, pos_prompts, neg_prompts):
         self.set_latent(latent)
@@ -209,10 +211,10 @@ class ImagePromptEditor(nn.Module):
         vector = torch.randn_like(self.latent, requires_grad=True, device=self.device)
         optim = torch.optim.Adam([vector], lr=self.lr)

-        for
-            yield
+        for transform in self._optimize_CLIP_LPIPS(optim, original_img, vector, pos_prompts, neg_prompts):
+            yield transform

         print("Running LPIPS optim only")
-        for
-            yield
+        for transform in self._optimize_LPIPS(vector, original_img, optim):
+            yield transform
         yield vector if self.return_val == "vector" else self.latent + vector
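
The backend change is the commit's core: the optimization work now lives in two generator methods, _optimize_CLIP_LPIPS (joint CLIP + LPIPS phase) and _optimize_LPIPS (LPIPS-only reconstruction phase), and optimize drains them in sequence, yielding after every optimizer step so callers can render intermediate results. A minimal sketch of that generator-chaining pattern with stand-in losses (all names below are illustrative, not the repo's API):

import torch

def _stage_clip_lpips(vector, optim, steps):
    # Stand-in for _optimize_CLIP_LPIPS: one optimizer step per yield.
    for _ in range(steps):
        optim.zero_grad()
        loss = (vector ** 2).sum()       # stand-in for CLIP + LPIPS losses
        loss.backward()
        optim.step()
        yield vector

def _stage_lpips(vector, optim, steps):
    # Stand-in for _optimize_LPIPS: reconstruction-only refinement.
    for _ in range(steps):
        optim.zero_grad()
        loss = (vector - 1).abs().sum()  # stand-in for the LPIPS loss
        loss.backward()
        optim.step()
        yield vector

def optimize(vector, lr=0.1):
    optim = torch.optim.Adam([vector], lr=lr)
    for v in _stage_clip_lpips(vector, optim, steps=3):
        yield v                          # caller sees every intermediate step
    for v in _stage_lpips(vector, optim, steps=2):
        yield v
    yield vector                         # final value, mirroring backend.py

vec = torch.randn(8, requires_grad=True)
for intermediate in optimize(vec):
    pass  # e.g., decode and display a frame here

Because both helpers mutate the same vector through a shared optimizer, the split changes structure only: each yield hands back the tensor being optimized in place, just as the refactored optimize does with yield transform.
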