Christopher Akiki commited on
Commit
b830b93
Β·
1 Parent(s): 8ee1828

Minor fixes

Browse files
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -6,14 +6,14 @@ from transformers import AutoTokenizer
6
 
7
 
8
  tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased")
9
- normalized_input_embeddings = np.load("embeddings/bert-large-uncased/normalized.npy")
10
- unnormalized_input_embeddings = np.load("embeddings/bert-large-uncased/unnormalized.npy")
11
 
12
- index_L2 = IndexFlatL2(unnormalized_input_embeddings.shape[-1])
13
- index_L2.add(unnormalized_input_embeddings)
14
 
15
- index_IP = IndexFlatIP(normalized_input_embeddings.shape[-1])
16
- index_IP.add(normalized_input_embeddings)
17
 
18
 
19
  vocab = {v:k for k,v in tokenizer.vocab.items()}
@@ -27,12 +27,12 @@ def get_first_subword(word):
27
 
28
  def search(token_to_lookup, num_neighbors=250):
29
  i = get_first_subword(token_to_lookup)
30
- _ , I_IP = index_IP.search(normalized_input_embeddings[i:i+1], num_neighbors)
31
  hits_IP = lookup_table.take(I_IP[0])
32
  results_IP = hits_IP.values[1:]
33
  results_IP = [r for r in results_IP if not "[unused" in r]
34
 
35
- _ , I_L2 = index_L2.search(unnormalized_input_embeddings[i:i+1], num_neighbors)
36
  hits_L2 = lookup_table.take(I_L2[0])
37
  results_L2 = hits_L2.values[1:]
38
  results_L2 = [r for r in results_L2 if not "[unused" in r]
 
6
 
7
 
8
  tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased")
9
+ normalized = np.load("embeddings/bert-large-uncased/normalized.npy")
10
+ unnormalized = np.load("embeddings/bert-large-uncased/unnormalized.npy")
11
 
12
+ index_L2 = IndexFlatL2(unnormalized.shape[-1])
13
+ index_L2.add(unnormalized)
14
 
15
+ index_IP = IndexFlatIP(normalized.shape[-1])
16
+ index_IP.add(normalized)
17
 
18
 
19
  vocab = {v:k for k,v in tokenizer.vocab.items()}
 
27
 
28
  def search(token_to_lookup, num_neighbors=250):
29
  i = get_first_subword(token_to_lookup)
30
+ _ , I_IP = index_IP.search(normalized[i:i+1], num_neighbors)
31
  hits_IP = lookup_table.take(I_IP[0])
32
  results_IP = hits_IP.values[1:]
33
  results_IP = [r for r in results_IP if not "[unused" in r]
34
 
35
+ _ , I_L2 = index_L2.search(unnormalized[i:i+1], num_neighbors)
36
  hits_L2 = lookup_table.take(I_L2[0])
37
  results_L2 = hits_L2.values[1:]
38
  results_L2 = [r for r in results_L2 if not "[unused" in r]