File size: 2,847 Bytes
c4e6a63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import json
import torch
import numpy as np
from tqdm import tqdm


def _encode_normalized_prompts(model, texts):
    """Encode ``texts`` with ``model.text_encoder`` and L2-normalize the pooled output.

    Texts are tokenized with ``model.tokenizer``, padded/truncated to
    ``model.tokenizer.model_max_length``, and moved to ``model.device``. Each row of
    ``pooler_output`` is divided by its L2 norm, so dot products between returned rows
    are cosine similarities.
    """
    text_input = model.tokenizer(
        texts,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        embeddings = model.text_encoder(text_input.input_ids.to(model.device)).pooler_output
    return embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)


def get_topk_similar_words(model, prompt, base_word, vocab, k=30, batch_size=1000):
    """Return the ``k`` vocabulary words whose filled-in prompt is closest to ``base_word``'s.

    ``prompt`` is a template containing ``{word}``. It is filled with ``base_word`` and
    with every entry of ``vocab``. All filled prompts are embedded with the pooled output
    of ``model.text_encoder``. Words are ranked by cosine similarity to the base
    embedding, highest first. The base word itself is included if it is in ``vocab``.

    Args:
        model: object exposing ``tokenizer``, ``text_encoder`` and ``device``.
        prompt: format string with a ``{word}`` placeholder.
        base_word: word whose prompt embedding is the similarity reference.
        vocab: sequence of candidate words.
        k: number of words to return; clamped to ``len(vocab)``; ``k <= 0`` yields [].
        batch_size: number of vocabulary prompts encoded per text-encoder call.

    Returns:
        list of up to ``k`` words from ``vocab``, most similar first. Order among exactly
        tied similarities is unspecified. An empty ``vocab`` returns [] instead of
        raising.

    Side effects:
        Prints the prompt, the base word and the selected words.
    """
    base_embedding = _encode_normalized_prompts(model, [prompt.format(word=base_word)])

    prompts = [prompt.format(word=word) for word in vocab]
    similar_words = []
    if prompts:
        # Encode in chunks to bound text-encoder memory. Normalization is per row, so
        # normalizing each chunk is identical to normalizing after concatenation.
        vocab_embeddings = torch.cat([
            _encode_normalized_prompts(model, prompts[start:start + batch_size])
            for start in tqdm(range(0, len(prompts), batch_size))
        ])
        similarities = vocab_embeddings.matmul(base_embedding.reshape(-1))
        top_k = max(0, min(k, len(prompts)))
        # topk avoids sorting the whole vocabulary; .float() keeps it valid for
        # half-precision encoders on CPU.
        top_indices = torch.topk(similarities.float(), top_k).indices.cpu().tolist()
        similar_words = [vocab[index] for index in top_indices]

    print(f"prompt: {prompt}")
    print(f"initial word: {base_word}")
    print(f"TOP {k} SIMILAR WORDS:")
    print(similar_words)
    return similar_words

def get_proxy_words(args, ldm_stable, vocab_path="vocab.json", prefilter_k=50,
                    prefilter_prompt="a photo of a {word}"):
    """Return proxy words for ``args.object_of_interest``, with the object first.

    If ``args.proxy_words`` is non-empty, the result is the object followed by those
    words as given. Otherwise the candidate words come from the vocabulary search:

    1. Read the keys of the JSON object stored at ``vocab_path``.
    2. Keep only alphabetic words longer than one character.
    3. Keep the ``prefilter_k`` words whose ``prefilter_prompt`` embedding is most
       similar to the object's.
    4. Re-rank those with ``args.prompt`` and keep ``args.number_of_variations``.

    Args:
        args: namespace providing ``proxy_words`` (list or None), ``object_of_interest``,
            ``prompt`` (format string with ``{word}``) and ``number_of_variations``.
        ldm_stable: model passed through to ``get_topk_similar_words``.
        vocab_path: path to a UTF-8 JSON file whose top-level keys are candidate words.
        prefilter_k: size of the first, prompt-independent shortlist.
        prefilter_prompt: template used for the first shortlist.

    Returns:
        list of str. In the vocabulary-search path, ``args.object_of_interest`` appears
        exactly once, at index 0.
    """
    # Truthiness also covers proxy_words=None, where len() would raise.
    if args.proxy_words:
        return [args.object_of_interest] + list(args.proxy_words)

    # JSON is UTF-8; the context manager closes the file instead of leaking it.
    with open(vocab_path, encoding="utf-8") as vocab_file:
        vocab = list(json.load(vocab_file).keys())
    vocab = [word for word in vocab if word.isalpha() and len(word) > 1]

    filtered_vocab = get_topk_similar_words(
        ldm_stable, prefilter_prompt, args.object_of_interest, vocab, k=prefilter_k)
    proxy_words = get_topk_similar_words(
        ldm_stable, args.prompt, args.object_of_interest, filtered_vocab,
        k=args.number_of_variations)

    # Move the object to the front exactly once. Checking only index 0 would
    # duplicate it whenever it ranked second or later.
    return [args.object_of_interest] + [
        word for word in proxy_words if word != args.object_of_interest
    ]

def get_proxy_prompts(args, ldm_stable):
    """Build prompts for the object of interest and for each of its proxy words.

    Returns a 3-tuple:
      * the list returned by ``get_proxy_words(args, ldm_stable)``;
      * a one-element list holding ``args.prompt`` filled with ``args.object_of_interest``;
      * one ``{"word": ..., "prompt": ...}`` dict per proxy word, in the same order,
        where ``prompt`` is ``args.prompt`` filled with that word.
    """
    template = args.prompt
    words = get_proxy_words(args, ldm_stable)

    base_prompts = [template.format(word=args.object_of_interest)]

    entries = []
    for candidate in words:
        entries.append({"word": candidate, "prompt": template.format(word=candidate)})

    return words, base_prompts, entries