Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -72,10 +72,25 @@ def evaluate_model(labels, preds):
|
|
72 |
cm = confusion_matrix(labels, preds)
|
73 |
accuracy = accuracy_score(labels, preds)
|
74 |
roc_score = roc_auc_score(labels, preds)
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
fpr, tpr, _ = roc_curve(labels, preds)
|
77 |
roc_auc = auc(fpr, tpr)
|
78 |
|
|
|
79 |
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
|
80 |
|
81 |
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=LABELS).plot(cmap=plt.cm.Blues, ax=ax1)
|
@@ -92,7 +107,7 @@ def evaluate_model(labels, preds):
|
|
92 |
|
93 |
plt.tight_layout()
|
94 |
|
95 |
-
return accuracy, roc_score,
|
96 |
|
97 |
def load_url(url):
|
98 |
try:
|
@@ -134,7 +149,7 @@ def create_gradio_interface():
|
|
134 |
gr.Markdown(f"### Results for {MODEL_NAME}")
|
135 |
output_acc = gr.Label(label="Accuracy")
|
136 |
output_roc = gr.Label(label="ROC Score")
|
137 |
-
output_report = gr.
|
138 |
output_plots = gr.Plot(label="Confusion Matrix and ROC Curve")
|
139 |
|
140 |
load_btn.click(load_url, in_url, [inp, message])
|
@@ -154,4 +169,4 @@ def create_gradio_interface():
|
|
154 |
|
155 |
if __name__ == "__main__":
|
156 |
app = create_gradio_interface()
|
157 |
-
app.launch(show_api=False, max_threads=24)
|
|
|
72 |
cm = confusion_matrix(labels, preds)
|
73 |
accuracy = accuracy_score(labels, preds)
|
74 |
roc_score = roc_auc_score(labels, preds)
|
75 |
+
|
76 |
+
# Generate classification report as a dictionary
|
77 |
+
report_dict = classification_report(labels, preds, output_dict=True)
|
78 |
+
|
79 |
+
# Formatting the report as an HTML string for cleaner display
|
80 |
+
report_html = "<h4>Classification Report</h4>"
|
81 |
+
report_html += "<table><tr><th>Class</th><th>Precision</th><th>Recall</th><th>F1-Score</th></tr>"
|
82 |
+
|
83 |
+
for key, value in report_dict.items():
|
84 |
+
if isinstance(value, dict):
|
85 |
+
report_html += f"<tr><td>{key}</td><td>{value['precision']:.2f}</td><td>{value['recall']:.2f}</td><td>{value['f1-score']:.2f}</td></tr>"
|
86 |
+
|
87 |
+
report_html += "</table>"
|
88 |
+
|
89 |
+
# Compute ROC Curve
|
90 |
fpr, tpr, _ = roc_curve(labels, preds)
|
91 |
roc_auc = auc(fpr, tpr)
|
92 |
|
93 |
+
# Plot Confusion Matrix and ROC Curve
|
94 |
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
|
95 |
|
96 |
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=LABELS).plot(cmap=plt.cm.Blues, ax=ax1)
|
|
|
107 |
|
108 |
plt.tight_layout()
|
109 |
|
110 |
+
return accuracy, roc_score, report_html, fig
|
111 |
|
112 |
def load_url(url):
|
113 |
try:
|
|
|
149 |
gr.Markdown(f"### Results for {MODEL_NAME}")
|
150 |
output_acc = gr.Label(label="Accuracy")
|
151 |
output_roc = gr.Label(label="ROC Score")
|
152 |
+
output_report = gr.HTML(label="Classification Report") # Changed to gr.HTML for formatted display
|
153 |
output_plots = gr.Plot(label="Confusion Matrix and ROC Curve")
|
154 |
|
155 |
load_btn.click(load_url, in_url, [inp, message])
|
|
|
169 |
|
170 |
if __name__ == "__main__":
|
171 |
app = create_gradio_interface()
|
172 |
+
app.launch(show_api=False, max_threads=24)
|