m7n committed on
Commit
cabc445
·
1 Parent(s): 9c57eb7

Implement DOI list handling in CSV uploads and add highlight color feature in predictions. Refactor embedding creation function for clarity and enhance OpenAlex record fetching with new utility function.

Browse files
Files changed (3) hide show
  1. app.py +52 -32
  2. color_utils.py +19 -0
  3. openalex_utils.py +23 -7
app.py CHANGED
@@ -51,6 +51,8 @@ import gradio as gr
51
  print(f"Gradio version: {gr.__version__}")
52
 
53
  import subprocess
 
 
54
 
55
  def print_datamapplot_version():
56
  try:
@@ -101,7 +103,7 @@ try:
101
  except (ImportError, ModuleNotFoundError):
102
  HAS_SPACES = False
103
 
104
- # Provide a harmless fallback so decorators dont explode
105
  if not HAS_SPACES:
106
  class _Dummy:
107
  def GPU(self, *a, **k):
@@ -125,7 +127,8 @@ from openalex_utils import (
125
  openalex_url_to_pyalex_query,
126
  get_field,
127
  process_records_to_df,
128
- openalex_url_to_filename
 
129
  )
130
  from styles import DATAMAP_CUSTOM_CSS
131
  from data_setup import (
@@ -234,10 +237,10 @@ def create_embeddings_299(texts_to_embedd):
234
 
235
 
236
  # else:
237
- # def create_embeddings(texts_to_embedd):
238
- # """Create embeddings for the input texts using the loaded model."""
239
- # return model.encode(texts_to_embedd, show_progress_bar=True, batch_size=192)
240
-
241
 
242
 
243
 
@@ -247,7 +250,7 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
247
  sample_reduction_method, plot_time_checkbox,
248
  locally_approximate_publication_date_checkbox,
249
  download_csv_checkbox, download_png_checkbox, citation_graph_checkbox,
250
- csv_upload,
251
  progress=gr.Progress()):
252
  """
253
  Main prediction pipeline that processes OpenAlex queries and creates visualizations.
@@ -264,6 +267,7 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
264
  download_png_checkbox (bool): Whether to download PNG data
265
  citation_graph_checkbox (bool): Whether to add citation graph
266
  csv_upload (str): Path to uploaded CSV file
 
267
  progress (gr.Progress): Gradio progress tracker
268
 
269
  Returns:
@@ -309,25 +313,33 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
309
  records_df = pd.read_csv(csv_upload)
310
  filename = os.path.splitext(os.path.basename(csv_upload))[0]
311
 
312
- # Convert *every* cell that looks like a serialized list/dict
313
- def _try_parse_obj(cell):
314
- if isinstance(cell, str):
315
- txt = cell.strip()
316
- if (txt.startswith('{') and txt.endswith('}')) or (txt.startswith('[') and txt.endswith(']')):
317
- # Try JSON first
318
- try:
319
- return json.loads(txt)
320
- except Exception:
321
- pass
322
- # Fallback to Python-repr (single quotes etc.)
323
- try:
324
- return ast.literal_eval(txt)
325
- except Exception:
326
- pass
327
- return cell
328
-
329
- records_df = records_df.map(_try_parse_obj)
330
- print(records_df.head())
 
 
 
 
 
 
 
 
331
 
332
  else:
333
  error_message = f"Error: Unsupported file type. Please upload a CSV or PKL file."
@@ -458,8 +470,10 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
458
 
459
  basedata_df['color'] = '#ced4d211'
460
 
 
 
461
  if not plot_time_checkbox:
462
- records_df['color'] = '#5e2784'
463
  else:
464
  cmap = colormaps.haline
465
  if not locally_approximate_publication_date_checkbox:
@@ -521,7 +535,7 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
521
  height=1000,
522
  point_radius_min_pixels=1,
523
  text_outline_width=5,
524
- point_hover_color='#5e2784',
525
  point_radius_max_pixels=7,
526
  cmap=black_cmap,
527
  background_image=graph_file_name if citation_graph_checkbox else None,
@@ -807,9 +821,14 @@ with gr.Blocks(theme=theme, css="""
807
  label="Upload your own CSV file downloaded via pyalex.",
808
  file_types=[".csv"],
809
  )
810
-
811
 
812
-
 
 
 
 
 
 
813
 
814
  with gr.Column(scale=2):
