Enhance data processing by adding CSV and Pickle file upload support, improving error handling, and refining the prediction pipeline. Update the `process_records_to_df` function to handle existing DataFrames and ensure required fields are processed correctly.
Files changed:
- app.py +182 −129
- openalex_utils.py +23 −5
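One non-obvious step in the new CSV path: list- and dict-valued columns (for example `referenced_works` or `primary_topic`) survive a CSV round-trip only as strings, so `app.py` now re-parses them with `ast.literal_eval` before handing the DataFrame to `process_records_to_df`. A minimal sketch of that round-trip, with hypothetical toy data that is not part of this commit:

```python
import ast
import pandas as pd

# A list-valued column is written to CSV as its repr-string...
df = pd.DataFrame({
    "id": ["W1", "W2"],
    "referenced_works": [["W10", "W11"], ["W12"]],
})
df.to_csv("records.csv", index=False)

loaded = pd.read_csv("records.csv")
print(type(loaded["referenced_works"][0]))  # <class 'str'> after the round-trip

def parse_literal(value):
    # Mirror of the guard used in app.py: only parse strings that look like
    # Python list/dict literals; everything else is returned unchanged.
    if isinstance(value, str) and (
        (value.startswith("[") and value.endswith("]"))
        or (value.startswith("{") and value.endswith("}"))
    ):
        return ast.literal_eval(value)
    return value

loaded["referenced_works"] = loaded["referenced_works"].apply(parse_literal)
print(type(loaded["referenced_works"][0]))  # back to <class 'list'>
```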
app.py
CHANGED
@@ -1,14 +1,16 @@
 import time
 print(f"Starting up: {time.strftime('%Y-%m-%d %H:%M:%S')}")
-
+# source openalex_env_map/bin/activate
 # Standard library imports
 import os
 from pathlib import Path
 from datetime import datetime
 from itertools import chain
+import ast  # Add this import at the top with the standard library imports
 
 import base64
 import json
+import pickle
 
 # Third-party imports
 import numpy as np
@@ -169,13 +171,13 @@ else:
 
 
 
-
-
+
 
 def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_checkbox,
             sample_reduction_method, plot_time_checkbox,
             locally_approximate_publication_date_checkbox,
             download_csv_checkbox, download_png_checkbox, citation_graph_checkbox,
+            csv_upload,
             progress=gr.Progress()):
     """
     Main prediction pipeline that processes OpenAlex queries and creates visualizations.
@@ -188,11 +190,28 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
         sample_reduction_method (str): Method for sample reduction ("Random" or "Order of Results")
         plot_time_checkbox (bool): Whether to color points by publication date
         locally_approximate_publication_date_checkbox (bool): Whether to approximate publication date locally before plotting.
+        download_csv_checkbox (bool): Whether to download CSV data
+        download_png_checkbox (bool): Whether to download PNG data
+        citation_graph_checkbox (bool): Whether to add citation graph
+        csv_upload (str): Path to uploaded CSV file
         progress (gr.Progress): Gradio progress tracker
 
     Returns:
         tuple: (link to visualization, iframe HTML)
     """
+    # Initialize start_time at the beginning of the function
+    start_time = time.time()
+
+    # Helper function to generate error responses
+    def create_error_response(error_message):
+        return [
+            error_message,
+            gr.DownloadButton(label="Download Interactive Visualization", value='html_file_path', visible=False),
+            gr.DownloadButton(label="Download CSV Data", value='csv_file_path', visible=False),
+            gr.DownloadButton(label="Download Static Plot", value='png_file_path', visible=False),
+            gr.Button(visible=False)
+        ]
+
     # Get the authentication token
     if is_running_in_hf_space():
         token = _get_token(request)
@@ -208,103 +227,145 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
     else:
         user_type = "registered"
     print(f"User type: {user_type}")
-
-
-    # Check if input is empty or whitespace
-    print(f"Input: {text_input}")
-    if not text_input or text_input.isspace():
-        error_message = "Error: Please enter a valid OpenAlex URL in the 'OpenAlex-search URL'-field"
-        return [
-            error_message, # iframe HTML
-            gr.DownloadButton(label="Download Interactive Visualization", value='html_file_path', visible=False), # html download
-            gr.DownloadButton(label="Download CSV Data", value='csv_file_path', visible=False), # csv download
-            gr.DownloadButton(label="Download Static Plot", value='png_file_path', visible=False), # png download
-            gr.Button(visible=False) # cancel button state
-        ]
 
-[… 29 removed lines not rendered in the diff view …]
+    # Check if a file has been uploaded or if we need to use OpenAlex query
+    if csv_upload is not None:
+        print(f"Using uploaded file instead of OpenAlex query: {csv_upload}")
+        try:
+            file_extension = os.path.splitext(csv_upload)[1].lower()
+
+            if file_extension == '.csv':
+                # Read the CSV file
+                records_df = pd.read_csv(csv_upload)
+                filename = os.path.splitext(os.path.basename(csv_upload))[0]
+
+                # Process dictionary-like strings in the DataFrame
+                for column in records_df.columns:
+                    # Check if the column contains dictionary-like strings
+                    if records_df[column].dtype == 'object':
+                        try:
+                            # Use a sample value to check if it looks like a dictionary or list
+                            sample_value = records_df[column].dropna().iloc[0] if not records_df[column].dropna().empty else None
+                            # Add type checking before using startswith
+                            if isinstance(sample_value, str) and (sample_value.startswith('{') or sample_value.startswith('[')):
+                                # Try to convert strings to Python objects using ast.literal_eval
+                                records_df[column] = records_df[column].apply(
+                                    lambda x: ast.literal_eval(x) if isinstance(x, str) and (
+                                        (x.startswith('{') and x.endswith('}')) or
+                                        (x.startswith('[') and x.endswith(']'))
+                                    ) else x
+                                )
+                        except (ValueError, SyntaxError, TypeError) as e:
+                            # If conversion fails, keep as string
+                            print(f"Could not convert column {column} to Python objects: {e}")
+
+            elif file_extension == '.pkl':
+                # Read the pickle file
+                with open(csv_upload, 'rb') as f:
+                    records_df = pickle.load(f)
+                filename = os.path.splitext(os.path.basename(csv_upload))[0]
+
+            else:
+                error_message = f"Error: Unsupported file type. Please upload a CSV or PKL file."
+                return create_error_response(error_message)
+
+            records_df = process_records_to_df(records_df)
+
+            # Make sure we have the required columns
+            required_columns = ['title', 'abstract', 'publication_year']
+            missing_columns = [col for col in required_columns if col not in records_df.columns]
+
+            if missing_columns:
+                error_message = f"Error: Uploaded file is missing required columns: {', '.join(missing_columns)}"
+                return create_error_response(error_message)
+
+            print(f"Successfully loaded {len(records_df)} records from uploaded file")
+            progress(0.2, desc="Processing uploaded data...")
+
+        except Exception as e:
+            error_message = f"Error processing uploaded file: {str(e)}"
+            return create_error_response(error_message)
+    else:
+        # Check if input is empty or whitespace
+        print(f"Input: {text_input}")
+        if not text_input or text_input.isspace():
+            error_message = "Error: Please enter a valid OpenAlex URL in the 'OpenAlex-search URL'-field or upload a CSV file"
+            return create_error_response(error_message)
+
+        print('Starting data projection pipeline')
+        progress(0.1, desc="Starting...")
+
+        # Split input into multiple URLs if present
+        urls = [url.strip() for url in text_input.split(';')]
+        records = []
+        total_query_length = 0
 
-[… 6 removed lines not rendered in the diff view …]
+        # Use first URL for filename
+        first_query, first_params = openalex_url_to_pyalex_query(urls[0])
+        filename = openalex_url_to_filename(urls[0])
+        print(f"Filename: {filename}")
+
+        # Process each URL
+        for i, url in enumerate(urls):
+            query, params = openalex_url_to_pyalex_query(url)
+            query_length = query.count()
+            total_query_length += query_length
+            print(f'Requesting {query_length} entries from query {i+1}/{len(urls)}...')
 
-
-
-                    for record in page:
-                        records.append(record)
-                        records_per_query += 1
-                        progress(0.1 + (0.2 * len(records) / (total_query_length)),
-                                 desc=f"Getting data from query {i+1}/{len(urls)}...")
-
-                        if reduce_sample_checkbox and sample_reduction_method == "First n samples" and records_per_query >= target_size:
-                            should_break = True
-                            break
-                    # If we get here without an exception, break the retry loop
-                    break
-                except Exception as e:
-                    print(f"Error processing page: {e}")
-                    if retry_attempt < max_retries - 1:
-                        wait_time = base_wait_time * (exponent ** retry_attempt) + random.random()
-                        print(f"Retrying in {wait_time:.2f} seconds (attempt {retry_attempt + 1}/{max_retries})...")
-                        time.sleep(wait_time)
-                    else:
-                        print(f"Maximum retries reached. Continuing with next page.")
+            target_size = sample_size_slider if reduce_sample_checkbox and sample_reduction_method == "First n samples" else query_length
+            records_per_query = 0
 
+            should_break = False
+            for page in query.paginate(per_page=200, n_max=None):
+                # Add retry mechanism for processing each page
+                max_retries = 5
+                base_wait_time = 1 # Starting wait time in seconds
+                exponent = 1.5 # Exponential factor
+
+                for retry_attempt in range(max_retries):
+                    try:
+                        for record in page:
+                            records.append(record)
+                            records_per_query += 1
+                            progress(0.1 + (0.2 * len(records) / (total_query_length)),
+                                     desc=f"Getting data from query {i+1}/{len(urls)}...")
+
+                            if reduce_sample_checkbox and sample_reduction_method == "First n samples" and records_per_query >= target_size:
+                                should_break = True
+                                break
+                        # If we get here without an exception, break the retry loop
+                        break
+                    except Exception as e:
+                        print(f"Error processing page: {e}")
+                        if retry_attempt < max_retries - 1:
+                            wait_time = base_wait_time * (exponent ** retry_attempt) + random.random()
+                            print(f"Retrying in {wait_time:.2f} seconds (attempt {retry_attempt + 1}/{max_retries})...")
+                            time.sleep(wait_time)
+                        else:
+                            print(f"Maximum retries reached. Continuing with next page.")
+
+                if should_break:
+                    break
            if should_break:
                break
-[… 16 removed lines not rendered in the diff view …]
-    # Create embeddings
+        print(f"Query completed in {time.time() - start_time:.2f} seconds")
+
+        # Process records
+        processing_start = time.time()
+        records_df = process_records_to_df(records)
+
+        if reduce_sample_checkbox and sample_reduction_method != "All":
+            sample_size = min(sample_size_slider, len(records_df))
+            if sample_reduction_method == "n random samples":
+                records_df = records_df.sample(sample_size)
+            elif sample_reduction_method == "First n samples":
+                records_df = records_df.iloc[:sample_size]
+        print(f"Records processed in {time.time() - processing_start:.2f} seconds")
+
+    # Create embeddings - this happens regardless of data source
     embedding_start = time.time()
     progress(0.3, desc="Embedding Data...")
-    texts_to_embedd = [f"{title} {abstract}" for title, abstract
-                       in zip(records_df['title'], records_df['abstract'])]
-
+    texts_to_embedd = [f"{title} {abstract}" for title, abstract in zip(records_df['title'], records_df['abstract'])]
 
     if is_running_in_hf_space():
         if len(texts_to_embedd) < 2000:
@@ -357,13 +418,18 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
        norm = mcolors.Normalize(vmin=local_years.min(), vmax=local_years.max())
        records_df['color'] = [mcolors.to_hex(cmap(norm(year))) for year in local_years]
 
-
-
    stacked_df = pd.concat([basedata_df, records_df], axis=0, ignore_index=True)
    stacked_df = stacked_df.fillna("Unlabelled")
    stacked_df['parsed_field'] = [get_field(row) for ix, row in stacked_df.iterrows()]
    extra_data = pd.DataFrame(stacked_df['doi'])
    print(f"Visualization data prepared in {time.time() - viz_prep_start:.2f} seconds")
+
+    # Prepare file paths
+    html_file_name = f"{filename}.html"
+    html_file_path = static_dir / html_file_name
+    csv_file_path = static_dir / f"{filename}.csv"
+    png_file_path = static_dir / f"{filename}.png"
+
    if citation_graph_checkbox:
        citation_graph_start = time.time()
        citation_graph = create_citation_graph(records_df)
@@ -372,9 +438,6 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
        draw_citation_graph(citation_graph,path=graph_file_path,bundle_edges=True,
                            min_max_coordinates=[np.min(stacked_df['x']),np.max(stacked_df['x']),np.min(stacked_df['y']),np.max(stacked_df['y'])])
        print(f"Citation graph created and saved in {time.time() - citation_graph_start:.2f} seconds")
-
-
-
 
    # Create and save plot
    plot_start = time.time()
@@ -382,10 +445,9 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
    # Create a solid black colormap
    black_cmap = mcolors.LinearSegmentedColormap.from_list('black', ['#000000', '#000000'])
 
-
    plot = datamapplot.create_interactive_plot(
        stacked_df[['x','y']].values,
-
+        np.array(stacked_df['cluster_2_labels']),
        np.array(['Unlabelled' if pd.isna(x) else x for x in stacked_df['parsed_field']]),
 
        hover_text=[str(row['title']) for ix, row in stacked_df.iterrows()],
@@ -413,23 +475,16 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
    )
 
    # Save plot
-    html_file_name = f"{filename}.html"
-    html_file_path = static_dir / html_file_name
    plot.save(html_file_path)
    print(f"Plot created and saved in {time.time() - plot_start:.2f} seconds")
-
 
-    #datamapplot==0.5.1
    # Save additional files if requested
-    csv_file_path = static_dir / f"{filename}.csv"
-    png_file_path = static_dir / f"{filename}.png"
-
    if download_csv_checkbox:
        # Export relevant column
        export_df = records_df[['title', 'abstract', 'doi', 'publication_year', 'x', 'y','id','primary_topic']]
-        export_df['parsed_field'] =
+        export_df['parsed_field'] = [get_field(row) for ix, row in export_df.iterrows()]
        export_df['referenced_works'] = [', '.join(x) for x in records_df['referenced_works']]
-        if locally_approximate_publication_date_checkbox:
+        if locally_approximate_publication_date_checkbox and plot_time_checkbox:
            export_df['approximate_publication_year'] = local_years
        export_df.to_csv(csv_file_path, index=False)
 
@@ -453,17 +508,10 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
 
        # Replace less common labels with 'Unlabelled'
        combined_labels = np.array(['Unlabelled' if label not in top_30_labels else label for label in combined_labels])
-        #combined_labels = np.array(['Unlabelled' for label in combined_labels])
-        #if label not in top_30_labels else label
        colors_base = ['#536878' for _ in range(len(labels1))]
        print(f"Sample preparation completed in {time.time() - sample_prep_start:.2f} seconds")
 
        # Create main plot
-        print(labels1)
-        print(labels2)
-        print(sample_to_plot[['x','y']].values)
-        print(combined_labels)
-
        main_plot_start = time.time()
        fig, ax = datamapplot.create_plot(
            sample_to_plot[['x','y']].values,
@@ -487,15 +535,12 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
        )
        print(f"Main plot creation completed in {time.time() - main_plot_start:.2f} seconds")
 
-
        if citation_graph_checkbox:
-
            # Read and add the graph image
            graph_img = plt.imread(graph_file_path)
            ax.imshow(graph_img, extent=[np.min(stacked_df['x']),np.max(stacked_df['x']),np.min(stacked_df['y']),np.max(stacked_df['y'])],
-
-
-
+                      alpha=0.9, aspect='auto')
+
        if len(records_df) > 50_000:
            point_size = .5
        elif len(records_df) > 10_000:
@@ -539,17 +584,12 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
        # Save plot
        save_start = time.time()
        plt.axis('off')
-        png_file_path = static_dir / f"{filename}.png"
        plt.savefig(png_file_path, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Plot saving completed in {time.time() - save_start:.2f} seconds")
 
        print(f"Total PNG generation completed in {time.time() - png_start_time:.2f} seconds")
 
-
-
-
-
    progress(1.0, desc="Done!")
    print(f"Total pipeline completed in {time.time() - start_time:.2f} seconds")
    iframe = f"""<iframe src="{html_file_path}" width="100%" height="1000px"></iframe>"""
@@ -650,10 +690,16 @@ with gr.Blocks(theme=theme, css="""
                value="First n samples",
                info="How to choose the samples to keep."
            )
+
+            if is_running_in_hf_zero_gpu():
+                max_sample_size = 20000
+            else:
+                max_sample_size = 250000
+
            sample_size_slider = gr.Slider(
                label="Sample Size",
                minimum=500,
-                maximum=
+                maximum=max_sample_size,
                step=10,
                value=1000,
                info="How many samples to keep.",
@@ -691,6 +737,12 @@ with gr.Blocks(theme=theme, css="""
                info="Adds a citation graph of the sample to the plot."
            )
 
+            gr.Markdown("### Upload Your Own Data")
+            csv_upload = gr.File(
+                file_count="single",
+                label="Upload your own CSV or Pickle file downloaded via pyalex.",
+                file_types=[".csv", ".pkl"],
+            )
 
 
        with gr.Column(scale=2):
@@ -706,7 +758,7 @@ with gr.Blocks(theme=theme, css="""
 
            ## Who made this?
 
-            This project was developed by [Maximilian Noichl](https://maxnoichl.eu) (Utrecht University), in cooperation with Andrea
+            This project was developed by [Maximilian Noichl](https://maxnoichl.eu) (Utrecht University), in cooperation with Andrea Loettgers and Tarja Knuuttila at the [Possible Life project](http://www.possiblelife.eu/), at the University of Vienna. If this project is useful in any way for your research, we would appreciate citation of **...**
 
            This project received funding from the European Research Council under the European Union's Horizon 2020 research and innovation programme (LIFEMODE project, grant agreement No. 818772).
 
@@ -770,7 +822,8 @@ with gr.Blocks(theme=theme, css="""
            locally_approximate_publication_date_checkbox,
            download_csv_checkbox,
            download_png_checkbox,
-            citation_graph_checkbox
+            citation_graph_checkbox,
+            csv_upload
        ],
        outputs=[html, html_download, csv_download, png_download, cancel_btn]
    )
@@ -811,4 +864,4 @@ if __name__ == "__main__":
    os.environ["GRADIO_SSR_MODE"] = "True"
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
 else:
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    uvicorn.run(app, host="0.0.0.0", port=7860)
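For context on how the new `csv_upload` input reaches `predict`: with Gradio's default `type="filepath"`, a single-file `gr.File` component passes the callback a path string (or `None` when nothing was uploaded), which is why `app.py` can branch on `csv_upload is not None` and call `os.path.splitext` on it. A standalone sketch of that behaviour, with illustrative component names that are not the app's actual layout:

```python
import os
import gradio as gr

def describe_upload(csv_upload):
    # With the default type="filepath", csv_upload is a path string or None.
    if csv_upload is None:
        return "No file uploaded - the app would fall back to the OpenAlex URL."
    return f"Got {os.path.basename(csv_upload)} ({os.path.splitext(csv_upload)[1]})"

with gr.Blocks() as demo:
    upload = gr.File(
        file_count="single",
        file_types=[".csv", ".pkl"],
        label="Upload your own CSV or Pickle file downloaded via pyalex.",
    )
    out = gr.Textbox(label="Result")
    upload.change(describe_upload, inputs=upload, outputs=out)

if __name__ == "__main__":
    demo.launch()
```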
openalex_utils.py
CHANGED
@@ -83,23 +83,41 @@ def get_field(x):
 def process_records_to_df(records):
     """
     Convert OpenAlex records to a pandas DataFrame with processed fields.
+    Can handle either raw OpenAlex records or an existing DataFrame.
 
     Args:
-        records (list): List of OpenAlex record dictionaries
+        records (list or pd.DataFrame): List of OpenAlex record dictionaries or existing DataFrame
 
     Returns:
         pandas.DataFrame: Processed DataFrame with abstracts, publications, and titles
     """
-
-
-
+    # If records is already a DataFrame, use it directly
+    if isinstance(records, pd.DataFrame):
+        records_df = records.copy()
+        # Only process abstract_inverted_index and primary_location if they exist
+        if 'abstract_inverted_index' in records_df.columns:
+            records_df['abstract'] = [invert_abstract(t) for t in records_df['abstract_inverted_index']]
+        if 'primary_location' in records_df.columns:
+            records_df['parsed_publication'] = [get_pub(x) for x in records_df['primary_location']]
+    else:
+        # Process raw records as before
+        records_df = pd.DataFrame(records)
+        records_df['abstract'] = [invert_abstract(t) for t in records_df['abstract_inverted_index']]
+        records_df['parsed_publication'] = [get_pub(x) for x in records_df['primary_location']]
 
+    # Fill missing values and deduplicate
     records_df['parsed_publication'] = records_df['parsed_publication'].fillna(' ')
     records_df['abstract'] = records_df['abstract'].fillna(' ')
     records_df['title'] = records_df['title'].fillna(' ')
     records_df = records_df.drop_duplicates(subset=['id']).reset_index(drop=True)
 
-    return records_df
+    return records_df
+
+
+
+
+
+
 
 def openalex_url_to_filename(url):
     """