felipekitamura committed
Commit 5bed619 (verified) · Parent(s): 68af8b4

Upload 2 files

Files changed (2):
  1. omnibin/__init__.py +1 -1
  2. omnibin/utils.py +263 -0
omnibin/__init__.py CHANGED
@@ -1,4 +1,4 @@
-from .metrics import generate_binary_classification_report
+from .metrics import generate_binary_classification_report, ColorScheme
 
 __version__ = "0.1.0"
 __all__ = ["generate_binary_classification_report"]
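
The only change re-exports ColorScheme at the package top level; note that __all__ is left untouched. A minimal sketch of what the new import enables, assuming omnibin is installed and that omnibin.metrics re-exports the ColorScheme enum defined in utils.py below (which is what this import implies):

from omnibin import generate_binary_classification_report, ColorScheme

colors = ColorScheme.VIBRANT.value   # each member's value is a dict of plot colors
print(colors['roc_curve'])           # -> #FF6B6B
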
omnibin/utils.py ADDED
@@ -0,0 +1,263 @@
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+import seaborn as sns
+from tqdm import tqdm
+from sklearn.metrics import (
+    accuracy_score, recall_score, precision_score, f1_score, roc_auc_score,
+    average_precision_score, confusion_matrix, matthews_corrcoef, roc_curve,
+    precision_recall_curve
+)
+from sklearn.calibration import calibration_curve
+from enum import Enum
+import os
+
+class ColorScheme(Enum):
+    DEFAULT = {
+        'positive_class': 'tab:blue',
+        'negative_class': 'tab:orange',
+        'roc_curve': 'tab:blue',
+        'pr_curve': 'tab:blue',
+        'threshold_line': 'black',
+        'calibration_curve': 'tab:blue',
+        'calibration_reference': 'gray',
+        'metrics_colors': ['tab:blue', 'tab:red', 'tab:green', 'tab:purple', 'tab:orange', 'tab:brown', 'tab:pink'],
+        'cmap': 'Blues'
+    }
+
+    MONOCHROME = {
+        'positive_class': '#404040',
+        'negative_class': '#808080',
+        'roc_curve': '#000000',
+        'pr_curve': '#000000',
+        'threshold_line': '#000000',
+        'calibration_curve': '#000000',
+        'calibration_reference': '#808080',
+        'metrics_colors': ['#000000', '#404040', '#606060', '#808080', '#A0A0A0', '#C0C0C0', '#E0E0E0'],
+        'cmap': 'Greys'
+    }
+
+    VIBRANT = {
+        'positive_class': '#FF6B6B',
+        'negative_class': '#4ECDC4',
+        'roc_curve': '#FF6B6B',
+        'pr_curve': '#4ECDC4',
+        'threshold_line': '#2C3E50',
+        'calibration_curve': '#FF6B6B',
+        'calibration_reference': '#95A5A6',
+        'metrics_colors': ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEEAD', '#D4A5A5', '#9B59B6'],
+        'cmap': 'Greens'
+    }
+
+def calculate_metrics_by_threshold(y_true, y_scores):
+    """Calculate various metrics across different thresholds."""
+    thresholds = np.linspace(0, 1, 100)
+    metrics_by_threshold = []
+
+    for t in tqdm(thresholds, desc="Calculating metrics across thresholds"):
+        y_pred = (y_scores >= t).astype(int)
+        acc = accuracy_score(y_true, y_pred)
+        sens = recall_score(y_true, y_pred)
+        spec = recall_score(y_true, y_pred, pos_label=0)
+        ppv = precision_score(y_true, y_pred, zero_division=0)
+        mcc = matthews_corrcoef(y_true, y_pred)
+        f1 = f1_score(y_true, y_pred)
+        metrics_by_threshold.append([t, acc, sens, spec, ppv, mcc, f1])
+
+    return pd.DataFrame(metrics_by_threshold, columns=[
+        "Threshold", "Accuracy", "Sensitivity", "Specificity",
+        "PPV", "MCC", "F1 Score"
+    ])
+
+def bootstrap_metric(metric_func, y_true, y_scores, n_boot=1000):
+    """Calculate bootstrap confidence intervals for a given metric."""
+    stats = []
+    for _ in tqdm(range(n_boot), desc="Bootstrap iterations", leave=False):
+        indices = np.random.choice(range(len(y_true)), len(y_true), replace=True)
+        try:
+            stats.append(metric_func(y_true[indices], y_scores[indices]))
+        except:
+            continue
+    return np.percentile(stats, [2.5, 97.5])
+
+def bootstrap_curves(y_true, y_scores, n_boot=1000):
+    """Calculate bootstrap confidence intervals for ROC and PR curves."""
+    tprs = []
+    fprs = []
+    precisions = []
+    recalls = []
+
+    base_fpr, base_tpr, _ = roc_curve(y_true, y_scores)
+    base_precision, base_recall, _ = precision_recall_curve(y_true, y_scores)
+
+    common_fpr = np.linspace(0, 1, 100)
+    common_recall = np.linspace(0, 1, 100)
+
+    for _ in tqdm(range(n_boot), desc="Bootstrap iterations for curves", leave=False):
+        indices = np.random.choice(range(len(y_true)), len(y_true), replace=True)
+        try:
+            fpr, tpr, _ = roc_curve(y_true[indices], y_scores[indices])
+            tpr_interp = np.interp(common_fpr, fpr, tpr)
+            tprs.append(tpr_interp)
+
+            precision, recall, _ = precision_recall_curve(y_true[indices], y_scores[indices])
+            sort_idx = np.argsort(recall)
+            recall = recall[sort_idx]
+            precision = precision[sort_idx]
+            precision_interp = np.interp(common_recall, recall, precision)
+            precisions.append(precision_interp)
+        except:
+            continue
+
+    tpr_ci = np.percentile(tprs, [2.5, 97.5], axis=0)
+    precision_ci = np.percentile(precisions, [2.5, 97.5], axis=0)
+
+    return tpr_ci, precision_ci, common_fpr, common_recall
+
+def calculate_optimal_threshold(y_true, y_scores):
+    """Calculate the optimal threshold using ROC curve."""
+    fpr, tpr, roc_thresholds = roc_curve(y_true, y_scores)
+    j_scores = tpr - fpr
+    return roc_thresholds[np.argmax(j_scores)]
+
+def calculate_metrics_summary(y_true, y_scores, best_thresh):
+    """Calculate summary metrics at the optimal threshold."""
+    y_pred_opt = (y_scores >= best_thresh).astype(int)
+
+    return {
+        "Accuracy": accuracy_score(y_true, y_pred_opt),
+        "Sensitivity": recall_score(y_true, y_pred_opt),
+        "Specificity": recall_score(y_true, y_pred_opt, pos_label=0),
+        "PPV": precision_score(y_true, y_pred_opt, zero_division=0),
+        "MCC": matthews_corrcoef(y_true, y_pred_opt),
+        "F1 Score": f1_score(y_true, y_pred_opt),
+        "AUC-ROC": roc_auc_score(y_true, y_scores),
+        "AUC-PR": average_precision_score(y_true, y_scores)
+    }
+
+def calculate_confidence_intervals(y_true, y_scores, best_thresh, n_bootstrap=1000):
+    """Calculate confidence intervals for all metrics."""
+    metric_functions = {
+        "Accuracy": lambda yt, ys: accuracy_score(yt, ys >= best_thresh),
+        "Sensitivity": lambda yt, ys: recall_score(yt, ys >= best_thresh),
+        "Specificity": lambda yt, ys: recall_score(yt, ys >= best_thresh, pos_label=0),
+        "PPV": lambda yt, ys: precision_score(yt, ys >= best_thresh, zero_division=0),
+        "MCC": lambda yt, ys: matthews_corrcoef(yt, ys >= best_thresh),
+        "F1 Score": lambda yt, ys: f1_score(yt, ys >= best_thresh),
+        "AUC-ROC": lambda yt, ys: roc_auc_score(yt, ys),
+        "AUC-PR": lambda yt, ys: average_precision_score(yt, ys)
+    }
+
+    return {
+        name: bootstrap_metric(func, y_true, y_scores, n_boot=n_bootstrap)
+        for name, func in metric_functions.items()
+    }
+
+def create_output_directories(output_path):
+    """Create necessary output directories for plots and PDF."""
+    output_dir = os.path.dirname(output_path)
+    if output_dir:
+        os.makedirs(output_dir, exist_ok=True)
+
+    plots_dir = os.path.join(output_dir, "plots")
+    os.makedirs(plots_dir, exist_ok=True)
+
+    return plots_dir
+
+def plot_roc_pr_curves(y_true, y_scores, tpr_ci, precision_ci, common_fpr, common_recall, colors, dpi, plots_dir):
+    """Generate ROC and PR curves with confidence intervals."""
+    plt.figure(figsize=(12, 5), dpi=dpi)
+
+    plt.subplot(1, 2, 1)
+    fpr, tpr, _ = roc_curve(y_true, y_scores)
+    plt.plot(fpr, tpr, label="ROC curve", color=colors['roc_curve'])
+    plt.fill_between(common_fpr, tpr_ci[0], tpr_ci[1], alpha=0.3, color=colors['roc_curve'])
+    plt.plot([0, 1], [0, 1], "k--")
+    plt.xlabel("False Positive Rate")
+    plt.ylabel("True Positive Rate")
+    plt.title("ROC Curve")
+    plt.legend()
+
+    plt.subplot(1, 2, 2)
+    precision, recall, _ = precision_recall_curve(y_true, y_scores)
+    plt.plot(recall, precision, label="PR curve", color=colors['pr_curve'])
+    plt.fill_between(common_recall, precision_ci[0], precision_ci[1], alpha=0.3, color=colors['pr_curve'])
+    plt.xlabel("Recall")
+    plt.ylabel("Precision")
+    plt.title("Precision-Recall Curve")
+    plt.legend()
+
+    plt.savefig(os.path.join(plots_dir, "roc_pr.png"), dpi=dpi, bbox_inches='tight')
+    return plt.gcf()
+
+def plot_metrics_threshold(metrics_df, colors, dpi, plots_dir):
+    """Generate metrics vs threshold plot."""
+    plt.figure(figsize=(10, 6), dpi=dpi)
+    for i, col in enumerate(metrics_df.columns[1:]):
+        plt.plot(metrics_df["Threshold"], metrics_df[col], label=col,
+                 color=colors['metrics_colors'][i % len(colors['metrics_colors'])])
+    plt.xlabel("Threshold")
+    plt.ylabel("Metric Value")
+    plt.title("Metrics Across Thresholds")
+    plt.legend()
+
+    plt.savefig(os.path.join(plots_dir, "metrics_threshold.png"), dpi=dpi, bbox_inches='tight')
+    return plt.gcf()
+
+def plot_confusion_matrix(y_true, y_scores, best_thresh, colors, dpi, plots_dir):
+    """Generate confusion matrix plot."""
+    cm = confusion_matrix(y_true, y_scores >= best_thresh)
+    plt.figure(figsize=(5, 4), dpi=dpi)
+    sns.heatmap(cm, annot=True, fmt="d", cmap=colors['cmap'], cbar=False, annot_kws={"size": 12})
+    plt.title("Confusion Matrix (Optimal Threshold)", fontsize=12)
+    plt.xlabel("Predicted Label", fontsize=12)
+    plt.ylabel("True Label", fontsize=12)
+
+    plt.savefig(os.path.join(plots_dir, "confusion_matrix.png"), dpi=dpi, bbox_inches='tight')
+    return plt.gcf()
+
+def plot_calibration(y_true, y_scores, colors, dpi, plots_dir):
+    """Generate calibration plot."""
+    plt.figure(figsize=(6, 6), dpi=dpi)
+    prob_true, prob_pred = calibration_curve(y_true, y_scores, n_bins=10, strategy='uniform')
+    plt.plot(prob_pred, prob_true, marker='o', label='Calibration curve', color=colors['calibration_curve'])
+    plt.plot([0, 1], [0, 1], linestyle='--', color=colors['calibration_reference'])
+    plt.xlabel('Predicted Probability')
+    plt.ylabel('True Probability')
+    plt.title('Calibration Plot')
+    plt.legend()
+
+    plt.savefig(os.path.join(plots_dir, "calibration.png"), dpi=dpi, bbox_inches='tight')
+    return plt.gcf()
+
+def plot_metrics_summary(metrics_summary, conf_intervals, dpi, plots_dir):
+    """Generate metrics summary table plot."""
+    fig, ax = plt.subplots(figsize=(8, 6), dpi=dpi)
+    ax.axis("off")
+    table_data = [
+        [k, f"{v:.3f}", f"[{conf_intervals[k][0]:.3f}, {conf_intervals[k][1]:.3f}]"]
+        for k, v in metrics_summary.items()
+    ]
+    table = ax.table(cellText=table_data, colLabels=["Metric", "Value", "95% CI"], loc="center")
+    table.auto_set_font_size(False)
+    table.set_fontsize(10)
+    table.scale(1.2, 1.2)
+    ax.set_title("Performance Metrics at Optimal Threshold", fontweight="bold")
+
+    plt.savefig(os.path.join(plots_dir, "metrics_summary.png"), dpi=dpi, bbox_inches='tight')
+    return plt.gcf()
+
+def plot_prediction_distribution(y_true, y_scores, best_thresh, colors, dpi, plots_dir):
+    """Generate prediction distribution histogram."""
+    plt.figure(figsize=(10, 6), dpi=dpi)
+    plt.hist(y_scores[y_true == 1], bins=50, alpha=0.5, label='Positive Class', color=colors['positive_class'])
+    plt.hist(y_scores[y_true == 0], bins=50, alpha=0.5, label='Negative Class', color=colors['negative_class'])
+    plt.axvline(x=best_thresh, color=colors['threshold_line'], linestyle='--',
+                label=f'Optimal Threshold ({best_thresh:.3f})')
+    plt.xlabel('Predicted Probability')
+    plt.ylabel('Count')
+    plt.title('Distribution of Predictions')
+    plt.legend()
+
+    plt.savefig(os.path.join(plots_dir, "prediction_distribution.png"), dpi=dpi, bbox_inches='tight')
+    return plt.gcf()
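
For context, a minimal end-to-end sketch of how the helpers added in this file might be combined. This is not part of the commit: the synthetic data, the reports/report.pdf output path, dpi=150, and n_bootstrap=200 are illustrative choices; all function names and signatures come from the diff above.

import numpy as np
from omnibin.utils import (
    ColorScheme, calculate_optimal_threshold, calculate_metrics_summary,
    calculate_confidence_intervals, create_output_directories,
    plot_confusion_matrix, plot_prediction_distribution,
)

# Synthetic labels and scores, for illustration only.
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=500)
y_scores = np.clip(0.35 * y_true + rng.normal(0.4, 0.2, size=500), 0, 1)

colors = ColorScheme.DEFAULT.value                           # dict of plot colors
plots_dir = create_output_directories("reports/report.pdf")  # creates reports/plots/

best_thresh = calculate_optimal_threshold(y_true, y_scores)  # maximizes Youden's J (TPR - FPR)
summary = calculate_metrics_summary(y_true, y_scores, best_thresh)
cis = calculate_confidence_intervals(y_true, y_scores, best_thresh, n_bootstrap=200)

plot_confusion_matrix(y_true, y_scores, best_thresh, colors, dpi=150, plots_dir=plots_dir)
plot_prediction_distribution(y_true, y_scores, best_thresh, colors, dpi=150, plots_dir=plots_dir)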