815
  html = gr.HTML(
@@ -853,7 +872,7 @@ with gr.Blocks(theme=theme, css="""
853
 
854
  ## I want to use my own data!
855
 
856
- Sure! You can upload csv-files produced by downloading things from OpenAlex using the pyalex package. You will need to provide at least the columns `id`, `title`, `publication_year`, `doi`, `abstract` or `abstract_inverted_index`, `referenced_works` and `primary_topic`.
857
 
858
  </div>
859
  """)
@@ -894,7 +913,8 @@ with gr.Blocks(theme=theme, css="""
894
  download_csv_checkbox,
895
  download_png_checkbox,
896
  citation_graph_checkbox,
897
- csv_upload
 
898
  ],
899
  outputs=[html, html_download, csv_download, png_download, cancel_btn]
900
  )
 
51
  print(f"Gradio version: {gr.__version__}")
52
 
53
  import subprocess
54
+ import re
55
+ from color_utils import rgba_to_hex
56
 
57
  def print_datamapplot_version():
58
  try:
 
103
  except (ImportError, ModuleNotFoundError):
104
  HAS_SPACES = False
105
 
106
+ # Provide a harmless fallback so decorators don't explode
107
  if not HAS_SPACES:
108
  class _Dummy:
109
  def GPU(self, *a, **k):
 
127
  openalex_url_to_pyalex_query,
128
  get_field,
129
  process_records_to_df,
130
+ openalex_url_to_filename,
131
+ get_records_from_dois
132
  )
133
  from styles import DATAMAP_CUSTOM_CSS
134
  from data_setup import (
 
237
 
238
 
239
  # else:
240
def create_embeddings(texts_to_embedd):
    """Encode the given texts with the loaded model and return the embeddings."""
    embeddings = model.encode(
        texts_to_embedd,
        batch_size=192,
        show_progress_bar=True,
    )
    return embeddings
243
+
244
 
245
 
246
 
 
250
  sample_reduction_method, plot_time_checkbox,
251
  locally_approximate_publication_date_checkbox,
252
  download_csv_checkbox, download_png_checkbox, citation_graph_checkbox,
253
+ csv_upload, highlight_color,
254
  progress=gr.Progress()):
255
  """
256
  Main prediction pipeline that processes OpenAlex queries and creates visualizations.
 
267
  download_png_checkbox (bool): Whether to download PNG data
268
  citation_graph_checkbox (bool): Whether to add citation graph
269
  csv_upload (str): Path to uploaded CSV file
270
+ highlight_color (str): Color for highlighting points
271
  progress (gr.Progress): Gradio progress tracker
272
 
273
  Returns:
 
313
  records_df = pd.read_csv(csv_upload)
314
  filename = os.path.splitext(os.path.basename(csv_upload))[0]
315
 
316
+ # Check if this is a DOI-list CSV (single column, named 'doi' or similar)
317
+ if (len(records_df.columns) == 1 and records_df.columns[0].lower() in ['doi', 'dois']):
318
+ from openalex_utils import get_records_from_dois
319
+ doi_list = records_df.iloc[:,0].dropna().astype(str).tolist()
320
+ print(f"Detected DOI list with {len(doi_list)} DOIs. Downloading records from OpenAlex...")
321
+ records_df = get_records_from_dois(doi_list)
322
+ filename = f"doilist_{len(doi_list)}"
323
+ else:
324
+ # Convert *every* cell that looks like a serialized list/dict
325
+ def _try_parse_obj(cell):
326
+ if isinstance(cell, str):
327
+ txt = cell.strip()
328
+ if (txt.startswith('{') and txt.endswith('}')) or (txt.startswith('[') and txt.endswith(']')):
329
+ # Try JSON first
330
+ try:
331
+ return json.loads(txt)
332
+ except Exception:
333
+ pass
334
+ # Fallback to Python-repr (single quotes etc.)
335
+ try:
336
+ return ast.literal_eval(txt)
337
+ except Exception:
338
+ pass
339
+ return cell
340
+
341
+ records_df = records_df.map(_try_parse_obj)
342
+ print(records_df.head())
343
 
344
  else:
345
  error_message = f"Error: Unsupported file type. Please upload a CSV or PKL file."
 
470
 
471
  basedata_df['color'] = '#ced4d211'
472
 
473
+ highlight_color = rgba_to_hex(highlight_color)
474
+
475
  if not plot_time_checkbox:
476
+ records_df['color'] = highlight_color
477
  else:
478
  cmap = colormaps.haline
479
  if not locally_approximate_publication_date_checkbox:
 
535
  height=1000,
536
  point_radius_min_pixels=1,
537
  text_outline_width=5,
538
+ point_hover_color=highlight_color,
539
  point_radius_max_pixels=7,
540
  cmap=black_cmap,
541
  background_image=graph_file_name if citation_graph_checkbox else None,
 
821
  label="Upload your own CSV file downloaded via pyalex.",
822
  file_types=[".csv"],
823
  )
 
824
 
825
+ # --- Aesthetics Accordion ---
826
+ with gr.Accordion("Aesthetics", open=False):
827
+ highlight_color_picker = gr.ColorPicker(
828
+ label="Highlight Color",
829
+ value="#5e2784",
830
+ info="Choose the highlight color for your query points."
831
+ )
832
 
833
  with gr.Column(scale=2):
834
  html = gr.HTML(
 
872
 
873
  ## I want to use my own data!
874
 
875
+ Sure! You can upload csv-files produced by downloading records from OpenAlex using the pyalex package. You will need to provide at least the columns `id`, `title`, `publication_year`, `doi`, `abstract` or `abstract_inverted_index`, `referenced_works` and `primary_topic`. Alternatively, you can upload a csv-file with only the column `doi`, containing a column of DOIs. These will then be used to download the records from OpenAlex and then embed them on the map.
876
 
877
  </div>
878
  """)
 
913
  download_csv_checkbox,
914
  download_png_checkbox,
915
  citation_graph_checkbox,
916
+ csv_upload,
917
+ highlight_color_picker
918
  ],
919
  outputs=[html, html_download, csv_download, png_download, cancel_btn]
920
  )
color_utils.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
def rgba_to_hex(color):
    """Convert a CSS ``rgb(...)``/``rgba(...)`` string to ``#rrggbb`` hex.

    Args:
        color: Color value, typically a string such as ``"#5e2784"``,
            ``"#abc"`` or ``"rgba(94, 39, 132, 1)"``.

    Returns:
        str: The whitespace-stripped input when it already is a 4- or
        7-character ``#`` hex string; the ``#rrggbb`` equivalent of an
        rgb/rgba string (alpha is dropped, fractional channels are truncated,
        channels are clamped to 0-255); otherwise the default highlight
        color ``'#5e2784'``.
    """
    fallback = '#5e2784'
    if not isinstance(color, str):
        return fallback

    color = color.strip()
    # Already hex: pass through unchanged.
    if color.startswith('#') and len(color) in (4, 7):
        return color

    # Parentheses are escaped with a SINGLE backslash. In a raw string a
    # doubled backslash would require a literal "\" in the input, so
    # "rgba(...)" values would never match and always hit the fallback.
    match = re.match(r"rgba?\(([^)]+)\)", color)
    if not match:
        return fallback

    parts = match.group(1).split(',')
    try:
        # Truncate fractional channels, then clamp so that each channel
        # always formats to exactly two hex digits.
        channels = [min(255, max(0, int(float(part)))) for part in parts[:3]]
    except (ValueError, OverflowError):
        return fallback
    if len(channels) != 3:
        return fallback
    return '#{:02x}{:02x}{:02x}'.format(*channels)
openalex_utils.py CHANGED
@@ -131,12 +131,6 @@ def process_records_to_df(records):
131
 
132
  return records_df
133
 
134
-
135
-
136
-
137
-
138
-
139
-
140
  def openalex_url_to_filename(url):
141
  """
142
  Convert an OpenAlex URL to a filename-safe string with timestamp.
@@ -197,4 +191,26 @@ def openalex_url_to_filename(url):
197
  if len(filename) > 255:
198
  filename = filename[:251] # leave room for potential extension
199
 
200
- return filename
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  return records_df
133
 
 
 
 
 
 
 
134
  def openalex_url_to_filename(url):
135
  """
136
  Convert an OpenAlex URL to a filename-safe string with timestamp.
 
191
  if len(filename) > 255:
192
  filename = filename[:251] # leave room for potential extension
193
 
194
+ return filename
195
+
196
def get_records_from_dois(doi_list, block_size=50):
    """
    Download OpenAlex records for a list of DOIs in blocks.

    DOIs are whitespace-stripped, blank/None entries are dropped and
    duplicates are removed (first occurrence wins) before querying, so that
    stray CSV cells neither corrupt the ``|``-joined filter nor waste request
    slots. A block whose request fails is reported via ``print`` and skipped;
    the remaining blocks are still fetched (best-effort).

    Args:
        doi_list (list): List of DOIs (strings)
        block_size (int): Number of DOIs to fetch per request (default 50)
    Returns:
        pd.DataFrame: DataFrame of OpenAlex records (empty when there is
        nothing to fetch or every request failed)
    Raises:
        ValueError: If ``block_size`` is smaller than 1.
    """
    if block_size < 1:
        raise ValueError(f"block_size must be a positive integer, got {block_size!r}")

    stripped = (str(doi).strip() for doi in doi_list if doi is not None)
    # dict.fromkeys de-duplicates while preserving the input order.
    unique_dois = list(dict.fromkeys(doi for doi in stripped if doi))
    if not unique_dois:
        return pd.DataFrame()

    from pyalex import Works
    from tqdm import tqdm

    all_records = []
    for start in tqdm(range(0, len(unique_dois), block_size)):
        sublist = unique_dois[start:start + block_size]
        # OpenAlex treats "|" inside a filter value as a logical OR.
        doi_str = "|".join(sublist)
        try:
            record_list = Works().filter(doi=doi_str).get(per_page=block_size)
        except Exception as e:
            # Deliberate best-effort: report the failed block and carry on.
            print(f"Error fetching DOIs {sublist}: {e}")
        else:
            all_records.extend(record_list)
    return pd.DataFrame(all_records)