somewheresy committed on
Commit
ac664de
·
verified ·
1 Parent(s): a7193d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +290 -53
app.py CHANGED
@@ -1,4 +1,3 @@
1
- # Import necessary libraries
2
  import streamlit as st
3
  import pandas as pd
4
  import numpy as np
@@ -6,16 +5,58 @@ from sklearn.manifold import TSNE
6
  from datasets import load_dataset, Dataset
7
  from sklearn.cluster import KMeans
8
  import plotly.graph_objects as go
9
- import time
10
  import logging
11
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  # Additional libraries for querying
14
  from FlagEmbedding import FlagModel
15
 
16
  # Global variables and dataset loading
17
  global dataset_name
18
- dataset_name = 'somewheresystems/dataclysm-arxiv'
 
 
 
 
 
 
 
19
  st.session_state.dataclysm_arxiv = load_dataset(dataset_name, split="train")
20
  total_samples = len(st.session_state.dataclysm_arxiv)
21
 
@@ -75,7 +116,7 @@ def perform_tsne(embeddings):
75
  tsne_results = tsne.fit_transform(np.vstack(embeddings.tolist()))
76
 
77
  # Update progress bar to indicate completion
78
- progress_text.text(f"t-SNE completed. Processed {n_samples} samples with perplexity {perplexity}.")
79
  end_time = time.time() # End timing
80
  st.sidebar.text(f't-SNE completed in {end_time - start_time:.3f} seconds')
81
  return tsne_results
@@ -83,20 +124,71 @@ def perform_tsne(embeddings):
83
 
84
  def perform_clustering(df, tsne_results):
85
  start_time = time.time()
86
- # Perform KMeans clustering
87
- logging.info('Performing k-means clustering...')
88
  # Step 3: Visualization with Plotly
89
- df['tsne-3d-one'] = tsne_results[:,0]
90
- df['tsne-3d-two'] = tsne_results[:,1]
91
- df['tsne-3d-three'] = tsne_results[:,2]
92
-
93
- # Perform KMeans clustering
94
- kmeans = KMeans(n_clusters=16) # Change the number of clusters as needed
95
- df['cluster'] = kmeans.fit_predict(df[['tsne-3d-one', 'tsne-3d-two', 'tsne-3d-three']])
 
 
96
  end_time = time.time() # End timing
97
- st.sidebar.text(f'k-means clustering completed in {end_time - start_time:.3f} seconds')
98
  return df
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  def main():
101
  # Custom CSS
102
  custom_css = """
@@ -112,48 +204,184 @@ def main():
112
  color: #F8F8F8; /* Set the font color to F8F8F8 */
113
  }
114
  /* Add your CSS styles here */
 
 
 
 
 
115
  h1 {
116
  text-align: center;
117
  }
118
  h2,h3,h4 {
119
  text-align: justify;
120
- font-size: 8px
 
 
 
121
  }
122
  body {
123
- text-align: justify;
 
124
  }
 
125
  .stSlider .css-1cpxqw2 {
126
  background: #202020;
 
 
 
 
 
127
  }
128
  .stButton > button {
129
  background-color: #202020;
130
- width: 100%;
131
- border: none;
 
 
132
  padding: 10px 24px;
133
- border-radius: 5px;
134
  font-size: 16px;
135
  font-weight: bold;
 
 
 
 
 
 
 
 
 
 
136
  }
137
  .reportview-container .main .block-container {
138
- padding: 2rem;
139
  background-color: #202020;
 
 
 
 
 
 
 
140
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  </style>
142
  """
143
 
144
  # Inject custom CSS with markdown
145
  st.markdown(custom_css, unsafe_allow_html=True)
 
146
  st.sidebar.markdown(
147
- f'<img src="https://www.somewhere.systems/S2-white-logo.png" style="float: bottom-left; width: 32px; height: 32px; opacity: 1.0; animation: fadein 2s;">',
148
  unsafe_allow_html=True
149
  )
150
- st.sidebar.title('Spatial Search Engine')
151
-
 
152
  # Check if data needs to be loaded
153
  if 'data_loaded' not in st.session_state or not st.session_state.data_loaded:
154
  # User input for number of samples
