local-prompt-mixing / src /prompt_utils.py
orpatashnik's picture
add code
c4e6a63
raw
history blame
2.85 kB
import json
import torch
import numpy as np
from tqdm import tqdm
def _encode_pooled(model, texts):
    """Tokenize ``texts`` and return ``model.text_encoder``'s pooled output for them."""
    tokens = model.tokenizer(
        texts,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    # Inference only: no autograd graph is needed to rank similarities.
    with torch.no_grad():
        return model.text_encoder(tokens.input_ids.to(model.device)).pooler_output


def get_topk_similar_words(model, prompt, base_word, vocab, k=30, batch_size=1000):
    """Return the ``k`` words of ``vocab`` whose prompts are closest to the base prompt.

    ``prompt`` is a template with a ``{word}`` placeholder. It is filled once
    with ``base_word`` and once with every entry of ``vocab``. Each sentence is
    encoded with ``model.text_encoder``; the pooled outputs are L2-normalised
    and compared by dot product (i.e. cosine similarity).

    Args:
        model: Object exposing ``tokenizer``, ``text_encoder`` and ``device``.
        prompt: Template string formatted with ``word=...``.
        base_word: Word whose filled-in prompt is the reference.
        vocab: Indexable sequence of candidate words.
        k: Maximum number of words to return.
        batch_size: Number of candidate prompts encoded per forward pass.
            ``range`` raises ``ValueError`` if it is zero.

    Returns:
        A list of at most ``k`` words from ``vocab``, most similar first.
        An empty ``vocab`` yields an empty list without touching the model.
    """
    # Nothing to rank; torch.cat() below would raise on an empty list.
    if len(vocab) == 0:
        return []

    full_prompt_embedding = _encode_pooled(model, [prompt.format(word=base_word)])
    full_prompt_embedding = full_prompt_embedding / full_prompt_embedding.norm(p=2, dim=-1, keepdim=True)

    prompts = [prompt.format(word=word) for word in vocab]
    # Encode the candidates in chunks so one forward pass never sees the whole vocabulary.
    all_prompts_embeddings = [
        _encode_pooled(model, prompts[start:start + batch_size])
        for start in tqdm(range(0, len(prompts), batch_size))
    ]
    all_prompts_embeddings = torch.cat(all_prompts_embeddings)
    all_prompts_embeddings = all_prompts_embeddings / all_prompts_embeddings.norm(p=2, dim=-1, keepdim=True)

    # Both sides are unit vectors, so the matrix product gives one cosine
    # similarity per candidate (the base embedding is reshaped to a column).
    prompts_similarities = all_prompts_embeddings.matmul(full_prompt_embedding.view(-1, 1))
    # argsort is ascending; flip it so the most similar candidate comes first.
    sorted_prompts_similarities = np.flip(prompts_similarities.cpu().numpy().reshape(-1).argsort())

    print(f"prompt: {prompt}")
    print(f"initial word: {base_word}")
    print(f"TOP {k} SIMILAR WORDS:")
    similar_words = [vocab[index] for index in sorted_prompts_similarities[:k]]
    print(similar_words)
    return similar_words
def get_proxy_words(args, ldm_stable):
    """Return the object of interest followed by its proxy words.

    If ``args.proxy_words`` is non-empty, those words are used as given.
    Otherwise candidates are read from ``vocab.json`` in the current working
    directory and narrowed in two passes of ``get_topk_similar_words``: first
    to the 50 words closest under the generic template
    ``"a photo of a {word}"``, then to ``args.number_of_variations`` words
    under ``args.prompt``.

    Args:
        args: Namespace providing ``proxy_words``, ``object_of_interest``,
            ``prompt`` and ``number_of_variations``.
        ldm_stable: Model object forwarded to ``get_topk_similar_words``.

    Returns:
        A list whose first element is ``args.object_of_interest``.
    """
    if args.proxy_words:
        return [args.object_of_interest, *args.proxy_words]

    # Context manager closes the handle; JSON text is UTF-8 regardless of
    # the platform's default encoding.
    with open("vocab.json", encoding="utf-8") as vocab_file:
        raw_vocab = json.load(vocab_file)
    # Keep purely alphabetic entries longer than one character.
    vocab = [word for word in raw_vocab if word.isalpha() and len(word) > 1]

    filtered_vocab = get_topk_similar_words(
        ldm_stable, "a photo of a {word}", args.object_of_interest, vocab, k=50
    )
    proxy_words = get_topk_similar_words(
        ldm_stable, args.prompt, args.object_of_interest, filtered_vocab, k=args.number_of_variations
    )
    # Guard the empty case (e.g. number_of_variations == 0) before indexing,
    # and make sure the object of interest always leads the list.
    if not proxy_words or proxy_words[0] != args.object_of_interest:
        proxy_words = [args.object_of_interest] + proxy_words
    return proxy_words
def get_proxy_prompts(args, ldm_stable):
    """Build the base prompt and one prompt per proxy word.

    Returns a 3-tuple:
        * the proxy words from ``get_proxy_words``,
        * a single-element list holding ``args.prompt`` filled with
          ``args.object_of_interest``,
        * a list of ``{"word": ..., "prompt": ...}`` dicts, one per proxy
          word, in the same order as the proxy words.
    """
    words = get_proxy_words(args, ldm_stable)
    base_prompts = [args.prompt.format(word=args.object_of_interest)]

    per_word_prompts = []
    for candidate in words:
        filled = args.prompt.format(word=candidate)
        per_word_prompts.append({"word": candidate, "prompt": filled})

    return words, base_prompts, per_word_prompts