Jonas Leeb committed
Commit 6c71bbc · Parent: 65f9879

added plot
Files changed (1): app.py (+102, -19)
@@ -8,6 +8,9 @@ from transformers import BertTokenizer, BertModel
 import numpy as np
 from datasets import load_dataset
 from gensim.models import KeyedVectors
+import plotly.graph_objects as go
+from sklearn.decomposition import PCA
+
 
 
 
@@ -19,6 +22,7 @@ class ArxivSearch:
         self.titles = []
         self.raw_texts = []
         self.arxiv_ids = []
+        self.last_results = []
 
         self.embedding_dropdown = gr.Dropdown(
             choices=["tfidf", "word2vec", "bert"],
@@ -26,16 +30,48 @@
             label="Model"
         )
 
-        self.iface = gr.Interface(
-            fn=self.search_function,
-            inputs=[
-                gr.Textbox(lines=1, placeholder="Enter your search query"),
-                self.embedding_dropdown
-            ],
-            outputs=gr.Markdown(),
-            title="arXiv Search Engine",
-            description="Search arXiv papers by keyword and embedding model.",
-        )
+
+        # Add a button to show the 3D plot
+        self.plot_button = gr.Button("Show 3D Plot")
+
+        # Define the interface using Blocks for more flexibility
+        with gr.Blocks() as self.iface:
+            gr.Markdown("# arXiv Search Engine")
+            gr.Markdown("Search arXiv papers by keyword and embedding model.")
+            with gr.Row():
+                self.query_box = gr.Textbox(lines=1, placeholder="Enter your search query", label="Query")
+                self.embedding_dropdown.render()
+                self.plot_button.render()
+            with gr.Row():
+                self.plot_output = gr.Plot()
+                self.output_md = gr.Markdown()
+
+            self.query_box.submit(
+                self.search_function,
+                inputs=[self.query_box, self.embedding_dropdown],
+                outputs=self.output_md
+            )
+            self.embedding_dropdown.change(
+                self.search_function,
+                inputs=[self.query_box, self.embedding_dropdown],
+                outputs=self.output_md
+            )
+            self.plot_button.click(
+                self.plot_3d_embeddings,
+                inputs=[self.embedding_dropdown],
+                outputs=self.plot_output
+            )
+
+        # self.iface = gr.Interface(
+        #     fn=self.search_function,
+        #     inputs=[
+        #         gr.Textbox(lines=1, placeholder="Enter your search query"),
+        #         self.embedding_dropdown
+        #     ],
+        #     outputs=gr.Markdown(),
+        #     title="arXiv Search Engine",
+        #     description="Search arXiv papers by keyword and embedding model.",
+        # )
 
         self.load_data(dataset)
         # self.load_model(embedding)
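The hunk above replaces the single gr.Interface call with a gr.Blocks layout so the results pane and the new 3D plot can sit side by side, and it attaches the dropdown and button (created before the Blocks context) via .render(). Below is a minimal, self-contained sketch of that wiring pattern with placeholder callbacks; it is not the repo's code, and the names (fake_search, fake_plot, demo) are illustrative only:

    # Sketch of the gr.Blocks wiring pattern used in this commit (placeholder callbacks).
    import gradio as gr
    import plotly.graph_objects as go

    # Components built outside the Blocks context are attached later with .render().
    model_dropdown = gr.Dropdown(choices=["tfidf", "word2vec", "bert"], value="tfidf", label="Model")
    plot_button = gr.Button("Show 3D Plot")

    def fake_search(query, model):  # stands in for ArxivSearch.search_function
        return f"Results for **{query}** using `{model}`"

    def fake_plot(model):  # stands in for ArxivSearch.plot_3d_embeddings
        return go.Figure(go.Scatter3d(x=[0, 1], y=[0, 1], z=[0, 1], mode="markers"))

    with gr.Blocks() as demo:
        gr.Markdown("# Demo")
        with gr.Row():
            query_box = gr.Textbox(lines=1, placeholder="Enter your search query", label="Query")
            model_dropdown.render()
            plot_button.render()
        with gr.Row():
            plot_output = gr.Plot()
            output_md = gr.Markdown()

        # Each event re-runs its callback and writes the return value into the bound output.
        query_box.submit(fake_search, inputs=[query_box, model_dropdown], outputs=output_md)
        model_dropdown.change(fake_search, inputs=[query_box, model_dropdown], outputs=output_md)
        plot_button.click(fake_plot, inputs=[model_dropdown], outputs=plot_output)

    if __name__ == "__main__":
        demo.launch()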
@@ -45,15 +81,6 @@ class ArxivSearch:
 
         self.iface.launch()
 
-
-        # # --- Load data and embeddings ---
-        # with open("feature_names.txt", "r") as f:
-        #     feature_names = [line.strip() for line in f]
-
-        # tfidf_matrix = load_npz("tfidf_matrix_train.npz")
-
-        # Load dataset and initialize search engine
-
     def load_data(self, dataset):
         train_data = dataset["train"]
         for item in train_data.select(range(len(train_data))):
@@ -99,6 +126,57 @@ class ArxivSearch:
             scores.append((doc_idx, doc_score))
         scores.sort(key=lambda x: x[1], reverse=True)
         return scores[:top_n]
+
+    def plot_3d_embeddings(self, embedding):
+        # Project the stored document embeddings to 3D with PCA and highlight the last search results
+        pca = PCA(n_components=3)
+        results_indices = [i[0] for i in self.last_results]
+        if embedding == "tfidf":
+            reduced_data = pca.fit_transform(self.tfidf_matrix[:5000].toarray())
+            reduced_results_points = pca.transform(self.tfidf_matrix[results_indices].toarray()) if len(results_indices) > 0 else np.empty((0, 3))
+
+        elif embedding == "word2vec":
+            reduced_data = pca.fit_transform(self.word2vec_embeddings[:5000])
+            reduced_results_points = pca.transform(self.word2vec_embeddings[results_indices]) if len(results_indices) > 0 else np.empty((0, 3))
+
+        elif embedding == "bert":
+            reduced_data = pca.fit_transform(self.bert_embeddings[:5000])
+            reduced_results_points = pca.transform(self.bert_embeddings[results_indices]) if len(results_indices) > 0 else np.empty((0, 3))
+        else:
+            raise ValueError(f"Unsupported embedding type: {embedding}")
+        trace = go.Scatter3d(
+            x=reduced_data[:, 0],
+            y=reduced_data[:, 1],
+            z=reduced_data[:, 2],
+            mode='markers',
+            marker=dict(size=3.5, color='white', opacity=0.4),
+        )
+        layout = go.Layout(
+            margin=dict(l=0, r=0, b=0, t=0),
+            scene=dict(
+                xaxis_title='X',
+                yaxis_title='Y',
+                zaxis_title='Z',
+                xaxis=dict(backgroundcolor='black', color='white', gridcolor='gray', zerolinecolor='gray'),
+                yaxis=dict(backgroundcolor='black', color='white', gridcolor='gray', zerolinecolor='gray'),
+                zaxis=dict(backgroundcolor='black', color='white', gridcolor='gray', zerolinecolor='gray'),
+            ),
+            paper_bgcolor='black',  # Outside the plotting area
+            plot_bgcolor='black',  # Plotting area
+            font=dict(color='white')  # Axis and legend text
+        )
+        if len(reduced_results_points) > 0:
+            results_trace = go.Scatter3d(
+                x=reduced_results_points[:, 0],
+                y=reduced_results_points[:, 1],
+                z=reduced_results_points[:, 2],
+                mode='markers',
+                marker=dict(size=3.5, color='orange', opacity=0.9),
+            )
+            fig = go.Figure(data=[trace, results_trace], layout=layout)
+        else:
+            fig = go.Figure(data=[trace], layout=layout)
+        return fig
 
     def word2vec_search(self, query, top_n=5):
         tokens = [word for word in query.split() if word in self.wv_model.key_to_index]
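The new plot_3d_embeddings above fits a 3-component PCA on (at most) the first 5000 document vectors, projects the cached search hits with the same PCA, and draws both as Scatter3d traces (white cloud, orange hits). Below is a minimal sketch of that projection-and-highlight pattern on random vectors rather than the app's stored embeddings, so the array names here are stand-ins only:

    # Sketch: PCA to 3D plus a highlighted subset, mirroring plot_3d_embeddings.
    import numpy as np
    import plotly.graph_objects as go
    from sklearn.decomposition import PCA

    rng = np.random.default_rng(0)
    embeddings = rng.normal(size=(1000, 128))   # stand-in for tfidf/word2vec/BERT vectors
    hit_indices = [3, 42, 7]                    # stand-in for the cached result indices

    pca = PCA(n_components=3)
    cloud = pca.fit_transform(embeddings)            # fit on the corpus sample...
    hits = pca.transform(embeddings[hit_indices])    # ...then project the hits into the same space

    fig = go.Figure(
        data=[
            go.Scatter3d(x=cloud[:, 0], y=cloud[:, 1], z=cloud[:, 2],
                         mode="markers", marker=dict(size=3.5, color="white", opacity=0.4)),
            go.Scatter3d(x=hits[:, 0], y=hits[:, 1], z=hits[:, 2],
                         mode="markers", marker=dict(size=3.5, color="orange", opacity=0.9)),
        ],
        layout=go.Layout(paper_bgcolor="black", font=dict(color="white"),
                         margin=dict(l=0, r=0, b=0, t=0)),
    )
    fig.show()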
@@ -163,7 +241,12 @@ class ArxivSearch:
             return "No results found."
 
         if not results:
+            self.last_results = []
             return "No results found."
+
+
+        if results:
+            self.last_results = results
 
         output = ""
         display_rank = 1
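Taken together, search_function now caches the ranked (index, score) pairs in self.last_results, and plot_3d_embeddings reads that cache when the button is clicked, so the orange markers always correspond to the most recent query. A rough sketch of the call sequence the two UI events trigger, with `engine` standing in for a hypothetical, already-initialized ArxivSearch instance:

    # Hypothetical driver for the search -> cache -> plot flow added in this commit.
    markdown = engine.search_function("graph neural networks", "tfidf")  # fills engine.last_results
    fig = engine.plot_3d_embeddings("tfidf")        # highlights those cached hits in orange
    fig.write_html("embedding_space.html")          # or fig.show() outside the Gradio app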