155
- num_samples = st.sidebar.slider('Select number of samples', 1000, total_samples, 1000)
156
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  if st.sidebar.button('Initialize'):
158
  st.sidebar.text('Initializing data pipeline...')
159
 
@@ -171,8 +399,6 @@ def main():
171
  print(f"FAISS index for {column_name} added.")
172
 
173
  return dataset
174
-
175
-
176
 
177
  # Load data and perform t-SNE and clustering
178
  df, embeddings = load_data(num_samples)
@@ -209,35 +435,46 @@ def main():
209
  marker=dict(
210
  size=1,
211
  color=df['cluster'],
212
- colorscale='Viridis',
213
- opacity=0.8
214
  )
215
  )])
 
 
 
 
216
 
217
  fig.update_layout(
218
- plot_bgcolor='#202020',
 
219
  height=800,
220
  margin=dict(l=0, r=0, b=0, t=0),
221
- scene=dict(
222
- xaxis=dict(showbackground=True, backgroundcolor="#000000"),
223
- yaxis=dict(showbackground=True, backgroundcolor="#000000"),
224
- zaxis=dict(showbackground=True, backgroundcolor="#000000"),
225
- ),
226
- scene_camera=dict(eye=dict(x=0.001, y=0.001, z=0.001))
227
  )
228
  st.session_state.fig = fig
229
 
230
  # Display the plot if data is loaded
231
  if 'data_loaded' in st.session_state and st.session_state.data_loaded:
232
- st.plotly_chart(st.session_state.fig, use_container_width=True)
233
 
234
 
235
  # Sidebar for detailed view
236
  if 'df' in st.session_state:
237
  # Sidebar for querying
238
  with st.sidebar:
239
- st.sidebar.markdown("### Query Embeddings")
240
- query = st.text_input("Enter your query:")
 
 
 
 
 
 
 
 
 
 
 
241
  if st.button("Search"):
242
  # Define the model
243
  print("Initializing model...")
@@ -248,7 +485,7 @@ def main():
248
 
249
  query_embedding = model.encode([query])
250
  # Retrieve examples by title similarity (or abstract, depending on your preference)
251
- scores_title, retrieved_examples_title = st.session_state.dataclysm_title_indexed.get_nearest_examples('title_embedding', query_embedding, k=10)
252
  df_query = pd.DataFrame(retrieved_examples_title)
253
  df_query['proximity'] = scores_title
254
  df_query = df_query.sort_values(by='proximity', ascending=True)
@@ -257,19 +494,19 @@ def main():
257
  # Fix the <a href link> to display properly
258
  df_query['URL'] = df_query['id'].apply(lambda x: f'<a href="https://arxiv.org/abs/{x}" target="_blank">Link</a>')
259
  st.sidebar.markdown(df_query[['title', 'proximity', 'id']].to_html(escape=False), unsafe_allow_html=True)
260
- st.sidebar.markdown("# Detailed View")
261
- selected_index = st.sidebar.selectbox("Select Key", st.session_state.df.id)
262
 
263
- # Display metadata for the selected article
264
- selected_row = st.session_state.df[st.session_state.df['id'] == selected_index].iloc[0]
265
- st.markdown(f"### Title\n{selected_row['title']}", unsafe_allow_html=True)
266
- st.markdown(f"### Abstract\n{selected_row['abstract']}", unsafe_allow_html=True)
267
- st.markdown(f"[Read the full paper](https://arxiv.org/abs/{selected_row['id']})", unsafe_allow_html=True)
268
- st.markdown(f"[Download PDF](https://arxiv.org/pdf/{selected_row['id']})", unsafe_allow_html=True)
269
-
270
-
271
 
272
- if __name__ == "__main__":
273
- main()
274
 
 
 
275
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import pandas as pd
3
  import numpy as np
 
5
  from datasets import load_dataset, Dataset
6
  from sklearn.cluster import KMeans
7
  import plotly.graph_objects as go
8
+ import time, random, datetime
9
  import logging
