Commit ed67bfe: Duplicate from baulab/Erasing-Concepts-In-Diffusion
Co-authored-by: Jaden Fiotto-Kaufman <[email protected]>

Files changed:
- .gitattributes +35 -0
- .gitignore +1 -0
- README.md +54 -0
- StableDiffuser.py +279 -0
- __init__.py +0 -0
- app.py +258 -0
- finetuning.py +101 -0
- images/applications.png +3 -0
- models/car.pt +3 -0
- models/frenchhorn.pt +3 -0
- models/garbagetruck.pt +3 -0
- models/kellymckernan.pt +3 -0
- models/kilianeng.pt +3 -0
- models/pablopicasso.pt +3 -0
- models/rembrandt.pt +3 -0
- models/thomaskinkade.pt +3 -0
- models/tyleredlin.pt +3 -0
- models/vangogh.pt +3 -0
- requirements.txt +8 -0
- train.py +91 -0
- util.py +107 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
images/applications.png filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1 @@
__pycache__
README.md
ADDED
@@ -0,0 +1,54 @@
---
title: Erasing Concepts from Diffusion Models
emoji: π‘
colorFrom: indigo
colorTo: gray
sdk: gradio
sdk_version: 3.21.0
app_file: app.py
pinned: false
license: mit
duplicated_from: baulab/Erasing-Concepts-In-Diffusion
---

# Erasing Concepts from Diffusion Models

Project Website [https://erasing.baulab.info](https://erasing.baulab.info) <br>
Arxiv Preprint [https://arxiv.org/pdf/2303.07345.pdf](https://arxiv.org/pdf/2303.07345.pdf) <br>
Fine-tuned Weights [https://erasing.baulab.info/weights/esd_models/](https://erasing.baulab.info/weights/esd_models/) <br>
<div align='center'>
<img src = 'images/applications.png'>
</div>

Motivated by recent advancements in text-to-image diffusion, we study erasure of specific concepts from the model's weights. While Stable Diffusion has shown promise in producing explicit or realistic artwork, it has raised concerns regarding its potential for misuse. We propose a fine-tuning method that can erase a visual concept from a pre-trained diffusion model, given only the name of the style and using negative guidance as a teacher. We benchmark our method against previous approaches that remove sexually explicit content and demonstrate its effectiveness, performing on par with Safe Latent Diffusion and censored training.
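
As a concrete illustration of how negative guidance acts as the teacher, the loss implemented in `train.py` of this Space trains the edited model's noise prediction for the concept prompt toward the frozen model's unconditional prediction, shifted away from the concept direction. The following is a minimal sketch of that objective; the function and argument names are ours and not part of the codebase.

```python
import torch

# Sketch of the negative-guidance objective used in train.py (names are illustrative).
# eps_frozen_* are noise predictions of the frozen original U-Net; eps_esd_concept is the
# prediction of the copy being fine-tuned, all evaluated at the same partially denoised latent.
def esd_loss(eps_esd_concept, eps_frozen_concept, eps_frozen_neutral, negative_guidance=1.0):
    # Target: the unconditional prediction, pushed away from the concept direction.
    target = eps_frozen_neutral - negative_guidance * (eps_frozen_concept - eps_frozen_neutral)
    return torch.nn.functional.mse_loss(eps_esd_concept, target.detach())
```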

To evaluate artistic style removal, we conduct experiments erasing five modern artists from the network and conduct a user study to assess the human perception of the removed styles. Unlike previous methods, our approach can remove concepts from a diffusion model permanently rather than modifying the output at inference time, so it cannot be circumvented even if a user has access to the model weights.

Given only a short text description of an undesired visual concept and no additional data, our method fine-tunes model weights to erase the targeted concept. Our method can avoid NSFW content, stop imitation of a specific artist's style, or even erase a whole object class from model output, while preserving the model's behavior and capabilities on other topics.

## Demo vs github

This demo uses an updated implementation of the original Erasing codebase on which the publication is based.

## Running locally

1.) Create an environment using the packages included in the requirements.txt file

2.) Run `python app.py`

3.) Open the application in a browser at `http://127.0.0.1:7860/`

4.) Train, evaluate, and save models using our method (a minimal programmatic sketch is shown below)
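
For step 4, the training and inference entry points used by the Gradio app can also be called from a short script. Below is a minimal sketch assuming a CUDA GPU and the ESD-x defaults the demo uses; the prompt and the `models/custom.pt` path are placeholders, not files shipped with this Space.

```python
import torch
from train import train
from finetuning import FineTunedModel
from StableDiffuser import StableDiffuser

# Erase a concept with the ESD-x configuration (cross-attention modules only),
# mirroring the defaults used in app.py. Prompt and save path are examples.
train(prompt="Van Gogh", modules=".*attn2$", freeze_modules=[],
      iterations=150, negative_guidance=1, lr=1e-5, save_path="models/custom.pt")

# Compare generations with and without the erased concept.
diffuser = StableDiffuser(scheduler='DDIM').to('cuda').eval().half()
finetuner = FineTunedModel.from_checkpoint(diffuser, torch.load("models/custom.pt")).eval().half()

generator = torch.manual_seed(42)
original = diffuser("A portrait in the style of Van Gogh", n_steps=50, generator=generator)[0][0]

generator = torch.manual_seed(42)
with finetuner:  # temporarily swaps the fine-tuned modules into the U-Net
    erased = diffuser("A portrait in the style of Van Gogh", n_steps=50, generator=generator)[0][0]
```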

## Citing our work
The preprint can be cited as follows
```
@article{gandikota2023erasing,
  title={Erasing Concepts from Diffusion Models},
  author={Rohit Gandikota and Joanna Materzy\'nska and Jaden Fiotto-Kaufman and David Bau},
  journal={arXiv preprint arXiv:2303.07345},
  year={2023}
}
```
StableDiffuser.py
ADDED
@@ -0,0 +1,279 @@
import argparse

import torch
from baukit import TraceDict
from diffusers import AutoencoderKL, UNet2DConditionModel
from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
from diffusers.schedulers import EulerAncestralDiscreteScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_lms_discrete import LMSDiscreteScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
import util


def default_parser():

    parser = argparse.ArgumentParser()

    parser.add_argument('prompts', type=str, nargs='+')
    parser.add_argument('outpath', type=str)

    parser.add_argument('--images', type=str, nargs='+', default=None)
    parser.add_argument('--nsteps', type=int, default=1000)
    parser.add_argument('--nimgs', type=int, default=1)
    parser.add_argument('--start_itr', type=int, default=0)
    parser.add_argument('--return_steps', action='store_true', default=False)
    parser.add_argument('--pred_x0', action='store_true', default=False)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--seed', type=int, default=42)

    return parser


class StableDiffuser(torch.nn.Module):

    def __init__(self, scheduler='LMS'):

        super().__init__()

        # Load the autoencoder model which will be used to decode the latents into image space.
        self.vae = AutoencoderKL.from_pretrained(
            "CompVis/stable-diffusion-v1-4", subfolder="vae")

        # Load the tokenizer and text encoder to tokenize and encode the text.
        self.tokenizer = CLIPTokenizer.from_pretrained(
            "openai/clip-vit-large-patch14")
        self.text_encoder = CLIPTextModel.from_pretrained(
            "openai/clip-vit-large-patch14")

        # The UNet model for generating the latents.
        self.unet = UNet2DConditionModel.from_pretrained(
            "CompVis/stable-diffusion-v1-4", subfolder="unet")

        self.feature_extractor = CLIPFeatureExtractor.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="feature_extractor")
        self.safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="safety_checker")

        if scheduler == 'LMS':
            self.scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
        elif scheduler == 'DDIM':
            self.scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
        elif scheduler == 'DDPM':
            self.scheduler = DDPMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

        self.eval()

    def get_noise(self, batch_size, img_size, generator=None):

        param = list(self.parameters())[0]

        return torch.randn(
            (batch_size, self.unet.in_channels, img_size // 8, img_size // 8),
            generator=generator).type(param.dtype).to(param.device)

    def add_noise(self, latents, noise, step):

        return self.scheduler.add_noise(latents, noise, torch.tensor([self.scheduler.timesteps[step]]))

    def text_tokenize(self, prompts):

        return self.tokenizer(prompts, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")

    def text_detokenize(self, tokens):

        return [self.tokenizer.decode(token) for token in tokens if token != self.tokenizer.vocab_size - 1]

    def text_encode(self, tokens):

        return self.text_encoder(tokens.input_ids.to(self.unet.device))[0]

    def decode(self, latents):

        return self.vae.decode(1 / self.vae.config.scaling_factor * latents).sample

    def encode(self, tensors):

        return self.vae.encode(tensors).latent_dist.mode() * 0.18215

    def to_image(self, image):

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (image * 255).round().astype("uint8")
        pil_images = [Image.fromarray(image) for image in images]

        return pil_images

    def set_scheduler_timesteps(self, n_steps):
        self.scheduler.set_timesteps(n_steps, device=self.unet.device)

    def get_initial_latents(self, n_imgs, img_size, n_prompts, generator=None):

        noise = self.get_noise(n_imgs, img_size, generator=generator).repeat(n_prompts, 1, 1, 1)

        latents = noise * self.scheduler.init_noise_sigma

        return latents

    def get_text_embeddings(self, prompts, n_imgs):

        text_tokens = self.text_tokenize(prompts)

        text_embeddings = self.text_encode(text_tokens)

        unconditional_tokens = self.text_tokenize([""] * len(prompts))

        unconditional_embeddings = self.text_encode(unconditional_tokens)

        text_embeddings = torch.cat([unconditional_embeddings, text_embeddings]).repeat_interleave(n_imgs, dim=0)

        return text_embeddings

    def predict_noise(self,
                      iteration,
                      latents,
                      text_embeddings,
                      guidance_scale=7.5
                      ):

        # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
        latents = torch.cat([latents] * 2)
        latents = self.scheduler.scale_model_input(
            latents, self.scheduler.timesteps[iteration])

        # predict the noise residual
        noise_prediction = self.unet(
            latents, self.scheduler.timesteps[iteration], encoder_hidden_states=text_embeddings).sample

        # perform guidance
        noise_prediction_uncond, noise_prediction_text = noise_prediction.chunk(2)
        noise_prediction = noise_prediction_uncond + guidance_scale * \
            (noise_prediction_text - noise_prediction_uncond)

        return noise_prediction

    @torch.no_grad()
    def diffusion(self,
                  latents,
                  text_embeddings,
                  end_iteration=1000,
                  start_iteration=0,
                  return_steps=False,
                  pred_x0=False,
                  trace_args=None,
                  show_progress=True,
                  **kwargs):

        latents_steps = []
        trace_steps = []

        trace = None

        for iteration in tqdm(range(start_iteration, end_iteration), disable=not show_progress):

            if trace_args:

                trace = TraceDict(self, **trace_args)

            noise_pred = self.predict_noise(
                iteration,
                latents,
                text_embeddings,
                **kwargs)

            # compute the previous noisy sample x_t -> x_t-1
            output = self.scheduler.step(noise_pred, self.scheduler.timesteps[iteration], latents)

            if trace_args:

                trace.close()

                trace_steps.append(trace)

            latents = output.prev_sample

            if return_steps or iteration == end_iteration - 1:

                output = output.pred_original_sample if pred_x0 else latents

                if return_steps:
                    latents_steps.append(output.cpu())
                else:
                    latents_steps.append(output)

        return latents_steps, trace_steps

    @torch.no_grad()
    def __call__(self,
                 prompts,
                 img_size=512,
                 n_steps=50,
                 n_imgs=1,
                 end_iteration=None,
                 generator=None,
                 **kwargs
                 ):

        assert 0 <= n_steps <= 1000

        if not isinstance(prompts, list):

            prompts = [prompts]

        self.set_scheduler_timesteps(n_steps)

        latents = self.get_initial_latents(n_imgs, img_size, len(prompts), generator=generator)

        text_embeddings = self.get_text_embeddings(prompts, n_imgs=n_imgs)

        end_iteration = end_iteration or n_steps

        latents_steps, trace_steps = self.diffusion(
            latents,
            text_embeddings,
            end_iteration=end_iteration,
            **kwargs
        )

        latents_steps = [self.decode(latents.to(self.unet.device)) for latents in latents_steps]
        images_steps = [self.to_image(latents) for latents in latents_steps]

        for i in range(len(images_steps)):
            self.safety_checker = self.safety_checker.float()
            safety_checker_input = self.feature_extractor(images_steps[i], return_tensors="pt").to(latents_steps[0].device)
            image, has_nsfw_concept = self.safety_checker(
                images=latents_steps[i].float().cpu().numpy(), clip_input=safety_checker_input.pixel_values.float()
            )

            images_steps[i][0] = self.to_image(torch.from_numpy(image))[0]

        images_steps = list(zip(*images_steps))

        if trace_steps:

            return images_steps, trace_steps

        return images_steps


if __name__ == '__main__':

    parser = default_parser()

    args = parser.parse_args()

    diffuser = StableDiffuser(scheduler='DDIM').to(torch.device(args.device)).half()

    # Seed a generator for reproducible sampling (the constructor itself takes no seed argument).
    generator = torch.manual_seed(args.seed)

    images = diffuser(args.prompts,
                      n_steps=args.nsteps,
                      n_imgs=args.nimgs,
                      start_iteration=args.start_itr,
                      return_steps=args.return_steps,
                      pred_x0=args.pred_x0,
                      generator=generator
                      )

    util.image_grid(images, args.outpath)
__init__.py
ADDED
File without changes
app.py
ADDED
@@ -0,0 +1,258 @@
import gradio as gr
import torch
from finetuning import FineTunedModel
from StableDiffuser import StableDiffuser
from train import train

import os
model_map = {'Van Gogh': 'models/vangogh.pt',
             'Pablo Picasso': 'models/pablopicasso.pt',
             'Car': 'models/car.pt',
             'Garbage Truck': 'models/garbagetruck.pt',
             'French Horn': 'models/frenchhorn.pt',
             'Kilian Eng': 'models/kilianeng.pt',
             'Thomas Kinkade': 'models/thomaskinkade.pt',
             'Tyler Edlin': 'models/tyleredlin.pt',
             'Kelly McKernan': 'models/kellymckernan.pt',
             'Rembrandt': 'models/rembrandt.pt'}

ORIGINAL_SPACE_ID = 'baulab/Erasing-Concepts-In-Diffusion'
SPACE_ID = os.getenv('SPACE_ID')

SHARED_UI_WARNING = f'''## Attention - Training using the ESD-u method does not work in this shared UI. You can either duplicate and use it with a gpu with at least 40GB, or clone this repository to run on your own machine.
<center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></center>
'''


class Demo:

    def __init__(self) -> None:

        self.training = False
        self.generating = False

        self.diffuser = StableDiffuser(scheduler='DDIM').to('cuda').eval().half()

        with gr.Blocks() as demo:
            self.layout()
            demo.queue(concurrency_count=5).launch()

    def layout(self):

        with gr.Row():

            if SPACE_ID == ORIGINAL_SPACE_ID:

                self.warning = gr.Markdown(SHARED_UI_WARNING)

        with gr.Row():

            with gr.Tab("Test") as inference_column:

                with gr.Row():

                    self.explain_infr = gr.Markdown(interactive=False,
                                                    value='This is a demo of [Erasing Concepts from Stable Diffusion](https://erasing.baulab.info/). To try out a model where a concept has been erased, select a model and enter any prompt. For example, if you select the model "Van Gogh" you can generate images for the prompt "A portrait in the style of Van Gogh" and compare the erased and unerased models. We have also provided several other pre-fine-tuned models with artistic styles and objects erased (Check out the "ESD Model" drop-down). You can also train and run your own custom models. Check out the "train" section for custom erasure of concepts.')

                with gr.Row():

                    with gr.Column(scale=1):

                        self.prompt_input_infr = gr.Text(
                            placeholder="Enter prompt...",
                            label="Prompt",
                            info="Prompt to generate"
                        )

                        with gr.Row():

                            self.model_dropdown = gr.Dropdown(
                                label="ESD Model",
                                choices=list(model_map.keys()),
                                value='Van Gogh',
                                interactive=True
                            )

                            self.seed_infr = gr.Number(
                                label="Seed",
                                value=42
                            )

                    with gr.Column(scale=2):

                        self.infr_button = gr.Button(
                            value="Generate",
                            interactive=True
                        )

                        with gr.Row():

                            self.image_new = gr.Image(
                                label="ESD",
                                interactive=False
                            )
                            self.image_orig = gr.Image(
                                label="SD",
                                interactive=False
                            )

            with gr.Tab("Train") as training_column:

                with gr.Row():

                    self.explain_train = gr.Markdown(interactive=False,
                                                     value='In this part you can erase any concept from Stable Diffusion. Enter a prompt for the concept or style you want to erase, and select ESD-x if you want to focus erasure on prompts that mention the concept explicitly. [NOTE: ESD-u is currently unavailable in this space. But you can duplicate the space and run it on GPU with VRAM >40GB for enabling ESD-u]. With default settings, it takes about 15 minutes to fine-tune the model; then you can try inference above or download the weights. The training code used here is slightly different than the code tested in the original paper. Code and details are at [github link](https://github.com/rohitgandikota/erasing).')

                with gr.Row():

                    with gr.Column(scale=3):

                        self.prompt_input = gr.Text(
                            placeholder="Enter prompt...",
                            label="Prompt to Erase",
                            info="Prompt corresponding to concept to erase"
                        )

                        choices = ['ESD-x']
                        if torch.cuda.get_device_properties(0).total_memory * 1e-9 >= 40:
                            choices.append('ESD-u')

                        self.train_method_input = gr.Dropdown(
                            choices=choices,
                            value='ESD-x',
                            label='Train Method',
                            info='Method of training'
                        )

                        self.neg_guidance_input = gr.Number(
                            value=1,
                            label="Negative Guidance",
                            info='Guidance of negative training used to train'
                        )

                        self.iterations_input = gr.Number(
                            value=150,
                            precision=0,
                            label="Iterations",
                            info='iterations used to train'
                        )

                        self.lr_input = gr.Number(
                            value=1e-5,
                            label="Learning Rate",
                            info='Learning rate used to train'
                        )

                    with gr.Column(scale=1):

                        self.train_status = gr.Button(value='', variant='primary', label='Status', interactive=False)

                        self.train_button = gr.Button(
                            value="Train",
                        )

                        self.download = gr.Files()

        self.infr_button.click(self.inference, inputs=[
            self.prompt_input_infr,
            self.seed_infr,
            self.model_dropdown
            ],
            outputs=[
                self.image_new,
                self.image_orig
            ]
        )
        self.train_button.click(self.train, inputs=[
            self.prompt_input,
            self.train_method_input,
            self.neg_guidance_input,
            self.iterations_input,
            self.lr_input
            ],
            outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
        )

    def train(self, prompt, train_method, neg_guidance, iterations, lr, pbar=gr.Progress(track_tqdm=True)):

        if self.training:
            return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]

        if train_method == 'ESD-x':

            modules = ".*attn2$"
            frozen = []

        elif train_method == 'ESD-u':

            modules = "unet$"
            frozen = [".*attn2$", "unet.time_embedding$", "unet.conv_out$"]

        elif train_method == 'ESD-self':

            modules = ".*attn1$"
            frozen = []

        randn = torch.randint(1, 10000000, (1,)).item()

        save_path = f"models/{randn}_{prompt.lower().replace(' ', '')}.pt"

        self.training = True

        train(prompt, modules, frozen, iterations, neg_guidance, lr, save_path)

        self.training = False

        torch.cuda.empty_cache()

        model_map['Custom'] = save_path

        return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom model in the "Test" tab'), save_path, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom')]

    def inference(self, prompt, seed, model_name, pbar=gr.Progress(track_tqdm=True)):

        seed = seed or 42

        generator = torch.manual_seed(seed)

        model_path = model_map[model_name]

        checkpoint = torch.load(model_path)

        finetuner = FineTunedModel.from_checkpoint(self.diffuser, checkpoint).eval().half()

        torch.cuda.empty_cache()

        images = self.diffuser(
            prompt,
            n_steps=50,
            generator=generator
        )

        orig_image = images[0][0]

        torch.cuda.empty_cache()

        generator = torch.manual_seed(seed)

        with finetuner:

            images = self.diffuser(
                prompt,
                n_steps=50,
                generator=generator
            )

        edited_image = images[0][0]

        del finetuner
        torch.cuda.empty_cache()

        return edited_image, orig_image


demo = Demo()
finetuning.py
ADDED
@@ -0,0 +1,101 @@
import copy
import re
import torch
import util

class FineTunedModel(torch.nn.Module):

    def __init__(self,
                 model,
                 modules,
                 frozen_modules=[]
                 ):

        super().__init__()

        if isinstance(modules, str):
            modules = [modules]

        self.model = model
        self.ft_modules = {}
        self.orig_modules = {}

        util.freeze(self.model)

        for module_name, module in model.named_modules():
            for ft_module_regex in modules:

                match = re.search(ft_module_regex, module_name)

                if match is not None:

                    ft_module = copy.deepcopy(module)

                    self.orig_modules[module_name] = module
                    self.ft_modules[module_name] = ft_module

                    util.unfreeze(ft_module)

                    print(f"=> Finetuning {module_name}")

                    for ft_module_name, module in ft_module.named_modules():

                        ft_module_name = f"{module_name}.{ft_module_name}"

                        for freeze_module_name in frozen_modules:

                            match = re.search(freeze_module_name, ft_module_name)

                            if match:
                                print(f"=> Freezing {ft_module_name}")
                                util.freeze(module)

        self.ft_modules_list = torch.nn.ModuleList(self.ft_modules.values())
        self.orig_modules_list = torch.nn.ModuleList(self.orig_modules.values())

    @classmethod
    def from_checkpoint(cls, model, checkpoint, frozen_modules=[]):

        if isinstance(checkpoint, str):
            checkpoint = torch.load(checkpoint)

        modules = [f"{key}$" for key in list(checkpoint.keys())]

        ftm = FineTunedModel(model, modules, frozen_modules=frozen_modules)
        ftm.load_state_dict(checkpoint)

        return ftm

    def __enter__(self):

        for key, ft_module in self.ft_modules.items():
            util.set_module(self.model, key, ft_module)

    def __exit__(self, exc_type, exc_value, tb):

        for key, module in self.orig_modules.items():
            util.set_module(self.model, key, module)

    def parameters(self):

        parameters = []

        for ft_module in self.ft_modules.values():

            parameters.extend(list(ft_module.parameters()))

        return parameters

    def state_dict(self):

        state_dict = {key: module.state_dict() for key, module in self.ft_modules.items()}

        return state_dict

    def load_state_dict(self, state_dict):

        for key, sd in state_dict.items():

            self.ft_modules[key].load_state_dict(sd)
images/applications.png
ADDED
Binary image file (stored with Git LFS)
models/car.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1a486d8417dc06dcdadfafe738ca32fb9d48f3a1a144d96cb2781e9e5f0c6f98
size 3438317621
models/frenchhorn.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:48dc787885e54bbad818b57205ccb39796051b07b98a00b43776fb7d7e375fc0
size 3438372469
models/garbagetruck.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9e8bc8d2d973e941a16a03cae81113b5f5da07245888dd29934ea77af9242aba
size 3438373845
models/kellymckernan.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dee79125ec560c09fcd8820f2df9e3bee94b8a8e8ef3c8e9b330a80ec87cf45e
size 175879857
models/kilianeng.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f99b66762771ad4da29d7797b5f20b3b14390082a713d8891c800c759fafd11c
size 175878873
models/pablopicasso.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2988ba3a00af8e7796fdd459d31c63e45470d5b00b88fa0a18f1722c8c55fd9a
size 175879775
models/rembrandt.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b95b9bd8d13f2cd8fc389ec2ec6246a3da334cc0c09cba323a3743ae2453cf58
size 175879529
models/thomaskinkade.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7a31bd7f5bc1bf838e9443a7fd3f729d5326d5d44847a23dffdf6611591ea24a
size 175879201
models/tyleredlin.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:830a00af362c2805be1f41acd121b5b74936f75b08c94ab44c68fe1b355af901
size 175878955
models/vangogh.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:75cdb4313898f593b16f23dbceca498f3f16a749802450ab358a12c204404c27
size 175873179
requirements.txt
ADDED
@@ -0,0 +1,8 @@
gradio
torch==1.13.1 --index-url https://download.pytorch.org/whl/cu118
torchvision==0.14.1 --index-url https://download.pytorch.org/whl/cu118
diffusers
transformers
accelerate
scipy
git+https://github.com/davidbau/baukit.git
train.py
ADDED
@@ -0,0 +1,91 @@
from StableDiffuser import StableDiffuser
from finetuning import FineTunedModel
import torch
from tqdm import tqdm

def train(prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path):

    nsteps = 50

    diffuser = StableDiffuser(scheduler='DDIM').to('cuda')
    diffuser.train()

    finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)

    optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
    criteria = torch.nn.MSELoss()

    pbar = tqdm(range(iterations))

    with torch.no_grad():

        neutral_text_embeddings = diffuser.get_text_embeddings([''], n_imgs=1)
        positive_text_embeddings = diffuser.get_text_embeddings([prompt], n_imgs=1)

    del diffuser.vae
    del diffuser.text_encoder
    del diffuser.tokenizer

    torch.cuda.empty_cache()

    for i in pbar:

        with torch.no_grad():

            diffuser.set_scheduler_timesteps(nsteps)

            optimizer.zero_grad()

            iteration = torch.randint(1, nsteps - 1, (1,)).item()

            latents = diffuser.get_initial_latents(1, 512, 1)

            with finetuner:

                latents_steps, _ = diffuser.diffusion(
                    latents,
                    positive_text_embeddings,
                    start_iteration=0,
                    end_iteration=iteration,
                    guidance_scale=3,
                    show_progress=False
                )

            diffuser.set_scheduler_timesteps(1000)

            iteration = int(iteration / nsteps * 1000)

            positive_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
            neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)

        with finetuner:
            negative_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)

        positive_latents.requires_grad = False
        neutral_latents.requires_grad = False

        loss = criteria(negative_latents, neutral_latents - (negative_guidance * (positive_latents - neutral_latents)))  # loss = criteria(e_n, e_0) works the best try 5000 epochs

        loss.backward()
        optimizer.step()

    torch.save(finetuner.state_dict(), save_path)

    del diffuser, loss, optimizer, finetuner, negative_latents, neutral_latents, positive_latents, latents_steps, latents

    torch.cuda.empty_cache()

if __name__ == '__main__':

    import argparse

    parser = argparse.ArgumentParser()

    parser.add_argument('--prompt', required=True)
    parser.add_argument('--modules', required=True)
    parser.add_argument('--freeze_modules', nargs='+', required=True)
    parser.add_argument('--save_path', required=True)
    parser.add_argument('--iterations', type=int, required=True)
    parser.add_argument('--lr', type=float, required=True)
    parser.add_argument('--negative_guidance', type=float, required=True)

    train(**vars(parser.parse_args()))
util.py
ADDED
@@ -0,0 +1,107 @@
from PIL import Image
from matplotlib import pyplot as plt
import textwrap


def to_gif(images, path):

    images[0].save(path, save_all=True,
                   append_images=images[1:], loop=0, duration=len(images) * 20)


def figure_to_image(figure):

    figure.set_dpi(300)

    figure.canvas.draw()

    return Image.frombytes('RGB', figure.canvas.get_width_height(), figure.canvas.tostring_rgb())


def image_grid(images, outpath=None, column_titles=None, row_titles=None):

    n_rows = len(images)
    n_cols = len(images[0])

    fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols,
                            figsize=(n_cols, n_rows), squeeze=False)

    for row, _images in enumerate(images):

        for column, image in enumerate(_images):
            ax = axs[row][column]
            ax.imshow(image)
            if column_titles and row == 0:
                ax.set_title(textwrap.fill(
                    column_titles[column], width=12), fontsize='x-small')
            if row_titles and column == 0:
                ax.set_ylabel(row_titles[row], rotation=0, fontsize='x-small', labelpad=1.6 * len(row_titles[row]))
            ax.set_xticks([])
            ax.set_yticks([])

    plt.subplots_adjust(wspace=0, hspace=0)

    if outpath is not None:
        plt.savefig(outpath, bbox_inches='tight', dpi=300)
        plt.close()
    else:
        plt.tight_layout(pad=0)
        image = figure_to_image(plt.gcf())
        plt.close()
        return image


def get_module(module, module_name):

    if isinstance(module_name, str):
        module_name = module_name.split('.')

    if len(module_name) == 0:
        return module
    else:
        module = getattr(module, module_name[0])
        return get_module(module, module_name[1:])


def set_module(module, module_name, new_module):

    if isinstance(module_name, str):
        module_name = module_name.split('.')

    if len(module_name) == 1:
        return setattr(module, module_name[0], new_module)
    else:
        module = getattr(module, module_name[0])
        return set_module(module, module_name[1:], new_module)


def freeze(module):

    for parameter in module.parameters():

        parameter.requires_grad = False


def unfreeze(module):

    for parameter in module.parameters():

        parameter.requires_grad = True


def get_concat_h(im1, im2):
    dst = Image.new('RGB', (im1.width + im2.width, im1.height))
    dst.paste(im1, (0, 0))
    dst.paste(im2, (im1.width, 0))
    return dst

def get_concat_v(im1, im2):
    dst = Image.new('RGB', (im1.width, im1.height + im2.height))
    dst.paste(im1, (0, 0))
    dst.paste(im2, (0, im1.height))
    return dst