Commit 4e75298
Parent(s): 266993c

Update utils.py
utils.py CHANGED

@@ -218,3 +218,39 @@ token_emb_layer = text_encoder.text_model.embeddings.token_embedding
 
 pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
 #pos_emb_layer
+
+def func_generate(query, concept_idx, seed_start, contrast_loss=False, contrast_perc=None):
+    prompt = query + ' in the style of bulb'
+    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True,
+                           return_tensors="pt")
+    input_ids = text_input.input_ids.to(torch_device)
+
+    # Get token embeddings
+    position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
+    position_embeddings = pos_emb_layer(position_ids)
+
+    s = seed_start
+
+    token_embeddings = token_emb_layer(input_ids)
+    # The new embedding - our special birb word
+    replacement_token_embedding = concept_embeds[concept_idx].to(torch_device)
+
+    # Insert this into the token embeddings
+    token_embeddings[0, torch.where(input_ids[0] == 22373)] = replacement_token_embedding.to(torch_device)
+
+    # Combine with pos embs
+    input_embeddings = token_embeddings + position_embeddings
+
+    # Feed through to get final output embs
+    modified_output_embeddings = get_output_embeds(input_embeddings)
+
+    # And generate an image with this:
+
+    if contrast_loss and seed_values[concept_idx] > 0:
+        s = seed_values[concept_idx]
+    else:
+        s = random.randint(s + 1, s + 30)
+        seed_values[concept_idx] = s
+
+    g = torch.manual_seed(s)
+    return generate_with_embs(text_input, modified_output_embeddings, generator=g, contrast_loss=contrast_loss, contrast_perc=contrast_perc)
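The added func_generate ties together the objects defined earlier in utils.py: it tokenizes the prompt, replaces the token embedding wherever input_ids equals the placeholder id 22373 with the learned concept embedding from concept_embeds, re-adds the position embeddings, pushes the result through get_output_embeds, and calls generate_with_embs with either the seed cached in seed_values (when contrast_loss is set and a seed exists) or a fresh random seed. The sketch below is a hedged usage illustration, not part of the commit: it assumes the module-level objects referenced in the diff (tokenizer, text_encoder, torch_device, concept_embeds, seed_values, get_output_embeds, generate_with_embs) are already initialised, and the query string, concept index, seed_start and contrast_perc values are placeholders.

# Usage sketch (assumption, not taken from the commit).
# seed_values[concept_idx] is assumed to start at 0 so the first call draws a fresh seed.
concept_idx = 0   # hypothetical index into concept_embeds

# First call: contrast_loss is False, so a seed is drawn from
# [seed_start + 1, seed_start + 30] and cached in seed_values[concept_idx].
image = func_generate("a cosy cottage", concept_idx, seed_start=42)

# Second call: contrast_loss=True reuses the cached seed, so the contrastive
# variant is denoised from the same starting noise as the first image.
image_contrast = func_generate("a cosy cottage", concept_idx, seed_start=42,
                               contrast_loss=True, contrast_perc=0.5)

The per-concept seed cache appears to exist so that a contrast_loss run can reuse the seed of the preceding plain run for the same concept, letting both images start from identical latents and making the effect of the contrastive loss directly comparable.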