Jonas Leeb committed on
Commit 4bc7b36 · 1 Parent(s): 0cbf59f

added sbert and simplified plotting

Files changed (2)
  1. BERT embeddings/sbert_embedding.npz +3 -0
  2. app.py +57 -53
BERT embeddings/sbert_embedding.npz ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c5285b4f848c6dd13e7141a2684857daf6d8d02fdb18fb4812182fe31780c717
+size 40407781
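For reference, a minimal sketch (not part of the commit) of how an embeddings file with the key that app.py expects could be generated, assuming the same dataset object and the all-MiniLM-L6-v2 model the app loads:

import numpy as np
from sentence_transformers import SentenceTransformer

# Assumption: `dataset` is the same corpus object app.py is built around.
model = SentenceTransformer("all-MiniLM-L6-v2")
texts = [row["text"] for row in dataset["train"]]
embeddings = model.encode(texts, batch_size=64, show_progress_bar=True)

# app.py reads the array back via np.load(...)["sbert_embedding"], so save it under that key.
np.savez("BERT embeddings/sbert_embedding.npz", sbert_embedding=embeddings)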
app.py CHANGED
@@ -11,8 +11,8 @@ from gensim.models import KeyedVectors
 import plotly.graph_objects as go
 from sklearn.decomposition import PCA
 from transformers import AutoTokenizer, AutoModel
-
-
+from sentence_transformers import CrossEncoder
+from sentence_transformers import SentenceTransformer
 
 class ArxivSearch:
     def __init__(self, dataset, embedding="bert"):
@@ -28,7 +28,7 @@ class ArxivSearch:
 
         # model selection
         self.embedding_dropdown = gr.Dropdown(
-            choices=["tfidf", "word2vec", "bert", "scibert"],
+            choices=["tfidf", "word2vec", "bert", "scibert", "sbert"],
             value="bert",
             label="Model"
         )
@@ -56,7 +56,6 @@ class ArxivSearch:
             inputs=[self.query_box, self.embedding_dropdown],
             outputs=self.output_md
         )
-
         self.embedding_dropdown.change(
             self.search_function,
             inputs=[self.query_box, self.embedding_dropdown],
@@ -78,6 +77,7 @@ class ArxivSearch:
         self.load_model('word2vec')
         self.load_model('bert')
         self.load_model('scibert')
+        self.load_model('sbert')
 
         self.iface.launch()
 
@@ -113,6 +113,16 @@ class ArxivSearch:
             self.documents.append(text.strip())
             self.arxiv_ids.append(arxiv_id)
 
+    def plot_dense(self, embedding, pca, results_indices):
+        print(self.query_encoding.shape[0])
+        all_indices = list(set(results_indices) | set(range(min(5000, embedding.shape[0]))))
+        all_data = embedding[all_indices]
+        pca.fit(all_data)
+        reduced_data = pca.transform(embedding[:5000])
+        reduced_results_points = pca.transform(embedding[results_indices]) if len(results_indices) > 0 else np.empty((0, 3))
+        query_point = pca.transform(self.query_encoding) if self.query_encoding is not None and self.query_encoding.shape[0] > 0 else np.empty((0, 3))
+        return reduced_data, reduced_results_points, query_point
+
     def plot_3d_embeddings(self):
         # Example: plot random points, replace with your embeddings
         pca = PCA(n_components=3)
@@ -123,29 +133,14 @@
             pca.fit(all_data)
             reduced_data = pca.transform(self.tfidf_matrix[:5000].toarray())
             reduced_results_points = pca.transform(self.tfidf_matrix[results_indices].toarray()) if len(results_indices) > 0 else np.empty((0, 3))
-
         elif self.embedding == "word2vec":
-            all_indices = list(set(results_indices) | set(range(min(5000, self.word2vec_embeddings.shape[0]))))
-            all_data = self.word2vec_embeddings[all_indices]
-            pca.fit(all_data)
-            reduced_data = pca.transform(self.word2vec_embeddings[:5000])
-            reduced_results_points = pca.transform(self.word2vec_embeddings[results_indices]) if len(results_indices) > 0 else np.empty((0, 3))
-            query_point = pca.transform(self.query_encoding) if self.query_encoding is not None and self.query_encoding.shape[0] > 0 else np.empty((0, 3))
-
+            reduced_data, reduced_results_points, query_point = self.plot_dense(self.word2vec_embeddings, pca, results_indices)
         elif self.embedding == "bert":
-            all_indices = list(set(results_indices) | set(range(min(5000, self.bert_embeddings.shape[0]))))
-            all_data = self.bert_embeddings[all_indices]
-            pca.fit(all_data)
-            reduced_data = pca.transform(self.bert_embeddings[:5000])
-            reduced_results_points = pca.transform(self.bert_embeddings[results_indices]) if len(results_indices) > 0 else np.empty((0, 3))
-            query_point = pca.transform(self.query_encoding) if self.query_encoding is not None and self.query_encoding.shape[0] > 0 else np.empty((0, 3))
+            reduced_data, reduced_results_points, query_point = self.plot_dense(self.bert_embeddings, pca, results_indices)
+        elif self.embedding == "sbert":
+            reduced_data, reduced_results_points, query_point = self.plot_dense(self.sbert_embedding, pca, results_indices)
         elif self.embedding == "scibert":
-            all_indices = list(set(results_indices) | set(range(min(5000, self.scibert_embeddings.shape[0]))))
-            all_data = self.scibert_embeddings[all_indices]
-            pca.fit(all_data)
-            reduced_data = pca.transform(self.scibert_embeddings[:5000])
-            reduced_results_points = pca.transform(self.scibert_embeddings[results_indices]) if len(results_indices) > 0 else np.empty((0, 3))
-            query_point = pca.transform(self.query_encoding) if self.query_encoding is not None and self.query_encoding.shape[0] > 0 else np.empty((0, 3))
+            reduced_data, reduced_results_points, query_point = self.plot_dense(self.scibert_embeddings, pca, results_indices)
         else:
             raise ValueError(f"Unsupported embedding type: {self.embedding}")
         trace = go.Scatter3d(
@@ -159,9 +154,9 @@ class ArxivSearch:
         layout = go.Layout(
             margin=dict(l=0, r=0, b=0, t=0),
             scene=dict(
-                xaxis_title='X',
-                yaxis_title='Y',
-                zaxis_title='Z',
+                xaxis_title='PCA 1',
+                yaxis_title='PCA 2',
+                zaxis_title='PCA 3',
                 xaxis=dict(backgroundcolor='black', color='white', gridcolor='gray', zerolinecolor='gray'),
                 yaxis=dict(backgroundcolor='black', color='white', gridcolor='gray', zerolinecolor='gray'),
                 zaxis=dict(backgroundcolor='black', color='white', gridcolor='gray', zerolinecolor='gray'),
@@ -222,40 +217,40 @@ class ArxivSearch:
 
     def bert_search(self, query, top_n=10):
         with torch.no_grad():
-            inputs = self.tokenizer(query, return_tensors="pt", truncation=True, padding=True)
+            inputs = self.tokenizer((query+' ')*2, return_tensors="pt", truncation=True, max_length=512, padding='max_length')
             outputs = self.model(**inputs)
-            query_vec = normalize(outputs.last_hidden_state[:, 0, :].numpy())
+            # query_vec = normalize(outputs.last_hidden_state[:, 0, :].numpy())
+            query_vec = outputs.last_hidden_state[:, 0, :].numpy()
 
         self.query_encoding = query_vec
         sims = cosine_similarity(query_vec, self.bert_embeddings).flatten()
         top_indices = sims.argsort()[::-1][:top_n]
+        print(f"sim, top_indices: {sims}, {top_indices}")
         return [(i, sims[i]) for i in top_indices]
 
     def scibert_search(self, query, top_n=10):
         with torch.no_grad():
             inputs = self.sci_tokenizer(query, return_tensors="pt", truncation=True, padding=True, max_length=512)
             outputs = self.sci_model(**inputs)
-            query_vec = normalize(outputs.last_hidden_state[:, 0, :].numpy())
+            query_vec = outputs.last_hidden_state[:, 0, :].numpy()
 
         self.query_encoding = query_vec
         sims = cosine_similarity(query_vec, self.scibert_embeddings).flatten()
         top_indices = sims.argsort()[::-1][:top_n]
+        print(f"sim, top_indices: {sims}, {top_indices}")
         return [(i, sims[i]) for i in top_indices]
 
-    def bert_search_2(self, query, top_n=10):
-        with torch.no_grad():
-            inputs = self.tokenizer(query, return_tensors="pt", truncation=True, padding=True)
-            outputs = self.model(**inputs)
-            token_embeddings = outputs.last_hidden_state
-            attention_mask = inputs['attention_mask']
-            mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
-            sentence_embeddings = torch.sum(token_embeddings * mask_expanded, dim=1)
-            sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
-            query_vec = sentence_embeddings / sum_mask
+    def sbert_search(self, query, top_n=10):
+        query_vec = self.sbert_model.encode([query])
         self.query_encoding = query_vec
-        sims = cosine_similarity(query_vec, self.bert_embeddings).flatten()
-        top_indices = sims.argsort()[::-1][:top_n]
-        return [(i, sims[i]) for i in top_indices]
+        cos_scores = cosine_similarity(query_vec, self.sbert_embedding)[0]
+        top_k_indices = np.argsort(cos_scores)[-50:][::-1]
+        candidates = [dataset['train'][int(i)]['text'] for i in top_k_indices]
+        scores = self.cross_encoder.predict([(query, doc) for doc in candidates])
+        final_scores = 0.7 * scores + 0.3 * cos_scores[top_k_indices]
+        top_indices = top_k_indices[final_scores.argsort()[::-1][:top_n]]
+        print(f"sim, top_indices: {final_scores}, {top_indices}")
+        return [(top_k_indices[i], final_scores[i]) for i in final_scores.argsort()[::-1][:top_n]]
 
     def load_model(self, embedding):
         self.embedding = embedding
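The hunk above drops the bert_search_2 helper and adds sbert_search, which retrieves the 50 nearest documents by bi-encoder cosine similarity and then reranks them with a cross-encoder, blending the two scores 0.7/0.3. A minimal standalone sketch of that retrieve-then-rerank pattern, assuming precomputed corpus_texts and corpus_embeddings (these names are illustrative, not from the commit):

import numpy as np
from sentence_transformers import SentenceTransformer, CrossEncoder
from sklearn.metrics.pairwise import cosine_similarity

def rerank_search(query, corpus_texts, corpus_embeddings, top_k=50, top_n=10):
    # Coarse retrieval: bi-encoder cosine similarity against precomputed embeddings.
    bi_encoder = SentenceTransformer("all-MiniLM-L6-v2")
    cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")
    query_vec = bi_encoder.encode([query])
    cos_scores = cosine_similarity(query_vec, corpus_embeddings)[0]
    candidates = np.argsort(cos_scores)[-top_k:][::-1]
    # Fine reranking: score each (query, document) pair with the cross-encoder,
    # then blend with the cosine scores using the same 0.7/0.3 weights as the commit.
    ce_scores = cross_encoder.predict([(query, corpus_texts[int(i)]) for i in candidates])
    blended = 0.7 * ce_scores + 0.3 * cos_scores[candidates]
    order = np.argsort(blended)[::-1][:top_n]
    return [(int(candidates[i]), float(blended[i])) for i in order]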
@@ -277,6 +272,10 @@ class ArxivSearch:
             self.sci_tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
             self.sci_model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased')
             self.sci_model.eval()
+        elif self.embedding == "sbert":
+            self.sbert_model = SentenceTransformer("all-MiniLM-L6-v2")
+            self.sbert_embedding = np.load("BERT embeddings/sbert_embedding.npz")["sbert_embedding"]
+            self.cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")
         else:
             raise ValueError(f"Unsupported embedding type: {self.embedding}")
 
@@ -285,13 +284,15 @@ class ArxivSearch:
         pattern = re.compile(r'a\s*b\s*s\s*t\s*r\s*a\s*c\s*t|i\s*n\s*t\s*r\s*o\s*d\s*u\s*c\s*t\s*i\s*o\s*n', re.IGNORECASE)
         match = pattern.search(text)
         if match:
-            return text[:match.start()].strip()
+            return text[:match.start()].strip() if match.start() < 1000 else text[:100].strip()
         else:
-            return text[:100].strip()
+            return text[:300].strip()
 
+    def set_embedding(self, embedding):
+        self.embedding = embedding
 
     def search_function(self, query, embedding):
-        self.embedding = embedding
+        self.set_embedding(embedding)
         query = query.encode().decode('unicode_escape') # Interpret escape sequences
 
         # Load or switch embedding model here if needed
@@ -303,6 +304,8 @@ class ArxivSearch:
             results = self.bert_search(query)
         elif self.embedding == "scibert":
             results = self.scibert_search(query)
+        elif self.embedding == "sbert":
+            results = self.sbert_search(query)
         else:
             return "No results found."
 
@@ -317,13 +320,14 @@ class ArxivSearch:
         display_rank = 1
         for idx, score in results:
             if not self.arxiv_ids[idx]:
-                continue
-
-            link = f"https://arxiv.org/abs/{self.arxiv_ids[idx].replace('arxiv:', '')}"
-            snippet = self.snippet_before_abstract(self.documents[idx]).replace('\n', '<br>')
-            output += f"### Document {display_rank}\n"
-            output += f"[arXiv Link]({link})\n\n"
-            output += f"<pre>{snippet}</pre>\n\n---\n"
+                output += f"### Document {display_rank}\n"
+                output += f"<pre>{self.documents[idx][:200]}</pre>\n\n"
+            else:
+                link = f"https://arxiv.org/abs/{self.arxiv_ids[idx].replace('arxiv:', '')}"
+                snippet = self.snippet_before_abstract(self.documents[idx]).replace('\n', '<br>')
+                output += f"### Document {display_rank}\n"
+                output += f"[arXiv Link]({link})\n\n"
+                output += f"<pre>{snippet}</pre>\n\n---\n"
             display_rank += 1
 
         return output
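For context, a hypothetical usage sketch (not from this commit): the constructor wires up the Gradio interface, loads the selected models, and calls launch(), so driving the app only requires passing a dataset. The dataset identifier below is a placeholder assumption, since app.py's actual data loading sits outside this diff.

from datasets import load_dataset

dataset = load_dataset("placeholder/arxiv-corpus")  # hypothetical id; substitute the corpus app.py really uses
ArxivSearch(dataset, embedding="sbert")             # builds the UI, loads the models, and launches the app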
 