import re

import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from gensim.models import KeyedVectors
from scipy.sparse import load_npz
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import normalize
from transformers import BertTokenizer, BertModel
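

# Gradio front end for searching arXiv papers with three interchangeable
# embedding back ends: TF-IDF keyword matching, mean-pooled Word2Vec vectors,
# and BERT [CLS] embeddings. Document embeddings are precomputed and loaded
# from disk; only the query is embedded at search time.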
class ArxivSearch:
    def __init__(self, dataset, embedding="tfidf"):
        self.dataset = dataset
        self.embedding = embedding
        self.documents = []
        self.titles = []
        self.arxiv_ids = []

        self.embedding_dropdown = gr.Dropdown(
            choices=["tfidf", "word2vec", "bert"],
            value="tfidf",
            label="Model",
        )

        self.iface = gr.Interface(
            fn=self.search_function,
            inputs=[
                gr.Textbox(lines=1, placeholder="Enter your search query"),
                self.embedding_dropdown,
            ],
            outputs=gr.Markdown(),
            title="arXiv Search Engine",
            description="Search arXiv papers by keyword and embedding model.",
        )

        self.load_data(dataset)
        self.load_model(embedding)
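
    # Parse each raw document: lines before the "arXiv:" marker are treated
    # as the title, the arXiv id is extracted from that marker line, and
    # scanning stops once the abstract begins.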
    def load_data(self, dataset):
        train_data = dataset["train"]
        for item in train_data:
            text = item["text"]
            if not text or len(text.strip()) < 10:
                continue

            lines = text.splitlines()
            title_lines = []
            found_arxiv = False
            arxiv_id = None

            for line in lines:
                line_strip = line.strip()
                if not found_arxiv and line_strip.lower().startswith("arxiv:"):
                    found_arxiv = True
                    match = re.search(r"arxiv:\d{4}\.\d{4,5}v\d+", line_strip, flags=re.IGNORECASE)
                    if match:
                        arxiv_id = match.group(0).lower()
                elif not found_arxiv:
                    title_lines.append(line_strip)
                else:
                    if line_strip.lower().startswith("abstract"):
                        break

            title = " ".join(title_lines).strip()

            self.titles.append(title)
            self.documents.append(text.strip())
            self.arxiv_ids.append(arxiv_id)
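
    # Score each document by summing its TF-IDF weights over the query terms
    # found in the vocabulary, i.e. score(d) = sum over matching terms t of
    # tfidf[d, t], and return the top_n highest-scoring documents.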
    def keyword_match_ranking(self, query, top_n=5):
        query_terms = set(query.lower().split())
        query_indices = [i for i, term in enumerate(self.feature_names) if term in query_terms]
        if not query_indices:
            return []

        scores = []
        for doc_idx in range(self.tfidf_matrix.shape[0]):
            doc_vector = self.tfidf_matrix[doc_idx]
            doc_score = sum(doc_vector[0, i] for i in query_indices)
            if doc_score > 0:
                scores.append((doc_idx, doc_score))

        scores.sort(key=lambda x: x[1], reverse=True)
        return scores[:top_n]
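
    # Embed the query as the L2-normalized mean of its in-vocabulary word
    # vectors and rank documents by cosine similarity against the Word2Vec
    # document embeddings (normalized once at load time).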
    def word2vec_search(self, query, top_n=5):
        tokens = [word for word in query.split() if word in self.wv_model.key_to_index]
        if not tokens:
            return []
        vectors = np.array([self.wv_model[word] for word in tokens])
        query_vec = normalize(np.mean(vectors, axis=0).reshape(1, -1))
        sims = cosine_similarity(query_vec, self.word2vec_embeddings).flatten()
        top_indices = sims.argsort()[::-1][:top_n]
        return [(i, sims[i]) for i in top_indices]
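
    # Embed the query with BERT, taking the [CLS] token of the last hidden
    # layer as a sentence vector, and rank documents by cosine similarity.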
    def bert_search(self, query, top_n=5):
        with torch.no_grad():
            inputs = self.tokenizer(query, return_tensors="pt", truncation=True, padding=True)
            outputs = self.model(**inputs)
            query_vec = normalize(outputs.last_hidden_state[:, 0, :].numpy())
        sims = cosine_similarity(query_vec, self.bert_embeddings).flatten()
        top_indices = sims.argsort()[::-1][:top_n]
        return [(i, sims[i]) for i in top_indices]
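
    # Load only the artifacts the selected back end needs; paths point to
    # precomputed embedding files relative to the working directory.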
    def load_model(self, embedding):
        if embedding == "tfidf":
            self.tfidf_matrix = load_npz("TF-IDF embeddings/tfidf_matrix_train.npz")
            with open("TF-IDF embeddings/feature_names.txt", "r") as f:
                self.feature_names = [line.strip() for line in f]
        elif embedding == "word2vec":
            self.word2vec_embeddings = normalize(
                np.load("Word2Vec embeddings/word2vec_embedding.npz")["word2vec_embedding"]
            )
            self.wv_model = KeyedVectors.load("models/word2vec-trimmed.model")
        elif embedding == "bert":
            self.bert_embeddings = normalize(np.load("BERT embeddings/bert_embedding.npz")["bert_embedding"])
            self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
            self.model = BertModel.from_pretrained("bert-base-uncased")
            self.model.eval()
        else:
            raise ValueError(f"Unsupported embedding type: {embedding}")
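
    # Return the text before the abstract/introduction heading as a preview.
    # The spaced-out pattern (e.g. "a b s t r a c t") also catches headings
    # whose letters were separated during PDF text extraction.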
    def snippet_before_abstract(self, text):
        pattern = re.compile(
            r"a\s*b\s*s\s*t\s*r\s*a\s*c\s*t|i\s*n\s*t\s*r\s*o\s*d\s*u\s*c\s*t\s*i\s*o\s*n",
            re.IGNORECASE,
        )
        match = pattern.search(text)
        if match:
            return text[:match.start()].strip()
        else:
            return text[:100].strip()
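
    # Dispatch the query to the selected back end and render the results as
    # Markdown. Models are loaded lazily here so switching the dropdown at
    # runtime actually takes effect.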
    def search_function(self, query, embedding):
        # Load the selected back end lazily so dropdown changes take effect.
        if embedding != self.embedding:
            self.load_model(embedding)
            self.embedding = embedding

        if embedding == "tfidf":
            results = self.keyword_match_ranking(query)
        elif embedding == "word2vec":
            results = self.word2vec_search(query)
        elif embedding == "bert":
            results = self.bert_search(query)
        else:
            return f"Unsupported embedding type: {embedding}"

        if not results:
            return "No results found."

        output = ""
        display_rank = 1
        for idx, score in results:
            # Skip documents whose arXiv id could not be parsed.
            if not self.arxiv_ids[idx]:
                continue

            link = f"https://arxiv.org/abs/{self.arxiv_ids[idx].replace('arxiv:', '')}"
            snippet = self.snippet_before_abstract(self.documents[idx]).replace("\n", "<br>")
            output += f"### Document {display_rank}\n"
            output += f"[arXiv Link]({link})\n\n"
            output += f"<pre>{snippet}</pre>\n\n---\n"
            display_rank += 1

        return output


if __name__ == "__main__":
    dataset = load_dataset("ccdv/arxiv-classification", "no_ref")
    search_engine = ArxivSearch(dataset, embedding="tfidf")
    search_engine.iface.launch()
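
# Running this file directly downloads the dataset and launches the Gradio
# interface; the precomputed embedding files referenced in load_model() must
# be present relative to the working directory.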