|
"""Gradio app for searching arXiv papers with TF-IDF, Word2Vec, or BERT embeddings."""

import re

import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
from datasets import load_dataset
from gensim.models import KeyedVectors
from scipy.sparse import load_npz
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import normalize
from transformers import BertModel, BertTokenizer
|
|
class ArxivSearch: |
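    """arXiv paper search over TF-IDF, Word2Vec, and BERT embeddings, served through a Gradio UI."""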
|
    def __init__(self, dataset, embedding="tfidf"):
        self.dataset = dataset
        self.embedding = embedding
        self.documents = []
        self.titles = []
        self.raw_texts = []
        self.arxiv_ids = []
        self.last_results = []

        # Components built outside the Blocks context are attached to the
        # layout below via .render().
        self.embedding_dropdown = gr.Dropdown(
            choices=["tfidf", "word2vec", "bert"],
            value="tfidf",
            label="Model",
        )
        self.plot_button = gr.Button("Show 3D Plot")

        with gr.Blocks() as self.iface:
            gr.Markdown("# arXiv Search Engine")
            gr.Markdown("Search arXiv papers by keyword and embedding model.")
            with gr.Row():
                self.query_box = gr.Textbox(lines=1, placeholder="Enter your search query", label="Query")
                self.embedding_dropdown.render()
                self.plot_button.render()
            with gr.Row():
                self.plot_output = gr.Plot()
                self.output_md = gr.Markdown()

            # Re-run the search when a query is submitted or the model is switched.
            self.query_box.submit(
                self.search_function,
                inputs=[self.query_box, self.embedding_dropdown],
                outputs=self.output_md,
            )
            self.embedding_dropdown.change(
                self.search_function,
                inputs=[self.query_box, self.embedding_dropdown],
                outputs=self.output_md,
            )
            self.plot_button.click(
                self.plot_3d_embeddings,
                inputs=[self.embedding_dropdown],
                outputs=self.plot_output,
            )

        self.load_data(dataset)

        # Load all three embedding models up front so switching models is instant.
        self.load_model("tfidf")
        self.load_model("word2vec")
        self.load_model("bert")
|
def load_data(self, dataset): |
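        """Split each raw document into a title, an arXiv ID, and the full text."""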
|
        train_data = dataset["train"]
        for item in train_data:
            text = item["text"]
            if not text or len(text.strip()) < 10:
                continue

            # The title spans the lines before the "arXiv:..." line; the
            # abstract follows the line starting with "abstract".
            lines = text.splitlines()
            title_lines = []
            found_arxiv = False
            arxiv_id = None

            for line in lines:
                line_strip = line.strip()
                if not found_arxiv and line_strip.lower().startswith("arxiv:"):
                    found_arxiv = True
                    match = re.search(r'arxiv:\d{4}\.\d{4,5}v\d+', line_strip, flags=re.IGNORECASE)
                    if match:
                        arxiv_id = match.group(0).lower()
                elif not found_arxiv:
                    title_lines.append(line_strip)
                elif line_strip.lower().startswith("abstract"):
                    break

            title = " ".join(title_lines).strip()

            self.raw_texts.append(text.strip())
            self.titles.append(title)
            self.documents.append(text.strip())
            self.arxiv_ids.append(arxiv_id)
|
def keyword_match_ranking(self, query, top_n=5): |
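        """Rank documents by summed TF-IDF weight of the query terms; return (index, score) pairs."""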
|
        query_terms = query.lower().split()
        # Column indices of the TF-IDF features that appear in the query.
        query_indices = [i for i, term in enumerate(self.feature_names) if term in query_terms]
        if not query_indices:
            return []

        # Score each document by the sum of its TF-IDF weights over the query terms.
        scores = []
        for doc_idx in range(self.tfidf_matrix.shape[0]):
            doc_vector = self.tfidf_matrix[doc_idx]
            doc_score = sum(doc_vector[0, i] for i in query_indices)
            if doc_score > 0:
                scores.append((doc_idx, doc_score))
        scores.sort(key=lambda x: x[1], reverse=True)
        return scores[:top_n]
|
def plot_3d_embeddings(self, embedding): |
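        """Scatter-plot a 3D PCA projection of the corpus, highlighting the last search results."""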
|
|
|
        pca = PCA(n_components=3)
        results_indices = [i[0] for i in self.last_results]

        # Fit PCA on (at most) the first 5000 documents, then project the
        # current search results into the same 3D space.
        if embedding == "tfidf":
            corpus = self.tfidf_matrix[:5000].toarray()
            result_vectors = self.tfidf_matrix[results_indices].toarray()
        elif embedding == "word2vec":
            corpus = self.word2vec_embeddings[:5000]
            result_vectors = self.word2vec_embeddings[results_indices]
        elif embedding == "bert":
            corpus = self.bert_embeddings[:5000]
            result_vectors = self.bert_embeddings[results_indices]
        else:
            raise ValueError(f"Unsupported embedding type: {embedding}")

        reduced_data = pca.fit_transform(corpus)
        reduced_results_points = pca.transform(result_vectors) if len(results_indices) > 0 else np.empty((0, 3))

        trace = go.Scatter3d(
            x=reduced_data[:, 0],
            y=reduced_data[:, 1],
            z=reduced_data[:, 2],
            mode='markers',
            marker=dict(size=3.5, color='white', opacity=0.4),
        )
        layout = go.Layout(
            margin=dict(l=0, r=0, b=0, t=0),
            scene=dict(
                xaxis_title='X',
                yaxis_title='Y',
                zaxis_title='Z',
                xaxis=dict(backgroundcolor='black', color='white', gridcolor='gray', zerolinecolor='gray'),
                yaxis=dict(backgroundcolor='black', color='white', gridcolor='gray', zerolinecolor='gray'),
                zaxis=dict(backgroundcolor='black', color='white', gridcolor='gray', zerolinecolor='gray'),
            ),
            paper_bgcolor='black',
            plot_bgcolor='black',
            font=dict(color='white'),
        )

        # Overlay the current search results in orange on top of the corpus.
        if len(reduced_results_points) > 0:
            results_trace = go.Scatter3d(
                x=reduced_results_points[:, 0],
                y=reduced_results_points[:, 1],
                z=reduced_results_points[:, 2],
                mode='markers',
                marker=dict(size=3.5, color='orange', opacity=0.9),
            )
            fig = go.Figure(data=[trace, results_trace], layout=layout)
        else:
            fig = go.Figure(data=[trace], layout=layout)
        return fig
|
def word2vec_search(self, query, top_n=5): |
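        """Rank documents by cosine similarity to the averaged Word2Vec query vector."""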
|
        tokens = [word for word in query.split() if word in self.wv_model.key_to_index]
        if not tokens:
            return []
        # Mean of the token vectors, L2-normalized to match the document embeddings.
        vectors = np.array([self.wv_model[word] for word in tokens])
        query_vec = normalize(np.mean(vectors, axis=0).reshape(1, -1))
        sims = cosine_similarity(query_vec, self.word2vec_embeddings).flatten()
        top_indices = sims.argsort()[::-1][:top_n]
        return [(i, sims[i]) for i in top_indices]
|
def bert_search(self, query, top_n=5): |
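        """Rank documents by cosine similarity to the BERT [CLS] query embedding."""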
|
        with torch.no_grad():
            inputs = self.tokenizer(query, return_tensors="pt", truncation=True, padding=True)
            outputs = self.model(**inputs)
            # Use the [CLS] token of the last hidden layer as the query embedding.
            query_vec = normalize(outputs.last_hidden_state[:, 0, :].numpy())
        sims = cosine_similarity(query_vec, self.bert_embeddings).flatten()
        top_indices = sims.argsort()[::-1][:top_n]
        return [(i, sims[i]) for i in top_indices]
|
def load_model(self, embedding): |
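        """Load the precomputed document embeddings (and supporting models) for one embedding type."""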
|
        if embedding == "tfidf":
            self.tfidf_matrix = load_npz("TF-IDF embeddings/tfidf_matrix_train.npz")
            with open("TF-IDF embeddings/feature_names.txt", "r") as f:
                self.feature_names = [line.strip() for line in f]
        elif embedding == "word2vec":
            self.word2vec_embeddings = normalize(np.load("Word2Vec embeddings/word2vec_embedding.npz")["word2vec_embedding"])
            self.wv_model = KeyedVectors.load("models/word2vec-trimmed.model")
        elif embedding == "bert":
            self.bert_embeddings = normalize(np.load("BERT embeddings/bert_embedding.npz")["bert_embedding"])
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            self.model = BertModel.from_pretrained('bert-base-uncased')
            self.model.eval()
        else:
            raise ValueError(f"Unsupported embedding type: {embedding}")
|
def on_model_change(self, change): |
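        """Reload the chosen model. Not wired into the Gradio events above, which call search_function directly."""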
|
        new_model = change["new"]
        self.embedding = new_model
        self.load_model(new_model)
|
def snippet_before_abstract(self, text): |
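        """Return the text before the abstract/introduction heading, or the first 100 characters."""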
|
        # Match "abstract" or "introduction" even when PDF extraction has
        # spaced out the letters (e.g. "A b s t r a c t").
        pattern = re.compile(r'a\s*b\s*s\s*t\s*r\s*a\s*c\s*t|i\s*n\s*t\s*r\s*o\s*d\s*u\s*c\s*t\s*i\s*o\s*n', re.IGNORECASE)
        match = pattern.search(text)
        if match:
            return text[:match.start()].strip()
        return text[:100].strip()
|
def search_function(self, query, embedding): |
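        """Run a search with the selected embedding model and return Markdown-formatted results."""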
|
|
|
        if embedding == "tfidf":
            results = self.keyword_match_ranking(query)
        elif embedding == "word2vec":
            results = self.word2vec_search(query)
        elif embedding == "bert":
            results = self.bert_search(query)
        else:
            return f"Unsupported embedding type: {embedding}"

        if not results:
            self.last_results = []
            return "No results found."
        self.last_results = results

        # Render results as Markdown; documents without a parsed arXiv ID are skipped.
        output = ""
        display_rank = 1
        for idx, score in results:
            if not self.arxiv_ids[idx]:
                continue
            link = f"https://arxiv.org/abs/{self.arxiv_ids[idx].replace('arxiv:', '')}"
            snippet = self.snippet_before_abstract(self.documents[idx]).replace('\n', '<br>')
            output += f"### Document {display_rank}\n"
            output += f"[arXiv Link]({link})\n\n"
            output += f"<pre>{snippet}</pre>\n\n---\n"
            display_rank += 1
        return output
|
|
if __name__ == "__main__":
    dataset = load_dataset("ccdv/arxiv-classification", "no_ref")
    search_engine = ArxivSearch(dataset, embedding="tfidf")
search_engine.iface.launch() |
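
# Programmatic use (sketch): the search can also be called without the UI,
# e.g. search_engine.search_function("graph neural networks", "bert"),
# which returns the rendered Markdown results.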