token-explorer / app.py
christopher's picture
Update app.py
19c7d3c
raw
history blame
2.37 kB
import gradio as gr
from faiss import IndexFlatIP, IndexFlatL2
import pandas as pd
import numpy as np
from transformers import AutoTokenizer
# Tokenizer plus two pre-computed copies of the bert-large-uncased input
# embedding matrix, loaded from .npy files next to this script.
tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased")
normalized_input_embeddings = np.load("normalized_bert_input_embeddings.npy")
unnormalized_input_embeddings = np.load("unnormalized_bert_input_embeddings.npy")


def _flat_index(index_factory, matrix):
    """Build a flat (exact, brute-force) FAISS index over the rows of *matrix*."""
    flat = index_factory(matrix.shape[-1])
    flat.add(matrix)
    return flat


# L2 distance is searched over the raw embeddings, inner product over the
# normalized copy (normalization is assumed from the file name, not checked).
index_L2 = _flat_index(IndexFlatL2, unnormalized_input_embeddings)
index_IP = _flat_index(IndexFlatIP, normalized_input_embeddings)

# Invert token -> id into id -> token and sort by id, so positional access
# (Series.take) turns a FAISS row number into a token string.
# NOTE(review): this assumes token ids are 0..N-1 and that embedding row r
# belongs to token id r — confirm against how the .npy files were exported.
vocab = {token_id: token for token, token_id in tokenizer.vocab.items()}
lookup_table = pd.Series(vocab).sort_index()
def get_first_subword(word):
    """Return the vocabulary id of *word*, or of its first WordPiece.

    An exact vocabulary entry (e.g. ``"##ness"``) is looked up directly, so
    subword tokens can be queried verbatim. Anything else is run through the
    tokenizer without special tokens ([CLS]/[SEP]) and the id of the first
    resulting piece is returned.

    Raises:
        ValueError: if *word* tokenizes to nothing (e.g. an empty string).
    """
    try:
        return tokenizer.vocab[word]
    except KeyError:
        # Not a literal vocabulary entry: fall back to tokenizing it.
        pass
    input_ids = tokenizer(word, add_special_tokens=False)["input_ids"]
    if not input_ids:
        # Previously an opaque IndexError; give the UI a readable message.
        raise ValueError(f"No tokens found for input {word!r}; please enter a token.")
    return input_ids[0]
def _neighbor_tokens(index, embeddings, token_id, num_neighbors):
    """Return tokens nearest to *token_id* in *index*, nearest first.

    The query token itself, FAISS padding ids (-1, returned when fewer than
    *num_neighbors* vectors exist) and ``[unused...]`` placeholder tokens
    are dropped.
    """
    # Slice rather than index so the query stays 2-D, as FAISS search expects.
    _, ids = index.search(embeddings[token_id:token_id + 1], num_neighbors)
    # Exclude the query by id, not by position: with ties or duplicate vectors
    # the query is not guaranteed to be the first hit.
    kept = [j for j in ids[0] if j >= 0 and j != token_id]
    tokens = lookup_table.take(kept).values
    return [t for t in tokens if "[unused" not in t]


def _split_subwords(tokens):
    """Partition *tokens* into (tokens without "##", tokens with "##"), keeping order."""
    words = [t for t in tokens if "##" not in t]
    pieces = [t for t in tokens if "##" in t]
    return words, pieces


def search(token_to_lookup, num_neighbors=100):
    """Find the vocabulary tokens whose input embeddings are nearest to a token.

    Args:
        token_to_lookup: a vocabulary token or free text; when it is not an
            exact vocabulary entry, its first WordPiece is used.
        num_neighbors: number of hits requested from each FAISS index before
            filtering. Coerced to ``int`` so a numeric UI widget's float
            value is accepted.

    Returns:
        A 4-tuple of lists, each ordered nearest first:
        (inner-product tokens without "##", inner-product "##" subwords,
         L2 tokens without "##", L2 "##" subwords).
    """
    token_id = get_first_subword(token_to_lookup)
    k = int(num_neighbors)
    results_IP = _neighbor_tokens(index_IP, normalized_input_embeddings, token_id, k)
    results_L2 = _neighbor_tokens(index_L2, unnormalized_input_embeddings, token_id, k)
    ip_words, ip_pieces = _split_subwords(results_IP)
    l2_words, l2_pieces = _split_subwords(results_L2)
    return ip_words, ip_pieces, l2_words, l2_pieces
# Web UI: one text input, four text outputs (one per list returned by search).
iface = gr.Interface(
    fn=search,
    # To expose search()'s num_neighbors, pass a list of inputs instead, e.g.
    # add gr.Number(value=50, label="number of neighbors") after the Textbox.
    inputs=gr.Textbox(lines=1, label="Vocabulary Token", placeholder="Enter token..."),
    outputs=[
        gr.Textbox(label="IP-Nearest tokens"),
        gr.Textbox(label="IP-Nearest subwords"),
        gr.Textbox(label="L2-Nearest tokens"),
        gr.Textbox(label="L2-Nearest subwords"),
    ],
    examples=[
        ["##logy"],
        ["##ness"],
        ["##ity"],
        ["responded"],
        ["queen"],
        ["king"],
        ["hospital"],
        ["disease"],
        ["grammar"],
        ["philosophy"],
        ["aristotle"],
        ["##ting"],
        ["woman"],
        ["man"],
    ],
)
# Enable request queueing via queue(): launch(enable_queue=...) is deprecated
# in Gradio 3.x and no longer accepted by Gradio 4.
iface.queue()
iface.launch(debug=True, show_error=True)