MaxNoichl committed on
Commit
20cd1f4
·
1 Parent(s): e5dee3b

Fixed categorical colors.

Browse files
Files changed (2) hide show
  1. app.py +83 -44
  2. openalex_utils.py +14 -3
app.py CHANGED
@@ -558,8 +558,9 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
558
  break
559
 
560
  if should_break_current_query:
561
- print(f"Successfully broke from page loop for query {i+1}")
562
- break
 
563
  # Continue to next query - don't break out of the main query loop
564
  print(f"Query completed in {time.time() - start_time:.2f} seconds")
565
  print(f"Total records collected: {len(records)}")
@@ -576,6 +577,7 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
576
 
577
  # Add query_index to the dataframe
578
  records_df['query_index'] = query_indices[:len(records_df)]
 
579
 
580
  if reduce_sample_checkbox and sample_reduction_method != "All" and sample_reduction_method != "n random samples":
581
  # Note: We skip "n random samples" here because PyAlex sampling is already done above
@@ -611,7 +613,9 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
611
  if sample_reduction_method == "First n samples":
612
  records_df = records_df.iloc[:sample_size]
613
  print(f"Records processed in {time.time() - processing_start:.2f} seconds")
614
-
 
 
615
  # Create embeddings - this happens regardless of data source
616
  embedding_start = time.time()
617
  progress(0.3, desc="Embedding Data...")
@@ -655,7 +659,7 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
655
  print('Highlight color:', highlight_color)
656
 
657
  # Check if we have multiple queries and categorical coloring is enabled
658
- urls = [url.strip() for url in text_input.split(';')] if text_input else ['']
659
  has_multiple_queries = len(urls) > 1 and not csv_upload
660
 
661
  if treat_as_categorical_checkbox and has_multiple_queries:
@@ -677,45 +681,29 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
677
  print(f"Warning: Could not load colormap '{selected_colormap_name}' for categorical coloring: {e}")
678
  # Fallback to default categorical colors
679
  categorical_colors = [
680
- '#e41a1c', # Red
681
- '#377eb8', # Blue
682
- '#4daf4a', # Green
683
- '#984ea3', # Purple
684
- '#ff7f00', # Orange
685
- '#ffff33', # Yellow
686
- '#a65628', # Brown
687
- '#f781bf', # Pink
688
- '#999999', # Gray
689
- '#66c2a5', # Teal
690
- '#fc8d62', # Light Orange
691
- '#8da0cb', # Light Blue
692
- '#e78ac3', # Light Pink
693
- '#a6d854', # Light Green
694
- '#ffd92f', # Light Yellow
695
- '#e5c494', # Beige
696
- '#b3b3b3', # Light Gray
697
- ]
698
  else:
699
  # Use default categorical colors
700
  categorical_colors = [
701
- '#e41a1c', # Red
702
- '#377eb8', # Blue
703
- '#4daf4a', # Green
704
- '#984ea3', # Purple
705
- '#ff7f00', # Orange
706
- '#ffff33', # Yellow
707
- '#a65628', # Brown
708
- '#f781bf', # Pink
709
- '#999999', # Gray
710
- '#66c2a5', # Teal
711
- '#fc8d62', # Light Orange
712
- '#8da0cb', # Light Blue
713
- '#e78ac3', # Light Pink
714
- '#a6d854', # Light Green
715
- '#ffd92f', # Light Yellow
716
- '#e5c494', # Beige
717
- '#b3b3b3', # Light Gray
718
- ]
719
 
720
  # Assign colors based on query_index
