Jonas Leeb
committed on
Commit
·
6c71bbc
1
Parent(s):
65f9879
added plot
Browse files
app.py
CHANGED
@@ -8,6 +8,9 @@ from transformers import BertTokenizer, BertModel
|
|
8 |
import numpy as np
|
9 |
from datasets import load_dataset
|
10 |
from gensim.models import KeyedVectors
|
|
|
|
|
|
|
11 |
|
12 |
|
13 |
|
@@ -19,6 +22,7 @@ class ArxivSearch:
|
|
19 |
self.titles = []
|
20 |
self.raw_texts = []
|
21 |
self.arxiv_ids = []
|
|
|
22 |
|
23 |
self.embedding_dropdown = gr.Dropdown(
|
24 |
choices=["tfidf", "word2vec", "bert"],
|
@@ -26,16 +30,48 @@ class ArxivSearch:
|
|
26 |
label="Model"
|
27 |
)
|
28 |
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
self.load_data(dataset)
|
41 |
# self.load_model(embedding)
|
@@ -45,15 +81,6 @@ class ArxivSearch:
|
|
45 |
|
46 |
self.iface.launch()
|
47 |
|
48 |
-
|
49 |
-
# # --- Load data and embeddings ---
|
50 |
-
# with open("feature_names.txt", "r") as f:
|
51 |
-
# feature_names = [line.strip() for line in f]
|
52 |
-
|
53 |
-
# tfidf_matrix = load_npz("tfidf_matrix_train.npz")
|
54 |
-
|
55 |
-
# Load dataset and initialize search engine
|
56 |
-
|
57 |
def load_data(self, dataset):
|
58 |
train_data = dataset["train"]
|
59 |
for item in train_data.select(range(len(train_data))):
|
@@ -99,6 +126,57 @@ class ArxivSearch:
|
|
99 |
scores.append((doc_idx, doc_score))
|
100 |
scores.sort(key=lambda x: x[1], reverse=True)
|
101 |
return scores[:top_n]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
def word2vec_search(self, query, top_n=5):
|
104 |
tokens = [word for word in query.split() if word in self.wv_model.key_to_index]
|
@@ -163,7 +241,12 @@ class ArxivSearch:
|
|
163 |
return "No results found."
|
164 |
|
165 |
if not results:
|
|
|
166 |
return "No results found."
|
|
|
|
|
|
|
|
|
167 |
|
168 |
output = ""
|
169 |
display_rank = 1
|
|
|
8 |
import numpy as np
|
9 |
from datasets import load_dataset
|
10 |
from gensim.models import KeyedVectors
|
11 |
+
import plotly.graph_objects as go
|
12 |
+
from sklearn.decomposition import PCA
|
13 |
+
|
14 |
|
15 |
|
16 |
|
|
|
22 |
self.titles = []
|
23 |
self.raw_texts = []
|
24 |
self.arxiv_ids = []
|
25 |
+
self.last_results = []
|
26 |
|
27 |
self.embedding_dropdown = gr.Dropdown(
|
28 |
choices=["tfidf", "word2vec", "bert"],
|
|
|
30 |
label="Model"
|
31 |
)
|
32 |
|
33 |
+
|
34 |
+
# Add a button to show the 3D plot
|
35 |
+
self.plot_button = gr.Button("Show 3D Plot")
|
36 |
+
|
37 |
+
# Define the interface using Blocks for more flexibility
|
38 |
+
with gr.Blocks() as self.iface:
|
39 |
+
gr.Markdown("# arXiv Search Engine")
|
40 |
+
gr.Markdown("Search arXiv papers by keyword and embedding model.")
|
41 |
+
with gr.Row():
|
42 |
+
self.query_box = gr.Textbox(lines=1, placeholder="Enter your search query", label="Query")
|
43 |
+
self.embedding_dropdown.render()
|
44 |
+
self.plot_button.render()
|
45 |
+
with gr.Row():
|
46 |
+
self.plot_output = gr.Plot()
|
47 |
+
self.output_md = gr.Markdown()
|
48 |
+
|
49 |
+
self.query_box.submit(
|
50 |
+
self.search_function,
|
51 |
+
inputs=[self.query_box, self.embedding_dropdown],
|
52 |
+
outputs=self.output_md
|
53 |
+
)
|
54 |
+
self.embedding_dropdown.change(
|
55 |
+
self.search_function,
|
56 |
+
inputs=[self.query_box, self.embedding_dropdown],
|
57 |
+
outputs=self.output_md
|
58 |
+
)
|
59 |
+
self.plot_button.click(
|
60 |
+
self.plot_3d_embeddings,
|
61 |
+
inputs=[self.embedding_dropdown],
|
62 |
+
outputs=self.plot_output
|
63 |
+
)
|
64 |
+
|
65 |
+
# self.iface = gr.Interface(
|
66 |
+
# fn=self.search_function,
|
67 |
+
# inputs=[
|
68 |
+
# gr.Textbox(lines=1, placeholder="Enter your search query"),
|
69 |
+
# self.embedding_dropdown
|
70 |
+
# ],
|
71 |
+
# outputs=gr.Markdown(),
|
72 |
+
# title="arXiv Search Engine",
|
73 |
+
# description="Search arXiv papers by keyword and embedding model.",
|
74 |
+
# )
|
75 |
|
76 |
self.load_data(dataset)
|
77 |
# self.load_model(embedding)
|
|
|
81 |
|
82 |
self.iface.launch()
|
83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
def load_data(self, dataset):
|
85 |
train_data = dataset["train"]
|
86 |
for item in train_data.select(range(len(train_data))):
|
|
|
126 |
scores.append((doc_idx, doc_score))
|
127 |
scores.sort(key=lambda x: x[1], reverse=True)
|
128 |
return scores[:top_n]
|
129 |
+
|
130 |
+
def plot_3d_embeddings(self, embedding, max_points=5000):
    """Build a 3D PCA scatter plot of the document embeddings.

    A 3-component PCA is fitted on the first ``max_points`` rows of the
    matrix selected by ``embedding``. Those rows are drawn as faint white
    markers. The documents referenced by ``self.last_results`` (the first
    element of each entry is used as a row index) are projected with the
    same fitted PCA and drawn on top as orange markers.

    Args:
        embedding: Name of the embedding model; one of ``"tfidf"``,
            ``"word2vec"`` or ``"bert"``.
        max_points: Number of leading rows used to fit the PCA and shown
            as background points. Defaults to 5000.

    Returns:
        A ``plotly.graph_objects.Figure`` with the background trace and,
        when ``self.last_results`` is non-empty, a second trace for the
        result documents.

    Raises:
        ValueError: If ``embedding`` is not one of the supported names.
    """
    attribute_by_model = {
        "tfidf": "tfidf_matrix",
        "word2vec": "word2vec_embeddings",
        "bert": "bert_embeddings",
    }
    # Validate first so an unsupported model fails before any work is done.
    if embedding not in attribute_by_model:
        raise ValueError(f"Unsupported embedding type: {embedding}")
    vectors = getattr(self, attribute_by_model[embedding])

    def to_dense(rows):
        # TF-IDF rows are converted with .toarray() before being handed to
        # PCA; the other embeddings are passed through unchanged.
        return rows.toarray() if embedding == "tfidf" else rows

    pca = PCA(n_components=3)
    reduced_data = pca.fit_transform(to_dense(vectors[:max_points]))

    traces = [
        go.Scatter3d(
            x=reduced_data[:, 0],
            y=reduced_data[:, 1],
            z=reduced_data[:, 2],
            mode='markers',
            marker=dict(size=3.5, color='white', opacity=0.4),
        )
    ]

    # Highlight the most recent search hits, projected into the same PCA
    # space as the background points.
    result_rows = [hit[0] for hit in self.last_results]
    if result_rows:
        reduced_results_points = pca.transform(to_dense(vectors[result_rows]))
        traces.append(
            go.Scatter3d(
                x=reduced_results_points[:, 0],
                y=reduced_results_points[:, 1],
                z=reduced_results_points[:, 2],
                mode='markers',
                marker=dict(size=3.5, color='orange', opacity=0.9),
            )
        )

    axis_style = dict(
        backgroundcolor='black',
        color='white',
        gridcolor='gray',
        zerolinecolor='gray',
    )
    layout = go.Layout(
        margin=dict(l=0, r=0, b=0, t=0),
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z',
            xaxis=dict(axis_style),
            yaxis=dict(axis_style),
            zaxis=dict(axis_style),
        ),
        paper_bgcolor='black',  # Outside the plotting area
        plot_bgcolor='black',  # Plotting area
        font=dict(color='white'),  # Axis and legend text
    )
    return go.Figure(data=traces, layout=layout)
|
180 |
|
181 |
def word2vec_search(self, query, top_n=5):
|
182 |
tokens = [word for word in query.split() if word in self.wv_model.key_to_index]
|
|
|
241 |
return "No results found."
|
242 |
|
243 |
if not results:
|
244 |
+
self.last_results = []
|
245 |
return "No results found."
|
246 |
+
|
247 |
+
|
248 |
+
if results:
|
249 |
+
self.last_results = results
|
250 |
|
251 |
output = ""
|
252 |
display_rank = 1
|