christopher committed on
Commit
a147e52
·
1 Parent(s): a7c877e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -4
app.py CHANGED
@@ -7,8 +7,21 @@ from transformers import AutoTokenizer
7
 
8
  tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased")
9
  input_embeddings = np.load("bert_input_embeddings.npy")
10
- index = IndexFlatL2(input_embeddings.shape[-1])
11
- index.add(input_embeddings)
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  vocab = {v:k for k,v in tokenizer.vocab.items()}
13
  lookup_table = pd.Series(vocab).sort_index()
14
 
@@ -18,9 +31,9 @@ def get_first_subword(word):
18
  except:
19
  return tokenizer(word, add_special_tokens=False)['input_ids'][0]
20
 
21
- def search(token_to_lookup, num_neighbors=50):
22
  i = get_first_subword(token_to_lookup)
23
- _ , I = index.search(input_embeddings[i:i+1], num_neighbors)
24
  hits = lookup_table.take(I[0])
25
  results = hits.values[1:]
26
  return [r for r in results if not "##" in r], [[r for r in results if "##" in r]]
 
7
 
8
  tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased")
9
  input_embeddings = np.load("bert_input_embeddings.npy")
10
+ unnormalized_input_embeddings = np.load("unnormalized_bert_input_embeddings.npy")
11
+
12
+ index_L2 = IndexFlatL2(input_embeddings.shape[-1])
13
+ index_L2.add(input_embeddings)
14
+
15
+ index_IP = IndexFlatIP(input_embeddings.shape[-1])
16
+ index_IP.add(input_embeddings)
17
+
18
+ index_L2_unnormalized = IndexFlatL2(unnormalized_input_embeddings.shape[-1])
19
+ index_L2_unnormalized.add(unnormalized_input_embeddings)
20
+
21
+ index_IP_unnormalized = IndexFlatIP(unnormalized_input_embeddings.shape[-1])
22
+ index_IP_unnormalized.add(unnormalized_input_embeddings)
23
+
24
+
25
  vocab = {v:k for k,v in tokenizer.vocab.items()}
26
  lookup_table = pd.Series(vocab).sort_index()
27
 
 
31
  except:
32
  return tokenizer(word, add_special_tokens=False)['input_ids'][0]
33
 
34
+ def search(token_to_lookup, num_neighbors=200):
35
  i = get_first_subword(token_to_lookup)
36
+ _ , I = index_L2_unnormalized.search(unnormalized_input_embeddings[i:i+1], num_neighbors)
37
  hits = lookup_table.take(I[0])
38
  results = hits.values[1:]
39
  return [r for r in results if not "##" in r], [[r for r in results if "##" in r]]