Jonas Leeb committed on
Commit dfc89a9
1 Parent(s): b4a0b98

query is also shown now

Files changed (1)
  1. app.py +33 -17
app.py CHANGED
@@ -23,6 +23,7 @@ class ArxivSearch:
         self.raw_texts = []
         self.arxiv_ids = []
         self.last_results = []
+        self.query_encoding = None
 
         self.embedding_dropdown = gr.Dropdown(
             choices=["tfidf", "word2vec", "bert"],
@@ -113,20 +114,7 @@ class ArxivSearch:
             self.documents.append(text.strip())
             self.arxiv_ids.append(arxiv_id)
 
-    def keyword_match_ranking(self, query, top_n=5):
-        query_terms = query.lower().split()
-        query_indices = [i for i, term in enumerate(self.feature_names) if term in query_terms]
-        if not query_indices:
-            return []
-        scores = []
-        for doc_idx in range(self.tfidf_matrix.shape[0]):
-            doc_vector = self.tfidf_matrix[doc_idx]
-            doc_score = sum(doc_vector[0, i] for i in query_indices)
-            if doc_score > 0:
-                scores.append((doc_idx, doc_score))
-        scores.sort(key=lambda x: x[1], reverse=True)
-        return scores[:top_n]
-
+
     def plot_3d_embeddings(self, embedding):
         # Example: plot random points, replace with your embeddings
         pca = PCA(n_components=3)
@@ -144,6 +132,7 @@ class ArxivSearch:
             pca.fit(all_data)
             reduced_data = pca.transform(self.word2vec_embeddings[:5000])
             reduced_results_points = pca.transform(self.word2vec_embeddings[results_indices]) if len(results_indices) > 0 else np.empty((0, 3))
+            query_point = pca.transform(self.query_encoding) if self.query_encoding is not None and self.query_encoding.shape[0] > 0 else np.empty((0, 3))
 
         elif embedding == "bert":
             all_indices = list(set(results_indices) | set(range(min(5000, self.bert_embeddings.shape[0]))))
@@ -151,6 +140,7 @@ class ArxivSearch:
             pca.fit(all_data)
             reduced_data = pca.transform(self.bert_embeddings[:5000])
             reduced_results_points = pca.transform(self.bert_embeddings[results_indices]) if len(results_indices) > 0 else np.empty((0, 3))
+            query_point = pca.transform(self.query_encoding) if self.query_encoding is not None and self.query_encoding.shape[0] > 0 else np.empty((0, 3))
 
         else:
             raise ValueError(f"Unsupported embedding type: {embedding}")
@@ -159,7 +149,8 @@ class ArxivSearch:
             y=reduced_data[:, 1],
             z=reduced_data[:, 2],
             mode='markers',
-            marker=dict(size=3.5, color='white', opacity=0.4),
+            marker=dict(size=3.5, color='#cccccc', opacity=0.35),
+            name='All Documents'
         )
         layout = go.Layout(
             margin=dict(l=0, r=0, b=0, t=0),
@@ -182,18 +173,42 @@ class ArxivSearch:
                 z=reduced_results_points[:, 2],
                 mode='markers',
                 marker=dict(size=3.5, color='orange', opacity=0.75),
+                name='Results'
+            )
+            query_trace = go.Scatter3d(
+                x=query_point[:, 0],
+                y=query_point[:, 1],
+                z=query_point[:, 2],
+                mode='markers',
+                marker=dict(size=5, color='red', opacity=0.8),
+                name='Query'
             )
-            fig = go.Figure(data=[trace, results_trace], layout=layout)
+            fig = go.Figure(data=[trace, results_trace, query_trace], layout=layout)
         else:
             fig = go.Figure(data=[trace], layout=layout)
         return fig
-
+
+    def keyword_match_ranking(self, query, top_n=5):
+        query_terms = query.lower().split()
+        query_indices = [i for i, term in enumerate(self.feature_names) if term in query_terms]
+        if not query_indices:
+            return []
+        scores = []
+        for doc_idx in range(self.tfidf_matrix.shape[0]):
+            doc_vector = self.tfidf_matrix[doc_idx]
+            doc_score = sum(doc_vector[0, i] for i in query_indices)
+            if doc_score > 0:
+                scores.append((doc_idx, doc_score))
+        scores.sort(key=lambda x: x[1], reverse=True)
+        return scores[:top_n]
+
     def word2vec_search(self, query, top_n=5):
         tokens = [word for word in query.split() if word in self.wv_model.key_to_index]
         if not tokens:
             return []
         vectors = np.array([self.wv_model[word] for word in tokens])
         query_vec = normalize(np.mean(vectors, axis=0).reshape(1, -1))
+        self.query_encoding = query_vec
         sims = cosine_similarity(query_vec, self.word2vec_embeddings).flatten()
         top_indices = sims.argsort()[::-1][:top_n]
         return [(i, sims[i]) for i in top_indices]
@@ -203,6 +218,7 @@ class ArxivSearch:
         inputs = self.tokenizer(query, return_tensors="pt", truncation=True, padding=True)
         outputs = self.model(**inputs)
         query_vec = normalize(outputs.last_hidden_state[:, 0, :].numpy())
+        self.query_encoding = query_vec
         sims = cosine_similarity(query_vec, self.bert_embeddings).flatten()
         top_indices = sims.argsort()[::-1][:top_n]
         return [(i, sims[i]) for i in top_indices]
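The pattern this commit introduces — store the encoded query, project it with the PCA already fitted on the document embeddings, and draw it as its own trace — can be tried in isolation. Below is a minimal, self-contained sketch of that pattern; it uses random stand-in vectors and illustrative names (doc_embeddings, query_vec, top_indices) rather than the app's word2vec/BERT data:

import numpy as np
import plotly.graph_objects as go
from sklearn.decomposition import PCA
from sklearn.preprocessing import normalize
from sklearn.metrics.pairwise import cosine_similarity

rng = np.random.default_rng(0)
doc_embeddings = normalize(rng.normal(size=(500, 300)))  # stand-in for the document vectors
query_vec = normalize(rng.normal(size=(1, 300)))         # stand-in for the encoded query (self.query_encoding)

# Rank documents by cosine similarity and keep the top 5, mirroring word2vec_search / the BERT search.
sims = cosine_similarity(query_vec, doc_embeddings).flatten()
top_indices = sims.argsort()[::-1][:5]

# Fit PCA on the documents, then reuse the same projection for the results and the query,
# so all three point sets land in the same 3-D space.
pca = PCA(n_components=3)
reduced_docs = pca.fit_transform(doc_embeddings)
reduced_results = pca.transform(doc_embeddings[top_indices])
query_point = pca.transform(query_vec)

fig = go.Figure(data=[
    go.Scatter3d(x=reduced_docs[:, 0], y=reduced_docs[:, 1], z=reduced_docs[:, 2],
                 mode='markers', marker=dict(size=3.5, color='#cccccc', opacity=0.35),
                 name='All Documents'),
    go.Scatter3d(x=reduced_results[:, 0], y=reduced_results[:, 1], z=reduced_results[:, 2],
                 mode='markers', marker=dict(size=3.5, color='orange', opacity=0.75),
                 name='Results'),
    go.Scatter3d(x=query_point[:, 0], y=query_point[:, 1], z=query_point[:, 2],
                 mode='markers', marker=dict(size=5, color='red', opacity=0.8),
                 name='Query'),
])
fig.show()

As in plot_3d_embeddings, the query point is transformed with the same fitted PCA as the documents and results (not refitted), which is what keeps it comparable to the document cloud.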