721
  query_color_map = {query_idx: categorical_colors[i % len(categorical_colors)]
@@ -813,19 +801,39 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
813
  color_mapping = {}
814
 
815
  # Get readable names for each query URL
 
816
  for i, query_idx in enumerate(unique_queries):
817
  try:
818
  if query_idx < len(urls):
819
  readable_name = openalex_url_to_readable_name(urls[query_idx])
820
- # Truncate long names for legend display
821
- if len(readable_name) > 25:
822
- readable_name = readable_name[:22] + "..."
 
 
823
  else:
824
  readable_name = f"Query {query_idx + 1}"
825
- except Exception:
826
  readable_name = f"Query {query_idx + 1}"
 
 
 
 
 
 
 
 
 
 
 
 
 
827
 
 
828
  color_mapping[readable_name] = query_color_map[query_idx]
 
 
 
829
 
830
  legend_html, legend_css = categorical_legend_html_css(
831
  color_mapping,
@@ -1043,6 +1051,37 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
1043
  alpha=0.8,
1044
  s=point_size
1045
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1046
  print(f"Scatter plot creation completed in {time.time() - scatter_start:.2f} seconds")
1047
 
1048
  # Save plot
 
558
  break
559
 
560
  if should_break_current_query:
561
+ print(f"Successfully downloaded target size for query {i+1}, moving to next query")
562
+ # Continue to next query instead of breaking the entire query loop
563
+ continue
564
  # Continue to next query - don't break out of the main query loop
565
  print(f"Query completed in {time.time() - start_time:.2f} seconds")
566
  print(f"Total records collected: {len(records)}")
 
577
 
578
  # Add query_index to the dataframe
579
  records_df['query_index'] = query_indices[:len(records_df)]
580
+
581
 
582
  if reduce_sample_checkbox and sample_reduction_method != "All" and sample_reduction_method != "n random samples":
583
  # Note: We skip "n random samples" here because PyAlex sampling is already done above
 
613
  if sample_reduction_method == "First n samples":
614
  records_df = records_df.iloc[:sample_size]
615
  print(f"Records processed in {time.time() - processing_start:.2f} seconds")
616
+
617
+ print(query_indices)
618
+ print(records_df)
619
  # Create embeddings - this happens regardless of data source
620
  embedding_start = time.time()
621
  progress(0.3, desc="Embedding Data...")
 
659
  print('Highlight color:', highlight_color)
660
 
661
  # Check if we have multiple queries and categorical coloring is enabled
662
+ # Note: urls was already parsed earlier in the function, so we should use that
663
  has_multiple_queries = len(urls) > 1 and not csv_upload
664
 
665
  if treat_as_categorical_checkbox and has_multiple_queries:
 
681
  print(f"Warning: Could not load colormap '{selected_colormap_name}' for categorical coloring: {e}")
682
  # Fallback to default categorical colors
683
  categorical_colors = [
684
+ "#80418F", # Plum
685
+ "#EDA958", # Earth Yellow
686
+ "#F35264", # Crayola Red
687
+ "#087CA7", # Cerulean
688
+ "#FA826B", # Salmon
689
+ "#475C8F", # Navy Blue
690
+ "#579DA3", # Moonstone Green
691
+ "#d61d22", # Bright Red
692
+ "#97bb3c", # Lime Green
693
+ ]
 
 
 
 
 
 
 
 
694
  else:
695
  # Use default categorical colors
696
  categorical_colors = [
697
+ "#80418F", # Plum
698
+ "#EDA958", # Earth Yellow
699
+ "#F35264", # Crayola Red
700
+ "#087CA7", # Cerulean
701
+ "#FA826B", # Salmon
702
+ "#475C8F", # Navy Blue
703
+ "#579DA3", # Moonstone Green
704
+ "#d61d22", # Bright Red
705
+ "#97bb3c", # Lime Green
706
+ ]
 
 
 
 
 
 
 
 
707
 
708
  # Assign colors based on query_index
709
  query_color_map = {query_idx: categorical_colors[i % len(categorical_colors)]
 
801
  color_mapping = {}
802
 
803
  # Get readable names for each query URL
804
+ used_names = set() # Track used names to ensure uniqueness
805
  for i, query_idx in enumerate(unique_queries):
806
  try:
807
  if query_idx < len(urls):
808
  readable_name = openalex_url_to_readable_name(urls[query_idx])
809
+ print(f"Query {query_idx}: Original readable name: '{readable_name}'")
810
+ # Truncate long names for legend display (increased from 25 to 40 chars)
811
+ if len(readable_name) > 40:
812
+ readable_name = readable_name[:37] + "..."
813
+ print(f"Query {query_idx}: Truncated to: '{readable_name}'")
814
  else:
815
  readable_name = f"Query {query_idx + 1}"
816
+ except Exception as e:
817
  readable_name = f"Query {query_idx + 1}"
818
+ print(f"Query {query_idx}: Exception generating name: {e}")
819
+
820
+ # Ensure uniqueness - if name is already used, append query number
821
+ original_name = readable_name
822
+ counter = 1
823
+ while readable_name in used_names:
824
+ print(f"Query {query_idx}: Name '{readable_name}' already used, making unique...")
825
+ readable_name = f"{original_name} ({query_idx + 1})"
826
+ if len(readable_name) > 40:
827
+ # Re-truncate if needed after adding query number
828
+ base_name = original_name[:32] + "..."
829
+ readable_name = f"{base_name} ({query_idx + 1})"
830
+ counter += 1
831
 
832
+ used_names.add(readable_name)
833
  color_mapping[readable_name] = query_color_map[query_idx]
834
+ print(f"Query {query_idx}: Final legend name: '{readable_name}' -> color: {query_color_map[query_idx]}")
835
+
836
+ print(f"Final color mapping: {color_mapping}")
837
 
838
  legend_html, legend_css = categorical_legend_html_css(
839
  color_mapping,
 
1051
  alpha=0.8,
1052
  s=point_size
1053
  )
1054
+ # Add legend for categorical coloring (not time-based)
1055
+ if plot_type_dropdown != "Time-based coloring" and treat_as_categorical_checkbox and has_multiple_queries:
1056
+ # Get unique categories and their colors from the color mapping dict
1057
+ unique_categories = records_df['query_index'].unique()
1058
+
1059
+ # Create legend handles with larger point size using the color mapping
1060
+ legend_handles = []
1061
+ for query_idx in sorted(unique_categories):
1062
+ # Get the readable name for this query
1063
+ try:
1064
+ if query_idx < len(urls):
1065
+ readable_name = openalex_url_to_readable_name(urls[query_idx])
1066
+ # Truncate long names for legend display
1067
+ if len(readable_name) > 40:
1068
+ readable_name = readable_name[:37] + "..."
1069
+ else:
1070
+ readable_name = f"Query {query_idx + 1}"
1071
+ except Exception as e:
1072
+ readable_name = f"Query {query_idx + 1}"
1073
+
1074
+ color = query_color_map[query_idx]
1075
+ legend_handles.append(plt.Line2D([0], [0], marker='o', color='w',
1076
+ markerfacecolor=color, markersize=9,
1077
+ label=readable_name, linestyle='None'))
1078
+
1079
+ # Add legend in upper left corner
1080
+ plt.legend(handles=legend_handles, loc='upper left', frameon=False,
1081
+ fancybox=False, shadow=False, framealpha=0.9, fontsize=9,
1082
+ #prop={'weight': 'bold'}
1083
+ )
1084
+
1085
  print(f"Scatter plot creation completed in {time.time() - scatter_start:.2f} seconds")
1086
 
1087
  # Save plot
openalex_utils.py CHANGED
@@ -258,6 +258,12 @@ def openalex_url_to_readable_name(url):
258
  search_term = value.strip('"\'')
259
  parts.append(f"Search: '{search_term}'")
260
 
 
 
 
 
 
 
261
  elif key == 'publication_year':
262
  # Handle year ranges or single years
263
  if '-' in value:
@@ -348,8 +354,13 @@ def openalex_url_to_readable_name(url):
348
 
349
  else:
350
  # Generic handling for other filters
 
351
  clean_key = key.replace('_', ' ').replace('.', ' ').title()
352
- clean_value = value.replace('_', ' ')
 
 
 
 
353
  parts.append(f"{clean_key}: {clean_value}")
354
 
355
  except Exception as e:
@@ -370,7 +381,7 @@ def openalex_url_to_readable_name(url):
370
  description = f"Works from {year_range}"
371
 
372
  # Limit length to keep it readable
373
- if len(description) > 100:
374
- description = description[:97] + "..."
375
 
376
  return description
 
258
  search_term = value.strip('"\'')
259
  parts.append(f"Search: '{search_term}'")
260
 
261
+ elif key == 'title_and_abstract.search':
262
+ # Handle title and abstract search specifically
263
+ from urllib.parse import unquote_plus
264
+ search_term = unquote_plus(value).strip('"\'')
265
+ parts.append(f"T&A: '{search_term}'")
266
+
267
  elif key == 'publication_year':
268
  # Handle year ranges or single years
269
  if '-' in value:
 
354
 
355
  else:
356
  # Generic handling for other filters
357
+ from urllib.parse import unquote_plus
358
  clean_key = key.replace('_', ' ').replace('.', ' ').title()
359
+ # Properly decode URL-encoded values
360
+ try:
361
+ clean_value = unquote_plus(value).replace('_', ' ')
362
+ except:
363
+ clean_value = value.replace('_', ' ')
364
  parts.append(f"{clean_key}: {clean_value}")
365
 
366
  except Exception as e:
 
381
  description = f"Works from {year_range}"
382
 
383
  # Limit length to keep it readable
384
+ if len(description) > 60:
385
+ description = description[:57] + "..."
386
 
387
  return description