import json

import numpy as np
import torch
from tqdm import tqdm


def get_topk_similar_words(model, prompt, base_word, vocab, k=30):
    # Embed the prompt filled with the base word and keep its pooled CLIP text embedding.
    text_input = model.tokenizer(
        [prompt.format(word=base_word)],
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        encoder_output = model.text_encoder(text_input.input_ids.to(model.device))
        full_prompt_embedding = encoder_output.pooler_output
    full_prompt_embedding = full_prompt_embedding / full_prompt_embedding.norm(p=2, dim=-1, keepdim=True)

    # Embed the same prompt template for every candidate word in the vocabulary, in batches.
    prompts = [prompt.format(word=word) for word in vocab]
    batch_size = 1000
    all_prompts_embeddings = []
    for i in tqdm(range(0, len(prompts), batch_size)):
        curr_prompts = prompts[i:i + batch_size]
        with torch.no_grad():
            text_input = model.tokenizer(
                curr_prompts,
                padding="max_length",
                max_length=model.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            curr_embeddings = model.text_encoder(text_input.input_ids.to(model.device)).pooler_output
        all_prompts_embeddings.append(curr_embeddings)
    all_prompts_embeddings = torch.cat(all_prompts_embeddings)
    all_prompts_embeddings = all_prompts_embeddings / all_prompts_embeddings.norm(p=2, dim=-1, keepdim=True)

    # Cosine similarity (dot product of unit-normalized embeddings) between each candidate
    # prompt and the base prompt; argsort ascending, then flip for most-similar-first order.
    prompts_similarities = all_prompts_embeddings.matmul(full_prompt_embedding.view(-1, 1))
    sorted_indices = np.flip(prompts_similarities.cpu().numpy().reshape(-1).argsort())

    print(f"prompt: {prompt}")
    print(f"initial word: {base_word}")
    print(f"TOP {k} SIMILAR WORDS:")
    similar_words = [vocab[index] for index in sorted_indices[:k]]
    print(similar_words)
    return similar_words


def get_proxy_words(args, ldm_stable):
    # If proxy words were supplied explicitly, use them as-is, with the object of
    # interest always placed first.
    if len(args.proxy_words) > 0:
        return [args.object_of_interest] + args.proxy_words
    # Otherwise mine proxy words from the vocabulary: a coarse filtering pass with a
    # generic prompt, followed by a refinement pass with the user's actual prompt.
    with open("vocab.json") as f:
        vocab = list(json.load(f).keys())
    vocab = [word for word in vocab if word.isalpha() and len(word) > 1]
    filtered_vocab = get_topk_similar_words(ldm_stable, "a photo of a {word}", args.object_of_interest, vocab, k=50)
    proxy_words = get_topk_similar_words(ldm_stable, args.prompt, args.object_of_interest, filtered_vocab,
                                         k=args.number_of_variations)
    # Ensure the object of interest itself is the first proxy word.
    if proxy_words[0] != args.object_of_interest:
        proxy_words = [args.object_of_interest] + proxy_words
    return proxy_words


def get_proxy_prompts(args, ldm_stable):
    proxy_words = get_proxy_words(args, ldm_stable)
    prompts = [args.prompt.format(word=args.object_of_interest)]
    proxy_prompts = [{"word": word, "prompt": args.prompt.format(word=word)} for word in proxy_words]
    return proxy_words, prompts, proxy_prompts
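

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how these helpers might be wired together. It assumes a
# diffusers StableDiffusionPipeline as the model (it exposes .tokenizer,
# .text_encoder and .device, as used above) and an argparse-style namespace with
# the fields read by get_proxy_words / get_proxy_prompts. The concrete field
# values and the CUDA device below are hypothetical.
if __name__ == "__main__":
    from types import SimpleNamespace

    from diffusers import StableDiffusionPipeline

    ldm_stable = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4"
    ).to("cuda")

    args = SimpleNamespace(
        prompt="a photo of a {word} on the beach",  # must contain the {word} placeholder
        object_of_interest="dog",
        proxy_words=[],          # empty list -> proxy words are mined from vocab.json
        number_of_variations=5,
    )

    proxy_words, prompts, proxy_prompts = get_proxy_prompts(args, ldm_stable)
    print(proxy_words)
    print(prompts)
    print(proxy_prompts)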