Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -68,29 +68,90 @@ def process_zip(zip_file):
|
|
68 |
shutil.rmtree(temp_dir)
|
69 |
return evaluate_model(labels, preds)
|
70 |
|
71 |
-
def
|
72 |
-
|
73 |
-
accuracy = accuracy_score(labels, preds)
|
74 |
-
roc_score = roc_auc_score(labels, preds)
|
75 |
-
|
76 |
-
# Generate classification report as a dictionary
|
77 |
report_dict = classification_report(labels, preds, output_dict=True)
|
78 |
|
79 |
-
#
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
|
83 |
-
|
84 |
-
|
85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
fpr, tpr, _ = roc_curve(labels, preds)
|
91 |
roc_auc = auc(fpr, tpr)
|
92 |
|
93 |
-
# Plot Confusion Matrix and ROC Curve
|
94 |
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
|
95 |
|
96 |
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=LABELS).plot(cmap=plt.cm.Blues, ax=ax1)
|
@@ -149,7 +210,7 @@ def create_gradio_interface():
|
|
149 |
gr.Markdown(f"### Results for {MODEL_NAME}")
|
150 |
output_acc = gr.Label(label="Accuracy")
|
151 |
output_roc = gr.Label(label="ROC Score")
|
152 |
-
output_report = gr.HTML(label="Classification Report")
|
153 |
output_plots = gr.Plot(label="Confusion Matrix and ROC Curve")
|
154 |
|
155 |
load_btn.click(load_url, in_url, [inp, message])
|
@@ -169,4 +230,4 @@ def create_gradio_interface():
|
|
169 |
|
170 |
if __name__ == "__main__":
|
171 |
app = create_gradio_interface()
|
172 |
-
app.launch(show_api=False, max_threads=24)
|
|
|
68 |
shutil.rmtree(temp_dir)
|
69 |
return evaluate_model(labels, preds)
|
70 |
|
71 |
+
def format_classification_report(labels, preds):
|
72 |
+
# Convert the report string to a dictionary
|
|
|
|
|
|
|
|
|
73 |
report_dict = classification_report(labels, preds, output_dict=True)
|
74 |
|
75 |
+
# Create an HTML table
|
76 |
+
html = """
|
77 |
+
<style>
|
78 |
+
.report-table {
|
79 |
+
border-collapse: collapse;
|
80 |
+
width: 100%;
|
81 |
+
font-family: Arial, sans-serif;
|
82 |
+
}
|
83 |
+
.report-table th, .report-table td {
|
84 |
+
border: 1px solid #ddd;
|
85 |
+
padding: 8px;
|
86 |
+
text-align: center;
|
87 |
+
}
|
88 |
+
.report-table th {
|
89 |
+
background-color: #f2f2f2;
|
90 |
+
font-weight: bold;
|
91 |
+
}
|
92 |
+
.report-table tr:nth-child(even) {
|
93 |
+
background-color: #f9f9f9;
|
94 |
+
}
|
95 |
+
.report-table tr:hover {
|
96 |
+
background-color: #f5f5f5;
|
97 |
+
}
|
98 |
+
</style>
|
99 |
+
<table class="report-table">
|
100 |
+
<tr>
|
101 |
+
<th>Class</th>
|
102 |
+
<th>Precision</th>
|
103 |
+
<th>Recall</th>
|
104 |
+
<th>F1-Score</th>
|
105 |
+
<th>Support</th>
|
106 |
+
</tr>
|
107 |
+
"""
|
108 |
|
109 |
+
# Add rows for each class
|
110 |
+
for class_name in ['0', '1']:
|
111 |
+
html += f"""
|
112 |
+
<tr>
|
113 |
+
<td>{class_name}</td>
|
114 |
+
<td>{report_dict[class_name]['precision']:.2f}</td>
|
115 |
+
<td>{report_dict[class_name]['recall']:.2f}</td>
|
116 |
+
<td>{report_dict[class_name]['f1-score']:.2f}</td>
|
117 |
+
<td>{report_dict[class_name]['support']}</td>
|
118 |
+
</tr>
|
119 |
+
"""
|
120 |
|
121 |
+
# Add summary rows
|
122 |
+
html += f"""
|
123 |
+
<tr>
|
124 |
+
<td>Accuracy</td>
|
125 |
+
<td colspan="3">{report_dict['accuracy']:.2f}</td>
|
126 |
+
<td>{report_dict['macro avg']['support']}</td>
|
127 |
+
</tr>
|
128 |
+
<tr>
|
129 |
+
<td>Macro Avg</td>
|
130 |
+
<td>{report_dict['macro avg']['precision']:.2f}</td>
|
131 |
+
<td>{report_dict['macro avg']['recall']:.2f}</td>
|
132 |
+
<td>{report_dict['macro avg']['f1-score']:.2f}</td>
|
133 |
+
<td>{report_dict['macro avg']['support']}</td>
|
134 |
+
</tr>
|
135 |
+
<tr>
|
136 |
+
<td>Weighted Avg</td>
|
137 |
+
<td>{report_dict['weighted avg']['precision']:.2f}</td>
|
138 |
+
<td>{report_dict['weighted avg']['recall']:.2f}</td>
|
139 |
+
<td>{report_dict['weighted avg']['f1-score']:.2f}</td>
|
140 |
+
<td>{report_dict['weighted avg']['support']}</td>
|
141 |
+
</tr>
|
142 |
+
</table>
|
143 |
+
"""
|
144 |
|
145 |
+
return html
|
146 |
+
|
147 |
+
def evaluate_model(labels, preds):
|
148 |
+
cm = confusion_matrix(labels, preds)
|
149 |
+
accuracy = accuracy_score(labels, preds)
|
150 |
+
roc_score = roc_auc_score(labels, preds)
|
151 |
+
report_html = format_classification_report(labels, preds)
|
152 |
fpr, tpr, _ = roc_curve(labels, preds)
|
153 |
roc_auc = auc(fpr, tpr)
|
154 |
|
|
|
155 |
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
|
156 |
|
157 |
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=LABELS).plot(cmap=plt.cm.Blues, ax=ax1)
|
|
|
210 |
gr.Markdown(f"### Results for {MODEL_NAME}")
|
211 |
output_acc = gr.Label(label="Accuracy")
|
212 |
output_roc = gr.Label(label="ROC Score")
|
213 |
+
output_report = gr.HTML(label="Classification Report")
|
214 |
output_plots = gr.Plot(label="Confusion Matrix and ROC Curve")
|
215 |
|
216 |
load_btn.click(load_url, in_url, [inp, message])
|
|
|
230 |
|
231 |
if __name__ == "__main__":
|
232 |
app = create_gradio_interface()
|
233 |
+
app.launch(show_api=False, max_threads=24)
|