EduardoPacheco commited on
Commit
3be65ae
·
1 Parent(s): 63a3fbe
Files changed (1) hide show
  1. utils.py +117 -0
utils.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import plotly.graph_objects as go
3
+ from sklearn.metrics import PrecisionRecallDisplay, precision_recall_curve, average_precision_score
4
+
5
+ def plot_multi_label_pr_curve(clf, X_test: np.ndarray, Y_test: np.ndarray):
6
+ n_classes = Y_test.shape[1]
7
+ y_score = clf.decision_function(X_test)
8
+
9
+ # For each class
10
+ precision = dict()
11
+ recall = dict()
12
+ average_precision = dict()
13
+ for i in range(n_classes):
14
+ precision[i], recall[i], _ = precision_recall_curve(Y_test[:, i], y_score[:, i])
15
+ average_precision[i] = average_precision_score(Y_test[:, i], y_score[:, i])
16
+
17
+ # A "micro-average": quantifying score on all classes jointly
18
+ precision["micro"], recall["micro"], _ = precision_recall_curve(
19
+ Y_test.ravel(), y_score.ravel()
20
+ )
21
+ average_precision["micro"] = average_precision_score(Y_test, y_score, average="micro")
22
+
23
+ # Plotting
24
+ fig = go.Figure()
25
+
26
+
27
+ # Plottin Precision-Recall Curves for each class
28
+ colors = ["navy", "turquoise", "darkorange", "gold"]
29
+ keys = list(precision.keys())
30
+
31
+ for color, key in zip(colors, keys):
32
+ if key=="micro":
33
+ name = f"Micro-average Precision-Recall (AP={average_precision[key]:.2f})"
34
+ else:
35
+ name = f"Precision-recall for class {key} (AP={average_precision[key]:.2f})"
36
+ fig.add_trace(
37
+ go.Scatter(
38
+ x=recall[key],
39
+ y=precision[key],
40
+ mode="lines",
41
+ name=name,
42
+ line=dict(color=color),
43
+ showlegend=True,
44
+ line_shape="hv"
45
+ )
46
+ )
47
+
48
+ # Creating Iso-F1 Curves
49
+ f_scores = np.linspace(0.2, 0.8, num=4)
50
+ for idx, f_score in enumerate(f_scores):
51
+ if idx==0:
52
+ name = "Iso-F1 Curves"
53
+ showlegend = True
54
+ else:
55
+ name = ""
56
+ showlegend = False
57
+ x = np.linspace(0.01, 1, 1001)
58
+ y = f_score * x / (2 * x - f_score)
59
+ mask = y >= 0
60
+ fig.add_trace(go.Scatter(x=x[mask], y=y[mask], mode='lines', line_color='gray', name=name, showlegend=showlegend))
61
+ fig.add_annotation(x=0.9, y=y[900] + 0.02, text=f"<b>f1={f_score:0.1f}</b>", showarrow=False, font=dict(size=15))
62
+
63
+
64
+ fig.update_yaxes(range=[0, 1.05])
65
+
66
+ fig.update_layout(
67
+ title='Extension of Precision-Recall Curve to Multi-Class',
68
+ xaxis_title='Recall',
69
+ yaxis_title='Precision',
70
+ )
71
+
72
+ return fig
73
+
74
+
75
+ def plot_binary_pr_curve(clf, X_test: np.ndarray, y_test:np.array):
76
+ # make predictions on the test data
77
+ y_pred = clf.decision_function(X_test)
78
+
79
+ # calculate precision and recall for different probability thresholds
80
+ precision, recall, _ = precision_recall_curve(y_test, y_pred)
81
+
82
+ # calculate the average precision
83
+ ap = average_precision_score(y_test, y_pred)
84
+
85
+ # Plotting
86
+ fig = go.Figure()
87
+
88
+ fig.add_trace(
89
+ go.Scatter(
90
+ x=recall,
91
+ y=precision,
92
+ mode="lines",
93
+ name=f"LinearSVC (AP={ap:.2f})",
94
+ line=dict(color="blue"),
95
+ showlegend=True,
96
+ line_shape="hv"
97
+ )
98
+ )
99
+
100
+ # Make x-range slightly larger than max value
101
+ fig.update_xaxes(range=[-0.05, 1.05])
102
+ # Make Legend text size larger
103
+ fig.update_layout(
104
+ title='2-Class Precision-Recall Curve',
105
+ xaxis_title='Recall (Positive label: 1)',
106
+ yaxis_title='Precision (Positive label: 1)',
107
+ legend=dict(
108
+ x=0.009,
109
+ y=0.05,
110
+ font=dict(
111
+ size=12,
112
+ ),
113
+ )
114
+ )
115
+
116
+ return fig
117
+