mgyigit commited on
Commit
83ac1b8
·
verified ·
1 Parent(s): 4d48db2

Update src/vis_utils.py

Browse files
Files changed (1) hide show
  1. src/vis_utils.py +20 -22
src/vis_utils.py CHANGED
@@ -26,13 +26,13 @@ def benchmark_plot(benchmark_type, methods_selected, x_metric, y_metric, aspect,
26
  if benchmark_type == 'similarity':
27
  return plot_similarity_results(methods_selected, x_metric, y_metric)
28
  elif benchmark_type == 'function':
29
- return plot_function_results(aspect, single_metric, methods_selected)
30
  elif benchmark_type == 'family':
31
- return plot_family_results("./data/family_results.csv", methods_selected, x_metric, save_path="./plot_images")
32
  elif benchmark_type == "affinity":
33
- return plot_affinity_results("./data/affinity_results.csv", methods_selected, x_metric, save_path="./plot_images")
34
-
35
- return 0
36
 
37
  def get_method_color(method):
38
  return color_dict.get(method, 'black') # If method is not in color_dict, use black
@@ -119,7 +119,7 @@ def plot_similarity_results(methods_selected, x_metric, y_metric, similarity_pat
119
 
120
  return filename
121
 
122
- def plot_function_results(aspect, metric, method_names, function_path="/tmp/function_results.csv"):
123
  if not os.path.exists(function_path):
124
  benchmark_types = ["similarity", "function", "family", "affinity"] #download all files for faster results later
125
  download_from_hub(benchmark_types)
@@ -168,18 +168,16 @@ def plot_function_results(aspect, metric, method_names, function_path="/tmp/func
168
 
169
  return filename
170
 
171
- def plot_family_results(file_path, method_names, metric, save_path="./plot_images"):
172
- # Load data
173
- df = pd.read_csv(file_path)
174
-
175
- # Filter by method names and selected metric columns
 
 
 
176
  df = df[df['Method'].isin(method_names)]
177
- metric_columns = [col for col in df.columns if col.startswith(f"{metric}_")]
178
-
179
- # Check if there are columns matching the selected metric
180
- if not metric_columns:
181
- print(f"No columns found for metric '{metric}'.")
182
- return None
183
 
184
  # Reshape data for plotting
185
  df_long = pd.melt(df[['Method'] + metric_columns], id_vars=['Method'], var_name='Fold', value_name='Value')
@@ -204,12 +202,12 @@ def plot_family_results(file_path, method_names, metric, save_path="./plot_image
204
  ax.hlines(ytick + 0.5, -0.1, 1, linestyles='dashed')
205
 
206
  # Apply color settings to y-axis labels
207
- set_colors_and_marks_for_representation_groups(ax)
208
-
209
- # Ensure save directory exists
210
- os.makedirs(save_path, exist_ok=True)
211
-
212
  # Save the plot
 
213
  filename = os.path.join(save_path, f"{metric}_family_results.png")
214
  ax.get_figure().savefig(filename, dpi=400, bbox_inches='tight')
215
  plt.close() # Close the plot to free memory
 
26
  if benchmark_type == 'similarity':
27
  return plot_similarity_results(methods_selected, x_metric, y_metric)
28
  elif benchmark_type == 'function':
29
+ return plot_function_results(methods_selected, aspect, single_metric)
30
  elif benchmark_type == 'family':
31
+ return plot_family_results(methods_selected, dataset, single_metric)
32
  elif benchmark_type == "affinity":
33
+ return plot_affinity_results(methods_selected, single_metric)
34
+ else:
35
+ return -1
36
 
37
  def get_method_color(method):
38
  return color_dict.get(method, 'black') # If method is not in color_dict, use black
 
119
 
120
  return filename
121
 
122
+ def plot_function_results(method_names, aspect, metric, function_path="/tmp/function_results.csv"):
123
  if not os.path.exists(function_path):
124
  benchmark_types = ["similarity", "function", "family", "affinity"] #download all files for faster results later
125
  download_from_hub(benchmark_types)
 
168
 
169
  return filename
170
 
171
+ def plot_family_results(methods_selected, dataset, metric, family_path="/tmp/family_results.csv"):
172
+ if not os.path.exists(function_path):
173
+ benchmark_types = ["similarity", "function", "family", "affinity"] #download all files for faster results later
174
+ download_from_hub(benchmark_types)
175
+
176
+ df = pd.read_csv(family_path)
177
+
178
+ # Filter by method names and selected dataset columns
179
  df = df[df['Method'].isin(method_names)]
180
+ metric_columns = [col for col in df.columns if col.startswith(f"{dataset}_{metric}_")]
 
 
 
 
 
181
 
182
  # Reshape data for plotting
183
  df_long = pd.melt(df[['Method'] + metric_columns], id_vars=['Method'], var_name='Fold', value_name='Value')
 
202
  ax.hlines(ytick + 0.5, -0.1, 1, linestyles='dashed')
203
 
204
  # Apply color settings to y-axis labels
205
+ for label in ax.get_yticklabels():
206
+ method = label.get_text()
207
+ label.set_color(get_method_color(method))
208
+
 
209
  # Save the plot
210
+ save_path = "/tmp"
211
  filename = os.path.join(save_path, f"{metric}_family_results.png")
212
  ax.get_figure().savefig(filename, dpi=400, bbox_inches='tight')
213
  plt.close() # Close the plot to free memory