cmckinle commited on
Commit
24dd2a8
·
verified ·
1 Parent(s): c07149a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -4
app.py CHANGED
@@ -72,10 +72,25 @@ def evaluate_model(labels, preds):
72
  cm = confusion_matrix(labels, preds)
73
  accuracy = accuracy_score(labels, preds)
74
  roc_score = roc_auc_score(labels, preds)
75
- report = classification_report(labels, preds)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  fpr, tpr, _ = roc_curve(labels, preds)
77
  roc_auc = auc(fpr, tpr)
78
 
 
79
  fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
80
 
81
  ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=LABELS).plot(cmap=plt.cm.Blues, ax=ax1)
@@ -92,7 +107,7 @@ def evaluate_model(labels, preds):
92
 
93
  plt.tight_layout()
94
 
95
- return accuracy, roc_score, report, fig
96
 
97
  def load_url(url):
98
  try:
@@ -134,7 +149,7 @@ def create_gradio_interface():
134
  gr.Markdown(f"### Results for {MODEL_NAME}")
135
  output_acc = gr.Label(label="Accuracy")
136
  output_roc = gr.Label(label="ROC Score")
137
- output_report = gr.Textbox(label="Classification Report", lines=10)
138
  output_plots = gr.Plot(label="Confusion Matrix and ROC Curve")
139
 
140
  load_btn.click(load_url, in_url, [inp, message])
@@ -154,4 +169,4 @@ def create_gradio_interface():
154
 
155
  if __name__ == "__main__":
156
  app = create_gradio_interface()
157
- app.launch(show_api=False, max_threads=24)
 
72
  cm = confusion_matrix(labels, preds)
73
  accuracy = accuracy_score(labels, preds)
74
  roc_score = roc_auc_score(labels, preds)
75
+
76
+ # Generate classification report as a dictionary
77
+ report_dict = classification_report(labels, preds, output_dict=True)
78
+
79
+ # Formatting the report as an HTML string for cleaner display
80
+ report_html = "<h4>Classification Report</h4>"
81
+ report_html += "<table><tr><th>Class</th><th>Precision</th><th>Recall</th><th>F1-Score</th></tr>"
82
+
83
+ for key, value in report_dict.items():
84
+ if isinstance(value, dict):
85
+ report_html += f"<tr><td>{key}</td><td>{value['precision']:.2f}</td><td>{value['recall']:.2f}</td><td>{value['f1-score']:.2f}</td></tr>"
86
+
87
+ report_html += "</table>"
88
+
89
+ # Compute ROC Curve
90
  fpr, tpr, _ = roc_curve(labels, preds)
91
  roc_auc = auc(fpr, tpr)
92
 
93
+ # Plot Confusion Matrix and ROC Curve
94
  fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
95
 
96
  ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=LABELS).plot(cmap=plt.cm.Blues, ax=ax1)
 
107
 
108
  plt.tight_layout()
109
 
110
+ return accuracy, roc_score, report_html, fig
111
 
112
  def load_url(url):
113
  try:
 
149
  gr.Markdown(f"### Results for {MODEL_NAME}")
150
  output_acc = gr.Label(label="Accuracy")
151
  output_roc = gr.Label(label="ROC Score")
152
+ output_report = gr.HTML(label="Classification Report") # Changed to gr.HTML for formatted display
153
  output_plots = gr.Plot(label="Confusion Matrix and ROC Curve")
154
 
155
  load_btn.click(load_url, in_url, [inp, message])
 
169
 
170
  if __name__ == "__main__":
171
  app = create_gradio_interface()
172
+ app.launch(show_api=False, max_threads=24)