felipekitamura committed · Commit 4556ec6 · verified · 1 Parent(s): 4b878b7

Update omnibin/metrics.py

Files changed (1)
  1. omnibin/metrics.py +224 -0
omnibin/metrics.py CHANGED
@@ -0,0 +1,224 @@
+ import numpy as np
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ from tqdm import tqdm
+ import os
+ from sklearn.metrics import (
+     accuracy_score, recall_score, precision_score, f1_score, roc_auc_score,
+     average_precision_score, confusion_matrix, matthews_corrcoef, roc_curve,
+     precision_recall_curve
+ )
+ from sklearn.calibration import calibration_curve
+ from matplotlib.backends.backend_pdf import PdfPages
+
+ def generate_binary_classification_report(y_true, y_scores, output_path="omnibin_report.pdf", n_bootstrap=1000, random_seed=42, dpi=300):
+     # Set random seed for reproducibility
+     if random_seed is not None:
+         np.random.seed(random_seed)
+
+     # Ensure output directory exists
+     output_dir = os.path.dirname(output_path)
+     if output_dir:
+         os.makedirs(output_dir, exist_ok=True)
+
+     # Set DPI for all figures
+     plt.rcParams['figure.dpi'] = dpi
+
+     thresholds = np.linspace(0, 1, 100)
+     metrics_by_threshold = []
+
+     for t in tqdm(thresholds, desc="Calculating metrics across thresholds"):
+         y_pred = (y_scores >= t).astype(int)
+         acc = accuracy_score(y_true, y_pred)
+         sens = recall_score(y_true, y_pred)
+         spec = recall_score(y_true, y_pred, pos_label=0)
+         ppv = precision_score(y_true, y_pred, zero_division=0)
+         mcc = matthews_corrcoef(y_true, y_pred)
+         f1 = f1_score(y_true, y_pred)
+         metrics_by_threshold.append([t, acc, sens, spec, ppv, mcc, f1])
+
+     metrics_df = pd.DataFrame(metrics_by_threshold, columns=[
+         "Threshold", "Accuracy", "Sensitivity", "Specificity",
+         "PPV", "MCC", "F1 Score"
+     ])
+
+     def bootstrap_metric(metric_func, y_true, y_scores, n_boot=1000):
+         stats = []
+         for _ in tqdm(range(n_boot), desc="Bootstrap iterations", leave=False):
+             indices = np.random.choice(range(len(y_true)), len(y_true), replace=True)
+             try:
+                 stats.append(metric_func(y_true[indices], y_scores[indices]))
+             except Exception:
+                 continue
+         return np.percentile(stats, [2.5, 97.5])
+
+     def bootstrap_curves(y_true, y_scores, n_boot=1000):
+         tprs = []
+         fprs = []
+         precisions = []
+         recalls = []
+
+         # Get the base curves to determine common points
+         base_fpr, base_tpr, _ = roc_curve(y_true, y_scores)
+         base_precision, base_recall, _ = precision_recall_curve(y_true, y_scores)
+
+         # Create common x-axis points
+         common_fpr = np.linspace(0, 1, 100)
+         common_recall = np.linspace(0, 1, 100)
+
+         for _ in tqdm(range(n_boot), desc="Bootstrap iterations for curves", leave=False):
+             indices = np.random.choice(range(len(y_true)), len(y_true), replace=True)
+             try:
+                 # ROC curve
+                 fpr, tpr, _ = roc_curve(y_true[indices], y_scores[indices])
+                 tpr_interp = np.interp(common_fpr, fpr, tpr)
+                 tprs.append(tpr_interp)
+
+                 # PR curve - handle precision interpolation carefully
+                 precision, recall, _ = precision_recall_curve(y_true[indices], y_scores[indices])
+                 # Sort by recall to ensure proper interpolation
+                 sort_idx = np.argsort(recall)
+                 recall = recall[sort_idx]
+                 precision = precision[sort_idx]
+                 # Interpolate precision values
+                 precision_interp = np.interp(common_recall, recall, precision)
+                 precisions.append(precision_interp)
+             except Exception:
+                 continue
+
+         # Calculate confidence intervals
+         tpr_ci = np.percentile(tprs, [2.5, 97.5], axis=0)
+         precision_ci = np.percentile(precisions, [2.5, 97.5], axis=0)
+
+         return tpr_ci, precision_ci, common_fpr, common_recall
+
+     fpr, tpr, roc_thresholds = roc_curve(y_true, y_scores)
+     j_scores = tpr - fpr
+     best_thresh = roc_thresholds[np.argmax(j_scores)]
+     y_pred_opt = (y_scores >= best_thresh).astype(int)
+
+     metrics_summary = {
+         "Accuracy": accuracy_score(y_true, y_pred_opt),
+         "Sensitivity": recall_score(y_true, y_pred_opt),
+         "Specificity": recall_score(y_true, y_pred_opt, pos_label=0),
+         "PPV": precision_score(y_true, y_pred_opt, zero_division=0),
+         "MCC": matthews_corrcoef(y_true, y_pred_opt),
+         "F1 Score": f1_score(y_true, y_pred_opt),
+         "AUC-ROC": roc_auc_score(y_true, y_scores),
+         "AUC-PR": average_precision_score(y_true, y_scores)
+     }
+
+     conf_intervals = {}
+     for name, func in {
+         "Accuracy": lambda yt, ys: accuracy_score(yt, ys >= best_thresh),
+         "Sensitivity": lambda yt, ys: recall_score(yt, ys >= best_thresh),
+         "Specificity": lambda yt, ys: recall_score(yt, ys >= best_thresh, pos_label=0),
+         "PPV": lambda yt, ys: precision_score(yt, ys >= best_thresh, zero_division=0),
+         "MCC": lambda yt, ys: matthews_corrcoef(yt, ys >= best_thresh),
+         "F1 Score": lambda yt, ys: f1_score(yt, ys >= best_thresh),
+         "AUC-ROC": lambda yt, ys: roc_auc_score(yt, ys),
+         "AUC-PR": lambda yt, ys: average_precision_score(yt, ys)
+     }.items():
+         ci = bootstrap_metric(func, y_true, y_scores, n_boot=n_bootstrap)
+         conf_intervals[name] = ci
+
+     # Create output directory for individual plots
+     plots_dir = os.path.join(output_dir, "plots")
+     os.makedirs(plots_dir, exist_ok=True)
+
+     with PdfPages(output_path) as pdf:
+         # ROC and PR Curves with proper confidence intervals
+         plt.figure(figsize=(12, 5), dpi=dpi)
+
+         # Calculate confidence intervals for curves
+         tpr_ci, precision_ci, common_fpr, common_recall = bootstrap_curves(y_true, y_scores, n_boot=n_bootstrap)
+
+         plt.subplot(1, 2, 1)
+         fpr, tpr, _ = roc_curve(y_true, y_scores)
+         plt.plot(fpr, tpr, label="ROC curve")
+         plt.fill_between(common_fpr, tpr_ci[0], tpr_ci[1], alpha=0.3)
+         plt.plot([0, 1], [0, 1], "k--")
+         plt.xlabel("False Positive Rate")
+         plt.ylabel("True Positive Rate")
+         plt.title("ROC Curve")
+         plt.legend()
+
+         plt.subplot(1, 2, 2)
+         precision, recall, _ = precision_recall_curve(y_true, y_scores)
+         plt.plot(recall, precision, label="PR curve")
+         plt.fill_between(common_recall, precision_ci[0], precision_ci[1], alpha=0.3)
+         plt.xlabel("Recall")
+         plt.ylabel("Precision")
+         plt.title("Precision-Recall Curve")
+         plt.legend()
+         plt.savefig(os.path.join(plots_dir, "roc_pr.png"), dpi=dpi, bbox_inches='tight')
+         pdf.savefig(dpi=dpi)
+         plt.close()
+
+         # Metrics vs Threshold
+         plt.figure(figsize=(10, 6), dpi=dpi)
+         for col in metrics_df.columns[1:]:
+             plt.plot(metrics_df["Threshold"], metrics_df[col], label=col)
+         plt.xlabel("Threshold")
+         plt.ylabel("Metric Value")
+         plt.title("Metrics Across Thresholds")
+         plt.legend()
+         plt.savefig(os.path.join(plots_dir, "metrics_threshold.png"), dpi=dpi, bbox_inches='tight')
+         pdf.savefig(dpi=dpi)
+         plt.close()
+
+         # Confusion Matrix
+         cm = confusion_matrix(y_true, y_pred_opt)
+         plt.figure(figsize=(5, 4), dpi=dpi)
+         sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", cbar=False)
+         plt.title("Confusion Matrix (Optimal Threshold)")
+         plt.xlabel("Predicted Label")
+         plt.ylabel("True Label")
+         plt.savefig(os.path.join(plots_dir, "confusion_matrix.png"), dpi=dpi, bbox_inches='tight')
+         pdf.savefig(dpi=dpi)
+         plt.close()
+
+         # Calibration Plot
+         plt.figure(figsize=(6, 6), dpi=dpi)
+         prob_true, prob_pred = calibration_curve(y_true, y_scores, n_bins=10, strategy='uniform')
+         plt.plot(prob_pred, prob_true, marker='o', label='Calibration curve')
+         plt.plot([0, 1], [0, 1], linestyle='--', color='gray')
+         plt.xlabel('Predicted Probability')
+         plt.ylabel('True Probability')
+         plt.title('Calibration Plot')
+         plt.legend()
+         plt.savefig(os.path.join(plots_dir, "calibration.png"), dpi=dpi, bbox_inches='tight')
+         pdf.savefig(dpi=dpi)
+         plt.close()
+
+         # Metrics Summary Table
+         fig, ax = plt.subplots(figsize=(8, 6), dpi=dpi)
+         ax.axis("off")
+         table_data = [
+             [k, f"{v:.3f}", f"[{conf_intervals[k][0]:.3f}, {conf_intervals[k][1]:.3f}]"]
+             for k, v in metrics_summary.items()
+         ]
+         table = ax.table(cellText=table_data, colLabels=["Metric", "Value", "95% CI"], loc="center")
+         table.auto_set_font_size(False)
+         table.set_fontsize(10)
+         table.scale(1.2, 1.2)
+         ax.set_title("Performance Metrics at Optimal Threshold", fontweight="bold")
+         plt.savefig(os.path.join(plots_dir, "metrics_summary.png"), dpi=dpi, bbox_inches='tight')
+         pdf.savefig(dpi=dpi)
+         plt.close()
+
+         # Prediction Distribution Histogram
+         plt.figure(figsize=(10, 6), dpi=dpi)
+         plt.hist(y_scores[y_true == 1], bins=50, alpha=0.5, label='Positive Class', color='blue')
+         plt.hist(y_scores[y_true == 0], bins=50, alpha=0.5, label='Negative Class', color='red')
+         plt.axvline(x=best_thresh, color='black', linestyle='--', label=f'Optimal Threshold ({best_thresh:.3f})')
+         plt.xlabel('Predicted Probability')
+         plt.ylabel('Count')
+         plt.title('Distribution of Predictions')
+         plt.legend()
+         plt.savefig(os.path.join(plots_dir, "prediction_distribution.png"), dpi=dpi, bbox_inches='tight')
+         pdf.savefig(dpi=dpi)
+         plt.close()
+
+     return output_path
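
A minimal usage sketch (not part of the commit, illustration only): it assumes the package is importable as omnibin.metrics and that y_true and y_scores are 1-D NumPy arrays, since the function indexes them with bootstrap index arrays. The synthetic labels, scores, and output path below are placeholders.

import numpy as np
from omnibin.metrics import generate_binary_classification_report

# Synthetic binary labels and loosely correlated scores clipped to [0, 1]
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=500)
y_scores = np.clip(0.6 * y_true + rng.normal(0.2, 0.25, size=500), 0.0, 1.0)

report_path = generate_binary_classification_report(
    y_true,
    y_scores,
    output_path="reports/demo_report.pdf",  # hypothetical path; a "plots/" folder is created in the same directory
    n_bootstrap=200,                        # fewer bootstrap iterations for a quick run
    dpi=150,
)
print(f"PDF report written to {report_path}")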