Spaces:
Running
on
Zero
Implement DOI list handling in CSV uploads and add highlight color feature in predictions. Refactor embedding creation function for clarity and enhance OpenAlex record fetching with new utility function.
Browse files- app.py +52 -32
- color_utils.py +19 -0
- openalex_utils.py +23 -7
app.py
CHANGED
@@ -51,6 +51,8 @@ import gradio as gr
|
|
51 |
print(f"Gradio version: {gr.__version__}")
|
52 |
|
53 |
import subprocess
|
|
|
|
|
54 |
|
55 |
def print_datamapplot_version():
|
56 |
try:
|
@@ -101,7 +103,7 @@ try:
|
|
101 |
except (ImportError, ModuleNotFoundError):
|
102 |
HAS_SPACES = False
|
103 |
|
104 |
-
# Provide a harmless fallback so decorators don't explode
|
105 |
if not HAS_SPACES:
|
106 |
class _Dummy:
|
107 |
def GPU(self, *a, **k):
|
@@ -125,7 +127,8 @@ from openalex_utils import (
|
|
125 |
openalex_url_to_pyalex_query,
|
126 |
get_field,
|
127 |
process_records_to_df,
|
128 |
-
openalex_url_to_filename
|
|
|
129 |
)
|
130 |
from styles import DATAMAP_CUSTOM_CSS
|
131 |
from data_setup import (
|
@@ -234,10 +237,10 @@ def create_embeddings_299(texts_to_embedd):
|
|
234 |
|
235 |
|
236 |
# else:
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
|
242 |
|
243 |
|
@@ -247,7 +250,7 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
|
|
247 |
sample_reduction_method, plot_time_checkbox,
|
248 |
locally_approximate_publication_date_checkbox,
|
249 |
download_csv_checkbox, download_png_checkbox, citation_graph_checkbox,
|
250 |
-
csv_upload,
|
251 |
progress=gr.Progress()):
|
252 |
"""
|
253 |
Main prediction pipeline that processes OpenAlex queries and creates visualizations.
|
@@ -264,6 +267,7 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
|
|
264 |
download_png_checkbox (bool): Whether to download PNG data
|
265 |
citation_graph_checkbox (bool): Whether to add citation graph
|
266 |
csv_upload (str): Path to uploaded CSV file
|
|
|
267 |
progress (gr.Progress): Gradio progress tracker
|
268 |
|
269 |
Returns:
|
@@ -309,25 +313,33 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
|
|
309 |
records_df = pd.read_csv(csv_upload)
|
310 |
filename = os.path.splitext(os.path.basename(csv_upload))[0]
|
311 |
|
312 |
-
#
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
331 |
|
332 |
else:
|
333 |
error_message = f"Error: Unsupported file type. Please upload a CSV or PKL file."
|
@@ -458,8 +470,10 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
|
|
458 |
|
459 |
basedata_df['color'] = '#ced4d211'
|
460 |
|
|
|
|
|
461 |
if not plot_time_checkbox:
|
462 |
-
records_df['color'] =
|
463 |
else:
|
464 |
cmap = colormaps.haline
|
465 |
if not locally_approximate_publication_date_checkbox:
|
@@ -521,7 +535,7 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
|
|
521 |
height=1000,
|
522 |
point_radius_min_pixels=1,
|
523 |
text_outline_width=5,
|
524 |
-
point_hover_color=
|
525 |
point_radius_max_pixels=7,
|
526 |
cmap=black_cmap,
|
527 |
background_image=graph_file_name if citation_graph_checkbox else None,
|
@@ -807,9 +821,14 @@ with gr.Blocks(theme=theme, css="""
|
|
807 |
label="Upload your own CSV file downloaded via pyalex.",
|
808 |
file_types=[".csv"],
|
809 |
)
|
810 |
-
|
811 |
|
812 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
813 |
|
814 |
with gr.Column(scale=2):
|
815 |
html = gr.HTML(
|
@@ -853,7 +872,7 @@ with gr.Blocks(theme=theme, css="""
|
|
853 |
|
854 |
## I want to use my own data!
|
855 |
|
856 |
-
Sure! You can upload csv-files produced by downloading
|
857 |
|
858 |
</div>
|
859 |
""")
|
@@ -894,7 +913,8 @@ with gr.Blocks(theme=theme, css="""
|
|
894 |
download_csv_checkbox,
|
895 |
download_png_checkbox,
|
896 |
citation_graph_checkbox,
|
897 |
-
csv_upload
|
|
|
898 |
],
|
899 |
outputs=[html, html_download, csv_download, png_download, cancel_btn]
|
900 |
)
|
|
|
51 |
print(f"Gradio version: {gr.__version__}")
|
52 |
|
53 |
import subprocess
|
54 |
+
import re
|
55 |
+
from color_utils import rgba_to_hex
|
56 |
|
57 |
def print_datamapplot_version():
|
58 |
try:
|
|
|
103 |
except (ImportError, ModuleNotFoundError):
|
104 |
HAS_SPACES = False
|
105 |
|
106 |
+
# Provide a harmless fallback so decorators don't explode
|
107 |
if not HAS_SPACES:
|
108 |
class _Dummy:
|
109 |
def GPU(self, *a, **k):
|
|
|
127 |
openalex_url_to_pyalex_query,
|
128 |
get_field,
|
129 |
process_records_to_df,
|
130 |
+
openalex_url_to_filename,
|
131 |
+
get_records_from_dois
|
132 |
)
|
133 |
from styles import DATAMAP_CUSTOM_CSS
|
134 |
from data_setup import (
|
|
|
237 |
|
238 |
|
239 |
# else:
|
240 |
+
def create_embeddings(texts_to_embedd):
    """Embed the given texts with the globally loaded sentence model.

    Encodes in batches of 192 with a progress bar and returns whatever
    ``model.encode`` produces (an array of embedding vectors).

    Args:
        texts_to_embedd (list[str]): Texts to embed.

    Returns:
        The embeddings returned by ``model.encode``.
    """
    embeddings = model.encode(
        texts_to_embedd,
        show_progress_bar=True,
        batch_size=192,
    )
    return embeddings
|
243 |
+
|
244 |
|
245 |
|
246 |
|
|
|
250 |
sample_reduction_method, plot_time_checkbox,
|
251 |
locally_approximate_publication_date_checkbox,
|
252 |
download_csv_checkbox, download_png_checkbox, citation_graph_checkbox,
|
253 |
+
csv_upload, highlight_color,
|
254 |
progress=gr.Progress()):
|
255 |
"""
|
256 |
Main prediction pipeline that processes OpenAlex queries and creates visualizations.
|
|
|
267 |
download_png_checkbox (bool): Whether to download PNG data
|
268 |
citation_graph_checkbox (bool): Whether to add citation graph
|
269 |
csv_upload (str): Path to uploaded CSV file
|
270 |
+
highlight_color (str): Color for highlighting points
|
271 |
progress (gr.Progress): Gradio progress tracker
|
272 |
|
273 |
Returns:
|
|
|
313 |
records_df = pd.read_csv(csv_upload)
|
314 |
filename = os.path.splitext(os.path.basename(csv_upload))[0]
|
315 |
|
316 |
+
# Check if this is a DOI-list CSV (single column, named 'doi' or similar)
|
317 |
+
if (len(records_df.columns) == 1 and records_df.columns[0].lower() in ['doi', 'dois']):
|
318 |
+
from openalex_utils import get_records_from_dois
|
319 |
+
doi_list = records_df.iloc[:,0].dropna().astype(str).tolist()
|
320 |
+
print(f"Detected DOI list with {len(doi_list)} DOIs. Downloading records from OpenAlex...")
|
321 |
+
records_df = get_records_from_dois(doi_list)
|
322 |
+
filename = f"doilist_{len(doi_list)}"
|
323 |
+
else:
|
324 |
+
# Convert *every* cell that looks like a serialized list/dict
|
325 |
+
def _try_parse_obj(cell):
|
326 |
+
if isinstance(cell, str):
|
327 |
+
txt = cell.strip()
|
328 |
+
if (txt.startswith('{') and txt.endswith('}')) or (txt.startswith('[') and txt.endswith(']')):
|
329 |
+
# Try JSON first
|
330 |
+
try:
|
331 |
+
return json.loads(txt)
|
332 |
+
except Exception:
|
333 |
+
pass
|
334 |
+
# Fallback to Python-repr (single quotes etc.)
|
335 |
+
try:
|
336 |
+
return ast.literal_eval(txt)
|
337 |
+
except Exception:
|
338 |
+
pass
|
339 |
+
return cell
|
340 |
+
|
341 |
+
records_df = records_df.map(_try_parse_obj)
|
342 |
+
print(records_df.head())
|
343 |
|
344 |
else:
|
345 |
error_message = f"Error: Unsupported file type. Please upload a CSV or PKL file."
|
|
|
470 |
|
471 |
basedata_df['color'] = '#ced4d211'
|
472 |
|
473 |
+
highlight_color = rgba_to_hex(highlight_color)
|
474 |
+
|
475 |
if not plot_time_checkbox:
|
476 |
+
records_df['color'] = highlight_color
|
477 |
else:
|
478 |
cmap = colormaps.haline
|
479 |
if not locally_approximate_publication_date_checkbox:
|
|
|
535 |
height=1000,
|
536 |
point_radius_min_pixels=1,
|
537 |
text_outline_width=5,
|
538 |
+
point_hover_color=highlight_color,
|
539 |
point_radius_max_pixels=7,
|
540 |
cmap=black_cmap,
|
541 |
background_image=graph_file_name if citation_graph_checkbox else None,
|
|
|
821 |
label="Upload your own CSV file downloaded via pyalex.",
|
822 |
file_types=[".csv"],
|
823 |
)
|
|
|
824 |
|
825 |
+
# --- Aesthetics Accordion ---
|
826 |
+
with gr.Accordion("Aesthetics", open=False):
|
827 |
+
highlight_color_picker = gr.ColorPicker(
|
828 |
+
label="Highlight Color",
|
829 |
+
value="#5e2784",
|
830 |
+
info="Choose the highlight color for your query points."
|
831 |
+
)
|
832 |
|
833 |
with gr.Column(scale=2):
|
834 |
html = gr.HTML(
|
|
|
872 |
|
873 |
## I want to use my own data!
|
874 |
|
875 |
+
Sure! You can upload csv-files produced by downloading records from OpenAlex using the pyalex package. You will need to provide at least the columns `id`, `title`, `publication_year`, `doi`, `abstract` or `abstract_inverted_index`, `referenced_works` and `primary_topic`. Alternatively, you can upload a csv-file with only the column `doi`, containing a column of DOIs. These will then be used to download the records from OpenAlex and then embed them on the map.
|
876 |
|
877 |
</div>
|
878 |
""")
|
|
|
913 |
download_csv_checkbox,
|
914 |
download_png_checkbox,
|
915 |
citation_graph_checkbox,
|
916 |
+
csv_upload,
|
917 |
+
highlight_color_picker
|
918 |
],
|
919 |
outputs=[html, html_download, csv_download, png_download, cancel_btn]
|
920 |
)
|
color_utils.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
def rgba_to_hex(color):
    """Convert an ``rgb(...)``/``rgba(...)`` string to a hex color.

    Args:
        color: A color string. Accepted forms are a hex string
            (``#rrggbb`` or ``#rgb``, returned unchanged) or an
            ``rgb(r, g, b)`` / ``rgba(r, g, b, a)`` string; the alpha
            component, if present, is discarded.

    Returns:
        str: A hex color string. Falls back to ``'#5e2784'`` when the
        input is not a string or cannot be parsed.
    """
    if isinstance(color, str):
        color = color.strip()
        # Already hex (#rrggbb or #rgb) -- pass through untouched.
        if color.startswith('#') and len(color) in (4, 7):
            return color
        # rgb(...) / rgba(...): extract the comma-separated components.
        # BUGFIX: the previous pattern used doubled backslashes inside a
        # raw string (r"rgba?\\(...\\)"), which matched a literal '\'
        # before each parenthesis and therefore never matched real
        # rgb()/rgba() strings -- every picker value fell to the fallback.
        match = re.match(r"rgba?\(([^)]+)\)", color)
        if match:
            parts = match.group(1).split(',')
            r = int(float(parts[0]))
            g = int(float(parts[1]))
            b = int(float(parts[2]))
            return '#{:02x}{:02x}{:02x}'.format(r, g, b)
    # fallback: default highlight color
    return '#5e2784'
|
openalex_utils.py
CHANGED
@@ -131,12 +131,6 @@ def process_records_to_df(records):
|
|
131 |
|
132 |
return records_df
|
133 |
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
def openalex_url_to_filename(url):
|
141 |
"""
|
142 |
Convert an OpenAlex URL to a filename-safe string with timestamp.
|
@@ -197,4 +191,26 @@ def openalex_url_to_filename(url):
|
|
197 |
if len(filename) > 255:
|
198 |
filename = filename[:251] # leave room for potential extension
|
199 |
|
200 |
-
return filename
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
|
132 |
return records_df
|
133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
def openalex_url_to_filename(url):
|
135 |
"""
|
136 |
Convert an OpenAlex URL to a filename-safe string with timestamp.
|
|
|
191 |
if len(filename) > 255:
|
192 |
filename = filename[:251] # leave room for potential extension
|
193 |
|
194 |
+
return filename
|
195 |
+
|
196 |
+
def get_records_from_dois(doi_list, block_size=50):
    """Fetch OpenAlex works for the given DOIs, batched per request.

    Args:
        doi_list (list): DOI strings to resolve.
        block_size (int): How many DOIs to request at once (default 50).

    Returns:
        pd.DataFrame: One row per record retrieved from OpenAlex.
        Batches that fail are reported and skipped (best-effort).
    """
    from pyalex import Works
    from tqdm import tqdm

    collected = []
    batch_starts = range(0, len(doi_list), block_size)
    for start in tqdm(batch_starts):
        sublist = doi_list[start:start + block_size]
        # OpenAlex accepts a pipe-separated list as one OR-filter value.
        joined = "|".join(sublist)
        try:
            fetched = Works().filter(doi=joined).get(per_page=block_size)
        except Exception as e:
            # Best-effort download: report the failing batch, keep going.
            print(f"Error fetching DOIs {sublist}: {e}")
            continue
        collected.extend(fetched)
    return pd.DataFrame(collected)
|