m7n committed
Commit ad4e2b9 · 1 Parent(s): 3673672

Enhance data processing by adding CSV and Pickle file upload support, improving error handling, and refining the prediction pipeline. Update the `process_records_to_df` function to handle existing DataFrames and ensure required fields are processed correctly.

Files changed (2)
  1. app.py +182 -129
  2. openalex_utils.py +23 -5
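
The new upload path expects the same tabular shape that the OpenAlex query path produces: `predict` checks for `title`, `abstract`, and `publication_year`, and `process_records_to_df` additionally relies on `id` and `parsed_publication` (or a raw `primary_location`) when it cleans the frame. A minimal sketch of an uploadable file, with purely illustrative values and assuming only the checks visible in this diff:

```python
import pandas as pd

# Illustrative rows only; the column set mirrors the checks added in this commit.
df = pd.DataFrame({
    "id": ["https://openalex.org/W0000000001"],
    "title": ["An example paper"],
    "abstract": ["A short example abstract."],
    "parsed_publication": ["An Example Journal"],
    "publication_year": [2024],
})

df.to_csv("my_records.csv", index=False)  # uploadable via the new .csv branch
df.to_pickle("my_records.pkl")            # or via the new .pkl branch (read back with pickle.load)
```

A DataFrame written with `DataFrame.to_pickle` can be read back by the plain `pickle.load` call in the new `.pkl` branch; downstream steps of the pipeline may still expect further OpenAlex fields (e.g. `doi`, `referenced_works`) for the optional exports.
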
app.py CHANGED
@@ -1,14 +1,16 @@
 import time
 print(f"Starting up: {time.strftime('%Y-%m-%d %H:%M:%S')}")
-
+# source openalex_env_map/bin/activate
 # Standard library imports
 import os
 from pathlib import Path
 from datetime import datetime
 from itertools import chain
+import ast # Add this import at the top with the standard library imports
 
 import base64
 import json
+import pickle
 
 # Third-party imports
 import numpy as np
@@ -169,13 +171,13 @@ else:
 
 
 
-
-
+
 
 def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_checkbox,
             sample_reduction_method, plot_time_checkbox,
             locally_approximate_publication_date_checkbox,
             download_csv_checkbox, download_png_checkbox, citation_graph_checkbox,
+            csv_upload,
             progress=gr.Progress()):
     """
     Main prediction pipeline that processes OpenAlex queries and creates visualizations.
@@ -188,11 +190,28 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
         sample_reduction_method (str): Method for sample reduction ("Random" or "Order of Results")
         plot_time_checkbox (bool): Whether to color points by publication date
         locally_approximate_publication_date_checkbox (bool): Whether to approximate publication date locally before plotting.
+        download_csv_checkbox (bool): Whether to download CSV data
+        download_png_checkbox (bool): Whether to download PNG data
+        citation_graph_checkbox (bool): Whether to add citation graph
+        csv_upload (str): Path to uploaded CSV file
         progress (gr.Progress): Gradio progress tracker
 
     Returns:
         tuple: (link to visualization, iframe HTML)
     """
+    # Initialize start_time at the beginning of the function
+    start_time = time.time()
+
+    # Helper function to generate error responses
+    def create_error_response(error_message):
+        return [
+            error_message,
+            gr.DownloadButton(label="Download Interactive Visualization", value='html_file_path', visible=False),
+            gr.DownloadButton(label="Download CSV Data", value='csv_file_path', visible=False),
+            gr.DownloadButton(label="Download Static Plot", value='png_file_path', visible=False),
+            gr.Button(visible=False)
+        ]
+
     # Get the authentication token
     if is_running_in_hf_space():
         token = _get_token(request)
@@ -208,103 +227,145 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
     else:
         user_type = "registered"
     print(f"User type: {user_type}")
-
-
-    # Check if input is empty or whitespace
-    print(f"Input: {text_input}")
-    if not text_input or text_input.isspace():
-        error_message = "Error: Please enter a valid OpenAlex URL in the 'OpenAlex-search URL'-field"
-        return [
-            error_message, # iframe HTML
-            gr.DownloadButton(label="Download Interactive Visualization", value='html_file_path', visible=False), # html download
-            gr.DownloadButton(label="Download CSV Data", value='csv_file_path', visible=False), # csv download
-            gr.DownloadButton(label="Download Static Plot", value='png_file_path', visible=False), # png download
-            gr.Button(visible=False) # cancel button state
-        ]
 
-
-
-    # Check if the input is a valid OpenAlex URL
-
-
-
-    start_time = time.time()
-    print('Starting data projection pipeline')
-    progress(0.1, desc="Starting...")
-
-    # Split input into multiple URLs if present
-    urls = [url.strip() for url in text_input.split(';')]
-    records = []
-    total_query_length = 0
-
-    # Use first URL for filename
-    first_query, first_params = openalex_url_to_pyalex_query(urls[0])
-    filename = openalex_url_to_filename(urls[0])
-    print(f"Filename: {filename}")
-
-    # Process each URL
-    for i, url in enumerate(urls):
-        query, params = openalex_url_to_pyalex_query(url)
-        query_length = query.count()
-        total_query_length += query_length
-        print(f'Requesting {query_length} entries from query {i+1}/{len(urls)}...')
-
-        target_size = sample_size_slider if reduce_sample_checkbox and sample_reduction_method == "First n samples" else query_length
-        records_per_query = 0
+    # Check if a file has been uploaded or if we need to use OpenAlex query
+    if csv_upload is not None:
+        print(f"Using uploaded file instead of OpenAlex query: {csv_upload}")
+        try:
+            file_extension = os.path.splitext(csv_upload)[1].lower()
+
+            if file_extension == '.csv':
+                # Read the CSV file
+                records_df = pd.read_csv(csv_upload)
+                filename = os.path.splitext(os.path.basename(csv_upload))[0]
+
+                # Process dictionary-like strings in the DataFrame
+                for column in records_df.columns:
+                    # Check if the column contains dictionary-like strings
+                    if records_df[column].dtype == 'object':
+                        try:
+                            # Use a sample value to check if it looks like a dictionary or list
+                            sample_value = records_df[column].dropna().iloc[0] if not records_df[column].dropna().empty else None
+                            # Add type checking before using startswith
+                            if isinstance(sample_value, str) and (sample_value.startswith('{') or sample_value.startswith('[')):
+                                # Try to convert strings to Python objects using ast.literal_eval
+                                records_df[column] = records_df[column].apply(
+                                    lambda x: ast.literal_eval(x) if isinstance(x, str) and (
+                                        (x.startswith('{') and x.endswith('}')) or
+                                        (x.startswith('[') and x.endswith(']'))
+                                    ) else x
+                                )
+                        except (ValueError, SyntaxError, TypeError) as e:
+                            # If conversion fails, keep as string
+                            print(f"Could not convert column {column} to Python objects: {e}")
+
+            elif file_extension == '.pkl':
+                # Read the pickle file
+                with open(csv_upload, 'rb') as f:
+                    records_df = pickle.load(f)
+                filename = os.path.splitext(os.path.basename(csv_upload))[0]
+
+            else:
+                error_message = f"Error: Unsupported file type. Please upload a CSV or PKL file."
+                return create_error_response(error_message)
+
+            records_df = process_records_to_df(records_df)
+
+            # Make sure we have the required columns
+            required_columns = ['title', 'abstract', 'publication_year']
+            missing_columns = [col for col in required_columns if col not in records_df.columns]
+
+            if missing_columns:
+                error_message = f"Error: Uploaded file is missing required columns: {', '.join(missing_columns)}"
+                return create_error_response(error_message)
+
+            print(f"Successfully loaded {len(records_df)} records from uploaded file")
+            progress(0.2, desc="Processing uploaded data...")
+
+        except Exception as e:
+            error_message = f"Error processing uploaded file: {str(e)}"
+            return create_error_response(error_message)
+    else:
+        # Check if input is empty or whitespace
+        print(f"Input: {text_input}")
+        if not text_input or text_input.isspace():
+            error_message = "Error: Please enter a valid OpenAlex URL in the 'OpenAlex-search URL'-field or upload a CSV file"
+            return create_error_response(error_message)
+
+        print('Starting data projection pipeline')
+        progress(0.1, desc="Starting...")
+
+        # Split input into multiple URLs if present
+        urls = [url.strip() for url in text_input.split(';')]
+        records = []
+        total_query_length = 0
+
+        # Use first URL for filename
+        first_query, first_params = openalex_url_to_pyalex_query(urls[0])
+        filename = openalex_url_to_filename(urls[0])
+        print(f"Filename: {filename}")
+
+        # Process each URL
+        for i, url in enumerate(urls):
+            query, params = openalex_url_to_pyalex_query(url)
+            query_length = query.count()
+            total_query_length += query_length
+            print(f'Requesting {query_length} entries from query {i+1}/{len(urls)}...')
+
+            target_size = sample_size_slider if reduce_sample_checkbox and sample_reduction_method == "First n samples" else query_length
+            records_per_query = 0
 
-        should_break = False
-        for page in query.paginate(per_page=200, n_max=None):
-            # Add retry mechanism for processing each page
-            max_retries = 5
-            base_wait_time = 1 # Starting wait time in seconds
-            exponent = 1.5 # Exponential factor
-
-            for retry_attempt in range(max_retries):
-                try:
-                    for record in page:
-                        records.append(record)
-                        records_per_query += 1
-                        progress(0.1 + (0.2 * len(records) / (total_query_length)),
-                                 desc=f"Getting data from query {i+1}/{len(urls)}...")
-
-                        if reduce_sample_checkbox and sample_reduction_method == "First n samples" and records_per_query >= target_size:
-                            should_break = True
-                            break
-                    # If we get here without an exception, break the retry loop
-                    break
-                except Exception as e:
-                    print(f"Error processing page: {e}")
-                    if retry_attempt < max_retries - 1:
-                        wait_time = base_wait_time * (exponent ** retry_attempt) + random.random()
-                        print(f"Retrying in {wait_time:.2f} seconds (attempt {retry_attempt + 1}/{max_retries})...")
-                        time.sleep(wait_time)
-                    else:
-                        print(f"Maximum retries reached. Continuing with next page.")
-
+            should_break = False
+            for page in query.paginate(per_page=200, n_max=None):
+                # Add retry mechanism for processing each page
+                max_retries = 5
+                base_wait_time = 1 # Starting wait time in seconds
+                exponent = 1.5 # Exponential factor
+
+                for retry_attempt in range(max_retries):
+                    try:
+                        for record in page:
+                            records.append(record)
+                            records_per_query += 1
+                            progress(0.1 + (0.2 * len(records) / (total_query_length)),
+                                     desc=f"Getting data from query {i+1}/{len(urls)}...")
+
+                            if reduce_sample_checkbox and sample_reduction_method == "First n samples" and records_per_query >= target_size:
+                                should_break = True
+                                break
+                        # If we get here without an exception, break the retry loop
+                        break
+                    except Exception as e:
+                        print(f"Error processing page: {e}")
+                        if retry_attempt < max_retries - 1:
+                            wait_time = base_wait_time * (exponent ** retry_attempt) + random.random()
+                            print(f"Retrying in {wait_time:.2f} seconds (attempt {retry_attempt + 1}/{max_retries})...")
+                            time.sleep(wait_time)
+                        else:
+                            print(f"Maximum retries reached. Continuing with next page.")
+
+                if should_break:
+                    break
         if should_break:
             break
-    if should_break:
-        break
-    print(f"Query completed in {time.time() - start_time:.2f} seconds")
-
-    # Process records
-    processing_start = time.time()
-    records_df = process_records_to_df(records)
-
-    if reduce_sample_checkbox and sample_reduction_method != "All":
-        sample_size = min(sample_size_slider, len(records_df))
-        if sample_reduction_method == "n random samples":
-            records_df = records_df.sample(sample_size)
-        elif sample_reduction_method == "First n samples":
-            records_df = records_df.iloc[:sample_size]
-    print(f"Records processed in {time.time() - processing_start:.2f} seconds")
-
-    # Create embeddings
+        print(f"Query completed in {time.time() - start_time:.2f} seconds")
+
+        # Process records
+        processing_start = time.time()
+        records_df = process_records_to_df(records)
+
+        if reduce_sample_checkbox and sample_reduction_method != "All":
+            sample_size = min(sample_size_slider, len(records_df))
+            if sample_reduction_method == "n random samples":
+                records_df = records_df.sample(sample_size)
+            elif sample_reduction_method == "First n samples":
+                records_df = records_df.iloc[:sample_size]
+        print(f"Records processed in {time.time() - processing_start:.2f} seconds")
+
+    # Create embeddings - this happens regardless of data source
     embedding_start = time.time()
     progress(0.3, desc="Embedding Data...")
-    texts_to_embedd = [f"{title} {abstract}" for title, abstract
-                       in zip(records_df['title'], records_df['abstract'])]
-
+    texts_to_embedd = [f"{title} {abstract}" for title, abstract in zip(records_df['title'], records_df['abstract'])]
 
     if is_running_in_hf_space():
         if len(texts_to_embedd) < 2000:
@@ -357,13 +418,18 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
         norm = mcolors.Normalize(vmin=local_years.min(), vmax=local_years.max())
         records_df['color'] = [mcolors.to_hex(cmap(norm(year))) for year in local_years]
 
-
-
     stacked_df = pd.concat([basedata_df, records_df], axis=0, ignore_index=True)
     stacked_df = stacked_df.fillna("Unlabelled")
     stacked_df['parsed_field'] = [get_field(row) for ix, row in stacked_df.iterrows()]
     extra_data = pd.DataFrame(stacked_df['doi'])
     print(f"Visualization data prepared in {time.time() - viz_prep_start:.2f} seconds")
+
+    # Prepare file paths
+    html_file_name = f"{filename}.html"
+    html_file_path = static_dir / html_file_name
+    csv_file_path = static_dir / f"{filename}.csv"
+    png_file_path = static_dir / f"{filename}.png"
+
     if citation_graph_checkbox:
         citation_graph_start = time.time()
         citation_graph = create_citation_graph(records_df)
@@ -372,9 +438,6 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
         draw_citation_graph(citation_graph,path=graph_file_path,bundle_edges=True,
                             min_max_coordinates=[np.min(stacked_df['x']),np.max(stacked_df['x']),np.min(stacked_df['y']),np.max(stacked_df['y'])])
         print(f"Citation graph created and saved in {time.time() - citation_graph_start:.2f} seconds")
-
-
-
 
     # Create and save plot
     plot_start = time.time()
@@ -382,10 +445,9 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
     # Create a solid black colormap
     black_cmap = mcolors.LinearSegmentedColormap.from_list('black', ['#000000', '#000000'])
 
-
    plot = datamapplot.create_interactive_plot(
        stacked_df[['x','y']].values,
-        np.array(stacked_df['cluster_2_labels']),
+        np.array(stacked_df['cluster_2_labels']),
        np.array(['Unlabelled' if pd.isna(x) else x for x in stacked_df['parsed_field']]),

        hover_text=[str(row['title']) for ix, row in stacked_df.iterrows()],
@@ -413,23 +475,16 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
     )
 
     # Save plot
-    html_file_name = f"{filename}.html"
-    html_file_path = static_dir / html_file_name
     plot.save(html_file_path)
     print(f"Plot created and saved in {time.time() - plot_start:.2f} seconds")
-
 
-    #datamapplot==0.5.1
     # Save additional files if requested
-    csv_file_path = static_dir / f"{filename}.csv"
-    png_file_path = static_dir / f"{filename}.png"
-
     if download_csv_checkbox:
         # Export relevant column
         export_df = records_df[['title', 'abstract', 'doi', 'publication_year', 'x', 'y','id','primary_topic']]
-        export_df['parsed_field'] = [get_field(row) for ix, row in export_df.iterrows()]
+        export_df['parsed_field'] = [get_field(row) for ix, row in export_df.iterrows()]
         export_df['referenced_works'] = [', '.join(x) for x in records_df['referenced_works']]
-        if locally_approximate_publication_date_checkbox:
+        if locally_approximate_publication_date_checkbox and plot_time_checkbox:
             export_df['approximate_publication_year'] = local_years
         export_df.to_csv(csv_file_path, index=False)
 
@@ -453,17 +508,10 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
 
         # Replace less common labels with 'Unlabelled'
         combined_labels = np.array(['Unlabelled' if label not in top_30_labels else label for label in combined_labels])
-        #combined_labels = np.array(['Unlabelled' for label in combined_labels])
-        #if label not in top_30_labels else label
         colors_base = ['#536878' for _ in range(len(labels1))]
         print(f"Sample preparation completed in {time.time() - sample_prep_start:.2f} seconds")
 
         # Create main plot
-        print(labels1)
-        print(labels2)
-        print(sample_to_plot[['x','y']].values)
-        print(combined_labels)
-
         main_plot_start = time.time()
         fig, ax = datamapplot.create_plot(
             sample_to_plot[['x','y']].values,
@@ -487,15 +535,12 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
         )
         print(f"Main plot creation completed in {time.time() - main_plot_start:.2f} seconds")
 
-
        if citation_graph_checkbox:
-
            # Read and add the graph image
            graph_img = plt.imread(graph_file_path)
            ax.imshow(graph_img, extent=[np.min(stacked_df['x']),np.max(stacked_df['x']),np.min(stacked_df['y']),np.max(stacked_df['y'])],
-                     alpha=0.9, aspect='auto')
-
-
+                     alpha=0.9, aspect='auto')
+
        if len(records_df) > 50_000:
            point_size = .5
        elif len(records_df) > 10_000:
@@ -539,17 +584,12 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
         # Save plot
         save_start = time.time()
         plt.axis('off')
-        png_file_path = static_dir / f"{filename}.png"
         plt.savefig(png_file_path, dpi=300, bbox_inches='tight')
         plt.close()
         print(f"Plot saving completed in {time.time() - save_start:.2f} seconds")
 
         print(f"Total PNG generation completed in {time.time() - png_start_time:.2f} seconds")
 
-
-
-
-
     progress(1.0, desc="Done!")
     print(f"Total pipeline completed in {time.time() - start_time:.2f} seconds")
     iframe = f"""<iframe src="{html_file_path}" width="100%" height="1000px"></iframe>"""
@@ -650,10 +690,16 @@ with gr.Blocks(theme=theme, css="""
                 value="First n samples",
                 info="How to choose the samples to keep."
             )
+
+            if is_running_in_hf_zero_gpu():
+                max_sample_size = 20000
+            else:
+                max_sample_size = 250000
+
             sample_size_slider = gr.Slider(
                 label="Sample Size",
                 minimum=500,
-                maximum=20000,
+                maximum=max_sample_size,
                 step=10,
                 value=1000,
                 info="How many samples to keep.",
@@ -691,6 +737,12 @@ with gr.Blocks(theme=theme, css="""
                 info="Adds a citation graph of the sample to the plot."
             )
 
+            gr.Markdown("### Upload Your Own Data")
+            csv_upload = gr.File(
+                file_count="single",
+                label="Upload your own CSV or Pickle file downloaded via pyalex.",
+                file_types=[".csv", ".pkl"],
+            )
 
 
         with gr.Column(scale=2):
@@ -706,7 +758,7 @@ with gr.Blocks(theme=theme, css="""
 
     ## Who made this?
 
-    This project was developed by [Maximilian Noichl](https://maxnoichl.eu) (Utrecht University), in cooperation with Andrea Loettger and Tarja Knuuttila at the [Possible Life project](http://www.possiblelife.eu/), at the University of Vienna. If this project is useful in any way for your research, we would appreciate citation of **...**
+    This project was developed by [Maximilian Noichl](https://maxnoichl.eu) (Utrecht University), in cooperation with Andrea Loettgers and Tarja Knuuttila at the [Possible Life project](http://www.possiblelife.eu/), at the University of Vienna. If this project is useful in any way for your research, we would appreciate citation of **...**
 
    This project received funding from the European Research Council under the European Union's Horizon 2020 research and innovation programme (LIFEMODE project, grant agreement No. 818772).
 
@@ -770,7 +822,8 @@ with gr.Blocks(theme=theme, css="""
             locally_approximate_publication_date_checkbox,
             download_csv_checkbox,
             download_png_checkbox,
-            citation_graph_checkbox
+            citation_graph_checkbox,
+            csv_upload
         ],
         outputs=[html, html_download, csv_download, png_download, cancel_btn]
     )
@@ -811,4 +864,4 @@ if __name__ == "__main__":
     os.environ["GRADIO_SSR_MODE"] = "True"
     uvicorn.run("app:app", host="0.0.0.0", port=7860)
 else:
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    uvicorn.run(app, host="0.0.0.0", port=7860)
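
The `ast.literal_eval` loop in the new CSV branch exists because `DataFrame.to_csv` flattens nested OpenAlex fields (topics, locations, referenced works) into their string representation. A small self-contained sketch of that round trip, with an illustrative column name and value:

```python
import ast
import pandas as pd

# A nested field survives the trip to CSV only as a string...
df = pd.DataFrame({"id": ["W1"], "primary_topic": [{"display_name": "Philosophy of Science"}]})
df.to_csv("roundtrip.csv", index=False)

loaded = pd.read_csv("roundtrip.csv")
print(type(loaded["primary_topic"].iloc[0]))  # <class 'str'>

# ...and ast.literal_eval turns it back into a Python dict, as in the branch above.
loaded["primary_topic"] = loaded["primary_topic"].apply(
    lambda x: ast.literal_eval(x) if isinstance(x, str) and x.startswith("{") else x
)
print(type(loaded["primary_topic"].iloc[0]))  # <class 'dict'>
```
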
openalex_utils.py CHANGED
@@ -83,23 +83,41 @@ def get_field(x):
 def process_records_to_df(records):
     """
     Convert OpenAlex records to a pandas DataFrame with processed fields.
+    Can handle either raw OpenAlex records or an existing DataFrame.
 
     Args:
-        records (list): List of OpenAlex record dictionaries
+        records (list or pd.DataFrame): List of OpenAlex record dictionaries or existing DataFrame
 
     Returns:
         pandas.DataFrame: Processed DataFrame with abstracts, publications, and titles
     """
-    records_df = pd.DataFrame(records)
-    records_df['abstract'] = [invert_abstract(t) for t in records_df['abstract_inverted_index']]
-    records_df['parsed_publication'] = [get_pub(x) for x in records_df['primary_location']]
+    # If records is already a DataFrame, use it directly
+    if isinstance(records, pd.DataFrame):
+        records_df = records.copy()
+        # Only process abstract_inverted_index and primary_location if they exist
+        if 'abstract_inverted_index' in records_df.columns:
+            records_df['abstract'] = [invert_abstract(t) for t in records_df['abstract_inverted_index']]
+        if 'primary_location' in records_df.columns:
+            records_df['parsed_publication'] = [get_pub(x) for x in records_df['primary_location']]
+    else:
+        # Process raw records as before
+        records_df = pd.DataFrame(records)
+        records_df['abstract'] = [invert_abstract(t) for t in records_df['abstract_inverted_index']]
+        records_df['parsed_publication'] = [get_pub(x) for x in records_df['primary_location']]
 
+    # Fill missing values and deduplicate
     records_df['parsed_publication'] = records_df['parsed_publication'].fillna(' ')
     records_df['abstract'] = records_df['abstract'].fillna(' ')
     records_df['title'] = records_df['title'].fillna(' ')
     records_df = records_df.drop_duplicates(subset=['id']).reset_index(drop=True)
 
-    return records_df
+    return records_df
+
+
+
+
+
+
 
 def openalex_url_to_filename(url):
     """