10
+ from sklearn.cluster import HDBSCAN
11
+
12
+
13
+ BACKGROUND_COLOR = 'black'
14
+ COLOR = 'white'
15
+
16
def set_page_container_style(
        max_width: int = 10000, max_width_100_percent: bool = False,
        padding_top: int = 1, padding_right: int = 10, padding_left: int = 1, padding_bottom: int = 10,
        color: str = COLOR, background_color: str = BACKGROUND_COLOR,
    ):
    """Inject a ``<style>`` element that sizes, pads and colours the main page container.

    Args:
        max_width: Container width limit in pixels; ignored when
            ``max_width_100_percent`` is true.
        max_width_100_percent: Emit ``max-width: 100%`` instead of a pixel limit.
        padding_top: Top padding of the block container, in ``rem``.
        padding_right: Right padding of the block container, in ``rem``.
        padding_left: Left padding of the block container, in ``rem``.
        padding_bottom: Bottom padding of the block container, in ``rem``.
        color: CSS text colour applied to the main area.
        background_color: CSS background colour applied to the main area.
    """
    # Exactly one max-width declaration is interpolated into the template below.
    max_width_str = 'max-width: 100%;' if max_width_100_percent else f'max-width: {max_width}px;'

    # Doubled braces are literal CSS braces inside the f-string.
    page_css = f'''
        <style>
            .reportview-container .css-1lcbmhc .css-1outpf7 {{
                padding-top: 35px;
            }}
            .reportview-container .main .block-container {{
                {max_width_str}
                padding-top: {padding_top}rem;
                padding-right: {padding_right}rem;
                padding-left: {padding_left}rem;
                padding-bottom: {padding_bottom}rem;
            }}
            .reportview-container .main {{
                color: {color};
                background-color: {background_color};
            }}
        </style>
        '''
    st.markdown(page_css, unsafe_allow_html=True)
46
 
47
  # Additional libraries for querying
48
  from FlagEmbedding import FlagModel
49
 
50
  # Global variables and dataset loading
51
# Page configuration is issued before any other Streamlit call in this script.
st.set_page_config(layout="wide")

# Plain module-level assignment: a `global` statement has no effect at module scope.
dataset_name = "somewheresystems/dataclysm-arxiv"

set_page_container_style(
    max_width=1600, max_width_100_percent=True,
    padding_top=0, padding_right=10, padding_left=5, padding_bottom=10,
)

# Streamlit re-executes the whole script on every widget interaction, so the
# dataset is loaded once per session and reused from session state afterwards.
if "dataclysm_arxiv" not in st.session_state:
    st.session_state.dataclysm_arxiv = load_dataset(dataset_name, split="train")
total_samples = len(st.session_state.dataclysm_arxiv)
62
 
 
116
  tsne_results = tsne.fit_transform(np.vstack(embeddings.tolist()))
117
 
118
  # Update progress bar to indicate completion
119
+ progress_text.text(f"t-SNE completed at {datetime.datetime.now()}. Processed {n_samples} samples with perplexity {perplexity}.")
120
  end_time = time.time() # End timing
121
  st.sidebar.text(f't-SNE completed in {end_time - start_time:.3f} seconds')
122
  return tsne_results
 
124
 
125
def perform_clustering(df, tsne_results):
    """Add normalised t-SNE coordinates and HDBSCAN cluster labels to ``df``.

    Each of the first three columns of ``tsne_results`` is min-max scaled to
    ``[0, 1]`` and stored in ``df`` as ``tsne-3d-one``, ``tsne-3d-two`` and
    ``tsne-3d-three``. HDBSCAN is then fitted on those three columns and its
    labels are stored in ``df['cluster']``. The elapsed time is reported in
    the Streamlit sidebar.

    Args:
        df: DataFrame with one row per embedded sample; modified in place.
        tsne_results: 2-D array with at least three columns, one row per
            row of ``df``.

    Returns:
        The same ``df`` object, with the four columns above added.
    """
    start_time = time.time()
    logging.info('Performing HDBSCAN clustering...')

    # Min-max normalise all three axes at once.
    coords = np.asarray(tsne_results)[:, :3]
    lows = coords.min(axis=0)
    spans = coords.max(axis=0) - lows
    # A constant axis has zero span; dividing by 1 instead maps it to 0.0
    # rather than producing NaN, which HDBSCAN cannot cluster.
    spans[spans == 0] = 1
    scaled = (coords - lows) / spans

    axis_columns = ['tsne-3d-one', 'tsne-3d-two', 'tsne-3d-three']
    for position, column in enumerate(axis_columns):
        df[column] = scaled[:, position]

    # Density-based clustering on the normalised 3-D coordinates.
    clusterer = HDBSCAN(min_cluster_size=10, min_samples=50)
    df['cluster'] = clusterer.fit_predict(df[axis_columns])

    end_time = time.time()  # End timing
    st.sidebar.text(f'HDBSCAN clustering completed in {end_time - start_time:.3f} seconds')
    return df
