piyushgrover commited on
Commit
4e75298
Β·
1 Parent(s): 266993c

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +36 -0
utils.py CHANGED
@@ -218,3 +218,39 @@ token_emb_layer = text_encoder.text_model.embeddings.token_embedding
218
 
219
  pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
220
  #pos_emb_layer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
 
219
  pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
220
  #pos_emb_layer
221
+
222
+ def func_generate(query, concept_idx, seed_start, contrast_loss=False, contrast_perc=None):
223
+ prompt = query + ' in the style of bulb'
224
+ text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True,
225
+ return_tensors="pt")
226
+ input_ids = text_input.input_ids.to(torch_device)
227
+
228
+ # Get token embeddings
229
+ position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
230
+ position_embeddings = pos_emb_layer(position_ids)
231
+
232
+ s = seed_start
233
+
234
+ token_embeddings = token_emb_layer(input_ids)
235
+ # The new embedding - our special birb word
236
+ replacement_token_embedding = concept_embeds[concept_idx].to(torch_device)
237
+
238
+ # Insert this into the token embeddings
239
+ token_embeddings[0, torch.where(input_ids[0] == 22373)] = replacement_token_embedding.to(torch_device)
240
+
241
+ # Combine with pos embs
242
+ input_embeddings = token_embeddings + position_embeddings
243
+
244
+ # Feed through to get final output embs
245
+ modified_output_embeddings = get_output_embeds(input_embeddings)
246
+
247
+ # And generate an image with this:
248
+
249
+ if contrast_loss and seed_values[concept_idx] > 0:
250
+ s = seed_values[concept_idx]
251
+ else:
252
+ s = random.randint(s + 1, s + 30)
253
+ seed_values[concept_idx] = s
254
+
255
+ g = torch.manual_seed(s)
256
+ return generate_with_embs(text_input, modified_output_embeddings, generator=g, contrast_loss=contrast_loss, contrast_perc=contrast_perc)