SeiriRyu JadenFK commited on
Commit
ed67bfe
Β·
0 Parent(s):

Duplicate from baulab/Erasing-Concepts-In-Diffusion

Browse files

Co-authored-by: Jaden Fiotto-Kaufman <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ images/applications.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
README.md ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Erasing Concepts from Diffusion Models
3
+ emoji: πŸ’‘
4
+ colorFrom: indigo
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 3.21.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ duplicated_from: baulab/Erasing-Concepts-In-Diffusion
12
+ ---
13
+
14
+
15
+
16
+ # Erasing Concepts from Diffusion Models
17
+
18
+ Project Website [https://erasing.baulab.info](https://erasing.baulab.info) <br>
19
+ Arxiv Preprint [https://arxiv.org/pdf/2303.07345.pdf](https://arxiv.org/pdf/2303.07345.pdf) <br>
20
+ Fine-tuned Weights [https://erasing.baulab.info/weights/esd_models/](https://erasing.baulab.info/weights/esd_models/) <br>
21
+ <div align='center'>
22
+ <img src = 'images/applications.png'>
23
+ </div>
24
+
25
+ Motivated by recent advancements in text-to-image diffusion, we study erasure of specific concepts from the model's weights. While Stable Diffusion has shown promise in producing explicit or realistic artwork, it has raised concerns regarding its potential for misuse. We propose a fine-tuning method that can erase a visual concept from a pre-trained diffusion model, given only the name of the style and using negative guidance as a teacher. We benchmark our method against previous approaches that remove sexually explicit content and demonstrate its effectiveness, performing on par with Safe Latent Diffusion and censored training.
26
+
27
+ To evaluate artistic style removal, we conduct experiments erasing five modern artists from the network and conduct a user study to assess the human perception of the removed styles. Unlike previous methods, our approach can remove concepts from a diffusion model permanently rather than modifying the output at the inference time, so it cannot be circumvented even if a user has access to model weights
28
+
29
+ Given only a short text description of an undesired visual concept and no additional data, our method fine-tunes model weights to erase the targeted concept. Our method can avoid NSFW content, stop imitation of a specific artist's style, or even erase a whole object class from model output, while preserving the model's behavior and capabilities on other topics.
30
+
31
+ ## Demo vs github
32
+
33
+ This demo uses an updated implementation from the original Erasing codebase the publication is based from.
34
+
35
+ ## Running locally
36
+
37
+ 1.) Create an environment using the packages included in the requirements.txt file
38
+
39
+ 2.) Run `python app.py`
40
+
41
+ 3.) Open the application in browser at `http://127.0.0.1:7860/`
42
+
43
+ 4.) Train, evaluate, and save models using our method
44
+
45
+ ## Citing our work
46
+ The preprint can be cited as follows
47
+ ```
48
+ @article{gandikota2023erasing,
49
+ title={Erasing Concepts from Diffusion Models},
50
+ author={Rohit Gandikota and Joanna Materzy\'nska and Jaden Fiotto-Kaufman and David Bau},
51
+ journal={arXiv preprint arXiv:2303.07345},
52
+ year={2023}
53
+ }
54
+ ```
StableDiffuser.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import torch
4
+ from baukit import TraceDict
5
+ from diffusers import AutoencoderKL, UNet2DConditionModel
6
+ from PIL import Image
7
+ from tqdm.auto import tqdm
8
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
9
+ from diffusers.schedulers import EulerAncestralDiscreteScheduler
10
+ from diffusers.schedulers.scheduling_ddim import DDIMScheduler
11
+ from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
12
+ from diffusers.schedulers.scheduling_lms_discrete import LMSDiscreteScheduler
13
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
14
+ import util
15
+
16
+
17
+ def default_parser():
18
+
19
+ parser = argparse.ArgumentParser()
20
+
21
+ parser.add_argument('prompts', type=str, nargs='+')
22
+ parser.add_argument('outpath', type=str)
23
+
24
+ parser.add_argument('--images', type=str, nargs='+', default=None)
25
+ parser.add_argument('--nsteps', type=int, default=1000)
26
+ parser.add_argument('--nimgs', type=int, default=1)
27
+ parser.add_argument('--start_itr', type=int, default=0)
28
+ parser.add_argument('--return_steps', action='store_true', default=False)
29
+ parser.add_argument('--pred_x0', action='store_true', default=False)
30
+ parser.add_argument('--device', type=str, default='cuda:0')
31
+ parser.add_argument('--seed', type=int, default=42)
32
+
33
+ return parser
34
+
35
+
36
+ class StableDiffuser(torch.nn.Module):
37
+
38
+ def __init__(self,
39
+ scheduler='LMS'
40
+ ):
41
+
42
+ super().__init__()
43
+
44
+ # Load the autoencoder model which will be used to decode the latents into image space.
45
+ self.vae = AutoencoderKL.from_pretrained(
46
+ "CompVis/stable-diffusion-v1-4", subfolder="vae")
47
+
48
+ # Load the tokenizer and text encoder to tokenize and encode the text.
49
+ self.tokenizer = CLIPTokenizer.from_pretrained(
50
+ "openai/clip-vit-large-patch14")
51
+ self.text_encoder = CLIPTextModel.from_pretrained(
52
+ "openai/clip-vit-large-patch14")
53
+
54
+ # The UNet model for generating the latents.
55
+ self.unet = UNet2DConditionModel.from_pretrained(
56
+ "CompVis/stable-diffusion-v1-4", subfolder="unet")
57
+
58
+ self.feature_extractor = CLIPFeatureExtractor.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="feature_extractor")
59
+ self.safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="safety_checker")
60
+
61
+ if scheduler == 'LMS':
62
+ self.scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
63
+ elif scheduler == 'DDIM':
64
+ self.scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
65
+ elif scheduler == 'DDPM':
66
+ self.scheduler = DDPMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
67
+
68
+ self.eval()
69
+
70
+ def get_noise(self, batch_size, img_size, generator=None):
71
+
72
+ param = list(self.parameters())[0]
73
+
74
+ return torch.randn(
75
+ (batch_size, self.unet.in_channels, img_size // 8, img_size // 8),
76
+ generator=generator).type(param.dtype).to(param.device)
77
+
78
+ def add_noise(self, latents, noise, step):
79
+
80
+ return self.scheduler.add_noise(latents, noise, torch.tensor([self.scheduler.timesteps[step]]))
81
+
82
+ def text_tokenize(self, prompts):
83
+
84
+ return self.tokenizer(prompts, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
85
+
86
+ def text_detokenize(self, tokens):
87
+
88
+ return [self.tokenizer.decode(token) for token in tokens if token != self.tokenizer.vocab_size - 1]
89
+
90
+ def text_encode(self, tokens):
91
+
92
+ return self.text_encoder(tokens.input_ids.to(self.unet.device))[0]
93
+
94
+ def decode(self, latents):
95
+
96
+ return self.vae.decode(1 / self.vae.config.scaling_factor * latents).sample
97
+
98
+ def encode(self, tensors):
99
+
100
+ return self.vae.encode(tensors).latent_dist.mode() * 0.18215
101
+
102
+ def to_image(self, image):
103
+
104
+ image = (image / 2 + 0.5).clamp(0, 1)
105
+ image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
106
+ images = (image * 255).round().astype("uint8")
107
+ pil_images = [Image.fromarray(image) for image in images]
108
+
109
+ return pil_images
110
+
111
+ def set_scheduler_timesteps(self, n_steps):
112
+ self.scheduler.set_timesteps(n_steps, device=self.unet.device)
113
+
114
+ def get_initial_latents(self, n_imgs, img_size, n_prompts, generator=None):
115
+
116
+ noise = self.get_noise(n_imgs, img_size, generator=generator).repeat(n_prompts, 1, 1, 1)
117
+
118
+ latents = noise * self.scheduler.init_noise_sigma
119
+
120
+ return latents
121
+
122
+ def get_text_embeddings(self, prompts, n_imgs):
123
+
124
+ text_tokens = self.text_tokenize(prompts)
125
+
126
+ text_embeddings = self.text_encode(text_tokens)
127
+
128
+ unconditional_tokens = self.text_tokenize([""] * len(prompts))
129
+
130
+ unconditional_embeddings = self.text_encode(unconditional_tokens)
131
+
132
+ text_embeddings = torch.cat([unconditional_embeddings, text_embeddings]).repeat_interleave(n_imgs, dim=0)
133
+
134
+ return text_embeddings
135
+
136
+ def predict_noise(self,
137
+ iteration,
138
+ latents,
139
+ text_embeddings,
140
+ guidance_scale=7.5
141
+ ):
142
+
143
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
144
+ latents = torch.cat([latents] * 2)
145
+ latents = self.scheduler.scale_model_input(
146
+ latents, self.scheduler.timesteps[iteration])
147
+
148
+ # predict the noise residual
149
+ noise_prediction = self.unet(
150
+ latents, self.scheduler.timesteps[iteration], encoder_hidden_states=text_embeddings).sample
151
+
152
+ # perform guidance
153
+ noise_prediction_uncond, noise_prediction_text = noise_prediction.chunk(2)
154
+ noise_prediction = noise_prediction_uncond + guidance_scale * \
155
+ (noise_prediction_text - noise_prediction_uncond)
156
+
157
+ return noise_prediction
158
+
159
+ @torch.no_grad()
160
+ def diffusion(self,
161
+ latents,
162
+ text_embeddings,
163
+ end_iteration=1000,
164
+ start_iteration=0,
165
+ return_steps=False,
166
+ pred_x0=False,
167
+ trace_args=None,
168
+ show_progress=True,
169
+ **kwargs):
170
+
171
+ latents_steps = []
172
+ trace_steps = []
173
+
174
+ trace = None
175
+
176
+ for iteration in tqdm(range(start_iteration, end_iteration), disable=not show_progress):
177
+
178
+ if trace_args:
179
+
180
+ trace = TraceDict(self, **trace_args)
181
+
182
+ noise_pred = self.predict_noise(
183
+ iteration,
184
+ latents,
185
+ text_embeddings,
186
+ **kwargs)
187
+
188
+ # compute the previous noisy sample x_t -> x_t-1
189
+ output = self.scheduler.step(noise_pred, self.scheduler.timesteps[iteration], latents)
190
+
191
+ if trace_args:
192
+
193
+ trace.close()
194
+
195
+ trace_steps.append(trace)
196
+
197
+ latents = output.prev_sample
198
+
199
+ if return_steps or iteration == end_iteration - 1:
200
+
201
+ output = output.pred_original_sample if pred_x0 else latents
202
+
203
+ if return_steps:
204
+ latents_steps.append(output.cpu())
205
+ else:
206
+ latents_steps.append(output)
207
+
208
+ return latents_steps, trace_steps
209
+
210
+ @torch.no_grad()
211
+ def __call__(self,
212
+ prompts,
213
+ img_size=512,
214
+ n_steps=50,
215
+ n_imgs=1,
216
+ end_iteration=None,
217
+ generator=None,
218
+ **kwargs
219
+ ):
220
+
221
+ assert 0 <= n_steps <= 1000
222
+
223
+ if not isinstance(prompts, list):
224
+
225
+ prompts = [prompts]
226
+
227
+ self.set_scheduler_timesteps(n_steps)
228
+
229
+ latents = self.get_initial_latents(n_imgs, img_size, len(prompts), generator=generator)
230
+
231
+ text_embeddings = self.get_text_embeddings(prompts,n_imgs=n_imgs)
232
+
233
+ end_iteration = end_iteration or n_steps
234
+
235
+ latents_steps, trace_steps = self.diffusion(
236
+ latents,
237
+ text_embeddings,
238
+ end_iteration=end_iteration,
239
+ **kwargs
240
+ )
241
+
242
+ latents_steps = [self.decode(latents.to(self.unet.device)) for latents in latents_steps]
243
+ images_steps = [self.to_image(latents) for latents in latents_steps]
244
+
245
+ for i in range(len(images_steps)):
246
+ self.safety_checker = self.safety_checker.float()
247
+ safety_checker_input = self.feature_extractor(images_steps[i], return_tensors="pt").to(latents_steps[0].device)
248
+ image, has_nsfw_concept = self.safety_checker(
249
+ images=latents_steps[i].float().cpu().numpy(), clip_input=safety_checker_input.pixel_values.float()
250
+ )
251
+
252
+ images_steps[i][0] = self.to_image(torch.from_numpy(image))[0]
253
+
254
+ images_steps = list(zip(*images_steps))
255
+
256
+ if trace_steps:
257
+
258
+ return images_steps, trace_steps
259
+
260
+ return images_steps
261
+
262
+
263
+ if __name__ == '__main__':
264
+
265
+ parser = default_parser()
266
+
267
+ args = parser.parse_args()
268
+
269
+ diffuser = StableDiffuser(seed=args.seed, scheduler='DDIM').to(torch.device(args.device)).half()
270
+
271
+ images = diffuser(args.prompts,
272
+ n_steps=args.nsteps,
273
+ n_imgs=args.nimgs,
274
+ start_iteration=args.start_itr,
275
+ return_steps=args.return_steps,
276
+ pred_x0=args.pred_x0
277
+ )
278
+
279
+ util.image_grid(images, args.outpath)
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from finetuning import FineTunedModel
4
+ from StableDiffuser import StableDiffuser
5
+ from train import train
6
+
7
+ import os
8
+ model_map = {'Van Gogh' : 'models/vangogh.pt',
9
+ 'Pablo Picasso': 'models/pablopicasso.pt',
10
+ 'Car' : 'models/car.pt',
11
+ 'Garbage Truck': 'models/garbagetruck.pt',
12
+ 'French Horn': 'models/frenchhorn.pt',
13
+ 'Kilian Eng' : 'models/kilianeng.pt',
14
+ 'Thomas Kinkade' : 'models/thomaskinkade.pt',
15
+ 'Tyler Edlin' : 'models/tyleredlin.pt',
16
+ 'Kelly McKernan': 'models/kellymckernan.pt',
17
+ 'Rembrandt': 'models/rembrandt.pt' }
18
+
19
+ ORIGINAL_SPACE_ID = 'baulab/Erasing-Concepts-In-Diffusion'
20
+ SPACE_ID = os.getenv('SPACE_ID')
21
+
22
+ SHARED_UI_WARNING = f'''## Attention - Training using the ESD-u method does not work in this shared UI. You can either duplicate and use it with a gpu with at least 40GB, or clone this repository to run on your own machine.
23
+ <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
24
+ '''
25
+
26
+
27
+ class Demo:
28
+
29
+ def __init__(self) -> None:
30
+
31
+ self.training = False
32
+ self.generating = False
33
+
34
+ self.diffuser = StableDiffuser(scheduler='DDIM').to('cuda').eval().half()
35
+
36
+ with gr.Blocks() as demo:
37
+ self.layout()
38
+ demo.queue(concurrency_count=5).launch()
39
+
40
+
41
+ def layout(self):
42
+
43
+ with gr.Row():
44
+
45
+ if SPACE_ID == ORIGINAL_SPACE_ID:
46
+
47
+ self.warning = gr.Markdown(SHARED_UI_WARNING)
48
+
49
+ with gr.Row():
50
+
51
+ with gr.Tab("Test") as inference_column:
52
+
53
+ with gr.Row():
54
+
55
+ self.explain_infr = gr.Markdown(interactive=False,
56
+ value='This is a demo of [Erasing Concepts from Stable Diffusion](https://erasing.baulab.info/). To try out a model where a concept has been erased, select a model and enter any prompt. For example, if you select the model "Van Gogh" you can generate images for the prompt "A portrait in the style of Van Gogh" and compare the erased and unerased models. We have also provided several other pre-fine-tuned models with artistic styles and objects erased (Check out the "ESD Model" drop-down). You can also train and run your own custom models. Check out the "train" section for custom erasure of concepts.')
57
+
58
+ with gr.Row():
59
+
60
+ with gr.Column(scale=1):
61
+
62
+ self.prompt_input_infr = gr.Text(
63
+ placeholder="Enter prompt...",
64
+ label="Prompt",
65
+ info="Prompt to generate"
66
+ )
67
+
68
+ with gr.Row():
69
+
70
+ self.model_dropdown = gr.Dropdown(
71
+ label="ESD Model",
72
+ choices= list(model_map.keys()),
73
+ value='Van Gogh',
74
+ interactive=True
75
+ )
76
+
77
+ self.seed_infr = gr.Number(
78
+ label="Seed",
79
+ value=42
80
+ )
81
+
82
+ with gr.Column(scale=2):
83
+
84
+ self.infr_button = gr.Button(
85
+ value="Generate",
86
+ interactive=True
87
+ )
88
+
89
+ with gr.Row():
90
+
91
+ self.image_new = gr.Image(
92
+ label="ESD",
93
+ interactive=False
94
+ )
95
+ self.image_orig = gr.Image(
96
+ label="SD",
97
+ interactive=False
98
+ )
99
+
100
+ with gr.Tab("Train") as training_column:
101
+
102
+ with gr.Row():
103
+
104
+ self.explain_train= gr.Markdown(interactive=False,
105
+ value='In this part you can erase any concept from Stable Diffusion. Enter a prompt for the concept or style you want to erase, and select ESD-x if you want to focus erasure on prompts that mention the concept explicitly. [NOTE: ESD-u is currently unavailable in this space. But you can duplicate the space and run it on GPU with VRAM >40GB for enabling ESD-u]. With default settings, it takes about 15 minutes to fine-tune the model; then you can try inference above or download the weights. The training code used here is slightly different than the code tested in the original paper. Code and details are at [github link](https://github.com/rohitgandikota/erasing).')
106
+
107
+ with gr.Row():
108
+
109
+ with gr.Column(scale=3):
110
+
111
+ self.prompt_input = gr.Text(
112
+ placeholder="Enter prompt...",
113
+ label="Prompt to Erase",
114
+ info="Prompt corresponding to concept to erase"
115
+ )
116
+
117
+ choices = ['ESD-x']
118
+ if torch.cuda.get_device_properties(0).total_memory * 1e-9 >= 40:
119
+ choices.append('ESD-u')
120
+
121
+ self.train_method_input = gr.Dropdown(
122
+ choices=choices,
123
+ value='ESD-x',
124
+ label='Train Method',
125
+ info='Method of training'
126
+ )
127
+
128
+ self.neg_guidance_input = gr.Number(
129
+ value=1,
130
+ label="Negative Guidance",
131
+ info='Guidance of negative training used to train'
132
+ )
133
+
134
+ self.iterations_input = gr.Number(
135
+ value=150,
136
+ precision=0,
137
+ label="Iterations",
138
+ info='iterations used to train'
139
+ )
140
+
141
+ self.lr_input = gr.Number(
142
+ value=1e-5,
143
+ label="Learning Rate",
144
+ info='Learning rate used to train'
145
+ )
146
+
147
+ with gr.Column(scale=1):
148
+
149
+ self.train_status = gr.Button(value='', variant='primary', label='Status', interactive=False)
150
+
151
+ self.train_button = gr.Button(
152
+ value="Train",
153
+ )
154
+
155
+ self.download = gr.Files()
156
+
157
+ self.infr_button.click(self.inference, inputs = [
158
+ self.prompt_input_infr,
159
+ self.seed_infr,
160
+ self.model_dropdown
161
+ ],
162
+ outputs=[
163
+ self.image_new,
164
+ self.image_orig
165
+ ]
166
+ )
167
+ self.train_button.click(self.train, inputs = [
168
+ self.prompt_input,
169
+ self.train_method_input,
170
+ self.neg_guidance_input,
171
+ self.iterations_input,
172
+ self.lr_input
173
+ ],
174
+ outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
175
+ )
176
+
177
+ def train(self, prompt, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
178
+
179
+ if self.training:
180
+ return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
181
+
182
+ if train_method == 'ESD-x':
183
+
184
+ modules = ".*attn2$"
185
+ frozen = []
186
+
187
+ elif train_method == 'ESD-u':
188
+
189
+ modules = "unet$"
190
+ frozen = [".*attn2$", "unet.time_embedding$", "unet.conv_out$"]
191
+
192
+ elif train_method == 'ESD-self':
193
+
194
+ modules = ".*attn1$"
195
+ frozen = []
196
+
197
+ randn = torch.randint(1, 10000000, (1,)).item()
198
+
199
+ save_path = f"models/{randn}_{prompt.lower().replace(' ', '')}.pt"
200
+
201
+ self.training = True
202
+
203
+ train(prompt, modules, frozen, iterations, neg_guidance, lr, save_path)
204
+
205
+ self.training = False
206
+
207
+ torch.cuda.empty_cache()
208
+
209
+ model_map['Custom'] = save_path
210
+
211
+ return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom model in the "Test" tab'), save_path, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom')]
212
+
213
+
214
+ def inference(self, prompt, seed, model_name, pbar = gr.Progress(track_tqdm=True)):
215
+
216
+ seed = seed or 42
217
+
218
+ generator = torch.manual_seed(seed)
219
+
220
+ model_path = model_map[model_name]
221
+
222
+ checkpoint = torch.load(model_path)
223
+
224
+ finetuner = FineTunedModel.from_checkpoint(self.diffuser, checkpoint).eval().half()
225
+
226
+ torch.cuda.empty_cache()
227
+
228
+ images = self.diffuser(
229
+ prompt,
230
+ n_steps=50,
231
+ generator=generator
232
+ )
233
+
234
+
235
+ orig_image = images[0][0]
236
+
237
+ torch.cuda.empty_cache()
238
+
239
+ generator = torch.manual_seed(seed)
240
+
241
+ with finetuner:
242
+
243
+ images = self.diffuser(
244
+ prompt,
245
+ n_steps=50,
246
+ generator=generator
247
+ )
248
+
249
+ edited_image = images[0][0]
250
+
251
+ del finetuner
252
+ torch.cuda.empty_cache()
253
+
254
+ return edited_image, orig_image
255
+
256
+
257
+ demo = Demo()
258
+
finetuning.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import re
3
+ import torch
4
+ import util
5
+
6
+ class FineTunedModel(torch.nn.Module):
7
+
8
+ def __init__(self,
9
+ model,
10
+ modules,
11
+ frozen_modules=[]
12
+ ):
13
+
14
+ super().__init__()
15
+
16
+ if isinstance(modules, str):
17
+ modules = [modules]
18
+
19
+ self.model = model
20
+ self.ft_modules = {}
21
+ self.orig_modules = {}
22
+
23
+ util.freeze(self.model)
24
+
25
+ for module_name, module in model.named_modules():
26
+ for ft_module_regex in modules:
27
+
28
+ match = re.search(ft_module_regex, module_name)
29
+
30
+ if match is not None:
31
+
32
+ ft_module = copy.deepcopy(module)
33
+
34
+ self.orig_modules[module_name] = module
35
+ self.ft_modules[module_name] = ft_module
36
+
37
+ util.unfreeze(ft_module)
38
+
39
+ print(f"=> Finetuning {module_name}")
40
+
41
+ for ft_module_name, module in ft_module.named_modules():
42
+
43
+ ft_module_name = f"{module_name}.{ft_module_name}"
44
+
45
+ for freeze_module_name in frozen_modules:
46
+
47
+ match = re.search(freeze_module_name, ft_module_name)
48
+
49
+ if match:
50
+ print(f"=> Freezing {ft_module_name}")
51
+ util.freeze(module)
52
+
53
+ self.ft_modules_list = torch.nn.ModuleList(self.ft_modules.values())
54
+ self.orig_modules_list = torch.nn.ModuleList(self.orig_modules.values())
55
+
56
+
57
+ @classmethod
58
+ def from_checkpoint(cls, model, checkpoint, frozen_modules=[]):
59
+
60
+ if isinstance(checkpoint, str):
61
+ checkpoint = torch.load(checkpoint)
62
+
63
+ modules = [f"{key}$" for key in list(checkpoint.keys())]
64
+
65
+ ftm = FineTunedModel(model, modules, frozen_modules=frozen_modules)
66
+ ftm.load_state_dict(checkpoint)
67
+
68
+ return ftm
69
+
70
+
71
+ def __enter__(self):
72
+
73
+ for key, ft_module in self.ft_modules.items():
74
+ util.set_module(self.model, key, ft_module)
75
+
76
+ def __exit__(self, exc_type, exc_value, tb):
77
+
78
+ for key, module in self.orig_modules.items():
79
+ util.set_module(self.model, key, module)
80
+
81
+ def parameters(self):
82
+
83
+ parameters = []
84
+
85
+ for ft_module in self.ft_modules.values():
86
+
87
+ parameters.extend(list(ft_module.parameters()))
88
+
89
+ return parameters
90
+
91
+ def state_dict(self):
92
+
93
+ state_dict = {key: module.state_dict() for key, module in self.ft_modules.items()}
94
+
95
+ return state_dict
96
+
97
+ def load_state_dict(self, state_dict):
98
+
99
+ for key, sd in state_dict.items():
100
+
101
+ self.ft_modules[key].load_state_dict(sd)
images/applications.png ADDED

Git LFS Details

  • SHA256: deec6225d4533fb66380321c76a918ff0b9932502192a25f2f32e6f11f2c5db2
  • Pointer size: 132 Bytes
  • Size of remote file: 2.01 MB
models/car.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1a486d8417dc06dcdadfafe738ca32fb9d48f3a1a144d96cb2781e9e5f0c6f98
3
+ size 3438317621
models/frenchhorn.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:48dc787885e54bbad818b57205ccb39796051b07b98a00b43776fb7d7e375fc0
3
+ size 3438372469
models/garbagetruck.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e8bc8d2d973e941a16a03cae81113b5f5da07245888dd29934ea77af9242aba
3
+ size 3438373845
models/kellymckernan.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dee79125ec560c09fcd8820f2df9e3bee94b8a8e8ef3c8e9b330a80ec87cf45e
3
+ size 175879857
models/kilianeng.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f99b66762771ad4da29d7797b5f20b3b14390082a713d8891c800c759fafd11c
3
+ size 175878873
models/pablopicasso.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2988ba3a00af8e7796fdd459d31c63e45470d5b00b88fa0a18f1722c8c55fd9a
3
+ size 175879775
models/rembrandt.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b95b9bd8d13f2cd8fc389ec2ec6246a3da334cc0c09cba323a3743ae2453cf58
3
+ size 175879529
models/thomaskinkade.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a31bd7f5bc1bf838e9443a7fd3f729d5326d5d44847a23dffdf6611591ea24a
3
+ size 175879201
models/tyleredlin.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:830a00af362c2805be1f41acd121b5b74936f75b08c94ab44c68fe1b355af901
3
+ size 175878955
models/vangogh.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:75cdb4313898f593b16f23dbceca498f3f16a749802450ab358a12c204404c27
3
+ size 175873179
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ torch==1.13.1 --index-url https://download.pytorch.org/whl/cu118
3
+ torchvision==0.14.1 --index-url https://download.pytorch.org/whl/cu118
4
+ diffusers
5
+ transformers
6
+ accelerate
7
+ scipy
8
+ git+https://github.com/davidbau/baukit.git
train.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from StableDiffuser import StableDiffuser
2
+ from finetuning import FineTunedModel
3
+ import torch
4
+ from tqdm import tqdm
5
+
6
+ def train(prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path):
7
+
8
+ nsteps = 50
9
+
10
+ diffuser = StableDiffuser(scheduler='DDIM').to('cuda')
11
+ diffuser.train()
12
+
13
+ finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)
14
+
15
+ optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
16
+ criteria = torch.nn.MSELoss()
17
+
18
+ pbar = tqdm(range(iterations))
19
+
20
+ with torch.no_grad():
21
+
22
+ neutral_text_embeddings = diffuser.get_text_embeddings([''],n_imgs=1)
23
+ positive_text_embeddings = diffuser.get_text_embeddings([prompt],n_imgs=1)
24
+
25
+ del diffuser.vae
26
+ del diffuser.text_encoder
27
+ del diffuser.tokenizer
28
+
29
+ torch.cuda.empty_cache()
30
+
31
+ for i in pbar:
32
+
33
+ with torch.no_grad():
34
+
35
+ diffuser.set_scheduler_timesteps(nsteps)
36
+
37
+ optimizer.zero_grad()
38
+
39
+ iteration = torch.randint(1, nsteps - 1, (1,)).item()
40
+
41
+ latents = diffuser.get_initial_latents(1, 512, 1)
42
+
43
+ with finetuner:
44
+
45
+ latents_steps, _ = diffuser.diffusion(
46
+ latents,
47
+ positive_text_embeddings,
48
+ start_iteration=0,
49
+ end_iteration=iteration,
50
+ guidance_scale=3,
51
+ show_progress=False
52
+ )
53
+
54
+ diffuser.set_scheduler_timesteps(1000)
55
+
56
+ iteration = int(iteration / nsteps * 1000)
57
+
58
+ positive_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
59
+ neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)
60
+
61
+ with finetuner:
62
+ negative_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
63
+
64
+ positive_latents.requires_grad = False
65
+ neutral_latents.requires_grad = False
66
+
67
+ loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents))) #loss = criteria(e_n, e_0) works the best try 5000 epochs
68
+
69
+ loss.backward()
70
+ optimizer.step()
71
+
72
+ torch.save(finetuner.state_dict(), save_path)
73
+
74
+ del diffuser, loss, optimizer, finetuner, negative_latents, neutral_latents, positive_latents, latents_steps, latents
75
+
76
+ torch.cuda.empty_cache()
77
+ if __name__ == '__main__':
78
+
79
+ import argparse
80
+
81
+ parser = argparse.ArgumentParser()
82
+
83
+ parser.add_argument('--prompt', required=True)
84
+ parser.add_argument('--modules', required=True)
85
+ parser.add_argument('--freeze_modules', nargs='+', required=True)
86
+ parser.add_argument('--save_path', required=True)
87
+ parser.add_argument('--iterations', type=int, required=True)
88
+ parser.add_argument('--lr', type=float, required=True)
89
+ parser.add_argument('--negative_guidance', type=float, required=True)
90
+
91
+ train(**vars(parser.parse_args()))
util.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from matplotlib import pyplot as plt
3
+ import textwrap
4
+
5
+
6
+ def to_gif(images, path):
7
+
8
+ images[0].save(path, save_all=True,
9
+ append_images=images[1:], loop=0, duration=len(images) * 20)
10
+
11
+
12
+ def figure_to_image(figure):
13
+
14
+ figure.set_dpi(300)
15
+
16
+ figure.canvas.draw()
17
+
18
+ return Image.frombytes('RGB', figure.canvas.get_width_height(), figure.canvas.tostring_rgb())
19
+
20
+
21
+ def image_grid(images, outpath=None, column_titles=None, row_titles=None):
22
+
23
+ n_rows = len(images)
24
+ n_cols = len(images[0])
25
+
26
+ fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols,
27
+ figsize=(n_cols, n_rows), squeeze=False)
28
+
29
+ for row, _images in enumerate(images):
30
+
31
+ for column, image in enumerate(_images):
32
+ ax = axs[row][column]
33
+ ax.imshow(image)
34
+ if column_titles and row == 0:
35
+ ax.set_title(textwrap.fill(
36
+ column_titles[column], width=12), fontsize='x-small')
37
+ if row_titles and column == 0:
38
+ ax.set_ylabel(row_titles[row], rotation=0, fontsize='x-small', labelpad=1.6 * len(row_titles[row]))
39
+ ax.set_xticks([])
40
+ ax.set_yticks([])
41
+
42
+ plt.subplots_adjust(wspace=0, hspace=0)
43
+
44
+ if outpath is not None:
45
+ plt.savefig(outpath, bbox_inches='tight', dpi=300)
46
+ plt.close()
47
+ else:
48
+ plt.tight_layout(pad=0)
49
+ image = figure_to_image(plt.gcf())
50
+ plt.close()
51
+ return image
52
+
53
+
54
+
55
+
56
+
57
+
58
+
59
+ def get_module(module, module_name):
60
+
61
+ if isinstance(module_name, str):
62
+ module_name = module_name.split('.')
63
+
64
+ if len(module_name) == 0:
65
+ return module
66
+ else:
67
+ module = getattr(module, module_name[0])
68
+ return get_module(module, module_name[1:])
69
+
70
+
71
+ def set_module(module, module_name, new_module):
72
+
73
+ if isinstance(module_name, str):
74
+ module_name = module_name.split('.')
75
+
76
+ if len(module_name) == 1:
77
+ return setattr(module, module_name[0], new_module)
78
+ else:
79
+ module = getattr(module, module_name[0])
80
+ return set_module(module, module_name[1:], new_module)
81
+
82
+
83
+ def freeze(module):
84
+
85
+ for parameter in module.parameters():
86
+
87
+ parameter.requires_grad = False
88
+
89
+
90
+ def unfreeze(module):
91
+
92
+ for parameter in module.parameters():
93
+
94
+ parameter.requires_grad = True
95
+
96
+
97
+ def get_concat_h(im1, im2):
98
+ dst = Image.new('RGB', (im1.width + im2.width, im1.height))
99
+ dst.paste(im1, (0, 0))
100
+ dst.paste(im2, (im1.width, 0))
101
+ return dst
102
+
103
+ def get_concat_v(im1, im2):
104
+ dst = Image.new('RGB', (im1.width, im1.height + im2.height))
105
+ dst.paste(im1, (0, 0))
106
+ dst.paste(im2, (0, im1.height))
107
+ return dst