142
 
143
def update_camera_position(fig, df, df_query, result_id, K=10):
    """Build a scatter plot that highlights the top-``K`` query hits.

    The ``K`` rows of ``df_query`` with the smallest ``proximity`` are looked
    up in ``df`` by ``id``. Their markers are enlarged (closest hit largest),
    a line is drawn from the closest hit to every other hit, and the camera
    eye is placed at 0.1x the closest hit's coordinates.

    Args:
        fig: The current figure. It is returned unchanged when none of the
            hits are present in ``df``; otherwise a new figure is built.
        df: DataFrame with ``id``, ``cluster``, ``hovertext`` and the three
            ``tsne-3d-*`` columns. A ``color`` column (copy of ``cluster``)
            is written to it.
        df_query: DataFrame of search hits with ``id`` and ``proximity``
            columns; smaller proximity is treated as closer.
        result_id: Not used by this implementation; kept so existing callers
            keep working.
        K: Maximum number of hits to highlight.

    Returns:
        A ``plotly.graph_objects.Figure``, or ``fig`` itself when there is
        nothing to highlight.
    """
    # Take ids and proximities from the same sorted, truncated frame so that
    # position i in one list always corresponds to position i in the other.
    ranked = df_query.sort_values(by='proximity', ascending=True).head(K)
    top_K_ids = ranked['id'].tolist()
    top_K_proximity = ranked['proximity'].tolist()

    # For repeated ids keep the best (first) rank and proximity.
    rank_by_id = {}
    proximity_by_id = {}
    for rank, (item_id, prox) in enumerate(zip(top_K_ids, top_K_proximity)):
        rank_by_id.setdefault(item_id, rank)
        proximity_by_id.setdefault(item_id, prox)

    top_results = df[df['id'].isin(top_K_ids)]
    if top_results.empty:
        # No hit is part of the plotted sample (or the query was empty).
        return fig

    # Boolean filtering keeps df's row order; reorder by query rank so that
    # row 0 really is the closest hit.
    order = top_results['id'].map(rank_by_id).to_numpy().argsort(kind='stable')
    top_results = top_results.iloc[order]
    anchor = top_results.iloc[0]

    # Focus the camera on the closest result
    camera_focus = dict(
        eye=dict(
            x=anchor['tsne-3d-one'] * 0.1,
            y=anchor['tsne-3d-two'] * 0.1,
            z=anchor['tsne-3d-three'] * 0.1,
        )
    )

    # Scale proximity to a 1..10 weight (closest = 10, farthest = 1) so that no
    # highlighted marker ends up with size 0. When every proximity is equal
    # (always the case for K == 1) all hits get the maximum weight instead of
    # dividing by zero.
    nearest = min(top_K_proximity)
    spread = max(top_K_proximity) - nearest
    size_by_id = {}
    for item_id, prox in proximity_by_id.items():
        weight = 10 - 9 * (prox - nearest) / spread if spread else 10
        size_by_id[item_id] = 5 * weight

    # Highlighted hits get their scaled size; every other point stays at 1.
    marker_sizes = [size_by_id.get(item_id, 1) for item_id in df['id']]

    # Keep the cluster label available under a dedicated colour column.
    df['color'] = df['cluster']

    fig = go.Figure(data=[go.Scatter3d(
        x=df['tsne-3d-one'],
        y=df['tsne-3d-two'],
        z=df['tsne-3d-three'],
        mode='markers',
        marker=dict(size=marker_sizes, color=df['color'], colorscale='Viridis', opacity=0.8, line_width=0),
        hovertext=df['hovertext'],
        hoverinfo='text',
    )])

    # Set grid opacity to 10%
    faint_axis = dict(gridcolor='rgba(128, 128, 128, 0.1)', color='rgba(128, 128, 128, 0.1)')
    fig.update_layout(scene=dict(xaxis=dict(faint_axis), yaxis=dict(faint_axis), zaxis=dict(faint_axis)))

    # Draw one thin line from the closest hit to each other hit found in df.
    for position in range(1, len(top_results)):
        hit = top_results.iloc[position]
        hit_id = hit['id']
        fig.add_trace(go.Scatter3d(
            x=[anchor['tsne-3d-one'], hit['tsne-3d-one']],
            y=[anchor['tsne-3d-two'], hit['tsne-3d-two']],
            z=[anchor['tsne-3d-three'], hit['tsne-3d-three']],
            mode='lines',
            line=dict(color='white', width=0.3),
            showlegend=True,
            name=hit_id,  # legend entry is the id of the hit the line ends at
            hovertext=f'Title: Top K Results\nID: {hit_id}, Proximity: {round(proximity_by_id[hit_id], 4)}',
            hoverinfo='text',
        ))

    fig.update_layout(plot_bgcolor='rgba(0,0,0,0)',
                      paper_bgcolor='rgba(0,0,0,0)',
                      scene_camera=camera_focus)
    return fig
