AYYasaswini committed on
Commit 71472f9 · verified · 1 Parent(s): 0c23a88

Update app.py

Files changed (1):
  1. app.py +762 -92

app.py CHANGED
@@ -1,8 +1,12 @@
 import gradio as gr
 from base64 import b64encode
 import numpy
 import torch
 from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
 from PIL import Image
 from torch import autocast
 from torchvision import transforms as tfms
@@ -11,9 +15,15 @@ from transformers import CLIPTextModel, CLIPTokenizer, logging
 import torchvision.transforms as T

 torch.manual_seed(1)
 logging.set_verbosity_error()
 torch_device = "cpu"

 vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")

 # Load the tokenizer and text encoder to tokenize and encode the text.
@@ -26,14 +36,43 @@ unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", sub
 # The noise scheduler
 scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

 vae = vae.to(torch_device)
 text_encoder = text_encoder.to(torch_device)
 unet = unet.to(torch_device);

-token_emb_layer = text_encoder.text_model.embeddings.token_embedding
-pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
-position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
-position_embeddings = pos_emb_layer(position_ids)

 def pil_to_latent(input_im):
     # Single image -> single latent in a batch (so size 1, 4, 64, 64)
@@ -52,6 +91,147 @@ def latents_to_pil(latents):
     pil_images = [Image.fromarray(image) for image in images]
     return pil_images

 def get_output_embeds(input_embeddings):
     # CLIP's text model uses causal mask, so we prepare it here:
     bsz, seq_len = input_embeddings.shape[:2]
@@ -77,15 +257,55 @@ def get_output_embeds(input_embeddings):
     # And now they're ready!
     return output

-def generate_with_embs(text_embeddings, seed, max_length):
     height = 512 # default height of Stable Diffusion
     width = 512 # default width of Stable Diffusion
-    num_inference_steps = 10 # Number of denoising steps
     guidance_scale = 7.5 # Scale for classifier-free guidance
-    generator = torch.manual_seed(seed) # Seed generator to create the inital latent noise
     batch_size = 1

-    # max_length = text_input.input_ids.shape[-1]
     uncond_input = tokenizer(
         [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
     )
@@ -124,109 +344,479 @@ def generate_with_embs(text_embeddings, seed, max_length):

     return latents_to_pil(latents)[0]

 # Prep Scheduler
-def set_timesteps(scheduler, num_inference_steps):
-    scheduler.set_timesteps(num_inference_steps)
-    scheduler.timesteps = scheduler.timesteps.to(torch.float32) # minor fix to ensure MPS compatibility, fixed in diffusers PR 3925

-def embed_style(prompt, style_embed, style_seed):
-    # Tokenize
-    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
-    input_ids = text_input.input_ids.to(torch_device)

-    # Get token embeddings
-    token_embeddings = token_emb_layer(input_ids)

-    replacement_token_embedding = style_embed.to(torch_device)
-    # replacement_token_embedding = birb_embed[embed_values[4]].to(torch_device)
-    # Assuming token_embeddings has shape [batch_size, seq_length, embedding_dim]
-    replacement_token_embedding = replacement_token_embedding[:768] # Adjust the size
-    replacement_token_embedding = replacement_token_embedding.unsqueeze(0) # Make it [1, 768] if necessary

-    # Insert this into the token embeddings
-    # token_embeddings[0, torch.where(input_ids[0]==6829)] = replacement_token_embedding.to(torch_device)
-    indices = torch.where(input_ids[0] == 6829)[0]
-    for index in indices:
-        token_embeddings[0, index] = replacement_token_embedding.to(torch_device)
-    # Combine with pos embs
-    input_embeddings = token_embeddings + position_embeddings

-    # Feed through to get final output embs
-    modified_output_embeddings = get_output_embeds(input_embeddings)

-    # And generate an image with this:
-    max_length = text_input.input_ids.shape[-1]
-    return generate_with_embs(modified_output_embeddings, style_seed, max_length)

-def loss_style(prompt, style_embed, style_seed):
-    # Tokenize
-    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
-    input_ids = text_input.input_ids.to(torch_device)

-    # Get token embeddings
-    token_embeddings = token_emb_layer(input_ids)

-    # The new embedding - our special birb word
-    replacement_token_embedding = style_embed.to(torch_device)
-    # Assuming token_embeddings has shape [batch_size, seq_length, embedding_dim]
-    replacement_token_embedding = replacement_token_embedding[:768] # Adjust the size
-    replacement_token_embedding = replacement_token_embedding.unsqueeze(0) # Make it [1, 768] if necessary
-    indices = torch.where(input_ids[0] == 6829)[0] # Extract indices where the condition is True
-    print(f"indices: {indices}") # Debug print
-    for index in indices:
-        print(f"index: {index}") # Debug print
-        token_embeddings[0, index] = replacement_token_embedding.to(torch_device) # Update each index

-    # Insert this into the token embeddings
-    # token_embeddings[0, torch.where(input_ids[0]==6829)] = replacement_token_embedding.to(torch_device)

-    # Combine with pos embs
-    input_embeddings = token_embeddings + position_embeddings

-    # Feed through to get final output embs
-    modified_output_embeddings = get_output_embeds(input_embeddings)

-    # And generate an image with this:
-    max_length = text_input.input_ids.shape[-1]
-    return generate_loss_based_image(modified_output_embeddings, style_seed, max_length)

-def sepia_loss(images):
-    sepia_tone = 0.393 * images[:,0] + 0.769 * images[:,1] + 0.189 * images[:,2]
-    error = torch.abs(sepia_tone - 0.5).mean()
     return error

-def generate_loss_based_image(text_embeddings, seed, max_length):
-    height = 64
-    width = 64
-    num_inference_steps = 10
-    guidance_scale = 8
-    generator = torch.manual_seed(64)
     batch_size = 1
-    loss_scale = 200

     uncond_input = tokenizer(
         [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
     )
     with torch.no_grad():
         uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
-    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

     # Prep Scheduler
-    set_timesteps(scheduler, num_inference_steps+1)

     # Prep latents
     latents = torch.randn(
-        (batch_size, unet.in_channels, height // 8, width // 8),
-        generator=generator,
     )
     latents = latents.to(torch_device)
     latents = latents * scheduler.init_noise_sigma

-    sched_out = None
-
     # Loop
     for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
         # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
@@ -242,39 +832,116 @@ def generate_loss_based_image(text_embeddings, seed, max_length):
         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

-        ### ADDITIONAL GUIDANCE ###
-        if i%5 == 0 and i>0:
             # Requires grad on the latents
             latents = latents.detach().requires_grad_()

             # Get the predicted x0:
-            scheduler._step_index -= 1
-            latents_x0 = scheduler.step(noise_pred, t, latents).pred_original_sample

             # Decode to image space
             denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)

             # Calculate loss
-            loss = sepia_loss(denoised_images) * loss_scale

             # Occasionally print it out
-            # if i%10==0:
-            print(i, 'loss:', loss)

             # Get gradient
             cond_grad = torch.autograd.grad(loss, latents)[0]

             # Modify the latents based on this gradient
-            latents = latents.detach() - cond_grad * sigma**2
-            # To PIL Images
-            im_t0 = latents_to_pil(latents_x0)[0]
-            im_next = latents_to_pil(latents)[0]

         # Now step with scheduler
         latents = scheduler.step(noise_pred, t, latents).prev_sample
-
-    return latents_to_pil(latents)[0]


 def generate_image_from_prompt(text_in, style_in):
@@ -291,6 +958,8 @@ def generate_image_from_prompt(text_in, style_in):
     style = dict_styles # (learn_embed[0])
     birb_embed = torch.load(learn_embed[0])
     #birb_embed.keys(), dict_styles['<gartic-phone>'].shape
     #style_embed = torch.load(dict_styles)
     #birb_embed = torch.load('learned_embeds.bin')
     #birb_embed.keys(), birb_embed['<birb-style>'].shape
@@ -301,7 +970,7 @@ def generate_image_from_prompt(text_in, style_in):
     #loss_generated_img = (loss_style(prompt, style_embed[0], style_seed))

     return [generated_image]
-

 # Define Interface
 title = 'Stable Diffusion Art Generator'
@@ -330,3 +999,4 @@ demo = gr.Interface(generate_image_from_prompt,
 )

 demo.launch(debug=True)

 import gradio as gr
 from base64 import b64encode
+
 import numpy
 import torch
 from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
+
 from PIL import Image
 from torch import autocast
 from torchvision import transforms as tfms

 import torchvision.transforms as T

 torch.manual_seed(1)
+
+# Suppress some unnecessary warnings when loading the CLIPTextModel
 logging.set_verbosity_error()
+
 torch_device = "cpu"

+# Load the autoencoder model which will be used to decode the latents into image space.
 vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")

 # Load the tokenizer and text encoder to tokenize and encode the text.
 # The noise scheduler
 scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

+# To the GPU we go!
 vae = vae.to(torch_device)
 text_encoder = text_encoder.to(torch_device)
 unet = unet.to(torch_device);

+"""## A diffusion loop
+
+If all you want is to make a picture with some text, you could ignore this notebook and use one of the existing tools (such as [DreamStudio](https://beta.dreamstudio.ai/)) or use the simplified pipeline from huggingface, as documented [here](https://huggingface.co/blog/stable_diffusion).
+
+What we want to do in this notebook is dig a little deeper into how this works, so we'll start by checking that the example code runs. Again, this is adapted from the [HF notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb) and looks very similar to what you'll find if you inspect [the `__call__()` method of the stable diffusion pipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L200).
+"""
+
+# Prep Scheduler
+def set_timesteps(scheduler, num_inference_steps):
+    scheduler.set_timesteps(num_inference_steps)
+    scheduler.timesteps = scheduler.timesteps.to(torch.float32) # minor fix to ensure MPS compatibility, fixed in diffusers PR 3925
+
+# Prep latents
+latents = torch.randn(
+    (batch_size, unet.in_channels, height // 8, width // 8),
+    generator=generator,
+)
+latents = latents.to(torch_device)
+latents = latents * scheduler.init_noise_sigma # Scaling (previous versions did latents = latents * self.scheduler.sigmas[0])
+
+# Loop
+with autocast("cuda"): # will fall back to CPU if no CUDA; no autocast for MPS
+    for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
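(The diff view collapses the body of this first loop. As an editor's sketch for readability only, the collapsed denoising step matches the full loop that appears later in `generate_with_embs_seed`:)

        # Sketch of the collapsed loop body (see generate_with_embs_seed below):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = scheduler.step(noise_pred, t, latents).prev_sample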
+"""It's working, but that's quite a bit of code! Let's look at the components one by one.
+
+## The Autoencoder (AE)
+
+The AE can 'encode' an image into some sort of latent representation, and decode this back into an image. I've wrapped the code for this into a couple of functions here so we can see what this looks like in action:
+"""

 def pil_to_latent(input_im):
     # Single image -> single latent in a batch (so size 1, 4, 64, 64)

     pil_images = [Image.fromarray(image) for image in images]
     return pil_images

+"""What does this look like at different timesteps? Experiment and see for yourself!
+
+If you uncomment the cell below you'll see that in this case the `scheduler.add_noise` function literally just adds noise scaled by sigma: `noisy_samples = original_samples + noise * sigmas`
+"""
+#encoded = pil_to_latent(input_image)
+#encoded.shape
+#decoded = latents_to_pil(encoded)[0]
+#decoded
+# ??scheduler.add_noise
+
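An editor's sketch (not part of this commit) of what extending that uncommented cell could look like, assuming an `input_image` PIL image is available:

# Sketch: check that scheduler.add_noise is just noise scaled by sigma and added.
set_timesteps(scheduler, 15)
encoded = pil_to_latent(input_image)   # assumed input image
noise = torch.randn_like(encoded)
step = 10
noised = scheduler.add_noise(encoded, noise, timesteps=torch.tensor([scheduler.timesteps[step]]))
by_hand = encoded + noise * scheduler.sigmas[step]
print(torch.allclose(noised, by_hand, atol=1e-4))  # expect True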
+"""Other diffusion models may be trained with different noising and scheduling approaches, some of which keep the variance fairly constant across noise levels ('variance preserving'), with different scaling and mixing tricks, instead of having noisy latents with higher and higher variance as more noise is added ('variance exploding').
+
+If we want to start from random noise instead of a noised image, we need to scale it by the largest sigma value used during training, ~14 in this case. And before these noisy latents are fed to the model they are scaled again in the so-called pre-conditioning step:
+`latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)` (now handled by `latent_model_input = scheduler.scale_model_input(latent_model_input, t)`).
+
+Again, this scaling/pre-conditioning differs between papers and implementations, so keep an eye out for this if you work with a different type of diffusion model.
+
+## Loop starting from noised version of input (AKA image2image)
+
+Let's see what happens when we use our image as a starting point, adding some noise and then doing the final few denoising steps in the loop with a new prompt.
+
+We'll use a similar loop to the first demo, but we'll skip the first `start_step` steps.
+
+To noise our image we'll use code like that shown above, using the scheduler to noise it to a level equivalent to step 10 (`start_step`).
+"""
+
+# Settings (same as before except for the new prompt)
+
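The settings and loop themselves are collapsed in this diff; as an illustrative editor's sketch only (names like `input_image` and `start_step` are assumptions, not from this commit), the noising step could look like:

# Sketch: noise the encoded input image to the level of start_step, then denoise
# from there with the usual loop, skipping steps i < start_step.
start_step = 10
set_timesteps(scheduler, 50)
encoded = pil_to_latent(input_image)
noise = torch.randn_like(encoded)
latents = scheduler.add_noise(encoded, noise, timesteps=torch.tensor([scheduler.timesteps[start_step]]))
latents = latents.to(torch_device)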
+"""You can see that some colours and structure from the image are kept, but we now have a new picture! The more noise you add and the more steps you do, the further away it gets from the input image.
+
+This is how the popular img2img pipeline works. Again, if this is your end goal there are tools to make this easy!
+
+But you can see that under the hood this is the same as the generation loop, just skipping the first few steps and starting from a noised image rather than pure noise.
+
+Explore changing how many steps are skipped and see how this affects the amount the image changes from the input.
+
+## Exploring the text -> embedding pipeline
+
+We use a text encoder model to turn our text into a set of 'embeddings' which are fed to the diffusion model as conditioning. Let's follow a piece of text through this process and see how it works.
+"""
+
+# Our text prompt
+prompt = 'A picture of a puppy'
+
+"""We begin with tokenization:"""
+
+# Turn the text into a sequence of tokens:
+text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+text_input['input_ids'][0] # View the tokens
+
+# See the individual tokens
+for t in text_input['input_ids'][0][:8]: # We'll just look at the first few to save you from a wall of '<|endoftext|>'
+    print(t, tokenizer.decoder.get(int(t)))
+
+# Note: token 6829 is 'puppy'
+
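A quick check of that claim (an editor's addition, not in the commit):

# 49406 and 49407 are the start/end-of-text tokens; 6829 is 'puppy'.
print(tokenizer('puppy')['input_ids'])  # [49406, 6829, 49407]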
+"""We can jump straight to the final (output) embeddings like so:"""
+
+# Grab the output embeddings
+output_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
+print('Shape:', output_embeddings.shape)
+output_embeddings
+
+"""We pass our tokens through the text_encoder and we magically get some numbers we can feed to the model.
+
+How are these generated? The tokens are transformed into a set of input embeddings, which are then fed through the transformer model to get the final output embeddings.
+
+To get these input embeddings, there are actually two steps - as revealed by inspecting `text_encoder.text_model.embeddings`:
+"""
+
+text_encoder.text_model.embeddings
+
+"""### Token embeddings
+
+The token is fed to the `token_embedding` to transform it into a vector. The function name `get_input_embeddings` here is misleading since these token embeddings need to be combined with the position embeddings before they are actually used as inputs to the model! Anyway, let's look at just the token embedding part first.
+
+We can look at the embedding layer:
+"""
+
+# Access the embedding layer
+token_emb_layer = text_encoder.text_model.embeddings.token_embedding
+token_emb_layer # Vocab size 49408, emb_dim 768
+
+"""And embed a token like so:"""
+
+# Embed a token - in this case the one for 'puppy'
+embedding = token_emb_layer(torch.tensor(6829, device=torch_device))
+embedding.shape # 768-dim representation
+
+"""This single token has been mapped to a 768-dimensional vector - the token embedding.
+
+We can do the same with all of the tokens in the prompt to get all the token embeddings:
+"""
+
+token_embeddings = token_emb_layer(text_input.input_ids.to(torch_device))
+print(token_embeddings.shape) # batch size 1, 77 tokens, 768 values for each
+token_embeddings
+
+"""### Positional Embeddings
+
+Positional embeddings tell the model where in a sequence a token is. Much like the token embedding, this is a set of (optionally learnable) parameters. But now instead of dealing with ~50k tokens we just need one for each position (77 total):
+"""
+
+pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
+pos_emb_layer
+
+"""We can get the positional embedding for each position:"""
+
+position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
+position_embeddings = pos_emb_layer(position_ids)
+print(position_embeddings.shape)
+position_embeddings
+
+"""### Combining token and position embeddings
+
+Time to combine the two. How do we do this? Just add them! Other approaches are possible but for this model this is how it is done.
+
+Combining them in this way gives us the final input embeddings, ready to feed through the transformer model:
+"""
+
+# And combining them we get the final input embeddings
+input_embeddings = token_embeddings + position_embeddings
+print(input_embeddings.shape)
+input_embeddings
+
+"""We can check that these are the same as the result we'd get from `text_encoder.text_model.embeddings`:"""
+
+# The following combines all the above steps (but doesn't let us fiddle with them!)
+text_encoder.text_model.embeddings(text_input.input_ids.to(torch_device))
+
+"""### Feeding these through the transformer model
+
+![transformer diagram](https://github.com/johnowhitaker/tglcourse/raw/main/images/text_encoder_noborder.png)
+
+We want to mess with these input embeddings (specifically the token embeddings) before we send them through the rest of the model, but first we should check that we know how to do that. I read the code of the text_encoder's `forward` method and, based on that, the code for the `forward` method of the text_model that the text_encoder wraps. To inspect it yourself, type `??text_encoder.text_model.forward` and you'll get the function info and source code - a useful debugging trick!
+
+Anyway, based on that we can copy in the bits we need to get the so-called 'last hidden state' and thus generate our final embeddings:
+"""
+
  def get_output_embeds(input_embeddings):
236
  # CLIP's text model uses causal mask, so we prepare it here:
237
  bsz, seq_len = input_embeddings.shape[:2]
 
257
  # And now they're ready!
258
  return output
259
 
260
+ out_embs_test = get_output_embeds(input_embeddings) # Feed through the model with our new function
261
+ print(out_embs_test.shape) # Check the output shape
262
+ out_embs_test # Inspect the output
263
+
+"""Note that these match the `output_embeddings` we saw near the start - we've figured out how to split up that one step ("get the text embeddings") into multiple sub-steps ready for us to modify.
+
+Now that we have this process in place, we can replace the input embedding of a token with a new one of our choice - which in our final use-case will be something we learn. To demonstrate the concept though, let's replace the input embedding for 'puppy' in the prompt we've been playing with with the embedding for token 2368, get a new set of output embeddings based on this, and use these to generate an image to see what we get:
+"""
+
+prompt = 'A picture of a puppy'
+
+# Tokenize
+text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+input_ids = text_input.input_ids.to(torch_device)
+
+# Get token embeddings
+token_embeddings = token_emb_layer(input_ids)
+
+# The new embedding. In this case just the input embedding of token 2368...
+replacement_token_embedding = text_encoder.get_input_embeddings()(torch.tensor(2368, device=torch_device))
+
+# Insert this into the token embeddings
+token_embeddings[0, torch.where(input_ids[0]==6829)] = replacement_token_embedding.to(torch_device)
+
+# Combine with pos embs
+input_embeddings = token_embeddings + position_embeddings
+
+# Feed through to get final output embs
+modified_output_embeddings = get_output_embeds(input_embeddings)
+
+print(modified_output_embeddings.shape)
+modified_output_embeddings
+
+"""The first few are the same, the last aren't. Everything at and after the position of the token we're replacing will be affected.
+
+If all went well, we should see something other than a puppy when we use these to generate an image. And sure enough, we do!
+"""
+
+# Generating an image with these modified embeddings
+
+def generate_with_embs(text_embeddings):
     height = 512 # default height of Stable Diffusion
     width = 512 # default width of Stable Diffusion
+    num_inference_steps = 30 # Number of denoising steps
     guidance_scale = 7.5 # Scale for classifier-free guidance
+    generator = torch.manual_seed(32) # Seed generator to create the initial latent noise
     batch_size = 1

+    max_length = text_input.input_ids.shape[-1]
     uncond_input = tokenizer(
         [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
     )

     return latents_to_pil(latents)[0]

 
347
+ #Generating an image with these modified embeddings
348
+
349
+ def generate_with_embs_seed(text_embeddings, seed, max_length):
350
+ """
351
+
352
+ Args:
353
+ text_embeddings:
354
+ seed:
355
+ max_length:
356
+
357
+ Returns:
358
+
359
+ """
360
+ height = 512 # default height of Stable Diffusion
361
+ width = 512 # default width of Stable Diffusion
362
+ num_inference_steps = 30 # Number of denoising steps
363
+ guidance_scale = 7.5 # Scale for classifier-free guidance
364
+ generator = torch.manual_seed(32) # Seed generator to create the inital latent noise
365
+ batch_size = 1
366
+
367
+ # max_length = text_input.input_ids.shape[-1]
368
+ uncond_input = tokenizer(
369
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
370
+ )
371
+ with torch.no_grad():
372
+ uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
373
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
374
+
375
+ # Prep Scheduler
376
+ set_timesteps(scheduler, num_inference_steps)
377
+
378
+ # Prep latents
379
+ latents = torch.randn(
380
+ (batch_size, unet.in_channels, height // 8, width // 8),
381
+ generator=generator,
382
+ )
383
+ latents = latents.to(torch_device)
384
+ latents = latents * scheduler.init_noise_sigma
385
+
386
+ # Loop
387
+ for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
388
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
389
+ latent_model_input = torch.cat([latents] * 2)
390
+ sigma = scheduler.sigmas[i]
391
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
392
+
393
+ # predict the noise residual
394
+ with torch.no_grad():
395
+ noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
396
+
397
+ # perform guidance
398
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
399
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
400
+
401
+ # compute the previous noisy sample x_t -> x_t-1
402
+ latents = scheduler.step(noise_pred, t, latents).prev_sample
403
+
404
+ return latents_to_pil(latents)[0]
405
+
+generate_with_embs(modified_output_embeddings)
+
+"""Surprise! Now you know what token 2368 means ;)
+
+**What can we do with this?** Why did we go to all of this trouble? Well, we'll see a more compelling use-case shortly, but the tl;dr is that once we can access and modify the token embeddings we can do tricks like replacing them with something else. In the example we just did, that was just another token embedding from the model's vocabulary, equivalent to just editing the prompt. But we can also mix tokens - for example, here's a half-puppy-half-skunk:
+"""
+
+# In case you're wondering how to get the token for a word, or the embedding for a token:
+prompt = 'skunk'
+print('tokenizer(prompt):', tokenizer(prompt))
+print('token_emb_layer([token_id]) shape:', token_emb_layer(torch.tensor([8797], device=torch_device)).shape)
+
+prompt = 'A picture of a puppy'
+
+# Tokenize
+text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+input_ids = text_input.input_ids.to(torch_device)
+
+# Get token embeddings
+token_embeddings = token_emb_layer(input_ids)
+
+# The new embedding, which is now a mixture of the token embeddings for 'puppy' and 'skunk'
+puppy_token_embedding = token_emb_layer(torch.tensor(6829, device=torch_device))
+skunk_token_embedding = token_emb_layer(torch.tensor(42194, device=torch_device))
+replacement_token_embedding = 0.5*puppy_token_embedding + 0.5*skunk_token_embedding
+
+# Insert this into the token embeddings
+token_embeddings[0, torch.where(input_ids[0]==6829)] = replacement_token_embedding.to(torch_device)
+
+# Combine with pos embs
+input_embeddings = token_embeddings + position_embeddings
+
+# Feed through to get final output embs
+modified_output_embeddings = get_output_embeds(input_embeddings)
+
+# Generate an image with these
+generate_with_embs(modified_output_embeddings)
+
+"""### Textual Inversion
+
+OK, so we can slip in a modified token embedding, and use this to generate an image. We used the mixed puppy/skunk token embedding in the above example, but what if instead we could 'learn' a new token embedding for a specific concept? This is the idea behind 'Textual Inversion', in which a few example images are used to create a new token embedding:
+
+![Overview image from the blog post](https://textual-inversion.github.io/static/images/training/training.JPG)
+_Diagram from the [textual inversion blog post](https://textual-inversion.github.io/static/images/training/training.JPG) - note it doesn't show the positional embeddings step for simplicity_
+
+We won't cover how this training works, but we can try loading one of these new 'concepts' from the [community-created SD concepts library](https://huggingface.co/sd-concepts-library) and see how it fits in with our example above. I'll use https://huggingface.co/sd-concepts-library/birb-style since it was the first one I made :) Download the learned_embeds.bin file from there and upload the file to wherever this notebook is before running this next cell:
+"""
+
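As an alternative to downloading by hand, the file can be fetched programmatically (editor's sketch; assumes the `huggingface_hub` package is installed):

# Fetch learned_embeds.bin for the birb-style concept from the Hub.
from huggingface_hub import hf_hub_download
embed_path = hf_hub_download(repo_id="sd-concepts-library/birb-style", filename="learned_embeds.bin")
birb_embed = torch.load(embed_path)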
+birb_embed = torch.load('learned_embeds.bin')
+birb_embed.keys(), birb_embed['<birb-style>'].shape
+
+"""We get a dictionary with a key (the special placeholder I used, <birb-style>) and the corresponding token embedding. As in the previous example, let's replace the 'puppy' token embedding with this and see what happens:"""
+
+prompt = 'A mouse in the style of puppy'
+
+# Tokenize
+text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+input_ids = text_input.input_ids.to(torch_device)
+
+# Get token embeddings
+token_embeddings = token_emb_layer(input_ids)
+
+# The new embedding - our special birb word
+replacement_token_embedding = birb_embed['<birb-style>'].to(torch_device)
+
+# Insert this into the token embeddings
+token_embeddings[0, torch.where(input_ids[0]==6829)] = replacement_token_embedding.to(torch_device)
+
+# Combine with pos embs
+input_embeddings = token_embeddings + position_embeddings
+
+# Feed through to get final output embs
+modified_output_embeddings = get_output_embeds(input_embeddings)
+
+# And generate an image with this:
+generate_with_embs(modified_output_embeddings)
+
+"""The token for 'puppy' was replaced with one that captures a particular style of painting, but it could just as easily represent a specific object or class of objects.
+
+Again, there is [a nice inference notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb) from hf to make it easy to use the different concepts, that properly handles using the names in prompts ("A \<cat-toy> in the style of \<birb-style>") without worrying about all this manual stuff. The goal of this notebook is to pull back the curtain a bit so you know what is going on behind the scenes :)
+
+## Messing with Embeddings
+
+Besides just replacing the token embedding of a single word, there are various other tricks we can try. For example, what if we create a 'chimera' by averaging the embeddings of two different prompts?
+"""
+
+# Embed two prompts
+text_input1 = tokenizer(["A mouse"], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+text_input2 = tokenizer(["A leopard"], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+with torch.no_grad():
+    text_embeddings1 = text_encoder(text_input1.input_ids.to(torch_device))[0]
+    text_embeddings2 = text_encoder(text_input2.input_ids.to(torch_device))[0]
+
+# Mix them together
+mix_factor = 0.35
+mixed_embeddings = (text_embeddings1*mix_factor + \
+                    text_embeddings2*(1-mix_factor))
+
+# Generate!
+generate_with_embs(mixed_embeddings)
+
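A natural follow-up experiment (editor's sketch, not in the commit) is to sweep the mix factor and watch the mouse-to-leopard transition:

# Sweep the interpolation factor between the two prompt embeddings.
for mix_factor in (0.0, 0.25, 0.5, 0.75, 1.0):
    mixed = text_embeddings1 * mix_factor + text_embeddings2 * (1 - mix_factor)
    generate_with_embs(mixed).save(f'mix_{mix_factor:.2f}.png')  # hypothetical output path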
+"""## The UNET and CFG
+
+Now it's time we looked at the actual diffusion model. This is typically a UNet that takes in the noisy latents (x) and predicts the noise. We use a conditional model that also takes in the timestep (t) and our text embedding (aka encoder_hidden_states) as conditioning. Feeding all of these into the model looks like this:
+`noise_pred = unet(latents, t, encoder_hidden_states=text_embeddings)["sample"]`
+
+We can try it out and see what the output looks like:
+"""
+
 # Prep Scheduler
+set_timesteps(scheduler, num_inference_steps)
+
+# What is our timestep
+t = scheduler.timesteps[0]
+sigma = scheduler.sigmas[0]
+
+# A noisy latent
+latents = torch.randn(
+    (batch_size, unet.in_channels, height // 8, width // 8),
+    generator=generator,
+)
+latents = latents.to(torch_device)
+latents = latents * scheduler.init_noise_sigma
+
+# Text embedding
+text_input = tokenizer(['A macaw'], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+with torch.no_grad():
+    text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
+
+# Run this through the unet to predict the noise residual
+with torch.no_grad():
+    noise_pred = unet(latents, t, encoder_hidden_states=text_embeddings)["sample"]
+
+latents.shape, noise_pred.shape # We get preds in the same shape as the input
+
+"""Given a set of noisy latents, the model predicts the noise component. We can remove this noise from the noisy latents to see what the output image looks like (`latents_x0 = latents - sigma * noise_pred`). And we can add most of the noise back to this predicted output to get the (hopefully slightly less noisy) input for the next diffusion step. To visualize this let's generate another image, saving both the predicted output (x0) and the next step (xt-1) after every step:"""
+
+prompt = 'Oil painting of an otter in a top hat'
+height = 512
+width = 512
+num_inference_steps = 50
+guidance_scale = 8
+generator = torch.manual_seed(32)
+batch_size = 1
+
+# Make a folder to store results
+#!rm -rf steps/
+#!mkdir -p steps/
+
+# Prep text
+text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+with torch.no_grad():
+    text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
+max_length = text_input.input_ids.shape[-1]
+uncond_input = tokenizer(
+    [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+)
+with torch.no_grad():
+    uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
+text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

+# Prep Scheduler
+set_timesteps(scheduler, num_inference_steps)
+
+# Prep latents
+latents = torch.randn(
+    (batch_size, unet.in_channels, height // 8, width // 8),
+    generator=generator,
+)
+latents = latents.to(torch_device)
+latents = latents * scheduler.init_noise_sigma
+
+# Loop
+for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
+    # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
+    latent_model_input = torch.cat([latents] * 2)
+    sigma = scheduler.sigmas[i]
+    latent_model_input = scheduler.scale_model_input(latent_model_input, t)
+
+    # predict the noise residual
+    with torch.no_grad():
+        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

+    # perform guidance
+    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

+    # Get the predicted x0:
+    # latents_x0 = latents - sigma * noise_pred # Calculating ourselves
+    scheduler_step = scheduler.step(noise_pred, t, latents)
+    latents_x0 = scheduler_step.pred_original_sample # Using the scheduler (Diffusers 0.4 and above)

+    # compute the previous noisy sample x_t -> x_t-1
+    latents = scheduler_step.prev_sample

+    # To PIL Images
+    im_t0 = latents_to_pil(latents_x0)[0]
+    im_next = latents_to_pil(latents)[0]

+    # Combine the two images and save for later viewing
+    im = Image.new('RGB', (1024, 512))
+    im.paste(im_next, (0, 0))
+    im.paste(im_t0, (512, 0))
+    im.save(f'steps/{i:04}.jpeg')

+# Make and show the progress video (change width to 1024 for full res)
+#!ffmpeg -v 1 -y -f image2 -framerate 12 -i steps/%04d.jpeg -c:v libx264 -preset slow -qp 18 -pix_fmt yuv420p out.mp4
+#mp4 = open('out.mp4','rb').read()
+#data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
+#HTML("""
+#<video width=600 controls>
+#  <source src="%s" type="video/mp4">
+#</video>
+#""" % data_url)
+
+#"""The version on the right shows the predicted 'final output' (x0) at each step, and this is what is usually used for progress videos etc. The version on the left is the 'next step'. I found it interesting to compare the two - watching the progress videos only, you'd think drastic changes are happening especially at early stages, but since the changes made per-step are relatively small the actual process is much more gradual.
+
+# Guidance
+
+def blue_loss(images):
+    # How far the blue channel values are from 0.9:
+    error = torch.abs(images[:,2] - 0.9).mean() # [:,2] -> all images in batch, only the blue channel
+    return error
+
+def orange_loss(images):
+    """
+    Calculate the mean absolute error between the RGB values of the images and the target orange color.
+
+    Parameters:
+    - images (torch.Tensor): A batch of images with shape (batch_size, channels, height, width).
+      The images are assumed to be in RGB format.
+
+    Returns:
+    - torch.Tensor: The mean absolute error for the orange color.
+    """
+    # Define the target RGB values for the color orange
+    target_orange = torch.tensor([255/255, 200/255, 0/255]).view(1, 3, 1, 1).to(images.device) # (R, G, B)
+
+    # Normalize images to [0, 1] range if not already normalized
+    images = images / 255.0 if images.max() > 1.0 else images
+
+    # Calculate the mean absolute error between the RGB values and the target orange values
+    error = torch.abs(images - target_orange).mean()
     return error
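A cheap sanity check for these loss functions (editor's sketch, not in the commit): on a random batch in [0, 1] both should return a scalar, and an all-orange image should drive orange_loss to zero.

# Sanity-check the losses on synthetic image batches.
imgs = torch.rand(1, 3, 64, 64)
print('blue_loss:', blue_loss(imgs).item())
print('orange_loss:', orange_loss(imgs).item())
orange = torch.tensor([1.0, 200/255, 0.0]).view(1, 3, 1, 1).expand(1, 3, 64, 64)
print('orange_loss (pure orange):', orange_loss(orange).item())  # ~0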
 
+"""During each update step, we find the gradient of the loss with respect to the current noisy latents, and tweak them in the direction that reduces this loss as well as performing the normal update step:"""
+
+prompt = 'A campfire (oil on canvas)' #@param
+height = 512 # default height of Stable Diffusion
+width = 512 # default width of Stable Diffusion
+num_inference_steps = 50 #@param # Number of denoising steps
+guidance_scale = 8 #@param # Scale for classifier-free guidance
+generator = torch.manual_seed(32) # Seed generator to create the initial latent noise
+batch_size = 1
+orange_loss_scale = 200 #@param
+
+# Prep text
+text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+with torch.no_grad():
+    text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
+
+# And the uncond. input as before:
+max_length = text_input.input_ids.shape[-1]
+uncond_input = tokenizer(
+    [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+)
+with torch.no_grad():
+    uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
+text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+# Prep Scheduler
+set_timesteps(scheduler, num_inference_steps)
+
+# Prep latents
+latents = torch.randn(
+    (batch_size, unet.in_channels, height // 8, width // 8),
+    generator=generator,
+)
+latents = latents.to(torch_device)
+latents = latents * scheduler.init_noise_sigma
+
+# Loop
+for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
+    # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
+    latent_model_input = torch.cat([latents] * 2)
+    sigma = scheduler.sigmas[i]
+    latent_model_input = scheduler.scale_model_input(latent_model_input, t)
+
+    # predict the noise residual
+    with torch.no_grad():
+        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
+
+    # perform CFG
+    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+    #### ADDITIONAL GUIDANCE ###
+    if i%5 == 0:
+        # Requires grad on the latents
+        latents = latents.detach().requires_grad_()
+
+        # Get the predicted x0:
+        latents_x0 = latents - sigma * noise_pred
+        # latents_x0 = scheduler.step(noise_pred, t, latents).pred_original_sample
+
+        # Decode to image space
+        denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)
+
+        # Calculate loss
+        loss = blue_loss(denoised_images) * orange_loss_scale
+
+        # Occasionally print it out
+        if i%10==0:
+            print(i, 'loss:', loss.item())
+
+        # Get gradient
+        cond_grad = torch.autograd.grad(loss, latents)[0]
+
+        # Modify the latents based on this gradient
+        latents = latents.detach() - cond_grad * sigma**2
+
+    # Now step with scheduler
+    latents = scheduler.step(noise_pred, t, latents).prev_sample
+
+latents_to_pil(latents)[0]
+
+prompt = 'A mouse in the style of puppy'
+
+# Tokenize
+text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+text_input
+input_ids = text_input.input_ids.to(torch_device)
+
+# Get token embeddings
+token_embeddings = token_emb_layer(input_ids)
+
+# The new embedding - our special birb word
+replacement_token_embedding = birb_embed['<birb-style>'].to(torch_device)
+
+# Insert this into the token embeddings
+token_embeddings[0, torch.where(input_ids[0]==6829)] = replacement_token_embedding.to(torch_device)
+
+# Combine with pos embs
+input_embeddings = token_embeddings + position_embeddings
+
+# Feed through to get final output embs
+modified_output_embeddings = get_output_embeds(input_embeddings)
+
+# And generate an image with this:
+generate_with_embs(modified_output_embeddings)
+
+text_input, input_ids, token_embeddings
+
+def generate_loss(modified_output_embeddings, seed, max_length):
+    # prompt = 'A campfire (oil on canvas)' #@param
+    height = 512 # default height of Stable Diffusion
+    width = 512 # default width of Stable Diffusion
+    num_inference_steps = 50 #@param # Number of denoising steps
+    guidance_scale = 8 #@param # Scale for classifier-free guidance
+    generator = torch.manual_seed(seed) # Seed generator to create the initial latent noise
     batch_size = 1
+    blue_loss_scale = 200 #@param
+
+    # Prep text
+    # text_input = tokenizer([""] * batch_size, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+    #input_ids = text_input.input_ids.to(torch_device)
+    # Get token embeddings
+    #token_embeddings = token_emb_layer(input_ids)
+
+    # The new embedding - our special birb word
+    #replacement_token_embedding = birb_embed['<birb-style>'].to(torch_device)
+    # Insert this into the token embeddings
+    #indices = torch.where(input_ids[0] == 6829)[0]
+    #token_embeddings[0, indices] = replacement_token_embedding.expand_as(token_embeddings[0, indices])
+
+    # Combine with pos embs
+    #input_embeddings = token_embeddings + position_embeddings
+
+    # Pass the modified embeddings to the text encoder
+    #with torch.no_grad():
+    #    text_embeddings = text_encoder(inputs_embeds=input_embeddings)[0]
+
+    # And the uncond. input as before:
+    # max_length = input_ids.shape[-1]
     uncond_input = tokenizer(
         [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
     )
     with torch.no_grad():
         uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
+    # Ensure both embeddings have the same shape
+    if uncond_embeddings.shape != modified_output_embeddings.shape:
+        raise ValueError(f"Shape mismatch: uncond_embeddings {uncond_embeddings.shape} vs modified_output_embeddings {modified_output_embeddings.shape}")
+
+    text_embeddings = torch.cat([uncond_embeddings, modified_output_embeddings])

     # Prep Scheduler
+    set_timesteps(scheduler, num_inference_steps)

     # Prep latents
     latents = torch.randn(
+        (batch_size, unet.in_channels, height // 8, width // 8),
+        generator=generator,
     )
     latents = latents.to(torch_device)
     latents = latents * scheduler.init_noise_sigma

     # Loop
     for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
         # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.

         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

+        #### ADDITIONAL GUIDANCE ###
+        if i % 5 == 0:
             # Requires grad on the latents
             latents = latents.detach().requires_grad_()

             # Get the predicted x0:
+            latents_x0 = latents - sigma * noise_pred
+            # latents_x0 = scheduler.step(noise_pred, t, latents).pred_original_sample

             # Decode to image space
             denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)

             # Calculate loss
+            loss = orange_loss(denoised_images) * blue_loss_scale

             # Occasionally print it out
+            if i % 10 == 0:
+                print(i, 'loss:', loss.item())

             # Get gradient
             cond_grad = torch.autograd.grad(loss, latents)[0]

             # Modify the latents based on this gradient
+            latents = latents.detach() - cond_grad * sigma ** 2

         # Now step with scheduler
         latents = scheduler.step(noise_pred, t, latents).prev_sample
+
+    # Convert the final latents to an image and display it
+    image = latents_to_pil(latents)[0]
+    image.show()
+    return image
+
+def generate_loss_style(prompt, style_embed, style_seed):
+    # Tokenize
+    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+    input_ids = text_input.input_ids.to(torch_device)
+
+    # Get token embeddings
+    token_embeddings = token_emb_layer(input_ids)
+    if isinstance(style_embed, dict):
+        style_embed = style_embed['<gartic-phone>']
+
+    # The new embedding - our special birb word
+    replacement_token_embedding = style_embed.to(torch_device)
+    # Assuming token_embeddings has shape [batch_size, seq_length, embedding_dim]
+    replacement_token_embedding = replacement_token_embedding[:768] # Adjust the size
+    replacement_token_embedding = replacement_token_embedding.unsqueeze(0) # Make it [1, 768] if necessary
+    indices = torch.where(input_ids[0] == 6829)[0] # Extract indices where the condition is True
+    print(f"indices: {indices}") # Debug print
+    for index in indices:
+        print(f"index: {index}") # Debug print
+        token_embeddings[0, index] = replacement_token_embedding.to(torch_device) # Update each index
+
+    # Insert this into the token embeddings
+    # token_embeddings[0, torch.where(input_ids[0]==6829)] = replacement_token_embedding.to(torch_device)
+
+    # Combine with pos embs
+    input_embeddings = token_embeddings + position_embeddings
+
+    # Feed through to get final output embs
+    modified_output_embeddings = get_output_embeds(input_embeddings)
+
+    # And generate an image with this:
+    max_length = text_input.input_ids.shape[-1]
+    return generate_loss(modified_output_embeddings, style_seed, max_length)
+
+def generate_embed_style(prompt, learned_style, seed):
+    # prompt = 'A campfire (oil on canvas)' #@param
+    height = 512 # default height of Stable Diffusion
+    width = 512 # default width of Stable Diffusion
+    num_inference_steps = 50 #@param # Number of denoising steps
+    guidance_scale = 8 #@param # Scale for classifier-free guidance
+    generator = torch.manual_seed(32) # Seed generator to create the initial latent noise
+    batch_size = 1
+    blue_loss_scale = 200 #@param
+    if isinstance(learned_style, dict):
+        learned_style = learned_style['<gartic-phone>']
+
+    # Prep text
+    text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+
+    input_ids = text_input.input_ids.to(torch_device)
+    # Get token embeddings
+    token_embeddings = text_encoder.get_input_embeddings()(input_ids)
+
+    # The new embedding - our special birb word
+    replacement_token_embedding = learned_style.to(torch_device)
+    replacement_token_embedding = replacement_token_embedding[:768] # Adjust the size
+    replacement_token_embedding = replacement_token_embedding.unsqueeze(0) # Make it [1, 768] if necessary
+    # Insert this into the token embeddings
+    indices = torch.where(input_ids[0] == 6829)[0]
+    for index in indices:
+        token_embeddings[0, index] = replacement_token_embedding.to(torch_device)
+    # Combine with pos embs
+    position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
+    position_embeddings = pos_emb_layer(position_ids)
+    #position_embeddings = text_encoder.get_position_embeddings()(position_ids)
+    input_embeddings = token_embeddings + position_embeddings
+    # Feed through to get final output embs
+    modified_output_embeddings = get_output_embeds(input_embeddings)
+    # And generate an image with this:
+    max_length = text_input.input_ids.shape[-1]
+    emb_seed = generate_with_embs_seed(modified_output_embeddings, seed, max_length)
+    #generate_loss_details = generate_loss(modified_output_embeddings, seed, max_length)
+    return emb_seed
+
 def generate_image_from_prompt(text_in, style_in):

     style = dict_styles # (learn_embed[0])
     birb_embed = torch.load(learn_embed[0])
     #birb_embed.keys(), dict_styles['<gartic-phone>'].shape
+
     #style_embed = torch.load(dict_styles)
     #birb_embed = torch.load('learned_embeds.bin')
     #birb_embed.keys(), birb_embed['<birb-style>'].shape

     #loss_generated_img = (loss_style(prompt, style_embed[0], style_seed))

     return [generated_image]
+
 # Define Interface
 title = 'Stable Diffusion Art Generator'

 )

 demo.launch(debug=True)
+