update animation creation
- ImageState.py +31 -9
- animation.py +2 -3
- app.py +14 -12
- configs.py +1 -1
ImageState.py
CHANGED
@@ -1,4 +1,7 @@
 # from align import align_from_path
+import imageio
+import glob
+import uuid
 from animation import clear_img_dir
 from backend import ImagePromptOptimizer, log
 import importlib
@@ -38,6 +41,9 @@ class ImageState:
         self.transform_history = []
         self.attn_mask = None
         self.prompt_optim = prompt_optimizer
+        self.state_id = "./" + str(uuid.uuid4())
+        print("NEW INSTANCE")
+        print(self.state_id)
         self._load_vectors()
         self.init_transforms()
     def _load_vectors(self):
@@ -45,6 +51,24 @@
         self.red_blue_vector = torch.load("./latent_vectors/2blue_eyes.pt", map_location=self.device)
         self.green_purple_vector = torch.load("./latent_vectors/nose_vector.pt", map_location=self.device)
         self.asian_vector = torch.load("./latent_vectors/asian10.pt", map_location=self.device)
+    def create_gif(self, total_duration, extend_frames, gif_name="face_edit.gif"):
+        images = []
+        folder = self.state_id
+        paths = glob.glob(folder + "/*")
+        frame_duration = total_duration / len(paths)
+        print(len(paths), "frame dur", frame_duration)
+        durations = [frame_duration] * len(paths)
+        if extend_frames:
+            durations[0] = 1.5
+            durations[-1] = 3
+        for file_name in os.listdir(folder):
+            if file_name.endswith('.png'):
+                file_path = os.path.join(folder, file_name)
+                images.append(imageio.imread(file_path))
+        # images[0] = images[0].set_meta_data({'duration': 1})
+        # images[-1] = images[-1].set_meta_data({'duration': 1})
+        imageio.mimsave(gif_name, images, duration=durations)
+        return gif_name
     def init_transforms(self):
         self.blue_eyes = torch.zeros_like(self.lip_vector)
         self.lip_size = torch.zeros_like(self.lip_vector)
@@ -104,10 +128,10 @@
         if self.quant:
             new_latent, _, _ = self.vqgan.quantize(new_latent.to(self.device))
         image = self._decode_latent_to_pil(new_latent)
-        img_dir =
+        img_dir = self.state_id
        if not os.path.exists(img_dir):
            os.mkdir(img_dir)
-        image.save(f"
+        image.save(f"{img_dir}/img_{num:06}.png")
         num += 1
         return (image, image) if return_twice else image
     def apply_gp_vector(self, weight):
@@ -149,14 +173,12 @@
         latent_index = int(index / 100 * (prompt_transform.iterations - 1))
         print(latent_index)
         self.current_prompt_transforms[-1] = prompt_transform.transforms[latent_index]
-        # print(self.current_prompt_transform)
-        # print(self.current_prompt_transforms.mean())
         return self._render_all_transformations()
-    def rescale_mask(self, mask):
-        rep = mask.clone()
-        rep[mask < 0.03] = -1000000
-        rep[mask >= 0.03] = 1
-        return rep
+    # def rescale_mask(self, mask):
+    #     rep = mask.clone()
+    #     rep[mask < 0.03] = -1000000
+    #     rep[mask >= 0.03] = 1
+    #     return rep
     def apply_prompts(self, positive_prompts, negative_prompts, lr, iterations, lpips_weight, reconstruction_steps):
         transform_log = PromptTransformHistory(iterations + reconstruction_steps)
         transform_log.transforms.append(torch.zeros_like(self.blend_latent, requires_grad=False))
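For context, the GIF-assembly logic added above can be exercised on its own. The sketch below mirrors the new `create_gif` flow (per-instance frame folder, per-frame durations, longer first and last frames) but is not the committed code: the `build_gif` name and the `sorted()` call are assumptions added so the frames play back in save order.

```python
import glob
import os

import imageio


def build_gif(frame_dir, total_duration, extend_frames=True, gif_name="face_edit.gif"):
    """Assemble the PNG frames saved in frame_dir into a GIF (illustrative sketch)."""
    # Sort so frames play back in the order they were saved (img_000000.png, img_000001.png, ...).
    paths = sorted(glob.glob(os.path.join(frame_dir, "*.png")))
    if not paths:
        raise ValueError(f"no frames found in {frame_dir}")
    durations = [total_duration / len(paths)] * len(paths)
    if extend_frames:
        durations[0] = 1.5   # hold the first frame longer
        durations[-1] = 3    # hold the last frame longer
    images = [imageio.imread(p) for p in paths]
    imageio.mimsave(gif_name, images, duration=durations)
    return gif_name


# Usage against a hypothetical per-state folder such as self.state_id:
# build_gif("./<state_id>", total_duration=3.0)
```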
animation.py
CHANGED
@@ -2,15 +2,14 @@ import imageio
 import glob
 import os
 
-def clear_img_dir():
-    img_dir = "./img_history"
+def clear_img_dir(img_dir):
     if not os.path.exists(img_dir):
         os.mkdir(img_dir)
     for filename in glob.glob(img_dir+"/*"):
         os.remove(filename)
 
 
-def create_gif(total_duration, extend_frames, folder
+def create_gif(total_duration, extend_frames, folder, gif_name="face_edit.gif"):
     images = []
     paths = glob.glob(folder + "/*")
     frame_duration = total_duration / len(paths)
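A short usage sketch for the refactored helpers: `clear_img_dir` now receives the directory instead of hard-coding `./img_history`, and `create_gif` takes an explicit output name. The `state_dir` value below is illustrative.

```python
from animation import clear_img_dir, create_gif

state_dir = "./3f2b9c1e-example"   # hypothetical per-instance frame folder
clear_img_dir(state_dir)           # creates the folder if missing, then empties it
# ... render and save edited frames into state_dir ...
gif_path = create_gif(total_duration=3.0, extend_frames=True,
                      folder=state_dir, gif_name="face_edit.gif")
```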
app.py
CHANGED
@@ -6,7 +6,8 @@ import wandb
 import torch
 
 from configs import set_major_global, set_major_local, set_small_local
-
+import uuid
+# print()'
 sys.path.append("taming-transformers")
 
 import gradio as gr
@@ -37,6 +38,8 @@ def get_cleared_mask():
     # mask.clear()
 
 class StateWrapper:
+    def create_gif(state, *args, **kwargs):
+        return state, state[0].create_gif(*args, **kwargs)
     def apply_asian_vector(state, *args, **kwargs):
         return state, *state[0].apply_asian_vector(*args, **kwargs)
     def apply_gp_vector(state, *args, **kwargs):
@@ -141,7 +144,7 @@
 minimum=0,
 maximum=100)
 
-apply_prompts = gr.Button(value="🎨 Apply Prompts", elem_id="apply")
+apply_prompts = gr.Button(variant="primary", value="🎨 Apply Prompts", elem_id="apply")
 clear = gr.Button(value="❌ Clear all transformations (irreversible)", elem_id="warning")
 with gr.Accordion(label="💾 Save Animation", open=False):
 gr.Text(value="Creates an animation of all the steps in the editing process", show_label=False)
@@ -149,7 +152,7 @@
 extend_frames = gr.Checkbox(value=True, label="Make first and last frame longer")
 gif = gr.File(interactive=False)
 create_animation = gr.Button(value="Create Animation")
-create_animation.click(create_gif, inputs=[duration, extend_frames], outputs=gif)
+create_animation.click(StateWrapper.create_gif, inputs=[state, duration, extend_frames], outputs=[state, gif])
 
 with gr.Column(scale=1):
 gr.Markdown(value="""## Text Prompting
@@ -166,12 +169,12 @@
 with gr.Row():
 gr.Markdown(value="## Preset Configs", show_label=False)
 with gr.Row():
-with gr.Column():
-
-with gr.Column():
-
-with gr.Column():
-
+# with gr.Column():
+small_local = gr.Button(value="Small Masked Changes (e.g. add lipstick)", elem_id="small_local").style(full_width=False)
+# with gr.Column():
+major_local = gr.Button(value="Major Masked Changes (e.g. change hair color or nose size)").style(full_width=False)
+# with gr.Column():
+major_global = gr.Button(value="Major Global Changes (e.g. change race / gender").style(full_width=False)
 iterations = gr.Slider(minimum=10,
 maximum=60,
 step=1,
@@ -181,14 +184,13 @@
 maximum=7e-1,
 value=1e-1,
 label="Learning Rate: How strong the change in each step will be (you should raise this for bigger changes (for example, changing hair color), and lower it for more minor changes. Raise if changes aren't strong enough")
-with gr.Accordion(label="Advanced Prompt Editing Options", open=False):
 lpips_weight = gr.Slider(minimum=0,
 maximum=50,
 value=1,
 label="Perceptual similarity weight (Keeps areas outside of the mask looking similar to the original. Increase if the rest of the image is changing too much while you're trying to change make a localized edit")
 reconstruction_steps = gr.Slider(minimum=0,
 maximum=50,
-value=
+value=3,
 step=1,
 label="Steps to run at the end of the optimization, optimizing only the masked perceptual loss. If the edit is changing the identity too much, this setting will run steps at the end that will 'pull' the image back towards the original identity")
 # discriminator_steps = gr.Slider(minimum=0,
@@ -196,7 +198,7 @@
 # step=1,
 # value=0,
 # label="Steps to run at the end, optimizing only the discriminator loss. This helps to reduce artefacts, but because the model is trained on CelebA, this will make your generations look more like generic white celebrities")
-clear.click(
+clear.click(StateWrapper.clear_transforms, inputs=[state], outputs=[state, out, mask])
 asian_weight.change(StateWrapper.apply_asian_vector, inputs=[state, asian_weight], outputs=[state, out, mask])
 lip_size.change(StateWrapper.apply_lip_vector, inputs=[state, lip_size], outputs=[state, out, mask])
 # hair_green_purple.change(StateWrapper.apply_gp_vector, inputs=[state, hair_green_purple], outputs=[state, out, mask])
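Every `StateWrapper` method follows the same pattern: accept the Gradio session `state` first, delegate to the `ImageState` object stored in `state[0]`, and return the state alongside the outputs so each event keeps the per-session object alive. A minimal self-contained sketch of that wiring, assuming a stand-in `SessionState` class in place of the real `ImageState`:

```python
import gradio as gr
import imageio
import numpy as np


class SessionState:
    """Stand-in for the real ImageState; only the GIF hook is sketched."""
    def create_gif(self, total_duration, extend_frames):
        # Write a tiny placeholder GIF so the event has a real file to return.
        frames = [np.full((64, 64, 3), c, dtype=np.uint8) for c in (0, 128, 255)]
        imageio.mimsave("face_edit.gif", frames, duration=total_duration / len(frames))
        return "face_edit.gif"


class StateWrapper:
    # Same shape as the wrappers in app.py: state in, (state, outputs...) out.
    def create_gif(state, *args, **kwargs):
        return state, state[0].create_gif(*args, **kwargs)


with gr.Blocks() as demo:
    state = gr.State([SessionState()])
    duration = gr.Slider(minimum=1, maximum=10, value=3, label="Duration")
    extend_frames = gr.Checkbox(value=True, label="Make first and last frame longer")
    gif = gr.File(interactive=False)
    create_animation = gr.Button(value="Create Animation")
    create_animation.click(StateWrapper.create_gif,
                           inputs=[state, duration, extend_frames],
                           outputs=[state, gif])

# demo.launch()
```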
configs.py
CHANGED
@@ -4,4 +4,4 @@ def set_small_local():
 def set_major_local():
     return (gr.Slider.update(value=25), gr.Slider.update(value=0.2), gr.Slider.update(value=36.6), gr.Slider.update(value=6))
 def set_major_global():
-    return (gr.Slider.update(value=30), gr.Slider.update(value=0.1), gr.Slider.update(value=
+    return (gr.Slider.update(value=30), gr.Slider.update(value=0.1), gr.Slider.update(value=1), gr.Slider.update(value=1))
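Each preset returns four `gr.Slider.update(...)` values, one per tuning slider. The diff does not show how the presets are bound to the new preset buttons, so the wiring below is an assumption: the slider order (iterations, lr, lpips_weight, reconstruction_steps) is inferred from the slider definitions in app.py, and the slider ranges are abbreviated.

```python
import gradio as gr

from configs import set_major_global

with gr.Blocks() as demo:
    iterations = gr.Slider(minimum=10, maximum=60, step=1, value=20, label="Iterations")
    lr = gr.Slider(minimum=1e-3, maximum=7e-1, value=1e-1, label="Learning Rate")
    lpips_weight = gr.Slider(minimum=0, maximum=50, value=1, label="Perceptual similarity weight")
    reconstruction_steps = gr.Slider(minimum=0, maximum=50, step=1, value=3, label="Reconstruction steps")
    major_global = gr.Button(value="Major Global Changes (e.g. change race / gender)")
    # One click pushes the preset's four values into the four sliders.
    major_global.click(set_major_global, inputs=[],
                       outputs=[iterations, lr, lpips_weight, reconstruction_steps])

# demo.launch()
```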