venkyyuvy committed
Commit 99f1917
1 Parent(s): 369bbda

init commit

Marc Allante.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2b0496315f14f212535f9350c3dbf05787ac50a78465d4be2f39a1ba373e4968
+ size 3819
app.py ADDED
@@ -0,0 +1,69 @@
+ import gradio as gr
+
+ import os
+ import torch
+ from image_generator import generate_image_per_prompt_style
+
+ torch.manual_seed(11)
+
+
+ # Set device
+ torch_device = (
+     "cuda"
+     if torch.cuda.is_available()
+     else "mps"
+     if torch.backends.mps.is_available()
+     else "cpu"
+ )
+ if "mps" == torch_device:
+     os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+ # Define Interface
+ title = "Generative Art - Stable Diffusion with Styles and additional guidance"
+
+ gr_interface = gr.Interface(
+     generate_image_per_prompt_style,
+     inputs=[
+         gr.Textbox("cat running", label="Prompt"),
+         gr.Dropdown(
+             [
+                 "illustration-style",
+                 "line-art",
+                 "hitokomoru-style",
+                 "midjourney-style",
+                 "hanfu-anime-style",
+                 "birb-style",
+                 "style-of-marc-allante",
+             ],
+             value="birb-style",
+             label="Pre-trained Styles",
+         ),
+         gr.Dropdown(
+             [
+                 "blue_loss",
+                 "cosine_loss",
+             ],
+             value="cosine_loss",
+             label="Additional guidance for image generation",
+         ),
+         gr.Textbox("on a city road", label="Additional Prompt"),
+     ],
+     outputs=[
+         gr.Gallery(
+             label="Generated images",
+             show_label=False,
+             elem_id="gallery",
+             columns=[2],
+             rows=[2],
+             object_fit="contain",
+             height="auto",
+         )
+     ],
+     title=title,
+     examples=[
+         ["A flying bird", "illustration-style", "blue_loss", ""],
+         ["cat running", "birb-style", "cosine_loss", "on a city road"],
+     ],
+ )
+ gr_interface.launch(debug=True)
+
+
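The interface above is a thin wrapper around generate_image_per_prompt_style from image_generator.py, so the app can also be exercised without the UI. A minimal sketch of a direct call, assuming the repository root is on the Python path; the prompt, loss, and additional-prompt values are illustrative, and the style must be a key of STYLE_EMBEDDINGS:

from image_generator import generate_image_per_prompt_style

style_image, guided_image = generate_image_per_prompt_style(
    "cat running",        # prompt
    "birb-style",         # key of STYLE_EMBEDDINGS
    "cosine_loss",        # "blue_loss" or "cosine_loss"
    "on a city road",     # additional prompt used by cosine_loss
)
style_image.save("style.png")
guided_image.save("guided.png")
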
birb-style.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f2e23a8f2d3628ed77acb8151751ecd4efc4017e8da86bc29af10f855ca308d9
+ size 3819
hanfu-anime-style.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:18ee85c31cff7a0ab35f90af24fbf1a4ab8a9960ab041511e386d5990953e050
+ size 3819
hitokomoru-style.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f81a9c575e329e08a24e08f47ae73c5b50dec4bcb557974552549b45e2d1b0d4
+ size 3819
illustration_style.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:44d65046c071e37f75f31a7a81a34c50a96080e8a3aedc7cda1094dae5d385f0
+ size 3819
image_generator.py ADDED
@@ -0,0 +1,380 @@
+ import os
+ from pathlib import Path
+ import torch
+ from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
+ from tqdm.auto import tqdm
+ from transformers import CLIPTextModel, CLIPTokenizer, logging
+
+ from utils import load_embedding_bin, set_timesteps, latents_to_pil
+ from loss import blue_loss, cosine_loss
+ from matplotlib import pyplot as plt
+
+ torch.manual_seed(11)
+ logging.set_verbosity_error()
+
+ # Set device
+ torch_device = (
+     "cuda"
+     if torch.cuda.is_available()
+     else "mps"
+     if torch.backends.mps.is_available()
+     else "cpu"
+ )
+ if "mps" == torch_device:
+     os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+ # Style embeddings
+ STYLE_EMBEDDINGS = {
+     "illustration-style": "illustration_style.bin",
+     "line-art": "line-art.bin",
+     "hitokomoru-style": "hitokomoru-style.bin",
+     "midjourney-style": "midjourney-style.bin",
+     "hanfu-anime-style": "hanfu-anime-style.bin",
+     "birb-style": "birb-style.bin",
+     "style-of-marc-allante": "Marc Allante.bin",
+ }
+ LOSS = {"blue_loss": blue_loss,
+         "cosine_loss": cosine_loss}
+ STYLE_SEEDS = [11, 56, 110, 65, 5, 29, 47]
+ # Load the autoencoder model which will be used to decode the latents into image space.
+ vae = AutoencoderKL.from_pretrained(
+     "CompVis/stable-diffusion-v1-4", subfolder="vae"
+ ).to(torch_device)
+ #
+ # Load the tokenizer and text encoder to tokenize and encode the text.
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+ text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(
+     torch_device
+ )
+ #
+ # The UNet model for generating the latents.
+ unet = UNet2DConditionModel.from_pretrained(
+     "CompVis/stable-diffusion-v1-4", subfolder="unet"
+ ).to(torch_device)
+ #
+ # The noise scheduler
+ scheduler = LMSDiscreteScheduler(
+     beta_start=0.00085,
+     beta_end=0.012,
+     beta_schedule="scaled_linear",
+     num_train_timesteps=1000,
+ )
+
+ # vae = vae
+ # text_encoder = text_encoder.to(torch_device)
+ unet = unet
+ token_emb_layer = text_encoder.text_model.embeddings.token_embedding
+ pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
+ position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
+ position_embeddings = pos_emb_layer(position_ids)
+
+
+ def build_causal_attention_mask(bsz, seq_len, dtype):
+     # lazily create causal attention mask, with full attention between the vision tokens
+     # pytorch uses additive attention mask; fill with -inf
+     mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
+     mask.fill_(torch.tensor(torch.finfo(dtype).min))
+     mask.triu_(1)  # zero out the lower diagonal
+     mask = mask.unsqueeze(1)  # expand mask
+     return mask
+
+
+ def get_output_embeds(input_embeddings):
+     # CLIP's text model uses causal mask, so we prepare it here:
+     bsz, seq_len = input_embeddings.shape[:2]
+     causal_attention_mask = build_causal_attention_mask(
+         bsz, seq_len, dtype=input_embeddings.dtype
+     )
+
+     # Getting the output embeddings involves calling the model with output_hidden_states=True
+     # so that it doesn't just return the pooled final predictions:
+     encoder_outputs = text_encoder.text_model.encoder(
+         inputs_embeds=input_embeddings,
+         attention_mask=None,  # We aren't using an attention mask so that can be None
+         causal_attention_mask=causal_attention_mask.to(torch_device),
+         output_attentions=None,
+         output_hidden_states=True,  # We want the output embs not the final output
+         return_dict=None,
+     )
+
+     # We're interested in the output hidden state only
+     output = encoder_outputs[0]
+
+     # There is a final layer norm we need to pass these through
+     output = text_encoder.text_model.final_layer_norm(output)
+
+     # And now they're ready!
+     return output
+
+
+ # Generating an image with these modified embeddings
+ def generate_with_embs(text_embeddings, seed, max_length):
+     height = 512  # default height of Stable Diffusion
+     width = 512  # default width of Stable Diffusion
+     num_inference_steps = 30  # Number of denoising steps
+     guidance_scale = 7.5  # Scale for classifier-free guidance
+     generator = torch.manual_seed(seed)
+     batch_size = 1
+
+     # tokenizer
+     uncond_input = tokenizer(
+         [""] * batch_size,
+         padding="max_length",
+         max_length=max_length,
+         return_tensors="pt",
+     )
+
+     with torch.no_grad():
+         uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
+     text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+     # Prep Scheduler
+     set_timesteps(scheduler, num_inference_steps)
+
+     # Prep latents
+     # step = " prep_latents "
+     latents = torch.randn(
+         (batch_size, unet.config.in_channels, height // 8, width // 8),
+         generator=generator,
+     )
+     latents = latents.to(torch_device)
+     latents = latents * scheduler.init_noise_sigma
+
+     # Loop
+     for i, t in tqdm(enumerate(scheduler.timesteps),
+                      total=len(scheduler.timesteps)):
+         # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
+         latent_model_input = torch.cat([latents] * 2)
+         latent_model_input = scheduler.scale_model_input(latent_model_input, t)
+
+         # predict the noise residual
+         with torch.no_grad():
+             noise_pred = unet(
+                 latent_model_input, t, encoder_hidden_states=text_embeddings
+             )["sample"]
+
+         # perform guidance
+         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+         noise_pred = noise_pred_uncond + guidance_scale * (
+             noise_pred_text - noise_pred_uncond
+         )
+
+         # compute the previous noisy sample x_t -> x_t-1
+         latents = scheduler.step(noise_pred, t, latents).prev_sample
+
+     return latents_to_pil(latents)[0]
+
+
+ def generate_image_from_embeddings(
+         mod_output_embeddings, seed, max_length,
+         loss_selection, additional_prompt):
+     height = 512
+     width = 512
+     num_inference_steps = 50
+     guidance_scale = 8
+     generator = torch.manual_seed(seed)
+     batch_size = 1
+     if loss_selection == "blue_loss":
+         loss_fn = LOSS["blue_loss"]
+         loss_scale = 120
+     else:
+         loss_fn = LOSS["cosine_loss"](additional_prompt)
+         loss_scale = 20
+
+     # Use the modified_output_embeddings directly
+     text_embeddings = mod_output_embeddings
+
+     uncond_input = tokenizer(
+         [""] * batch_size,
+         padding="max_length",
+         max_length=max_length,
+         return_tensors="pt",
+     )
+     with torch.no_grad():
+         uncond_embeddings = text_encoder(
+             uncond_input.input_ids.to(torch_device))[0]
+     text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+     # Prep Scheduler
+     set_timesteps(scheduler, num_inference_steps)
+
+     # Prep latents
+     latents = torch.randn(
+         (batch_size, unet.config.in_channels, height // 8, width // 8),
+         generator=generator,
+     )
+     latents = latents.to(torch_device)
+     latents = latents * scheduler.init_noise_sigma
+
+     # Loop
+     for i, t in tqdm(enumerate(scheduler.timesteps),
+                      total=len(scheduler.timesteps)):
+         # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
+         latent_model_input = torch.cat([latents] * 2)
+         sigma = scheduler.sigmas[i]
+         latent_model_input = scheduler.scale_model_input(latent_model_input, t)
+
+         # predict the noise residual
+         with torch.no_grad():
+             noise_pred = unet(
+                 latent_model_input, t, encoder_hidden_states=text_embeddings
+             )["sample"]
+
+         # perform CFG
+         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+         noise_pred = noise_pred_uncond + guidance_scale * (
+             noise_pred_text - noise_pred_uncond
+         )
+
+         ### ADDITIONAL GUIDANCE ###
+         if i % 2 == 0:
+             # Requires grad on the latents
+             latents = latents.detach().requires_grad_()
+
+             # Get the predicted x0:
+             # latents_x0 = latents - sigma * noise_pred
+             latents_x0 = scheduler.step(noise_pred, t, latents).pred_original_sample
+             scheduler._step_index -= 1  # rewind the internal counter; step() runs again for this timestep below
+             # Decode to image space
+             denoised_images = (
+                 vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5
+             )  # range (0, 1)
+
+             # Calculate loss
+             loss = loss_fn(denoised_images) * loss_scale
+
+             # Occasionally print it out
+             if i % 10 == 0:
+                 print(i, "loss:", loss.item())
+
+             # Get gradient
+             cond_grad = torch.autograd.grad(loss, latents)[0]
+
+             # Modify the latents based on this gradient
+             latents = latents.detach() - cond_grad * sigma**2
+
+         # Now step with scheduler
+         latents = scheduler.step(noise_pred, t, latents).prev_sample
+
+     return latents_to_pil(latents)[0]
+
+
+ def generate_image_per_style(prompt, style_embed, style_seed, style_embedding_key):
+     modified_output_embeddings = None
+     gen_out_style_image = None
+     max_length = 0
+
+     # Tokenize
+     text_input = tokenizer(
+         prompt,
+         padding="max_length",
+         max_length=tokenizer.model_max_length,
+         truncation=True,
+         return_tensors="pt",
+     )
+     input_ids = text_input.input_ids.to(torch_device)
+
+     # Get token embeddings
+     token_embeddings = token_emb_layer(input_ids)
+
+     replacement_token_embedding = style_embed[style_embedding_key]
+
+     # Insert this into the token embeddings
+     token_embeddings[
+         0, torch.where(input_ids[0] == 6829)[0]  # 6829: hard-coded placeholder token id whose embedding is replaced
+     ] = replacement_token_embedding.to(torch_device)
+
+     # Combine with pos embs
+     input_embeddings = token_embeddings + position_embeddings
+
+     # Feed through to get final output embs
+     modified_output_embeddings = get_output_embeds(input_embeddings)
+
+     # And generate an image with this:
+     max_length = text_input.input_ids.shape[-1]
+
+     gen_out_style_image = generate_with_embs(
+         modified_output_embeddings, style_seed, max_length
+     )
+
+     return gen_out_style_image
+
+
+ def generate_image_per_loss(
+         prompt, style_embed, style_seed, style_embedding_key,
+         loss, additional_prompt
+ ):
+     gen_out_loss_image = None
+
+     # Tokenize
+     text_input = tokenizer(
+         prompt,
+         padding="max_length",
+         max_length=tokenizer.model_max_length,
+         truncation=True,
+         return_tensors="pt",
+     )
+     input_ids = text_input.input_ids.to(torch_device)
+
+     # Get token embeddings
+     token_embeddings = token_emb_layer(input_ids)
+
+     replacement_token_embedding = style_embed[style_embedding_key].to(torch_device)
+
+     # Insert this into the token embeddings
+     token_embeddings[
+         0, torch.where(input_ids[0] == 6829)[0]
+     ] = replacement_token_embedding
+
+     # Combine with pos embs
+     input_embeddings = token_embeddings + position_embeddings
+     modified_output_embeddings = get_output_embeds(input_embeddings)
+
+     # max_length = tokenizer.model_max_length
+
+     max_length = text_input.input_ids.shape[-1]
+     gen_out_loss_image = generate_image_from_embeddings(
+         modified_output_embeddings, style_seed, max_length,
+         loss, additional_prompt
+     )
+
+     return gen_out_loss_image
+
+
+ def generate_image_per_prompt_style(text_in, style_in,
+                                     loss, additional_prompt):
+     gen_style_image = None
+     gen_loss_image = None
+     STYLE_KEYS = []
+     style_key = ""
+
+     if style_in not in STYLE_EMBEDDINGS:
+         raise ValueError(
+             f"Unknown style: {style_in}. Available styles are: {', '.join(STYLE_EMBEDDINGS.keys())}"
+         )
+
+     # one seed per style: reuse the module-level STYLE_SEEDS (seven entries, matching STYLE_EMBEDDINGS)
+     STYLE_KEYS = list(STYLE_EMBEDDINGS.keys())
+     print(f"prompt: {text_in}")
+     print(f"style: {style_in}")
+
+     idx = STYLE_KEYS.index(style_in)
+     style_file = STYLE_EMBEDDINGS[style_in]
+     print(f"style_file: {style_file}")
+
+     prompt = text_in
+
+     style_seed = STYLE_SEEDS[idx]
+
+     style_key = Path(style_file).stem
+     style_key = style_key.replace("_", "-")
+     print(style_key, STYLE_KEYS, style_file)
+
+     file_path = os.path.join(os.getcwd(), style_file)
+     embedding = load_embedding_bin(file_path)
+     style_key = f"<{style_key}>"
+
+     gen_style_image = generate_image_per_style(prompt, embedding, style_seed, style_key)
+
+     gen_loss_image = generate_image_per_loss(prompt, embedding, style_seed, style_key, loss, additional_prompt)
+
+     return [gen_style_image, gen_loss_image]
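The *.bin files added in this commit are textual-inversion style embeddings; they are read with load_embedding_bin (a plain torch.load) and looked up by the "<style-name>" placeholder key that generate_image_per_prompt_style builds from the file name. A short inspection sketch; the key name and the 768-dimensional shape (CLIP ViT-L/14 token embedding size) are expectations for these particular files, not guarantees:

import torch

embedding = torch.load("birb-style.bin")   # same as utils.load_embedding_bin
print(list(embedding.keys()))              # expected: ['<birb-style>']
print(embedding["<birb-style>"].shape)     # expected: torch.Size([768])
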
line-art.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0528436ec2228c659e0cf1316e713345bc97a3d88294f1a2987a3505d220e770
+ size 3819
loss.py ADDED
@@ -0,0 +1,64 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchvision.transforms import v2
+ from transformers import CLIPTextModel, CLIPTokenizer, \
+     CLIPProcessor, CLIPVisionModelWithProjection, CLIPTextModelWithProjection
+
+ import os
+ # from image_generator import get_output_embeds, position_embeddings
+
+
+ # Set device
+ torch_device = "cuda" if torch.cuda.is_available() else "mps" \
+     if torch.backends.mps.is_available() else "cpu"
+
+ if "mps" == torch_device:
+     os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"
+
+ # Load the tokenizer and text encoder to tokenize and encode the text.
+ clip_model_name = "openai/clip-vit-large-patch14"
+ tokenizer = CLIPTokenizer.from_pretrained(clip_model_name)
+ text_encoder = CLIPTextModel.from_pretrained(clip_model_name).to(torch_device)
+ vision_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_model_name).to(torch_device)
+ processor = CLIPProcessor.from_pretrained(clip_model_name)
+
+ # additional textual prompt used for cosine-similarity guidance
+ def get_text_embed(prompt="on a mountain"):
+     inputs = processor(text=prompt,
+                        return_tensors="pt",
+                        padding=True)
+     with torch.no_grad():
+         text_embed = CLIPTextModelWithProjection.from_pretrained(
+             clip_model_name)(**inputs).text_embeds.to(torch_device)
+     return text_embed
+
+ # def get_text_embed(prompt="on a mountain"):
+ #     text_input = tokenizer([prompt],
+ #                            padding="max_length",
+ #                            max_length=tokenizer.model_max_length,
+ #                            truncation=True,
+ #                            return_tensors="pt")
+ #     with torch.no_grad():
+ #         text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
+ #     input_embeddings = text_embeddings + position_embeddings.to(torch_device)
+ #     modified_output_embeddings = get_output_embeds(input_embeddings)
+ #     return modified_output_embeddings
+
+ class cosine_loss(nn.Module):
+     def __init__(self, prompt) -> None:
+         super().__init__()
+         self.text_embed = get_text_embed(prompt)
+
+     def forward(self, gen_image):
+         gen_image_clamped = gen_image.clamp(0, 1).mul(255)
+         resized_image = v2.Resize(224)(gen_image_clamped)
+         image_embed = vision_encoder(resized_image).image_embeds
+         similarity = F.cosine_similarity(self.text_embed, image_embed, dim=1)
+         loss = 1 - similarity.mean()
+         return loss
+
+ def blue_loss(images):
+     # How far the blue channel values are from 0.9:
+     error = torch.abs(images[:, 2] - 0.9).mean()  # [:, 2] -> all images in batch, only the blue channel
+     return error
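Both guidance terms operate on a batch of decoded images in the 0-1 range, exactly as image_generator.py passes them in. A minimal standalone sketch, with a random tensor standing in for a decoded VAE sample (instantiating cosine_loss downloads the CLIP projection weights on first use):

import torch
from loss import blue_loss, cosine_loss, torch_device

images = torch.rand(1, 3, 512, 512, device=torch_device)  # stand-in for decoded images in [0, 1]
print(blue_loss(images).item())         # mean distance of the blue channel from 0.9

loss_fn = cosine_loss("on a mountain")  # embeds the extra prompt once
print(loss_fn(images).item())           # 1 - cosine similarity between prompt and image embeddings
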
midjourney-style.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4865a5d2ecd012985940748023fd80e4fd299837f1dccedb85ee83be5bb1f957
+ size 3819
utils.py ADDED
@@ -0,0 +1,35 @@
+ import torch
+ import torchvision.transforms as tfms
+ from PIL import Image
+ from diffusers import AutoencoderKL
+
+ # Pick the best available device instead of hard-coding one
+ torch_device = "cuda" if torch.cuda.is_available() else "mps" \
+     if torch.backends.mps.is_available() else "cpu"
+
+ vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(torch_device)
+
+ def pil_to_latent(input_im):
+     # Single image -> single latent in a batch (so size 1, 4, 64, 64)
+     with torch.no_grad():
+         latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device) * 2 - 1)  # Note scaling
+     return 0.18215 * latent.latent_dist.sample()
+
+ def latents_to_pil(latents, torch_device=torch_device):
+     # batch of latents -> list of images
+     latents = (1 / 0.18215) * latents
+     with torch.no_grad():
+         image = vae.decode(latents.to(torch_device)).sample
+     image = (image / 2 + 0.5).clamp(0, 1)
+     image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
+     images = (image * 255).round().astype("uint8")
+     pil_images = [Image.fromarray(image) for image in images]
+     return pil_images
+
+ def load_embedding_bin(path):
+     return torch.load(path)
+
+ # Prep Scheduler
+ def set_timesteps(scheduler, num_inference_steps):
+     scheduler.set_timesteps(num_inference_steps)
+     scheduler.timesteps = scheduler.timesteps.to(torch.float32)  # minor fix to ensure MPS compatibility, fixed in diffusers PR 3925
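utils.py can also be exercised on its own for a quick latent round trip through the VAE. A sketch, assuming a local 512x512 RGB image; "input.png" is a placeholder path:

from PIL import Image
from utils import pil_to_latent, latents_to_pil

im = Image.open("input.png").convert("RGB").resize((512, 512))
latents = pil_to_latent(im)        # (1, 4, 64, 64), scaled by 0.18215
out = latents_to_pil(latents)[0]   # decode back to a PIL image
out.save("roundtrip.png")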