191
+
192
  def main():
193
  # Custom CSS
194
  custom_css = """
 
204
  color: #F8F8F8; /* Set the font color to F8F8F8 */
205
  }
206
  /* Add your CSS styles here */
207
+ .stPlotlyChart {
208
+ width: 100%;
209
+ height: 100%;
210
+ /* Other styles... */
211
+ }
212
  h1 {
213
  text-align: center;
214
  }
215
  h2,h3,h4 {
216
  text-align: justify;
217
+ font-size: 8px;
218
+ }
219
+ st-emotion-cache-1wmy9hl {
220
+ font-size: 8px;
221
  }
222
  body {
223
+ color: #fff;
224
+ background-color: #202020;
225
  }
226
+
227
  .stSlider .css-1cpxqw2 {
228
  background: #202020;
229
+ color: #fd5137;
230
+ }
231
+ .stSlider .text {
232
+ background: #202020;
233
+ color: #fd5137;
234
  }
235
  .stButton > button {
236
  background-color: #202020;
237
+ width: 60%;
238
+ margin-left: auto;
239
+ margin-right: auto;
240
+ display: block;
241
  padding: 10px 24px;
 
242
  font-size: 16px;
243
  font-weight: bold;
244
+ border: 1px solid #f8f8f8;
245
+ }
246
+ .stButton > button:hover {
247
+ color: #Fd5137
248
+ border: 1px solid #fd5137;
249
+ }
250
+ .stButton > button:active {
251
+ color: #F8F8F8;
252
+ border: 1px solid #fd5137;
253
+ background-color: #fd5137;
254
  }
255
  .reportview-container .main .block-container {
256
+ padding: 0;
257
  background-color: #202020;
258
+ width: 100%; /* Make the plotly graph take up full width */
259
+ }
260
+ .sidebar .sidebar-content {
261
+ background-image: linear-gradient(#202020,#202020);
262
+ color: white;
263
+ size: 0.2em; /* Make the text in the sidebar smaller */
264
+ padding: 0;
265
  }
266
+ .reportview-container .main .block-container {
267
+ background-color: #000000;
268
+ }
269
+ .stText {
270
+ padding: 0;
271
+ }
272
+ /* Set the main background color to #202020 */
273
+ .appview-container {
274
+ background-color: #000000;
275
+ padding: 0;
276
+ }
277
+ .stVerticalBlockBorderWrapper{
278
+ padding: 0;
279
+ margin-left: 0px;
280
+ }
281
+ .st-emotion-cache-1cypcdb {
282
+ background-color: #202020;
283
+ background-image: none;
284
+ color: #000000;
285
+ padding: 0;
286
+ }
287
+ .stPlotlyChart {
288
+ background-color: #000000;
289
+ background-image: none;
290
+ color: #000000;
291
+ padding: 0;
292
+ }
293
+ .reportview-container .css-1lcbmhc .css-1outpf7 {
294
+ padding-top: 35px;
295
+ }
296
+ .reportview-container .main .block-container {
297
+ max-width: 100%;
298
+ padding-top: 0rem;
299
+ padding-right: 0rem;
300
+ padding-left: 0rem;
301
+ padding-bottom: 10rem;
302
+ }
303
+ .reportview-container .main {
304
+ color: white;
305
+ background-color: black;
306
+ }
307
+ .st-emotion-cache-1avcm0n {
308
+ color: black;
309
+ background-color: black;
310
+ }
311
+ .st-emotion-cache-z5fcl4 {
312
+ padding-left: 0.1rem;
313
+ padding-right: 0.1rem;
314
+ }
315
+ .st-emotion-cache-z5fcl4 {
316
+ width: 100%;
317
+ padding: 3rem 1rem 1rem;
318
+ min-width: auto;
319
+ max-width: initial;
320
+ }
321
+ .st-emotion-cache-uf99v8 {
322
+ display: flex;
323
+ flex-direction: column;
324
+ width: 100%;
325
+ overflow: hidden;
326
+ -webkit-box-align: center;
327
+ align-items: center;
328
+ }
329
+
330
  </style>
331
  """
