mgyigit commited on
Commit
4166fb4
·
verified ·
1 Parent(s): 2014ab8

Update src/vis_utils.py

Browse files
Files changed (1) hide show
  1. src/vis_utils.py +51 -57
src/vis_utils.py CHANGED
@@ -14,81 +14,77 @@ sys.path.append('..')
14
  sys.path.append('.')
15
 
16
  from about import *
 
17
 
18
- global data_component, filter_component
19
 
20
- def get_baseline_df(selected_methods, selected_metrics):
21
- df = pd.read_csv(CSV_RESULT_PATH)
22
- present_columns = ["method_name"] + selected_metrics
23
- df = df[df['method_name'].isin(selected_methods)][present_columns]
24
- return df
25
 
26
- def get_method_color(method):
27
- return color_dict.get(method, 'black') # If method is not in color_dict, use black
28
 
29
- def set_colors_and_marks_for_representation_groups(ax):
30
- for label in ax.get_xticklabels():
31
- text = label.get_text()
32
- color = group_color_dict.get(text, 'black') # Default to black if label not in dict
33
- label.set_color(color)
34
- label.set_fontweight('bold')
35
-
36
- # Add a caret symbol to specific labels
37
- if text in {'MUT2VEC', 'PFAM', 'GENE2VEC', 'BERT-PFAM'}:
38
- label.set_text(f"^ {text}")
39
 
40
  def benchmark_plot(benchmark_type, methods_selected, x_metric, y_metric):
41
  if benchmark_type == 'similarity':
42
- title = f"{x_metric} vs {y_metric}"
43
- return plot_similarity_results(methods_selected, x_metric, y_metric, title)
44
  elif benchmark_type == 'function':
45
  return plot_function_results("./data/function_results.csv", x_metric, y_metric, methods_selected)
46
  elif benchmark_type == 'family':
47
  return plot_family_results("./data/family_results.csv", methods_selected, x_metric, save_path="./plot_images")
48
  elif benchmark_type == "affinity":
49
  return plot_affinity_results("./data/affinity_results.csv", methods_selected, x_metric, save_path="./plot_images")
50
- else:
51
- # Use general visualizer logic
52
- return general_visualizer_plot(methods_selected, x_metric=x_metric, y_metric=y_metric)
53
-
54
- def general_visualizer(methods_selected, x_metric, y_metric):
55
- df = pd.read_csv(CSV_RESULT_PATH)
56
- filtered_df = df[df['method_name'].isin(methods_selected)]
57
-
58
- # Create a Seaborn lineplot with method as hue
59
- plt.figure(figsize=(10, 8)) # Increase figure size
60
- sns.lineplot(
61
- data=filtered_df,
62
- x=x_metric,
63
- y=y_metric,
64
- hue="method_name", # Different colors for different methods
65
- marker="o", # Add markers to the line plot
66
- )
67
 
68
- # Add labels and title
69
- plt.xlabel(x_metric)
70
- plt.ylabel(y_metric)
71
- plt.title(f'{y_metric} vs {x_metric} for selected methods')
72
- plt.grid(True)
 
 
73
 
74
- # Save the plot to display it in Gradio
75
- plot_path = "plot.png"
76
- plt.savefig(plot_path)
77
- plt.close()
 
 
 
 
78
 
79
- return plot_path
 
 
 
 
80
 
81
- def plot_similarity_results(methods_selected, x_metric, y_metric, title):
82
- df = pd.read_csv(CSV_RESULT_PATH)
83
- # Filter the dataframe based on selected methods
84
- filtered_df = df[df['method_name'].isin(methods_selected)]
85
 
86
- def get_method_color(method):
87
- return color_dict.get(method.upper(), 'black')
 
 
 
 
 
 
 
88
 
89
  # Add a new column to the dataframe for the color
90
  filtered_df['color'] = filtered_df['method_name'].apply(get_method_color)
91
 
 
 
92
  adjust_text_dict = {
93
  'expand_text': (1.15, 1.4), 'expand_points': (1.15, 1.25), 'expand_objects': (1.05, 1.5),
94
  'expand_align': (1.05, 1.2), 'autoalign': 'xy', 'va': 'center', 'ha': 'center',
@@ -104,7 +100,7 @@ def plot_similarity_results(methods_selected, x_metric, y_metric, title):
104
  label='method_name')) # Label each point by the method name
105
  + p9.geom_point(size=3) # Add points with no jitter, set point size
106
  + p9.geom_text(nudge_y=0.02, size=8) # Add method names as labels, nudge slightly above the points
107
- + p9.labs(title=title, x=f"{x_metric}", y=f"{y_metric}") # Dynamic labels for X and Y axes
108
  + p9.scale_color_identity() # Use colors directly from the dataframe
109
  + p9.theme(legend_position='none',
110
  figure_size=(8, 8), # Set figure size
@@ -114,10 +110,8 @@ def plot_similarity_results(methods_selected, x_metric, y_metric, title):
114
  )
