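# Gradio demo: explore the nearest neighbors of a token in BERT's input
# (WordPiece) embedding space, comparing inner-product similarity on
# normalized vectors (i.e. cosine similarity) with L2 distance on the raw
# vectors, using two FAISS flat indexes.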
import gradio as gr
from faiss import IndexFlatIP, IndexFlatL2
import pandas as pd
import numpy as np
from transformers import AutoTokenizer


tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased")
normalized_input_embeddings = np.load("normalized_bert_input_embeddings.npy")
unnormalized_input_embeddings = np.load("unnormalized_bert_input_embeddings.npy")
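# The two .npy files above are precomputed artifacts. A minimal sketch of how
# they could be regenerated from the model's input embedding matrix (an
# assumption about the recipe, not necessarily how the originals were made):
#
#   from transformers import AutoModel
#   emb = AutoModel.from_pretrained("bert-large-uncased") \
#       .get_input_embeddings().weight.detach().numpy().astype("float32")
#   np.save("unnormalized_bert_input_embeddings.npy", emb)
#   np.save("normalized_bert_input_embeddings.npy",
#           emb / np.linalg.norm(emb, axis=1, keepdims=True))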

# L2 index over the raw embeddings: nearest neighbors by Euclidean distance.
index_L2 = IndexFlatL2(unnormalized_input_embeddings.shape[-1])
index_L2.add(unnormalized_input_embeddings)

# Inner-product index over unit-length embeddings: equivalent to cosine similarity.
index_IP = IndexFlatIP(normalized_input_embeddings.shape[-1])
index_IP.add(normalized_input_embeddings)


# Invert the tokenizer vocab (token -> id) into an id -> token lookup table,
# ordered so that position i holds the token with id i.
vocab = {v: k for k, v in tokenizer.vocab.items()}
lookup_table = pd.Series(vocab).sort_index()

def get_first_subword(word):
    """Return the vocab id of `word`, or of its first subword if it is out-of-vocabulary."""
    try:
        return tokenizer.vocab[word]
    except KeyError:
        return tokenizer(word, add_special_tokens=False)["input_ids"][0]

def search(token_to_lookup, num_neighbors=100):
    i = get_first_subword(token_to_lookup)

    # Cosine neighbors: inner product over the normalized embeddings.
    _, I_IP = index_IP.search(normalized_input_embeddings[i : i + 1], num_neighbors)
    hits_IP = lookup_table.take(I_IP[0])
    results_IP = hits_IP.values[1:]  # drop the query token itself
    results_IP = [r for r in results_IP if "[unused" not in r]

    # Euclidean neighbors: L2 distance over the raw embeddings.
    _, I_L2 = index_L2.search(unnormalized_input_embeddings[i : i + 1], num_neighbors)
    hits_L2 = lookup_table.take(I_L2[0])
    results_L2 = hits_L2.values[1:]  # drop the query token itself
    results_L2 = [r for r in results_L2 if "[unused" not in r]

    # Split each result list into whole-word tokens and "##" subword continuations.
    return (
        [r for r in results_IP if "##" not in r],
        [r for r in results_IP if "##" in r],
        [r for r in results_L2 if "##" not in r],
        [r for r in results_L2 if "##" in r],
    )


iface = gr.Interface(
    fn=search,
    # A second Number input for `num_neighbors` could be enabled instead:
    # inputs=[gr.Textbox(lines=1, label="Vocabulary Token", placeholder="Enter token..."), gr.Number(value=50, label="number of neighbors")],
    inputs=gr.Textbox(lines=1, label="Vocabulary Token", placeholder="Enter token..."),
    # One output box per list returned by `search`.
    outputs=[
        gr.Textbox(label="IP-Nearest tokens"),
        gr.Textbox(label="IP-Nearest subwords"),
        gr.Textbox(label="L2-Nearest tokens"),
        gr.Textbox(label="L2-Nearest subwords"),
    ],
    examples=[
        ["##logy"],
        ["##ness"],
        ["##nity"],
        ["responded"],
        ["queen"],
        ["king"],
        ["hospital"],
        ["disease"],
        ["grammar"],
        ["philosophy"],
        ["aristotle"],
        ["##ting"],
        ["woman"],
        ["man"]
    ],
)
iface.launch(enable_queue=True, debug=True, show_error=True)
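# Note: `enable_queue` is a Gradio 3.x `launch()` flag; on Gradio 4.x the
# equivalent is calling `iface.queue()` before `launch()`.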