332
 
333
  # Inject custom CSS with markdown
334
  st.markdown(custom_css, unsafe_allow_html=True)
335
+ st.sidebar.title('arXiv Spatial Search Engine')
336
  st.sidebar.markdown(
337
+ '<a href="http://dataclysm.xyz" target="_blank" style="display: flex; justify-content: center; padding: 10px;">dataclysm.xyz <img src="https://www.somewhere.systems/S2-white-logo.png" style="width: 8px; height: 8px;"></a>',
338
  unsafe_allow_html=True
339
  )
340
+ # Create a placeholder for the chart
341
+ chart_placeholder = st.empty()
342
+
343
  # Check if data needs to be loaded
344
  if 'data_loaded' not in st.session_state or not st.session_state.data_loaded:
345
  # User input for number of samples
346
+ num_samples = st.sidebar.slider('Select number of samples', 1000, int(round(total_samples/10)), 1000)
347
+ if 'fig' not in st.session_state:
348
+ with open('prayers.txt', 'r') as file:
349
+ lines = file.readlines()
350
+ random_line = random.choice(lines).strip()
351
+ st.session_state.fig = go.Figure(data=[go.Scatter3d(x=[], y=[], z=[], mode='markers')])
352
+ st.session_state.fig.add_annotation(
353
+ x=0.5,
354
+ y=0.5,
355
+ xref="paper",
356
+ yref="paper",
357
+ text=random_line,
358
+ showarrow=False,
359
+ font=dict(
360
+ size=16,
361
+ color="black"
362
+ ),
363
+ align="center",
364
+ ax=0,
365
+ ay=0,
366
+ bordercolor="black",
367
+ borderwidth=2,
368
+ borderpad=4,
369
+ bgcolor="white",
370
+ opacity=0.8
371
+ )
372
+ # Set grid opacity to 10%
373
+ st.session_state.fig.update_layout(scene = dict(xaxis = dict(gridcolor='rgba(128, 128, 128, 0.1)', color='rgba(128, 128, 128, 0.1)'),
374
+ yaxis = dict(gridcolor='rgba(128, 128, 128, 0.1)', color='rgba(128, 128, 128, 0.1)'),
375
+ zaxis = dict(gridcolor='rgba(128, 128, 128, 0.1)', color='rgba(128, 128, 128, 0.1)')))
376
+
377
+ st.session_state.fig.update_layout(
378
+ plot_bgcolor='rgba(0,0,0,0)',
379
+ paper_bgcolor='rgba(0,0,0,0)',
380
+ height=888,
381
+ margin=dict(l=0, r=0, b=0, t=0),
382
+ scene_camera=dict(eye=dict(x=0.1, y=0.1, z=0.1))
383
+ )
384
+ chart_placeholder.plotly_chart(st.session_state.fig, use_container_width=True)
385
  if st.sidebar.button('Initialize'):
386
  st.sidebar.text('Initializing data pipeline...')
387
 
 
399
  print(f"FAISS index for {column_name} added.")
400
 
401
  return dataset
 
 