115
 
116
  # Save the plot as an image
117
- save_path = "./plot_images" # Ensure this folder exists or adjust the path
118
- os.makedirs(save_path, exist_ok=True) # Create directory if it doesn't exist
119
  filename = os.path.join(save_path, title.replace(" ", "_") + "_Similarity_Scatter.png")
120
-
121
  g.save(filename=filename, dpi=400)
122
 
123
  return filename
 
14
  sys.path.append('.')
15
 
16
  from about import *
17
+ from saving_utils import download_from_hub
18
 
 
19
 
 
 
 
 
 
20
 
21
+ global data_component, filter_component
22
+
23
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  def benchmark_plot(benchmark_type, methods_selected, x_metric, y_metric):
26
  if benchmark_type == 'similarity':
27
+ return plot_similarity_results(methods_selected, x_metric, y_metric)
 
28
  elif benchmark_type == 'function':
29
  return plot_function_results("./data/function_results.csv", x_metric, y_metric, methods_selected)
30
  elif benchmark_type == 'family':
31
  return plot_family_results("./data/family_results.csv", methods_selected, x_metric, save_path="./plot_images")
32
  elif benchmark_type == "affinity":
33
  return plot_affinity_results("./data/affinity_results.csv", methods_selected, x_metric, save_path="./plot_images")
34
+
35
+ return 0
36
+
37
+ def get_method_color(method):
38
+ return color_dict.get(method, 'black') # If method is not in color_dict, use black
39
+
40
+
41
+ def get_labels_and_title(x_metric, y_metric):
42
+ # Define mapping for long forms
43
+ long_form_mapping = {
44
+ "MF": "Molecular Function",
45
+ "BP": "Biological Process",
46
+ "CC": "Cellular Component"
47
+ }
 
 
 
48
 
49
+ # Parse the metrics
50
+ def parse_metric(metric):
51
+ parts = metric.split("_")
52
+ dataset = parts[0] # sparse/200/500
53
+ category = parts[1] # MF/BP/CC
54
+ measure = parts[2] # pvalue/correlation
55
+ return dataset, category, measure
56
 
57
+ x_dataset, x_category, x_measure = parse_metric(x_metric)
58
+ y_dataset, y_category, y_measure = parse_metric(y_metric)
59
+
60
+ # Determine the title
61
+ if x_category == y_category:
62
+ title = long_form_mapping[x_category]
63
+ else:
64
+ title = f"{long_form_mapping[x_category]} vs {long_form_mapping[y_category]}"
65
 
66
+ # Determine the axis labels
67
+ x_label = f"{x_measure.capitalize()} on {x_dataset.capitalize()} Dataset"
68
+ y_label = f"{y_measure.capitalize()} on {y_dataset.capitalize()} Dataset"
69
+
70
+ return title, x_label, y_label
71
 
 
 
 
 
72
 
73
+ def plot_similarity_results(methods_selected, x_metric, y_metric, similarity_path="/tmp/similarity_results.csv"):
74
+ if not os.path.exists(similarity_path):
75
+ benchmark_types = ["similarity", "function", "family", "affinity"] #download all files for faster results later
76
+ download_from_hub(benchmark_types)
77
+
78
+ similarity_df = pd.read_csv(similarity_path)
79
+
80
+ # Filter the dataframe based on selected methods
81
+ filtered_df = similarity_df[similarity_df['method_name'].isin(methods_selected)]
82
 
83
  # Add a new column to the dataframe for the color
84
  filtered_df['color'] = filtered_df['method_name'].apply(get_method_color)
85
 
86
+ title, x_label, y_label = generate_labels_and_title(x_metric, y_metric)
87
+
88
  adjust_text_dict = {
89
  'expand_text': (1.15, 1.4), 'expand_points': (1.15, 1.25), 'expand_objects': (1.05, 1.5),
90
  'expand_align': (1.05, 1.2), 'autoalign': 'xy', 'va': 'center', 'ha': 'center',
 
100
  label='method_name')) # Label each point by the method name
101
  + p9.geom_point(size=3) # Add points with no jitter, set point size
102
  + p9.geom_text(nudge_y=0.02, size=8) # Add method names as labels, nudge slightly above the points
103
+ + p9.labs(title=title, x=x_label, y=y_label) # Dynamic labels for X and Y axes
104
  + p9.scale_color_identity() # Use colors directly from the dataframe
105
  + p9.theme(legend_position='none',
106
  figure_size=(8, 8), # Set figure size
 
110
  )
111
 
112
  # Save the plot as an image
113
+ save_path = "/tmp"
 
114
  filename = os.path.join(save_path, title.replace(" ", "_") + "_Similarity_Scatter.png")
 
115
  g.save(filename=filename, dpi=400)
116
 
117
  return filename