Jonas Leeb committed
Commit · fbba6d9
Parent(s): 5e54614
fixed bert not finding documents
Browse files
- SciBERT_embeddings/scibert_embedding.npz +3 -0
- app.py +55 -47
SciBERT_embeddings/scibert_embedding.npz ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fb6a12b25db606552aba11f26f9bc5c2ac475183b93a21c6abddc743087e3bcd
+size 80887259
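Note: the three added lines are a Git LFS pointer, not the embedding matrix itself; the ~81 MB archive is materialized by LFS at checkout (e.g. via git lfs pull). A minimal sanity check after cloning, a sketch assuming the archive key "bert_embedding" that load_model reads further down in this commit:

import numpy as np

# If LFS did not materialize the file, np.load fails on the 3-line pointer text.
archive = np.load("SciBERT_embeddings/scibert_embedding.npz")
print(archive.files)                     # expected: ['bert_embedding']
print(archive["bert_embedding"].shape)   # (num_documents, hidden_size)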
app.py CHANGED
@@ -10,14 +10,15 @@ from datasets import load_dataset
 from gensim.models import KeyedVectors
 import plotly.graph_objects as go
 from sklearn.decomposition import PCA
-
+from transformers import AutoTokenizer, AutoModel
 
 
 
 class ArxivSearch:
-    def __init__(self, dataset, embedding="
+    def __init__(self, dataset, embedding="bert"):
         self.dataset = dataset
         self.embedding = embedding
+        self.query = None
         self.documents = []
         self.titles = []
         self.raw_texts = []
@@ -25,17 +26,16 @@
         self.last_results = []
         self.query_encoding = None
 
+        # model selection
         self.embedding_dropdown = gr.Dropdown(
-            choices=["tfidf", "word2vec", "bert"],
-            value="
+            choices=["tfidf", "word2vec", "bert", "scibert"],
+            value="bert",
             label="Model"
         )
-
-
-        # Add a button to show the 3D plot
+
         self.plot_button = gr.Button("Show 3D Plot")
 
-        #
+        # Gradio blocks for UI elements
         with gr.Blocks() as self.iface:
             gr.Markdown("# arXiv Search Engine")
             gr.Markdown("Search arXiv papers by keyword and embedding model.")
@@ -64,7 +64,7 @@
         )
         self.plot_button.click(
             self.plot_3d_embeddings,
-            inputs=[self.embedding_dropdown],
+            inputs=[],
             outputs=self.plot_output
         )
         self.search_button.click(
@@ -73,22 +73,11 @@
             outputs=self.output_md
         )
 
-        # self.iface = gr.Interface(
-        #     fn=self.search_function,
-        #     inputs=[
-        #         gr.Textbox(lines=1, placeholder="Enter your search query"),
-        #         self.embedding_dropdown
-        #     ],
-        #     outputs=gr.Markdown(),
-        #     title="arXiv Search Engine",
-        #     description="Search arXiv papers by keyword and embedding model.",
-        # )
-
         self.load_data(dataset)
-        # self.load_model(embedding)
         self.load_model('tfidf')
         self.load_model('word2vec')
         self.load_model('bert')
+        self.load_model('scibert')
 
         self.iface.launch()
 
@@ -124,19 +113,18 @@
             self.documents.append(text.strip())
             self.arxiv_ids.append(arxiv_id)
 
-
-    def plot_3d_embeddings(self, embedding):
+    def plot_3d_embeddings(self):
         # Example: plot random points, replace with your embeddings
         pca = PCA(n_components=3)
         results_indices = [i[0] for i in self.last_results]
-        if embedding == "tfidf":
+        if self.embedding == "tfidf":
             all_indices = list(set(results_indices) | set(range(min(5000, self.tfidf_matrix.shape[0]))))
             all_data = self.tfidf_matrix[all_indices].toarray()
             pca.fit(all_data)
             reduced_data = pca.transform(self.tfidf_matrix[:5000].toarray())
             reduced_results_points = pca.transform(self.tfidf_matrix[results_indices].toarray()) if len(results_indices) > 0 else np.empty((0, 3))
 
-        elif embedding == "word2vec":
+        elif self.embedding == "word2vec":
             all_indices = list(set(results_indices) | set(range(min(5000, self.word2vec_embeddings.shape[0]))))
             all_data = self.word2vec_embeddings[all_indices]
             pca.fit(all_data)
@@ -144,16 +132,22 @@
             reduced_results_points = pca.transform(self.word2vec_embeddings[results_indices]) if len(results_indices) > 0 else np.empty((0, 3))
             query_point = pca.transform(self.query_encoding) if self.query_encoding is not None and self.query_encoding.shape[0] > 0 else np.empty((0, 3))
 
-        elif embedding == "bert":
+        elif self.embedding == "bert":
             all_indices = list(set(results_indices) | set(range(min(5000, self.bert_embeddings.shape[0]))))
             all_data = self.bert_embeddings[all_indices]
             pca.fit(all_data)
             reduced_data = pca.transform(self.bert_embeddings[:5000])
             reduced_results_points = pca.transform(self.bert_embeddings[results_indices]) if len(results_indices) > 0 else np.empty((0, 3))
             query_point = pca.transform(self.query_encoding) if self.query_encoding is not None and self.query_encoding.shape[0] > 0 else np.empty((0, 3))
-
+        elif self.embedding == "scibert":
+            all_indices = list(set(results_indices) | set(range(min(5000, self.scibert_embeddings.shape[0]))))
+            all_data = self.scibert_embeddings[all_indices]
+            pca.fit(all_data)
+            reduced_data = pca.transform(self.scibert_embeddings[:5000])
+            reduced_results_points = pca.transform(self.scibert_embeddings[results_indices]) if len(results_indices) > 0 else np.empty((0, 3))
+            query_point = pca.transform(self.query_encoding) if self.query_encoding is not None and self.query_encoding.shape[0] > 0 else np.empty((0, 3))
         else:
-            raise ValueError(f"Unsupported embedding type: {embedding}")
+            raise ValueError(f"Unsupported embedding type: {self.embedding}")
         trace = go.Scatter3d(
             x=reduced_data[:, 0],
             y=reduced_data[:, 1],
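Every branch above follows the same PCA recipe: fit on the union of the first 5000 background points and the current result indices, then project background and results separately, so hits with index >= 5000 are still covered by the fitted space. A standalone sketch of that recipe with stand-in data:

import numpy as np
from sklearn.decomposition import PCA

embeddings = np.random.rand(6000, 768)   # stand-in for a document embedding matrix
results_indices = [3, 41, 5977]          # hypothetical search hits

pca = PCA(n_components=3)
# Fit on background plus hits so out-of-range hits are inside the fitted space.
all_indices = list(set(results_indices) | set(range(min(5000, embeddings.shape[0]))))
pca.fit(embeddings[all_indices])

background = pca.transform(embeddings[:5000])       # the grey cloud
hits = pca.transform(embeddings[results_indices])   # the highlighted points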
@@ -185,7 +179,7 @@
             marker=dict(size=3.5, color='orange', opacity=0.75),
             name='Results'
         )
-        if not "tfidf" and self.query_encoding is not None and self.query_encoding.shape[0] > 0:
+        if not self.embedding == "tfidf" and self.query_encoding is not None and self.query_encoding.shape[0] > 0:
             query_trace = go.Scatter3d(
                 x=query_point[:, 0],
                 y=query_point[:, 1],
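The guard rewritten in this hunk was a genuine bug, not a style fix: not "tfidf" negates the truthiness of a non-empty string, so the old condition was constantly False and the query point was never drawn for any model. A two-line demonstration:

# The old condition was a constant: non-empty strings are truthy.
assert (not "tfidf") is False
# The fixed condition actually depends on the selected model.
embedding = "bert"
assert (not embedding == "tfidf") is True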
@@ -231,11 +225,23 @@
             inputs = self.tokenizer(query, return_tensors="pt", truncation=True, padding=True)
             outputs = self.model(**inputs)
             query_vec = normalize(outputs.last_hidden_state[:, 0, :].numpy())
+
         self.query_encoding = query_vec
         sims = cosine_similarity(query_vec, self.bert_embeddings).flatten()
         top_indices = sims.argsort()[::-1][:top_n]
         return [(i, sims[i]) for i in top_indices]
 
+    def scibert_search(self, query, top_n=10):
+        with torch.no_grad():
+            inputs = self.sci_tokenizer(query, return_tensors="pt", truncation=True, padding=True)
+            outputs = self.sci_model(**inputs)
+            query_vec = normalize(outputs.last_hidden_state[:, 0, :].numpy())
+
+        self.query_encoding = query_vec
+        sims = cosine_similarity(query_vec, self.scibert_embeddings).flatten()
+        top_indices = sims.argsort()[::-1][:top_n]
+        return [(i, sims[i]) for i in top_indices]
+
     def bert_search_2(self, query, top_n=10):
         with torch.no_grad():
             inputs = self.tokenizer(query, return_tensors="pt", truncation=True, padding=True)
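The added scibert_search mirrors bert_search: tokenize the query, take the [CLS] token (position 0 of the last hidden state), L2-normalize it, and rank documents by cosine similarity against the precomputed matrix. A self-contained sketch of that retrieval step; the document matrix here is a random stand-in for the .npz contents:

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.preprocessing import normalize
from sklearn.metrics.pairwise import cosine_similarity

tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
model = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased")
model.eval()

doc_embeddings = normalize(np.random.rand(1000, 768))  # stand-in document vectors

with torch.no_grad():
    inputs = tokenizer("graph neural networks", return_tensors="pt",
                       truncation=True, padding=True)
    outputs = model(**inputs)
    # [CLS] pooling: first token of the last hidden state, L2-normalized.
    query_vec = normalize(outputs.last_hidden_state[:, 0, :].numpy())

sims = cosine_similarity(query_vec, doc_embeddings).flatten()
top_indices = sims.argsort()[::-1][:10]   # ten most similar documents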
@@ -252,27 +258,28 @@
         return [(i, sims[i]) for i in top_indices]
 
     def load_model(self, embedding):
-        if embedding == "tfidf":
+        self.embedding = embedding
+        if self.embedding == "tfidf":
             self.tfidf_matrix = load_npz("TF-IDF embeddings/tfidf_matrix_train.npz")
             with open("TF-IDF embeddings/feature_names.txt", "r") as f:
                 self.feature_names = [line.strip() for line in f.readlines()]
-        elif embedding == "word2vec":
+        elif self.embedding == "word2vec":
             # Use trimmed model here
-            self.word2vec_embeddings =
+            self.word2vec_embeddings = np.load("Word2Vec embeddings/word2vec_embedding.npz")["word2vec_embedding"]
             self.wv_model = KeyedVectors.load("models/word2vec-trimmed.model")
-        elif embedding == "bert":
-            self.bert_embeddings =
+        elif self.embedding == "bert":
+            self.bert_embeddings = np.load("BERT embeddings/bert_embedding.npz")["bert_embedding"]
             self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
             self.model = BertModel.from_pretrained('bert-base-uncased')
             self.model.eval()
+        elif self.embedding == "scibert":
+            self.scibert_embeddings = np.load("SciBERT_embeddings/scibert_embedding.npz")["bert_embedding"]
+            self.sci_tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
+            self.sci_model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased')
+            self.sci_model.eval()
         else:
-            raise ValueError(f"Unsupported embedding type: {embedding}")
+            raise ValueError(f"Unsupported embedding type: {self.embedding}")
 
-    def on_model_change(self, change):
-        new_model = change["new"]
-        self.embedding = new_model
-        self.load_model(new_model)
-
 
     def snippet_before_abstract(self, text):
         pattern = re.compile(r'a\s*b\s*s\s*t\s*r\s*a\s*c\s*t|i\s*n\s*t\s*r\s*o\s*d\s*u\s*c\s*t\s*i\s*o\s*n', re.IGNORECASE)
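One detail worth flagging: the SciBERT archive is read with the key "bert_embedding" (not "scibert_embedding"), so the export script must have saved it under that name; with np.savez the keyword argument becomes the lookup key. A sketch of the assumed save side (not the author's actual export script):

import numpy as np

embeddings = np.random.rand(1000, 768)   # hypothetical SciBERT document vectors
# The keyword below is what np.load(...)["bert_embedding"] in load_model expects.
np.savez("scibert_embedding.npz", bert_embedding=embeddings)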
@@ -284,16 +291,18 @@
 
 
     def search_function(self, query, embedding):
-
-        query = query.
+        self.embedding = embedding
+        query = query.encode().decode('unicode_escape')  # Interpret escape sequences
 
         # Load or switch embedding model here if needed
-        if embedding == "tfidf":
+        if self.embedding == "tfidf":
             results = self.keyword_match_ranking(query)
-        elif embedding == "word2vec":
+        elif self.embedding == "word2vec":
             results = self.word2vec_search(query)
-        elif embedding == "bert":
+        elif self.embedding == "bert":
             results = self.bert_search(query)
+        elif self.embedding == "scibert":
+            results = self.scibert_search(query)
         else:
             return "No results found."
 
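The new query.encode().decode('unicode_escape') line interprets backslash escapes the user typed literally (the two characters "\" and "n" become a newline). One caveat: unicode_escape treats the bytes as Latin-1, so non-ASCII queries get mangled. A small demonstration of both behaviors:

raw = r"graph\nsearch"                        # literal backslash-n from the textbox
print(raw.encode().decode("unicode_escape"))  # "graph" and "search" on two lines

accented = "résumé"
print(accented.encode().decode("unicode_escape"))  # mojibake: rÃ©sumÃ©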
@@ -301,7 +310,6 @@
             self.last_results = []
             return "No results found."
 
-
         if results:
             self.last_results = results
 
@@ -323,5 +331,5 @@
 
 if __name__ == "__main__":
     dataset = load_dataset("ccdv/arxiv-classification", "no_ref") # replace with your dataset
-    search_engine = ArxivSearch(dataset
+    search_engine = ArxivSearch(dataset)
     search_engine.iface.launch()