402
 
403
  # Load data and perform t-SNE and clustering
404
  df, embeddings = load_data(num_samples)
 
435
  marker=dict(
436
  size=1,
437
  color=df['cluster'],
438
+ colorscale='Jet',
439
+ opacity=0.75
440
  )
441
  )])
442
+ # Set grid opacity to 10%
443
+ fig.update_layout(scene = dict(xaxis = dict(gridcolor='rgba(128, 128, 128, 0.1)', color='rgba(128, 128, 128, 0.1)'),
444
+ yaxis = dict(gridcolor='rgba(128, 128, 128, 0.1)', color='rgba(128, 128, 128, 0.1)'),
445
+ zaxis = dict(gridcolor='rgba(128, 128, 128, 0.1)', color='rgba(128, 128, 128, 0.1)')))
446
 
447
  fig.update_layout(
448
+ plot_bgcolor='rgba(0,0,0,0)',
449
+ paper_bgcolor='rgba(0,0,0,0)',
450
  height=800,
451
  margin=dict(l=0, r=0, b=0, t=0),
452
+ scene_camera=dict(eye=dict(x=0.1, y=0.1, z=0.1))
 
 
 
 
 
453
  )
454
  st.session_state.fig = fig
455
 
456
  # Display the plot if data is loaded
457
  if 'data_loaded' in st.session_state and st.session_state.data_loaded:
458
+ chart_placeholder.plotly_chart(st.session_state.fig, use_container_width=True)
459
 
460
 
461
  # Sidebar for detailed view
462
  if 'df' in st.session_state:
463
  # Sidebar for querying
464
  with st.sidebar:
465
+ st.sidebar.markdown("# Detailed View")
466
+ selected_index = st.sidebar.selectbox("Select Key", st.session_state.df.id)
467
+
468
+ # Display metadata for the selected article
469
+ selected_row = st.session_state.df[st.session_state.df['id'] == selected_index].iloc[0]
470
+ st.markdown(f"### Title\n{selected_row['title']}", unsafe_allow_html=True)
471
+ st.markdown(f"### Abstract\n{selected_row['abstract']}", unsafe_allow_html=True)
472
+ st.markdown(f"[Read the full paper](https://arxiv.org/abs/{selected_row['id']})", unsafe_allow_html=True)
473
+ st.markdown(f"[Download PDF](https://arxiv.org/pdf/{selected_row['id']})", unsafe_allow_html=True)
474
+
475
+ st.sidebar.markdown("### Find Similar in Latent Space")
476
+ query = st.text_input("", value=selected_row['title'])
477
+ top_k = st.slider("top k", 1, 100, 10)
478
  if st.button("Search"):
479
  # Define the model
480
  print("Initializing model...")
 
485
 
486
  query_embedding = model.encode([query])
487
  # Retrieve examples by title similarity (or abstract, depending on your preference)
488
+ scores_title, retrieved_examples_title = st.session_state.dataclysm_title_indexed.get_nearest_examples('title_embedding', query_embedding, k=top_k)
489
  df_query = pd.DataFrame(retrieved_examples_title)
490
  df_query['proximity'] = scores_title
491
  df_query = df_query.sort_values(by='proximity', ascending=True)
 
494
  # Fix the <a href link> to display properly
495
  df_query['URL'] = df_query['id'].apply(lambda x: f'<a href="https://arxiv.org/abs/{x}" target="_blank">Link</a>')
496
  st.sidebar.markdown(df_query[['title', 'proximity', 'id']].to_html(escape=False), unsafe_allow_html=True)
497
+ # Get the ID of the top search result
498
+ top_result_id = df_query.iloc[0]['id']
499
 
500
+ # Update the camera position and appearance of points
501
+ updated_fig = update_camera_position(st.session_state.fig, st.session_state.df, df_query, top_result_id,top_k)
 
 
 
 
 
 
502
 
503
+ # Update the figure in the session state and redraw the plot
504
+ st.session_state.fig = updated_fig
505
 
506
+ # Update the chart using the placeholder
507
+ chart_placeholder.plotly_chart(st.session_state.fig, use_container_width=True)
508
 
509
+
510
+
511
+ if __name__ == "__main__":
512
+ main()