mgyigit commited on
Commit
3671dd0
·
verified ·
1 Parent(s): dc454dd

Update src/vis_utils.py

Browse files
Files changed (1) hide show
  1. src/vis_utils.py +85 -44
src/vis_utils.py CHANGED
@@ -17,9 +17,67 @@ from about import *
17
 
18
  global data_component, filter_component
19
 
 
 
 
 
 
 
20
  def get_method_color(method):
21
  return color_dict.get(method, 'black') # If method is not in color_dict, use black
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def draw_scatter_plot_similarity(methods_selected, x_metric, y_metric, title):
24
  df = pd.read_csv(CSV_RESULT_PATH)
25
  # Filter the dataframe based on selected methods
@@ -64,50 +122,33 @@ def draw_scatter_plot_similarity(methods_selected, x_metric, y_metric, title):
64
 
65
  return filename
66
 
67
- def benchmark_plot(benchmark_type, methods_selected, x_metric, y_metric):
68
- if benchmark_type == 'flexible':
69
- # Use general visualizer logic
70
- return general_visualizer_plot(methods_selected, x_metric=x_metric, y_metric=y_metric)
71
- elif benchmark_type == 'similarity':
72
- title = f"{x_metric} vs {y_metric}"
73
- return draw_scatter_plot_similarity(methods_selected, x_metric, y_metric, title)
74
- elif benchmark_type == 'Benchmark 3':
75
- return benchmark_3_plot(x_metric, y_metric)
76
- elif benchmark_type == 'Benchmark 4':
77
- return benchmark_4_plot(x_metric, y_metric)
78
- else:
79
- return "Invalid benchmark type selected."
80
-
81
-
82
- def get_baseline_df(selected_methods, selected_metrics):
83
- df = pd.read_csv(CSV_RESULT_PATH)
84
- present_columns = ["method_name"] + selected_metrics
85
- df = df[df['method_name'].isin(selected_methods)][present_columns]
86
- return df
87
-
88
- def general_visualizer(methods_selected, x_metric, y_metric):
89
- df = pd.read_csv(CSV_RESULT_PATH)
90
- filtered_df = df[df['method_name'].isin(methods_selected)]
91
-
92
- # Create a Seaborn lineplot with method as hue
93
- plt.figure(figsize=(10, 8)) # Increase figure size
94
- sns.lineplot(
95
- data=filtered_df,
96
- x=x_metric,
97
- y=y_metric,
98
- hue="method_name", # Different colors for different methods
99
- marker="o", # Add markers to the line plot
100
- )
101
 
102
- # Add labels and title
103
- plt.xlabel(x_metric)
104
- plt.ylabel(y_metric)
105
- plt.title(f'{y_metric} vs {x_metric} for selected methods')
106
- plt.grid(True)
107
 
108
- # Save the plot to display it in Gradio
109
- plot_path = "plot.png"
110
- plt.savefig(plot_path)
111
- plt.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
- return plot_path
 
17
 
18
  global data_component, filter_component
19
 
20
+ def get_baseline_df(selected_methods, selected_metrics):
21
+ df = pd.read_csv(CSV_RESULT_PATH)
22
+ present_columns = ["method_name"] + selected_metrics
23
+ df = df[df['method_name'].isin(selected_methods)][present_columns]
24
+ return df
25
+
26
  def get_method_color(method):
27
  return color_dict.get(method, 'black') # If method is not in color_dict, use black
28
 
29
+ def set_colors_and_marks_for_representation_groups(ax):
30
+ for label in ax.get_xticklabels():
31
+ text = label.get_text()
32
+ color = group_color_dict.get(text, 'black') # Default to black if label not in dict
33
+ label.set_color(color)
34
+ label.set_fontweight('bold')
35
+
36
+ # Add a caret symbol to specific labels
37
+ if text in {'MUT2VEC', 'PFAM', 'GENE2VEC', 'BERT-PFAM'}:
38
+ label.set_text(f"^ {text}")
39
+
40
+ def benchmark_plot(benchmark_type, methods_selected, x_metric, y_metric):
41
+ if benchmark_type == 'flexible':
42
+ # Use general visualizer logic
43
+ return general_visualizer_plot(methods_selected, x_metric=x_metric, y_metric=y_metric)
44
+ elif benchmark_type == 'similarity':
45
+ title = f"{x_metric} vs {y_metric}"
46
+ return draw_scatter_plot_similarity(methods_selected, x_metric, y_metric, title)
47
+ elif benchmark_type == 'Benchmark 3':
48
+ return benchmark_3_plot(x_metric, y_metric)
49
+ elif benchmark_type == 'Benchmark 4':
50
+ return benchmark_4_plot(x_metric, y_metric)
51
+ else:
52
+ return "Invalid benchmark type selected."
53
+
54
+ def general_visualizer(methods_selected, x_metric, y_metric):
55
+ df = pd.read_csv(CSV_RESULT_PATH)
56
+ filtered_df = df[df['method_name'].isin(methods_selected)]
57
+
58
+ # Create a Seaborn lineplot with method as hue
59
+ plt.figure(figsize=(10, 8)) # Increase figure size
60
+ sns.lineplot(
61
+ data=filtered_df,
62
+ x=x_metric,
63
+ y=y_metric,
64
+ hue="method_name", # Different colors for different methods
65
+ marker="o", # Add markers to the line plot
66
+ )
67
+
68
+ # Add labels and title
69
+ plt.xlabel(x_metric)
70
+ plt.ylabel(y_metric)
71
+ plt.title(f'{y_metric} vs {x_metric} for selected methods')
72
+ plt.grid(True)
73
+
74
+ # Save the plot to display it in Gradio
75
+ plot_path = "plot.png"
76
+ plt.savefig(plot_path)
77
+ plt.close()
78
+
79
+ return plot_path
80
+
81
  def draw_scatter_plot_similarity(methods_selected, x_metric, y_metric, title):
82
  df = pd.read_csv(CSV_RESULT_PATH)
83
  # Filter the dataframe based on selected methods
 
122
 
123
  return filename
124
 
125
+ def visualize_aspect_metric_clustermap(file_path, aspect, metric, method_names):
126
+ # Load data
127
+ df = pd.read_csv(file_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
+ # Filter for selected methods
130
+ df = df[df['Method'].isin(method_names)]
 
 
 
131
 
132
+ # Filter columns for specified aspect and metric
133
+ columns_to_plot = [col for col in df.columns if col.startswith(f"{aspect}_") and col.endswith(f"_{metric}")]
134
+ df = df[['Method'] + columns_to_plot]
135
+ df.set_index('Method', inplace=True)
136
+
137
+ # Create clustermap
138
+ g = sns.clustermap(df, annot=True, cmap="YlGnBu", row_cluster=False, col_cluster=False, figsize=(15, 15))
139
+
140
+ # Get heatmap axis and customize labels
141
+ ax = g.ax_heatmap
142
+ ax.set_xlabel("")
143
+ ax.set_ylabel("")
144
+
145
+ # Apply color and caret adjustments to x-axis labels
146
+ set_colors_and_marks_for_representation_groups(ax)
147
+
148
+ # Save the plot as an image
149
+ os.makedirs(save_path, exist_ok=True) # Create directory if it doesn't exist
150
+ filename = os.path.join(save_path, f"{aspect}_{metric}_heatmap.png")
151
+ plt.savefig(filename, dpi=400, bbox_inches='tight')
152
+ plt.close() # Close the plot to free memory
153
 
154
+ return filename