Spaces:
Sleeping
Sleeping
Update src/vis_utils.py
Browse files- src/vis_utils.py +51 -57
src/vis_utils.py
CHANGED
@@ -14,81 +14,77 @@ sys.path.append('..')
|
|
14 |
sys.path.append('.')
|
15 |
|
16 |
from about import *
|
|
|
17 |
|
18 |
-
global data_component, filter_component
|
19 |
|
20 |
-
def get_baseline_df(selected_methods, selected_metrics):
|
21 |
-
df = pd.read_csv(CSV_RESULT_PATH)
|
22 |
-
present_columns = ["method_name"] + selected_metrics
|
23 |
-
df = df[df['method_name'].isin(selected_methods)][present_columns]
|
24 |
-
return df
|
25 |
|
26 |
-
|
27 |
-
|
28 |
|
29 |
-
def set_colors_and_marks_for_representation_groups(ax):
|
30 |
-
for label in ax.get_xticklabels():
|
31 |
-
text = label.get_text()
|
32 |
-
color = group_color_dict.get(text, 'black') # Default to black if label not in dict
|
33 |
-
label.set_color(color)
|
34 |
-
label.set_fontweight('bold')
|
35 |
-
|
36 |
-
# Add a caret symbol to specific labels
|
37 |
-
if text in {'MUT2VEC', 'PFAM', 'GENE2VEC', 'BERT-PFAM'}:
|
38 |
-
label.set_text(f"^ {text}")
|
39 |
|
40 |
def benchmark_plot(benchmark_type, methods_selected, x_metric, y_metric):
|
41 |
if benchmark_type == 'similarity':
|
42 |
-
|
43 |
-
return plot_similarity_results(methods_selected, x_metric, y_metric, title)
|
44 |
elif benchmark_type == 'function':
|
45 |
return plot_function_results("./data/function_results.csv", x_metric, y_metric, methods_selected)
|
46 |
elif benchmark_type == 'family':
|
47 |
return plot_family_results("./data/family_results.csv", methods_selected, x_metric, save_path="./plot_images")
|
48 |
elif benchmark_type == "affinity":
|
49 |
return plot_affinity_results("./data/affinity_results.csv", methods_selected, x_metric, save_path="./plot_images")
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
#
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
hue="method_name", # Different colors for different methods
|
65 |
-
marker="o", # Add markers to the line plot
|
66 |
-
)
|
67 |
|
68 |
-
#
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
|
|
|
|
73 |
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
|
|
|
|
|
|
|
|
78 |
|
79 |
-
|
|
|
|
|
|
|
|
|
80 |
|
81 |
-
def plot_similarity_results(methods_selected, x_metric, y_metric, title):
|
82 |
-
df = pd.read_csv(CSV_RESULT_PATH)
|
83 |
-
# Filter the dataframe based on selected methods
|
84 |
-
filtered_df = df[df['method_name'].isin(methods_selected)]
|
85 |
|
86 |
-
|
87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
# Add a new column to the dataframe for the color
|
90 |
filtered_df['color'] = filtered_df['method_name'].apply(get_method_color)
|
91 |
|
|
|
|
|
92 |
adjust_text_dict = {
|
93 |
'expand_text': (1.15, 1.4), 'expand_points': (1.15, 1.25), 'expand_objects': (1.05, 1.5),
|
94 |
'expand_align': (1.05, 1.2), 'autoalign': 'xy', 'va': 'center', 'ha': 'center',
|
@@ -104,7 +100,7 @@ def plot_similarity_results(methods_selected, x_metric, y_metric, title):
|
|
104 |
label='method_name')) # Label each point by the method name
|
105 |
+ p9.geom_point(size=3) # Add points with no jitter, set point size
|
106 |
+ p9.geom_text(nudge_y=0.02, size=8) # Add method names as labels, nudge slightly above the points
|
107 |
-
+ p9.labs(title=title, x=
|
108 |
+ p9.scale_color_identity() # Use colors directly from the dataframe
|
109 |
+ p9.theme(legend_position='none',
|
110 |
figure_size=(8, 8), # Set figure size
|
@@ -114,10 +110,8 @@ def plot_similarity_results(methods_selected, x_metric, y_metric, title):
|
|
114 |
)
|
115 |
|
116 |
# Save the plot as an image
|
117 |
-
save_path = "
|
118 |
-
os.makedirs(save_path, exist_ok=True) # Create directory if it doesn't exist
|
119 |
filename = os.path.join(save_path, title.replace(" ", "_") + "_Similarity_Scatter.png")
|
120 |
-
|
121 |
g.save(filename=filename, dpi=400)
|
122 |
|
123 |
return filename
|
|
|
14 |
sys.path.append('.')
|
15 |
|
16 |
from about import *
|
17 |
+
from saving_utils import download_from_hub
|
18 |
|
|
|
19 |
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
+
global data_component, filter_component
|
22 |
+
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
def benchmark_plot(benchmark_type, methods_selected, x_metric, y_metric):
|
26 |
if benchmark_type == 'similarity':
|
27 |
+
return plot_similarity_results(methods_selected, x_metric, y_metric)
|
|
|
28 |
elif benchmark_type == 'function':
|
29 |
return plot_function_results("./data/function_results.csv", x_metric, y_metric, methods_selected)
|
30 |
elif benchmark_type == 'family':
|
31 |
return plot_family_results("./data/family_results.csv", methods_selected, x_metric, save_path="./plot_images")
|
32 |
elif benchmark_type == "affinity":
|
33 |
return plot_affinity_results("./data/affinity_results.csv", methods_selected, x_metric, save_path="./plot_images")
|
34 |
+
|
35 |
+
return 0
|
36 |
+
|
37 |
+
def get_method_color(method):
|
38 |
+
return color_dict.get(method, 'black') # If method is not in color_dict, use black
|
39 |
+
|
40 |
+
|
41 |
+
def get_labels_and_title(x_metric, y_metric):
|
42 |
+
# Define mapping for long forms
|
43 |
+
long_form_mapping = {
|
44 |
+
"MF": "Molecular Function",
|
45 |
+
"BP": "Biological Process",
|
46 |
+
"CC": "Cellular Component"
|
47 |
+
}
|
|
|
|
|
|
|
48 |
|
49 |
+
# Parse the metrics
|
50 |
+
def parse_metric(metric):
|
51 |
+
parts = metric.split("_")
|
52 |
+
dataset = parts[0] # sparse/200/500
|
53 |
+
category = parts[1] # MF/BP/CC
|
54 |
+
measure = parts[2] # pvalue/correlation
|
55 |
+
return dataset, category, measure
|
56 |
|
57 |
+
x_dataset, x_category, x_measure = parse_metric(x_metric)
|
58 |
+
y_dataset, y_category, y_measure = parse_metric(y_metric)
|
59 |
+
|
60 |
+
# Determine the title
|
61 |
+
if x_category == y_category:
|
62 |
+
title = long_form_mapping[x_category]
|
63 |
+
else:
|
64 |
+
title = f"{long_form_mapping[x_category]} vs {long_form_mapping[y_category]}"
|
65 |
|
66 |
+
# Determine the axis labels
|
67 |
+
x_label = f"{x_measure.capitalize()} on {x_dataset.capitalize()} Dataset"
|
68 |
+
y_label = f"{y_measure.capitalize()} on {y_dataset.capitalize()} Dataset"
|
69 |
+
|
70 |
+
return title, x_label, y_label
|
71 |
|
|
|
|
|
|
|
|
|
72 |
|
73 |
+
def plot_similarity_results(methods_selected, x_metric, y_metric, similarity_path="/tmp/similarity_results.csv"):
|
74 |
+
if not os.path.exists(similarity_path):
|
75 |
+
benchmark_types = ["similarity", "function", "family", "affinity"] #download all files for faster results later
|
76 |
+
download_from_hub(benchmark_types)
|
77 |
+
|
78 |
+
similarity_df = pd.read_csv(similarity_path)
|
79 |
+
|
80 |
+
# Filter the dataframe based on selected methods
|
81 |
+
filtered_df = similarity_df[similarity_df['method_name'].isin(methods_selected)]
|
82 |
|
83 |
# Add a new column to the dataframe for the color
|
84 |
filtered_df['color'] = filtered_df['method_name'].apply(get_method_color)
|
85 |
|
86 |
+
title, x_label, y_label = generate_labels_and_title(x_metric, y_metric)
|
87 |
+
|
88 |
adjust_text_dict = {
|
89 |
'expand_text': (1.15, 1.4), 'expand_points': (1.15, 1.25), 'expand_objects': (1.05, 1.5),
|
90 |
'expand_align': (1.05, 1.2), 'autoalign': 'xy', 'va': 'center', 'ha': 'center',
|
|
|
100 |
label='method_name')) # Label each point by the method name
|
101 |
+ p9.geom_point(size=3) # Add points with no jitter, set point size
|
102 |
+ p9.geom_text(nudge_y=0.02, size=8) # Add method names as labels, nudge slightly above the points
|
103 |
+
+ p9.labs(title=title, x=x_label, y=y_label) # Dynamic labels for X and Y axes
|
104 |
+ p9.scale_color_identity() # Use colors directly from the dataframe
|
105 |
+ p9.theme(legend_position='none',
|
106 |
figure_size=(8, 8), # Set figure size
|
|
|
110 |
)
|
111 |
|
112 |
# Save the plot as an image
|
113 |
+
save_path = "/tmp"
|
|
|
114 |
filename = os.path.join(save_path, title.replace(" ", "_") + "_Similarity_Scatter.png")
|
|
|
115 |
g.save(filename=filename, dpi=400)
|
116 |
|
117 |
return filename
|