import re

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
from datasets import load_dataset
from gensim.models import KeyedVectors
from scipy.sparse import load_npz
from sentence_transformers import CrossEncoder, SentenceTransformer
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity
from transformers import BertTokenizer, BertModel


class ArxivSearch:
    def __init__(self, dataset, embedding="sbert"):
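        """Search arXiv papers with a choice of embedding back-ends.

        Builds the Gradio UI, loads the dataset, and preloads every model;
        `embedding` names the default backend shown in the dropdown.
        """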
        self.dataset = dataset
        self.embedding = embedding
        self.query = None
        self.documents = []
        self.titles = []
        self.raw_texts = []
        self.arxiv_ids = []
        self.last_results = []
        self.query_encoding = None

        # Widgets created before gr.Blocks() are attached to the layout via .render() below.
        self.embedding_dropdown = gr.Dropdown(
            choices=["tfidf", "word2vec", "bert", "sbert", "clustered sbert"],
            value="sbert",
            label="Model"
        )
        self.plot_button = gr.Button("Show 3D Plot")

        with gr.Blocks() as self.iface:
            gr.Markdown("# arXiv Search Engine")
            gr.Markdown("Search arXiv papers by keyword and embedding model.")

            self.plot_output = gr.Plot()

            with gr.Row():
                self.query_box = gr.Textbox(lines=1, placeholder="Enter your search query", label="Query")
                self.embedding_dropdown.render()
                self.plot_button.render()
                with gr.Column():
                    self.search_button = gr.Button("Search")

            self.output_md = gr.Markdown()

            # Submitting the textbox, changing the model, or clicking Search all
            # re-run the same search; the plot button redraws the 3D projection.
            self.query_box.submit(
                self.search_function,
                inputs=[self.query_box, self.embedding_dropdown],
                outputs=self.output_md
            )
            self.embedding_dropdown.change(
                self.search_function,
                inputs=[self.query_box, self.embedding_dropdown],
                outputs=self.output_md
            )
            self.plot_button.click(
                self.plot_3d_embeddings,
                inputs=[],
                outputs=self.plot_output
            )
            self.search_button.click(
                self.search_function,
                inputs=[self.query_box, self.embedding_dropdown],
                outputs=self.output_md
            )

        self.load_data(dataset)

        # Preload every backend once; "clustered sbert" also loads the plain
        # SBERT model, embeddings, and cross-encoder.
        self.load_model('tfidf')
        self.load_model('word2vec')
        self.load_model('bert')
        self.load_model('clustered sbert')
        # load_model overwrites self.embedding as a side effect, so restore the
        # requested default here. Launching is left to the __main__ block below,
        # which avoids starting the server twice.
        self.embedding = embedding

    def load_data(self, dataset):
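        """Split each raw paper into title, arXiv id, and full text.

        Documents shorter than 10 characters are skipped; papers without a
        recognizable id line get `None` so no arXiv link is rendered for them.
        """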
train_data = dataset["train"] |
|
for item in train_data.select(range(len(train_data))): |
|
text = item["text"] |
|
if not text or len(text.strip()) < 10: |
|
continue |
|
|
|
lines = text.splitlines() |
|
title_lines = [] |
|
found_arxiv = False |
|
arxiv_id = None |
|
|
|
for line in lines: |
|
line_strip = line.strip() |
|
if not found_arxiv and line_strip.lower().startswith("arxiv:"): |
|
found_arxiv = True |
|
match = re.search(r'arxiv:\d{4}\.\d{4,5}v\d', line_strip, flags=re.IGNORECASE) |
|
if match: |
|
arxiv_id = match.group(0).lower() |
|
elif not found_arxiv: |
|
title_lines.append(line_strip) |
|
else: |
|
if line_strip.lower().startswith("abstract"): |
|
break |
|
|
|
title = " ".join(title_lines).strip() |
|
|
|
self.raw_texts.append(text.strip()) |
|
self.titles.append(title) |
|
self.documents.append(text.strip()) |
|
self.arxiv_ids.append(arxiv_id) |
|
|
|

    def plot_dense(self, embedding, pca, results_indices):
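        """Project dense embeddings to 3D with PCA.

        Returns (all points, result points, query point); empty arrays stand
        in when there are no results or no encoded query yet.
        """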
        # Fit PCA on the plotted subset plus any result rows outside it, so
        # results and the query project into the same 3D space.
        all_indices = list(set(results_indices) | set(range(min(5000, embedding.shape[0]))))
        all_data = embedding[all_indices]
        pca.fit(all_data)
        reduced_data = pca.transform(embedding[:5000])
        reduced_results_points = pca.transform(embedding[results_indices]) if len(results_indices) > 0 else np.empty((0, 3))
        query_point = pca.transform(self.query_encoding) if self.query_encoding is not None and self.query_encoding.shape[0] > 0 else np.empty((0, 3))
        return reduced_data, reduced_results_points, query_point

    def plot_3d_embeddings(self):
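        """Render the document cloud, current results, and query as a 3D scatter.

        Only the first 5000 documents are plotted to keep the figure responsive.
        """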
        pca = PCA(n_components=3)
        results_indices = [i[0] for i in self.last_results]

        if self.embedding == "tfidf":
            # Sparse matrix: densify only the rows that are actually plotted.
            all_indices = list(set(results_indices) | set(range(min(5000, self.tfidf_matrix.shape[0]))))
            all_data = self.tfidf_matrix[all_indices].toarray()
            pca.fit(all_data)
            reduced_data = pca.transform(self.tfidf_matrix[:5000].toarray())
            reduced_results_points = pca.transform(self.tfidf_matrix[results_indices].toarray()) if len(results_indices) > 0 else np.empty((0, 3))
        elif self.embedding == "word2vec":
            reduced_data, reduced_results_points, query_point = self.plot_dense(self.word2vec_embeddings, pca, results_indices)
        elif self.embedding == "bert":
            reduced_data, reduced_results_points, query_point = self.plot_dense(self.bert_embeddings, pca, results_indices)
        elif self.embedding == "sbert" or self.embedding == "clustered sbert":
            reduced_data, reduced_results_points, query_point = self.plot_dense(self.sbert_embedding, pca, results_indices)
            if self.embedding == "clustered sbert":
                # Highlight the cluster the last query was routed to; before any
                # clustered search has run there is no top cluster yet.
                top_cluster_members = set(np.where(self.clusters == getattr(self, "top_cluster_index", -1))[0])
                cluster_colors = ["#00b7ff" if i in top_cluster_members else "#ffffff" for i in range(len(self.documents))]
        else:
            raise ValueError(f"Unsupported embedding type: {self.embedding}")

        results_scores = [i[1] for i in self.last_results]

        traces = []

        # Keep hover text and marker colors aligned with the plotted subset.
        n_points = reduced_data.shape[0]

        trace = go.Scatter3d(
            x=reduced_data[:, 0],
            y=reduced_data[:, 1],
            z=reduced_data[:, 2],
            mode='markers',
            marker=dict(size=3.5,
                        color="#ffffff" if self.embedding != "clustered sbert" else cluster_colors[:n_points],
                        opacity=0.2),
            name='All Documents',
            text=[self.arxiv_ids[i] if self.arxiv_ids[i] else " ".join(self.documents[i].split()[:10]) for i in range(n_points)],
            hoverinfo='text'
        )

        traces.append(trace)

        layout = go.Layout(
            margin=dict(l=0, r=0, b=0, t=0),
            scene=dict(
                xaxis_title='PCA 1',
                yaxis_title='PCA 2',
                zaxis_title='PCA 3',
                xaxis=dict(backgroundcolor='black', color='white', gridcolor='gray', zerolinecolor='gray'),
                yaxis=dict(backgroundcolor='black', color='white', gridcolor='gray', zerolinecolor='gray'),
                zaxis=dict(backgroundcolor='black', color='white', gridcolor='gray', zerolinecolor='gray'),
            ),
            paper_bgcolor='black',
            plot_bgcolor='black',
            font=dict(color='white'),
            legend=dict(
                bgcolor='rgba(0,0,0,0)',
                bordercolor='rgba(0,0,0,0)',
                x=0.01,
                y=0.99,
                xanchor='left',
                yanchor='top'
            )
        )

        if len(reduced_results_points) > 0:
            custom_colorscale = [
                [0.0, "#00ffea"],
                [1.0, "#ffea00"],
            ]

            results_trace = go.Scatter3d(
                x=reduced_results_points[:, 0],
                y=reduced_results_points[:, 1],
                z=reduced_results_points[:, 2],
                mode='markers',
                marker=dict(size=4.25,
                            color=results_scores,
                            colorscale=custom_colorscale,
                            opacity=0.99,
                            colorbar=dict(
                                title="Score",
                                bgcolor='rgba(0,0,0,0)',
                                bordercolor='rgba(0,0,0,0)'
                            )
                            ),
                name='Results',
                text=[self.documents[i][:100] for i in results_indices],
                hoverinfo='text'
            )

            traces.append(results_trace)

        # TF-IDF has no dense query vector, so the query marker is skipped there.
        if self.embedding != "tfidf" and self.query_encoding is not None and self.query_encoding.shape[0] > 0:
            query_trace = go.Scatter3d(
                x=query_point[:, 0],
                y=query_point[:, 1],
                z=query_point[:, 2],
                mode='markers',
                marker=dict(size=5, color='red', opacity=0.8),
                name='Query',
                text=[f"Query: {self.query}"],
                hoverinfo='text'
            )
            traces.append(query_trace)

        fig = go.Figure(data=traces, layout=layout)

        return fig

    def keyword_match_ranking(self, query, top_n=10):
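        """TF-IDF ranking: sum each document's weights over the query terms."""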
        query_terms = query.lower().split()
        query_indices = [i for i, term in enumerate(self.feature_names) if term in query_terms]
        if not query_indices:
            return []
        # Score each document by the sum of its TF-IDF weights for the query terms.
        scores = []
        for doc_idx in range(self.tfidf_matrix.shape[0]):
            doc_vector = self.tfidf_matrix[doc_idx]
            doc_score = sum(doc_vector[0, i] for i in query_indices)
            if doc_score > 0:
                scores.append((doc_idx, doc_score))
        scores.sort(key=lambda x: x[1], reverse=True)
        return scores[:top_n]

    def word2vec_search(self, query, top_n=10):
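        """Rank documents by cosine similarity to the mean query word vector."""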
        # Average the word vectors of the in-vocabulary query terms.
        tokens = [word for word in query.split() if word in self.wv_model.key_to_index]
        if not tokens:
            return []
        vectors = np.array([self.wv_model[word] for word in tokens])
        query_vec = np.mean(vectors, axis=0).reshape(1, -1)
        self.query_encoding = query_vec
        sims = cosine_similarity(query_vec, self.word2vec_embeddings).flatten()
        top_indices = sims.argsort()[::-1][:top_n]
        return [(i, sims[i]) for i in top_indices]

    def bert_search(self, query, top_n=10):
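        """Rank documents by cosine similarity to the BERT [CLS] query encoding."""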
        with torch.no_grad():
            # The query is repeated once and padded to max_length before encoding.
            inputs = self.tokenizer((query + ' ') * 2, return_tensors="pt", truncation=True, max_length=512, padding='max_length')
            outputs = self.model(**inputs)
            # Use the [CLS] token embedding as the query vector.
            query_vec = outputs.last_hidden_state[:, 0, :].numpy()

        self.query_encoding = query_vec
        sims = cosine_similarity(query_vec, self.bert_embeddings).flatten()
        top_indices = sims.argsort()[::-1][:top_n]
        print(f"sim, top_indices: {sims}, {top_indices}")
        return [(i, sims[i]) for i in top_indices]

    def sbert_search(self, query, top_n=10):
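        """Two-stage search: SBERT retrieval, then cross-encoder re-ranking.

        The final score blends 0.7 * cross-encoder with 0.3 * cosine similarity.
        """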
        query_vec = self.sbert_model.encode([query])
        self.query_encoding = query_vec
        # Retrieve the 50 nearest documents by cosine similarity, then re-rank
        # them with the cross-encoder.
        cos_scores = cosine_similarity(query_vec, self.sbert_embedding)[0]
        top_k_indices = np.argsort(cos_scores)[-50:][::-1]
        candidates = [self.dataset['train'][int(i)]['text'] for i in top_k_indices]
        scores = self.cross_encoder.predict([(query, doc) for doc in candidates])
        final_scores = 0.7 * scores + 0.3 * cos_scores[top_k_indices]
        order = final_scores.argsort()[::-1][:top_n]
        print(f"sim, top_indices: {final_scores}, {top_k_indices[order]}")
        return [(int(top_k_indices[i]), float(final_scores[i])) for i in order]

    def clustered_sbert_search(self, query, top_n=10):
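        """Like sbert_search, but restricted to the cluster nearest the query.

        Within-cluster indices are mapped back to full-dataset positions
        before returning.
        """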
        query_vec = self.sbert_model.encode([query])
        self.query_encoding = query_vec
        # Route the query to the cluster whose centroid it is most similar to,
        # then search only within that cluster.
        cos_cluster_scores = cosine_similarity(query_vec, self.cluster_centers)[0]
        self.top_cluster_index = np.argmax(cos_cluster_scores)
        cos_scores = cosine_similarity(query_vec, self.clustered_embeddings[self.top_cluster_index])[0]
        top_k_indices = np.argsort(cos_scores)[-50:][::-1]
        cluster_member_indices = np.where(self.clusters == self.top_cluster_index)[0]
        top_full_dataset_indices = cluster_member_indices[top_k_indices]
        candidates = [self.dataset['train'][int(i)]['text'] for i in top_full_dataset_indices]
        scores = self.cross_encoder.predict([(query, doc) for doc in candidates])
        final_scores = 0.7 * scores + 0.3 * cos_scores[top_k_indices]
        # Sort once and index documents and scores together, so each returned
        # document keeps its own score.
        order = final_scores.argsort()[::-1][:top_n]
        print(f"sim, top_indices: {final_scores}, {top_full_dataset_indices[order]}")
        return [(int(top_full_dataset_indices[i]), float(final_scores[i])) for i in order]

    def model_switch(self, embedding, progress=gr.Progress()):
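        """Switch backends, free the old backend's assets, and re-run the query."""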
        if self.embedding != embedding:
            old_embedding = self.embedding
            print(f"Switching model to {embedding}")
            self.load_model(embedding)
            print(f"Loaded {embedding} model")
            # Free the assets of the backend we just left.
            if old_embedding == "tfidf":
                del self.tfidf_matrix
                del self.feature_names
            if old_embedding == "word2vec":
                del self.word2vec_embeddings
                del self.wv_model
            if old_embedding == "bert":
                del self.bert_embeddings
                del self.tokenizer
                del self.model
            if old_embedding == "sbert":
                del self.sbert_model
                del self.sbert_embedding
                del self.cross_encoder
            print("old embedding removed")
            if self.query:
                return self.search_function(self.query, self.embedding)
            return ""
        return gr.update()

    def load_model(self, embedding):
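        """Load the precomputed embeddings and supporting models for one backend.

        Paths are relative to the working directory; "clustered sbert" loads
        the SBERT assets plus cluster labels and cluster centers.
        """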
        self.embedding = embedding
        if self.embedding == "tfidf":
            self.tfidf_matrix = load_npz("TF-IDF embeddings/tfidf_matrix_train.npz")
            with open("TF-IDF embeddings/feature_names.txt", "r") as f:
                self.feature_names = [line.strip() for line in f.readlines()]
        elif self.embedding == "word2vec":
            self.word2vec_embeddings = np.load("Word2Vec embeddings/word2vec_embedding.npz")["word2vec_embedding"]
            self.wv_model = KeyedVectors.load("models/word2vec-trimmed.model")
        elif self.embedding == "bert":
            self.bert_embeddings = np.load("BERT embeddings/bert_embedding.npz")["bert_embedding"]
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            self.model = BertModel.from_pretrained('bert-base-uncased')
            self.model.eval()
        elif self.embedding == "sbert" or self.embedding == "clustered sbert":
            self.sbert_model = SentenceTransformer("all-MiniLM-L6-v2")
            self.sbert_embedding = np.load("BERT embeddings/sbert_embedding.npz")["sbert_embedding"]
            self.cross_encoder = CrossEncoder("cross-encoder/ms-marco-TinyBERT-L-2-v2")
            if self.embedding == "clustered sbert":
                self.clusters = pd.read_csv('raf_clusters/cluster_labels_sbert.csv')['cluster_label'].values
                self.cluster_centers = pd.read_csv('BERT embeddings/sbert_cluster_centers.csv').values
                self.clustered_embeddings = [self.sbert_embedding[self.clusters == i] for i in np.unique(self.clusters)]
        else:
            raise ValueError(f"Unsupported embedding type: {self.embedding}")

    def snippet_before_abstract(self, text):
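        """Return the text before the abstract/introduction as a preview snippet."""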
        # Match "abstract" or "introduction" even when PDF extraction has
        # spaced the letters out (e.g. "A b s t r a c t").
        pattern = re.compile(r'a\s*b\s*s\s*t\s*r\s*a\s*c\s*t|i\s*n\s*t\s*r\s*o\s*d\s*u\s*c\s*t\s*i\s*o\s*n', re.IGNORECASE)
        match = pattern.search(text)
        if match:
            return text[:match.start()].strip() if match.start() < 1000 else text[:100].strip()
        return text[:300].strip()

    def set_embedding(self, embedding):
        self.embedding = embedding

    def search_function(self, query, embedding, progress=gr.Progress()):
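        """Dispatch the query to the selected backend and format results as Markdown."""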
        self.set_embedding(embedding)
        self.query = query
        # Interpret escape sequences typed literally into the textbox (e.g. "\n").
        query = query.encode().decode('unicode_escape')
        search_methods = {
            "tfidf": self.keyword_match_ranking,
            "word2vec": self.word2vec_search,
            "bert": self.bert_search,
            "sbert": self.sbert_search,
            "clustered sbert": self.clustered_sbert_search,
        }

        results = search_methods.get(self.embedding, lambda q: [])(query)

        if not results:
            self.last_results = []
            return "No results found."

        self.last_results = results

        output = ""
        display_rank = 1
        for idx, score in results:
            output += f"### Document {display_rank}\n"
            if not self.arxiv_ids[idx]:
                output += f"<pre>{self.documents[idx][:200]}</pre>\n\n"
            else:
                link = f"https://arxiv.org/abs/{self.arxiv_ids[idx].replace('arxiv:', '')}"
                snippet = self.snippet_before_abstract(self.documents[idx]).replace('\n', '<br>')
                output += f"[arXiv Link]({link})\n\n"
                output += f"<pre>{snippet}</pre>\n\n---\n"
            display_rank += 1

        return output
if __name__ == "__main__": |
|
dataset = load_dataset("ccdv/arxiv-classification", "no_ref") |
|
search_engine = ArxivSearch(dataset) |
|
search_engine.iface.launch() |