Spaces:
Sleeping
Sleeping
Update src/vis_utils.py
Browse files- src/vis_utils.py +85 -44
src/vis_utils.py
CHANGED
@@ -17,9 +17,67 @@ from about import *
|
|
17 |
|
18 |
global data_component, filter_component
|
19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
def get_method_color(method):
|
21 |
return color_dict.get(method, 'black') # If method is not in color_dict, use black
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
def draw_scatter_plot_similarity(methods_selected, x_metric, y_metric, title):
|
24 |
df = pd.read_csv(CSV_RESULT_PATH)
|
25 |
# Filter the dataframe based on selected methods
|
@@ -64,50 +122,33 @@ def draw_scatter_plot_similarity(methods_selected, x_metric, y_metric, title):
|
|
64 |
|
65 |
return filename
|
66 |
|
67 |
-
def
|
68 |
-
|
69 |
-
|
70 |
-
return general_visualizer_plot(methods_selected, x_metric=x_metric, y_metric=y_metric)
|
71 |
-
elif benchmark_type == 'similarity':
|
72 |
-
title = f"{x_metric} vs {y_metric}"
|
73 |
-
return draw_scatter_plot_similarity(methods_selected, x_metric, y_metric, title)
|
74 |
-
elif benchmark_type == 'Benchmark 3':
|
75 |
-
return benchmark_3_plot(x_metric, y_metric)
|
76 |
-
elif benchmark_type == 'Benchmark 4':
|
77 |
-
return benchmark_4_plot(x_metric, y_metric)
|
78 |
-
else:
|
79 |
-
return "Invalid benchmark type selected."
|
80 |
-
|
81 |
-
|
82 |
-
def get_baseline_df(selected_methods, selected_metrics):
|
83 |
-
df = pd.read_csv(CSV_RESULT_PATH)
|
84 |
-
present_columns = ["method_name"] + selected_metrics
|
85 |
-
df = df[df['method_name'].isin(selected_methods)][present_columns]
|
86 |
-
return df
|
87 |
-
|
88 |
-
def general_visualizer(methods_selected, x_metric, y_metric):
|
89 |
-
df = pd.read_csv(CSV_RESULT_PATH)
|
90 |
-
filtered_df = df[df['method_name'].isin(methods_selected)]
|
91 |
-
|
92 |
-
# Create a Seaborn lineplot with method as hue
|
93 |
-
plt.figure(figsize=(10, 8)) # Increase figure size
|
94 |
-
sns.lineplot(
|
95 |
-
data=filtered_df,
|
96 |
-
x=x_metric,
|
97 |
-
y=y_metric,
|
98 |
-
hue="method_name", # Different colors for different methods
|
99 |
-
marker="o", # Add markers to the line plot
|
100 |
-
)
|
101 |
|
102 |
-
#
|
103 |
-
|
104 |
-
plt.ylabel(y_metric)
|
105 |
-
plt.title(f'{y_metric} vs {x_metric} for selected methods')
|
106 |
-
plt.grid(True)
|
107 |
|
108 |
-
#
|
109 |
-
|
110 |
-
|
111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
|
113 |
-
return
|
|
|
17 |
|
18 |
global data_component, filter_component
|
19 |
|
20 |
+
def get_baseline_df(selected_methods, selected_metrics):
|
21 |
+
df = pd.read_csv(CSV_RESULT_PATH)
|
22 |
+
present_columns = ["method_name"] + selected_metrics
|
23 |
+
df = df[df['method_name'].isin(selected_methods)][present_columns]
|
24 |
+
return df
|
25 |
+
|
26 |
def get_method_color(method):
|
27 |
return color_dict.get(method, 'black') # If method is not in color_dict, use black
|
28 |
|
29 |
+
def set_colors_and_marks_for_representation_groups(ax):
|
30 |
+
for label in ax.get_xticklabels():
|
31 |
+
text = label.get_text()
|
32 |
+
color = group_color_dict.get(text, 'black') # Default to black if label not in dict
|
33 |
+
label.set_color(color)
|
34 |
+
label.set_fontweight('bold')
|
35 |
+
|
36 |
+
# Add a caret symbol to specific labels
|
37 |
+
if text in {'MUT2VEC', 'PFAM', 'GENE2VEC', 'BERT-PFAM'}:
|
38 |
+
label.set_text(f"^ {text}")
|
39 |
+
|
40 |
+
def benchmark_plot(benchmark_type, methods_selected, x_metric, y_metric):
|
41 |
+
if benchmark_type == 'flexible':
|
42 |
+
# Use general visualizer logic
|
43 |
+
return general_visualizer_plot(methods_selected, x_metric=x_metric, y_metric=y_metric)
|
44 |
+
elif benchmark_type == 'similarity':
|
45 |
+
title = f"{x_metric} vs {y_metric}"
|
46 |
+
return draw_scatter_plot_similarity(methods_selected, x_metric, y_metric, title)
|
47 |
+
elif benchmark_type == 'Benchmark 3':
|
48 |
+
return benchmark_3_plot(x_metric, y_metric)
|
49 |
+
elif benchmark_type == 'Benchmark 4':
|
50 |
+
return benchmark_4_plot(x_metric, y_metric)
|
51 |
+
else:
|
52 |
+
return "Invalid benchmark type selected."
|
53 |
+
|
54 |
+
def general_visualizer(methods_selected, x_metric, y_metric):
|
55 |
+
df = pd.read_csv(CSV_RESULT_PATH)
|
56 |
+
filtered_df = df[df['method_name'].isin(methods_selected)]
|
57 |
+
|
58 |
+
# Create a Seaborn lineplot with method as hue
|
59 |
+
plt.figure(figsize=(10, 8)) # Increase figure size
|
60 |
+
sns.lineplot(
|
61 |
+
data=filtered_df,
|
62 |
+
x=x_metric,
|
63 |
+
y=y_metric,
|
64 |
+
hue="method_name", # Different colors for different methods
|
65 |
+
marker="o", # Add markers to the line plot
|
66 |
+
)
|
67 |
+
|
68 |
+
# Add labels and title
|
69 |
+
plt.xlabel(x_metric)
|
70 |
+
plt.ylabel(y_metric)
|
71 |
+
plt.title(f'{y_metric} vs {x_metric} for selected methods')
|
72 |
+
plt.grid(True)
|
73 |
+
|
74 |
+
# Save the plot to display it in Gradio
|
75 |
+
plot_path = "plot.png"
|
76 |
+
plt.savefig(plot_path)
|
77 |
+
plt.close()
|
78 |
+
|
79 |
+
return plot_path
|
80 |
+
|
81 |
def draw_scatter_plot_similarity(methods_selected, x_metric, y_metric, title):
|
82 |
df = pd.read_csv(CSV_RESULT_PATH)
|
83 |
# Filter the dataframe based on selected methods
|
|
|
122 |
|
123 |
return filename
|
124 |
|
125 |
+
def visualize_aspect_metric_clustermap(file_path, aspect, metric, method_names):
|
126 |
+
# Load data
|
127 |
+
df = pd.read_csv(file_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
|
129 |
+
# Filter for selected methods
|
130 |
+
df = df[df['Method'].isin(method_names)]
|
|
|
|
|
|
|
131 |
|
132 |
+
# Filter columns for specified aspect and metric
|
133 |
+
columns_to_plot = [col for col in df.columns if col.startswith(f"{aspect}_") and col.endswith(f"_{metric}")]
|
134 |
+
df = df[['Method'] + columns_to_plot]
|
135 |
+
df.set_index('Method', inplace=True)
|
136 |
+
|
137 |
+
# Create clustermap
|
138 |
+
g = sns.clustermap(df, annot=True, cmap="YlGnBu", row_cluster=False, col_cluster=False, figsize=(15, 15))
|
139 |
+
|
140 |
+
# Get heatmap axis and customize labels
|
141 |
+
ax = g.ax_heatmap
|
142 |
+
ax.set_xlabel("")
|
143 |
+
ax.set_ylabel("")
|
144 |
+
|
145 |
+
# Apply color and caret adjustments to x-axis labels
|
146 |
+
set_colors_and_marks_for_representation_groups(ax)
|
147 |
+
|
148 |
+
# Save the plot as an image
|
149 |
+
os.makedirs(save_path, exist_ok=True) # Create directory if it doesn't exist
|
150 |
+
filename = os.path.join(save_path, f"{aspect}_{metric}_heatmap.png")
|
151 |
+
plt.savefig(filename, dpi=400, bbox_inches='tight')
|
152 |
+
plt.close() # Close the plot to free memory
|
153 |
|
154 |
+
return filename
|