felipekitamura committed
Commit 68af8b4 · verified · 1 Parent(s): f3dac67

Update omnibin/metrics.py

Files changed (1)
  1. omnibin/metrics.py +33 -195
omnibin/metrics.py CHANGED
@@ -11,214 +11,52 @@ from sklearn.metrics import (
 )
 from sklearn.calibration import calibration_curve
 from matplotlib.backends.backend_pdf import PdfPages
+from enum import Enum
+from .utils import (
+    ColorScheme, calculate_metrics_by_threshold, bootstrap_curves,
+    calculate_optimal_threshold, calculate_metrics_summary,
+    calculate_confidence_intervals, create_output_directories,
+    plot_roc_pr_curves, plot_metrics_threshold, plot_confusion_matrix,
+    plot_calibration, plot_metrics_summary, plot_prediction_distribution
+)
 
-def generate_binary_classification_report(y_true, y_scores, output_path="omnibin_report.pdf", n_bootstrap=1000, random_seed=42, dpi=300):
+def generate_binary_classification_report(y_true, y_scores, output_path="omnibin_report.pdf", n_bootstrap=1000, random_seed=42, dpi=300, color_scheme=ColorScheme.DEFAULT):
     # Set random seed for reproducibility
     if random_seed is not None:
         np.random.seed(random_seed)
 
-    # Ensure output directory exists
-    output_dir = os.path.dirname(output_path)
-    if output_dir:
-        os.makedirs(output_dir, exist_ok=True)
-
     # Set DPI for all figures
     plt.rcParams['figure.dpi'] = dpi
 
-    thresholds = np.linspace(0, 1, 100)
-    metrics_by_threshold = []
-
-    for t in tqdm(thresholds, desc="Calculating metrics across thresholds"):
-        y_pred = (y_scores >= t).astype(int)
-        acc = accuracy_score(y_true, y_pred)
-        sens = recall_score(y_true, y_pred)
-        spec = recall_score(y_true, y_pred, pos_label=0)
-        ppv = precision_score(y_true, y_pred, zero_division=0)
-        mcc = matthews_corrcoef(y_true, y_pred)
-        f1 = f1_score(y_true, y_pred)
-        metrics_by_threshold.append([t, acc, sens, spec, ppv, mcc, f1])
-
-    metrics_df = pd.DataFrame(metrics_by_threshold, columns=[
-        "Threshold", "Accuracy", "Sensitivity", "Specificity",
-        "PPV", "MCC", "F1 Score"
-    ])
-
-    def bootstrap_metric(metric_func, y_true, y_scores, n_boot=1000):
-        stats = []
-        for _ in tqdm(range(n_boot), desc="Bootstrap iterations", leave=False):
-            indices = np.random.choice(range(len(y_true)), len(y_true), replace=True)
-            try:
-                stats.append(metric_func(y_true[indices], y_scores[indices]))
-            except:
-                continue
-        return np.percentile(stats, [2.5, 97.5])
-
-    def bootstrap_curves(y_true, y_scores, n_boot=1000):
-        tprs = []
-        fprs = []
-        precisions = []
-        recalls = []
-
-        # Get the base curves to determine common points
-        base_fpr, base_tpr, _ = roc_curve(y_true, y_scores)
-        base_precision, base_recall, _ = precision_recall_curve(y_true, y_scores)
-
-        # Create common x-axis points
-        common_fpr = np.linspace(0, 1, 100)
-        common_recall = np.linspace(0, 1, 100)
-
-        for _ in tqdm(range(n_boot), desc="Bootstrap iterations for curves", leave=False):
-            indices = np.random.choice(range(len(y_true)), len(y_true), replace=True)
-            try:
-                # ROC curve
-                fpr, tpr, _ = roc_curve(y_true[indices], y_scores[indices])
-                tpr_interp = np.interp(common_fpr, fpr, tpr)
-                tprs.append(tpr_interp)
-
-                # PR curve - handle precision interpolation carefully
-                precision, recall, _ = precision_recall_curve(y_true[indices], y_scores[indices])
-                # Sort by recall to ensure proper interpolation
-                sort_idx = np.argsort(recall)
-                recall = recall[sort_idx]
-                precision = precision[sort_idx]
-                # Interpolate precision values
-                precision_interp = np.interp(common_recall, recall, precision)
-                precisions.append(precision_interp)
-            except:
-                continue
-
-        # Calculate confidence intervals
-        tpr_ci = np.percentile(tprs, [2.5, 97.5], axis=0)
-        precision_ci = np.percentile(precisions, [2.5, 97.5], axis=0)
-
-        return tpr_ci, precision_ci, common_fpr, common_recall
+    # Get color scheme
+    colors = color_scheme.value
 
-    fpr, tpr, roc_thresholds = roc_curve(y_true, y_scores)
-    j_scores = tpr - fpr
-    best_thresh = roc_thresholds[np.argmax(j_scores)]
-    y_pred_opt = (y_scores >= best_thresh).astype(int)
+    # Calculate metrics and optimal threshold
+    metrics_df = calculate_metrics_by_threshold(y_true, y_scores)
+    best_thresh = calculate_optimal_threshold(y_true, y_scores)
+    metrics_summary = calculate_metrics_summary(y_true, y_scores, best_thresh)
+    conf_intervals = calculate_confidence_intervals(y_true, y_scores, best_thresh, n_bootstrap)
 
-    metrics_summary = {
-        "Accuracy": accuracy_score(y_true, y_pred_opt),
-        "Sensitivity": recall_score(y_true, y_pred_opt),
-        "Specificity": recall_score(y_true, y_pred_opt, pos_label=0),
-        "PPV": precision_score(y_true, y_pred_opt, zero_division=0),
-        "MCC": matthews_corrcoef(y_true, y_pred_opt),
-        "F1 Score": f1_score(y_true, y_pred_opt),
-        "AUC-ROC": roc_auc_score(y_true, y_scores),
-        "AUC-PR": average_precision_score(y_true, y_scores)
-    }
+    # Create output directories
+    plots_dir = create_output_directories(output_path)
 
-    conf_intervals = {}
-    for name, func in {
-        "Accuracy": lambda yt, ys: accuracy_score(yt, ys >= best_thresh),
-        "Sensitivity": lambda yt, ys: recall_score(yt, ys >= best_thresh),
-        "Specificity": lambda yt, ys: recall_score(yt, ys >= best_thresh, pos_label=0),
-        "PPV": lambda yt, ys: precision_score(yt, ys >= best_thresh, zero_division=0),
-        "MCC": lambda yt, ys: matthews_corrcoef(yt, ys >= best_thresh),
-        "F1 Score": lambda yt, ys: f1_score(yt, ys >= best_thresh),
-        "AUC-ROC": lambda yt, ys: roc_auc_score(yt, ys),
-        "AUC-PR": lambda yt, ys: average_precision_score(yt, ys)
-    }.items():
-        ci = bootstrap_metric(func, y_true, y_scores, n_boot=n_bootstrap)
-        conf_intervals[name] = ci
-
-    # Create output directory for individual plots
-    plots_dir = os.path.join(output_dir, "plots")
-    os.makedirs(plots_dir, exist_ok=True)
+    # Calculate confidence intervals for curves
+    tpr_ci, precision_ci, common_fpr, common_recall = bootstrap_curves(y_true, y_scores, n_boot=n_bootstrap)
 
     with PdfPages(output_path) as pdf:
-        # ROC and PR Curves with proper confidence intervals
-        plt.figure(figsize=(12, 5), dpi=dpi)
-
-        # Calculate confidence intervals for curves
-        tpr_ci, precision_ci, common_fpr, common_recall = bootstrap_curves(y_true, y_scores, n_boot=n_bootstrap)
-
-        plt.subplot(1, 2, 1)
-        fpr, tpr, _ = roc_curve(y_true, y_scores)
-        plt.plot(fpr, tpr, label="ROC curve")
-        plt.fill_between(common_fpr, tpr_ci[0], tpr_ci[1], alpha=0.3)
-        plt.plot([0, 1], [0, 1], "k--")
-        plt.xlabel("False Positive Rate")
-        plt.ylabel("True Positive Rate")
-        plt.title("ROC Curve")
-        plt.legend()
-
-        plt.subplot(1, 2, 2)
-        precision, recall, _ = precision_recall_curve(y_true, y_scores)
-        plt.plot(recall, precision, label="PR curve")
-        plt.fill_between(common_recall, precision_ci[0], precision_ci[1], alpha=0.3)
-        plt.xlabel("Recall")
-        plt.ylabel("Precision")
-        plt.title("Precision-Recall Curve")
-        plt.legend()
-        plt.savefig(os.path.join(plots_dir, "roc_pr.png"), dpi=dpi, bbox_inches='tight')
-        pdf.savefig(dpi=dpi)
-        plt.close()
-
-        # Metrics vs Threshold
-        plt.figure(figsize=(10, 6), dpi=dpi)
-        for col in metrics_df.columns[1:]:
-            plt.plot(metrics_df["Threshold"], metrics_df[col], label=col)
-        plt.xlabel("Threshold")
-        plt.ylabel("Metric Value")
-        plt.title("Metrics Across Thresholds")
-        plt.legend()
-        plt.savefig(os.path.join(plots_dir, "metrics_threshold.png"), dpi=dpi, bbox_inches='tight')
-        pdf.savefig(dpi=dpi)
-        plt.close()
-
-        # Confusion Matrix
-        cm = confusion_matrix(y_true, y_pred_opt)
-        plt.figure(figsize=(5, 4), dpi=dpi)
-        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", cbar=False)
-        plt.title("Confusion Matrix (Optimal Threshold)")
-        plt.xlabel("Predicted Label")
-        plt.ylabel("True Label")
-        plt.savefig(os.path.join(plots_dir, "confusion_matrix.png"), dpi=dpi, bbox_inches='tight')
-        pdf.savefig(dpi=dpi)
-        plt.close()
-
-        # Calibration Plot
-        plt.figure(figsize=(6, 6), dpi=dpi)
-        prob_true, prob_pred = calibration_curve(y_true, y_scores, n_bins=10, strategy='uniform')
-        plt.plot(prob_pred, prob_true, marker='o', label='Calibration curve')
-        plt.plot([0, 1], [0, 1], linestyle='--', color='gray')
-        plt.xlabel('Predicted Probability')
-        plt.ylabel('True Probability')
-        plt.title('Calibration Plot')
-        plt.legend()
-        plt.savefig(os.path.join(plots_dir, "calibration.png"), dpi=dpi, bbox_inches='tight')
-        pdf.savefig(dpi=dpi)
-        plt.close()
-
-        # Metrics Summary Table
-        fig, ax = plt.subplots(figsize=(8, 6), dpi=dpi)
-        ax.axis("off")
-        table_data = [
-            [k, f"{v:.3f}", f"[{conf_intervals[k][0]:.3f}, {conf_intervals[k][1]:.3f}]"]
-            for k, v in metrics_summary.items()
+        # Generate and save all plots
+        plots = [
+            plot_roc_pr_curves(y_true, y_scores, tpr_ci, precision_ci, common_fpr, common_recall, colors, dpi, plots_dir),
+            plot_metrics_threshold(metrics_df, colors, dpi, plots_dir),
+            plot_confusion_matrix(y_true, y_scores, best_thresh, colors, dpi, plots_dir),
+            plot_calibration(y_true, y_scores, colors, dpi, plots_dir),
+            plot_metrics_summary(metrics_summary, conf_intervals, dpi, plots_dir),
+            plot_prediction_distribution(y_true, y_scores, best_thresh, colors, dpi, plots_dir)
         ]
-        table = ax.table(cellText=table_data, colLabels=["Metric", "Value", "95% CI"], loc="center")
-        table.auto_set_font_size(False)
-        table.set_fontsize(10)
-        table.scale(1.2, 1.2)
-        ax.set_title("Performance Metrics at Optimal Threshold", fontweight="bold")
-        plt.savefig(os.path.join(plots_dir, "metrics_summary.png"), dpi=dpi, bbox_inches='tight')
-        pdf.savefig(dpi=dpi)
-        plt.close()
-
-        # Prediction Distribution Histogram
-        plt.figure(figsize=(10, 6), dpi=dpi)
-        plt.hist(y_scores[y_true == 1], bins=50, alpha=0.5, label='Positive Class', color='blue')
-        plt.hist(y_scores[y_true == 0], bins=50, alpha=0.5, label='Negative Class', color='red')
-        plt.axvline(x=best_thresh, color='black', linestyle='--', label=f'Optimal Threshold ({best_thresh:.3f})')
-        plt.xlabel('Predicted Probability')
-        plt.ylabel('Count')
-        plt.title('Distribution of Predictions')
-        plt.legend()
-        plt.savefig(os.path.join(plots_dir, "prediction_distribution.png"), dpi=dpi, bbox_inches='tight')
-        pdf.savefig(dpi=dpi)
-        plt.close()
+
+        # Save all plots to PDF
+        for plot in plots:
+            pdf.savefig(plot, dpi=dpi)
+            plt.close(plot)
 
     return output_path
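
For reference, a minimal usage sketch of the refactored entry point, based on the signature shown in this diff. The import paths follow the package layout (omnibin/metrics.py, with ColorScheme imported from the package's utils module), but the synthetic data and the exact import of ColorScheme are assumptions for illustration, not part of this commit:

import numpy as np
from omnibin.metrics import generate_binary_classification_report
from omnibin.utils import ColorScheme  # assumed import path; the diff imports it from .utils

# Synthetic labels and scores for illustration only
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=500)
y_scores = np.clip(y_true * 0.3 + rng.normal(0.35, 0.25, size=500), 0, 1)

report_path = generate_binary_classification_report(
    y_true,
    y_scores,
    output_path="reports/omnibin_report.pdf",
    n_bootstrap=1000,
    random_seed=42,
    dpi=300,
    color_scheme=ColorScheme.DEFAULT,  # new parameter introduced in this commit
)
print(report_path)  # the function returns output_path

Note that after this refactor the plot_* helpers are expected to return figure objects, since the loop at the end of the function passes each element of plots to pdf.savefig(plot) and plt.close(plot).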