Jonas Leeb committed
Commit fbba6d9 · 1 Parent(s): 5e54614

fixed bert not finding documents

Files changed (2):
  1. SciBERT_embeddings/scibert_embedding.npz  +3 -0
  2. app.py  +55 -47
SciBERT_embeddings/scibert_embedding.npz ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fb6a12b25db606552aba11f26f9bc5c2ac475183b93a21c6abddc743087e3bcd
+size 80887259
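Note on the new file: it is tracked with Git LFS, so the diff records only the three-line pointer (spec version, SHA-256 of the content, and its size, roughly 81 MB); the embedding matrix itself lives in LFS storage. For reference, a minimal sketch of producing and reading such a file — the array key `bert_embedding` matches what `load_model()` in app.py expects, while the shape is a placeholder:

```python
import numpy as np

# Hypothetical precomputation step: store the SciBERT document matrix under
# the key "bert_embedding", since load_model() indexes the archive with it.
embeddings = np.zeros((5000, 768), dtype=np.float32)  # placeholder shape
np.savez_compressed("SciBERT_embeddings/scibert_embedding.npz",
                    bert_embedding=embeddings)

# Reading it back mirrors the line added to load_model():
loaded = np.load("SciBERT_embeddings/scibert_embedding.npz")["bert_embedding"]
assert loaded.shape == (5000, 768)
```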
app.py CHANGED
@@ -10,14 +10,15 @@ from datasets import load_dataset
 from gensim.models import KeyedVectors
 import plotly.graph_objects as go
 from sklearn.decomposition import PCA
-
+from transformers import AutoTokenizer, AutoModel
 
 
 
 class ArxivSearch:
-    def __init__(self, dataset, embedding="tfidf"):
+    def __init__(self, dataset, embedding="bert"):
         self.dataset = dataset
         self.embedding = embedding
+        self.query = None
         self.documents = []
         self.titles = []
         self.raw_texts = []
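The new `transformers` import uses the Auto* factories, which pick the concrete architecture from a checkpoint's config, so SciBERT (a BERT variant trained on scientific text with its own vocabulary) loads without any SciBERT-specific class. A quick check of what they resolve to (downloads the checkpoint on first run):

```python
from transformers import AutoTokenizer, AutoModel

tok = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
mdl = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased")

# The checkpoint's config declares model_type "bert", so the factories
# hand back BERT classes, e.g. BertTokenizerFast and BertModel.
print(type(tok).__name__, type(mdl).__name__)
```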
@@ -25,17 +26,16 @@ class ArxivSearch:
         self.last_results = []
         self.query_encoding = None
 
+        # model selection
         self.embedding_dropdown = gr.Dropdown(
-            choices=["tfidf", "word2vec", "bert"],
-            value="tfidf",
+            choices=["tfidf", "word2vec", "bert", "scibert"],
+            value="bert",
             label="Model"
         )
-
-
-        # Add a button to show the 3D plot
+
         self.plot_button = gr.Button("Show 3D Plot")
 
-        # Define the interface using Blocks for more flexibility
+        # Gradio blocks for UI elements
         with gr.Blocks() as self.iface:
             gr.Markdown("# arXiv Search Engine")
             gr.Markdown("Search arXiv papers by keyword and embedding model.")
@@ -64,7 +64,7 @@ class ArxivSearch:
             )
             self.plot_button.click(
                 self.plot_3d_embeddings,
-                inputs=[self.embedding_dropdown],
+                inputs=[],
                 outputs=self.plot_output
             )
             self.search_button.click(
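The plot button no longer forwards the dropdown value: `plot_3d_embeddings` takes no Gradio inputs and instead reads `self.embedding`, which `search_function` and `load_model` now keep up to date. A minimal sketch of the pattern, with a toy callback standing in for the real plotting code:

```python
import gradio as gr

class Sketch:
    def __init__(self):
        self.embedding = "bert"  # shared state, updated by other callbacks

    def plot(self):
        # No inputs from Gradio: the method reads instance state instead.
        return f"would plot the '{self.embedding}' embedding space"

s = Sketch()
with gr.Blocks() as iface:
    out = gr.Textbox(label="status")
    gr.Button("Show 3D Plot").click(s.plot, inputs=[], outputs=out)
```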
@@ -73,22 +73,11 @@ class ArxivSearch:
                 outputs=self.output_md
             )
 
-        # self.iface = gr.Interface(
-        #     fn=self.search_function,
-        #     inputs=[
-        #         gr.Textbox(lines=1, placeholder="Enter your search query"),
-        #         self.embedding_dropdown
-        #     ],
-        #     outputs=gr.Markdown(),
-        #     title="arXiv Search Engine",
-        #     description="Search arXiv papers by keyword and embedding model.",
-        # )
-
         self.load_data(dataset)
-        # self.load_model(embedding)
         self.load_model('tfidf')
         self.load_model('word2vec')
         self.load_model('bert')
+        self.load_model('scibert')
 
         self.iface.launch()
 
@@ -124,19 +113,18 @@ class ArxivSearch:
             self.documents.append(text.strip())
             self.arxiv_ids.append(arxiv_id)
 
-
-    def plot_3d_embeddings(self, embedding):
+    def plot_3d_embeddings(self):
         # Example: plot random points, replace with your embeddings
         pca = PCA(n_components=3)
         results_indices = [i[0] for i in self.last_results]
-        if embedding == "tfidf":
+        if self.embedding == "tfidf":
             all_indices = list(set(results_indices) | set(range(min(5000, self.tfidf_matrix.shape[0]))))
             all_data = self.tfidf_matrix[all_indices].toarray()
             pca.fit(all_data)
             reduced_data = pca.transform(self.tfidf_matrix[:5000].toarray())
             reduced_results_points = pca.transform(self.tfidf_matrix[results_indices].toarray()) if len(results_indices) > 0 else np.empty((0, 3))
 
-        elif embedding == "word2vec":
+        elif self.embedding == "word2vec":
             all_indices = list(set(results_indices) | set(range(min(5000, self.word2vec_embeddings.shape[0]))))
             all_data = self.word2vec_embeddings[all_indices]
             pca.fit(all_data)
@@ -144,16 +132,22 @@ class ArxivSearch:
             reduced_results_points = pca.transform(self.word2vec_embeddings[results_indices]) if len(results_indices) > 0 else np.empty((0, 3))
             query_point = pca.transform(self.query_encoding) if self.query_encoding is not None and self.query_encoding.shape[0] > 0 else np.empty((0, 3))
 
-        elif embedding == "bert":
+        elif self.embedding == "bert":
             all_indices = list(set(results_indices) | set(range(min(5000, self.bert_embeddings.shape[0]))))
             all_data = self.bert_embeddings[all_indices]
             pca.fit(all_data)
             reduced_data = pca.transform(self.bert_embeddings[:5000])
             reduced_results_points = pca.transform(self.bert_embeddings[results_indices]) if len(results_indices) > 0 else np.empty((0, 3))
             query_point = pca.transform(self.query_encoding) if self.query_encoding is not None and self.query_encoding.shape[0] > 0 else np.empty((0, 3))
-
+        elif self.embedding == "scibert":
+            all_indices = list(set(results_indices) | set(range(min(5000, self.scibert_embeddings.shape[0]))))
+            all_data = self.scibert_embeddings[all_indices]
+            pca.fit(all_data)
+            reduced_data = pca.transform(self.scibert_embeddings[:5000])
+            reduced_results_points = pca.transform(self.scibert_embeddings[results_indices]) if len(results_indices) > 0 else np.empty((0, 3))
+            query_point = pca.transform(self.query_encoding) if self.query_encoding is not None and self.query_encoding.shape[0] > 0 else np.empty((0, 3))
         else:
-            raise ValueError(f"Unsupported embedding type: {embedding}")
+            raise ValueError(f"Unsupported embedding type: {self.embedding}")
         trace = go.Scatter3d(
             x=reduced_data[:, 0],
             y=reduced_data[:, 1],
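Each branch fits the PCA on the union of the first 5000 rows and the current result indices before projecting. The union matters: a result ranked beyond row 5000 would otherwise be projected by a transform it never influenced, or index past the fitted sample. A stripped-down version of the idea, with a random matrix standing in for the real embeddings:

```python
import numpy as np
from sklearn.decomposition import PCA

emb = np.random.rand(6000, 768)   # stand-in for a (n_docs x dim) embedding matrix
results = [5500, 42, 7]           # indices returned by a search

# Fit on the plotted background *plus* the results so rows outside the
# first 5000 (here 5500) still get a consistent 3-D projection.
fit_idx = list(set(results) | set(range(min(5000, emb.shape[0]))))
pca = PCA(n_components=3).fit(emb[fit_idx])

background = pca.transform(emb[:5000])    # grey point cloud
highlights = pca.transform(emb[results])  # orange result markers
```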
@@ -185,7 +179,7 @@ class ArxivSearch:
             marker=dict(size=3.5, color='orange', opacity=0.75),
             name='Results'
         )
-        if not "tfidf" and self.query_encoding is not None and self.query_encoding.shape[0] > 0:
+        if not self.embedding == "tfidf" and self.query_encoding is not None and self.query_encoding.shape[0] > 0:
             query_trace = go.Scatter3d(
                 x=query_point[:, 0],
                 y=query_point[:, 1],
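The old guard was dead code: `not "tfidf"` negates a non-empty string literal, which is always truthy, so the condition was always False and the query point never rendered. The new guard consults the active model; a two-assertion demonstration:

```python
# A non-empty string literal is truthy, so the old guard never fired.
assert (not "tfidf") is False

# The fix compares the active model name; note that `not x == y` parses as
# `not (x == y)` (equivalent to x != y, which would read more plainly).
embedding = "bert"
assert (not embedding == "tfidf") is True
```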
@@ -231,11 +225,23 @@ class ArxivSearch:
             inputs = self.tokenizer(query, return_tensors="pt", truncation=True, padding=True)
             outputs = self.model(**inputs)
             query_vec = normalize(outputs.last_hidden_state[:, 0, :].numpy())
+
         self.query_encoding = query_vec
         sims = cosine_similarity(query_vec, self.bert_embeddings).flatten()
         top_indices = sims.argsort()[::-1][:top_n]
         return [(i, sims[i]) for i in top_indices]
 
+    def scibert_search(self, query, top_n=10):
+        with torch.no_grad():
+            inputs = self.sci_tokenizer(query, return_tensors="pt", truncation=True, padding=True)
+            outputs = self.sci_model(**inputs)
+            query_vec = normalize(outputs.last_hidden_state[:, 0, :].numpy())
+
+        self.query_encoding = query_vec
+        sims = cosine_similarity(query_vec, self.scibert_embeddings).flatten()
+        top_indices = sims.argsort()[::-1][:top_n]
+        return [(i, sims[i]) for i in top_indices]
+
     def bert_search_2(self, query, top_n=10):
         with torch.no_grad():
             inputs = self.tokenizer(query, return_tensors="pt", truncation=True, padding=True)
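`scibert_search` mirrors `bert_search`: encode the query, take the `[CLS]` token's hidden state as a sentence vector, L2-normalize it, and rank documents by cosine similarity. A self-contained sketch of that flow, with a random matrix in place of the precomputed document embeddings:

```python
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.preprocessing import normalize
from sklearn.metrics.pairwise import cosine_similarity

tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
model = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased")
model.eval()

doc_embeddings = np.random.rand(100, 768)  # placeholder document matrix

with torch.no_grad():
    inputs = tokenizer("graph neural networks", return_tensors="pt",
                       truncation=True, padding=True)
    # [:, 0, :] selects the [CLS] position; normalize() makes the vectors
    # unit-length, matching how the stored document vectors were prepared.
    query_vec = normalize(model(**inputs).last_hidden_state[:, 0, :].numpy())

sims = cosine_similarity(query_vec, doc_embeddings).flatten()
top = [(int(i), float(sims[i])) for i in sims.argsort()[::-1][:10]]
```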
@@ -252,27 +258,28 @@ class ArxivSearch:
         return [(i, sims[i]) for i in top_indices]
 
     def load_model(self, embedding):
-        if embedding == "tfidf":
+        self.embedding = embedding
+        if self.embedding == "tfidf":
             self.tfidf_matrix = load_npz("TF-IDF embeddings/tfidf_matrix_train.npz")
             with open("TF-IDF embeddings/feature_names.txt", "r") as f:
                 self.feature_names = [line.strip() for line in f.readlines()]
-        elif embedding == "word2vec":
+        elif self.embedding == "word2vec":
             # Use trimmed model here
-            self.word2vec_embeddings = normalize(np.load("Word2Vec embeddings/word2vec_embedding.npz")["word2vec_embedding"])
+            self.word2vec_embeddings = np.load("Word2Vec embeddings/word2vec_embedding.npz")["word2vec_embedding"]
             self.wv_model = KeyedVectors.load("models/word2vec-trimmed.model")
-        elif embedding == "bert":
-            self.bert_embeddings = normalize(np.load("BERT embeddings/bert_embedding.npz")["bert_embedding"])
+        elif self.embedding == "bert":
+            self.bert_embeddings = np.load("BERT embeddings/bert_embedding.npz")["bert_embedding"]
             self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
             self.model = BertModel.from_pretrained('bert-base-uncased')
             self.model.eval()
+        elif self.embedding == "scibert":
+            self.scibert_embeddings = np.load("SciBERT_embeddings/scibert_embedding.npz")["bert_embedding"]
+            self.sci_tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
+            self.sci_model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased')
+            self.sci_model.eval()
         else:
-            raise ValueError(f"Unsupported embedding type: {embedding}")
+            raise ValueError(f"Unsupported embedding type: {self.embedding}")
 
-    def on_model_change(self, change):
-        new_model = change["new"]
-        self.embedding = new_model
-        self.load_model(new_model)
-
 
     def snippet_before_abstract(self, text):
         pattern = re.compile(r'a\s*b\s*s\s*t\s*r\s*a\s*c\s*t|i\s*n\s*t\s*r\s*o\s*d\s*u\s*c\s*t\s*i\s*o\s*n', re.IGNORECASE)
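Two changes in `load_model`: the stored matrices are now loaded exactly as saved rather than re-normalized on the fly, and the method records its argument in `self.embedding` as a side effect. The side effect interacts with `__init__`, which calls it four times back to back, so the instance starts on "scibert" rather than the dropdown's "bert" default until the first search resets it. A toy reproduction:

```python
class Toy:
    def load_model(self, embedding):
        self.embedding = embedding  # side effect introduced in this commit

t = Toy()
for name in ("tfidf", "word2vec", "bert", "scibert"):
    t.load_model(name)  # mirrors the four eager loads in __init__
print(t.embedding)      # -> "scibert", not the dropdown default "bert"
```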
@@ -284,16 +291,18 @@ class ArxivSearch:
 
 
     def search_function(self, query, embedding):
-        # Preprocess the query
-        query = query.strip().lower()
+        self.embedding = embedding
+        query = query.encode().decode('unicode_escape')  # Interpret escape sequences
 
         # Load or switch embedding model here if needed
-        if embedding == "tfidf":
+        if self.embedding == "tfidf":
             results = self.keyword_match_ranking(query)
-        elif embedding == "word2vec":
+        elif self.embedding == "word2vec":
             results = self.word2vec_search(query)
-        elif embedding == "bert":
+        elif self.embedding == "bert":
             results = self.bert_search(query)
+        elif self.embedding == "scibert":
+            results = self.scibert_search(query)
         else:
             return "No results found."
 
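The query preprocessing swaps `strip().lower()` for a `unicode_escape` round-trip, which turns literally typed escape sequences (a backslash followed by `n`, `t`, `uXXXX`, ...) into the characters they name. It comes with a caveat worth knowing:

```python
s = r"graph\nnetworks"                      # user literally typed backslash-n
print(s.encode().decode("unicode_escape"))  # 'graph\nnetworks' with a real newline

# Caveat: encode() emits UTF-8 but 'unicode_escape' decodes byte-per-byte
# as Latin-1, so accented queries come back garbled.
print("Schrödinger".encode().decode("unicode_escape"))  # 'SchrÃ¶dinger'
```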
@@ -301,7 +310,6 @@ class ArxivSearch:
             self.last_results = []
             return "No results found."
 
-
         if results:
             self.last_results = results
 
@@ -323,5 +331,5 @@ class ArxivSearch:
 
 if __name__ == "__main__":
     dataset = load_dataset("ccdv/arxiv-classification", "no_ref")  # replace with your dataset
-    search_engine = ArxivSearch(dataset, embedding="tfidf")  # Initialize with tfidf or any other embedding
+    search_engine = ArxivSearch(dataset)
     search_engine.iface.launch()
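For completeness, the entry point after the change, assuming the embedding files listed above are present:

```python
from datasets import load_dataset

# The constructor now defaults to embedding="bert" and eagerly loads all
# four models, so no model argument is needed here.
dataset = load_dataset("ccdv/arxiv-classification", "no_ref")
search_engine = ArxivSearch(dataset)  # __init__ also launches the interface
```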
 