cmckinle commited on
Commit
80bc071
·
verified ·
1 Parent(s): 24dd2a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -17
app.py CHANGED
@@ -68,29 +68,90 @@ def process_zip(zip_file):
68
  shutil.rmtree(temp_dir)
69
  return evaluate_model(labels, preds)
70
 
71
- def evaluate_model(labels, preds):
72
- cm = confusion_matrix(labels, preds)
73
- accuracy = accuracy_score(labels, preds)
74
- roc_score = roc_auc_score(labels, preds)
75
-
76
- # Generate classification report as a dictionary
77
  report_dict = classification_report(labels, preds, output_dict=True)
78
 
79
- # Formatting the report as an HTML string for cleaner display
80
- report_html = "<h4>Classification Report</h4>"
81
- report_html += "<table><tr><th>Class</th><th>Precision</th><th>Recall</th><th>F1-Score</th></tr>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- for key, value in report_dict.items():
84
- if isinstance(value, dict):
85
- report_html += f"<tr><td>{key}</td><td>{value['precision']:.2f}</td><td>{value['recall']:.2f}</td><td>{value['f1-score']:.2f}</td></tr>"
 
 
 
 
 
 
 
 
86
 
87
- report_html += "</table>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- # Compute ROC Curve
 
 
 
 
 
 
90
  fpr, tpr, _ = roc_curve(labels, preds)
91
  roc_auc = auc(fpr, tpr)
92
 
93
- # Plot Confusion Matrix and ROC Curve
94
  fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
95
 
96
  ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=LABELS).plot(cmap=plt.cm.Blues, ax=ax1)
@@ -149,7 +210,7 @@ def create_gradio_interface():
149
  gr.Markdown(f"### Results for {MODEL_NAME}")
150
  output_acc = gr.Label(label="Accuracy")
151
  output_roc = gr.Label(label="ROC Score")
152
- output_report = gr.HTML(label="Classification Report") # Changed to gr.HTML for formatted display
153
  output_plots = gr.Plot(label="Confusion Matrix and ROC Curve")
154
 
155
  load_btn.click(load_url, in_url, [inp, message])
@@ -169,4 +230,4 @@ def create_gradio_interface():
169
 
170
  if __name__ == "__main__":
171
  app = create_gradio_interface()
172
- app.launch(show_api=False, max_threads=24)
 
68
  shutil.rmtree(temp_dir)
69
  return evaluate_model(labels, preds)
70
 
71
def format_classification_report(labels, preds):
    """Render a scikit-learn classification report as a styled HTML table.

    Builds one row per class plus the standard summary rows (accuracy,
    macro avg, weighted avg). Class rows are taken from the report dict
    itself rather than hard-coded, so this works for any label set
    (binary '0'/'1', multiclass, or string labels) — for binary 0/1
    input the output is identical to the previous hard-coded version,
    since dict insertion order lists the classes first.

    Args:
        labels: Ground-truth labels, as accepted by
            sklearn.metrics.classification_report.
        preds: Predicted labels, same length as ``labels``.

    Returns:
        str: An HTML fragment (inline <style> + <table>) for display,
        e.g. in a gr.HTML component.
    """
    # Convert the report to a dictionary so individual metrics are addressable.
    report_dict = classification_report(labels, preds, output_dict=True)

    # Table shell with inline CSS for a clean, self-contained render.
    html = """
    <style>
        .report-table {
            border-collapse: collapse;
            width: 100%;
            font-family: Arial, sans-serif;
        }
        .report-table th, .report-table td {
            border: 1px solid #ddd;
            padding: 8px;
            text-align: center;
        }
        .report-table th {
            background-color: #f2f2f2;
            font-weight: bold;
        }
        .report-table tr:nth-child(even) {
            background-color: #f9f9f9;
        }
        .report-table tr:hover {
            background-color: #f5f5f5;
        }
    </style>
    <table class="report-table">
        <tr>
            <th>Class</th>
            <th>Precision</th>
            <th>Recall</th>
            <th>F1-Score</th>
            <th>Support</th>
        </tr>
    """

    # Summary entries emitted after the per-class rows; everything else in
    # the report dict is a real class label.
    summary_keys = {"accuracy", "macro avg", "weighted avg"}

    # Add one row per class, in the order classification_report emitted them.
    for class_name, metrics in report_dict.items():
        if class_name in summary_keys:
            continue
        html += f"""
        <tr>
            <td>{class_name}</td>
            <td>{metrics['precision']:.2f}</td>
            <td>{metrics['recall']:.2f}</td>
            <td>{metrics['f1-score']:.2f}</td>
            <td>{metrics['support']}</td>
        </tr>
        """

    # Add summary rows. 'accuracy' is a bare float (no per-metric dict), so
    # it spans the three metric columns; support comes from the macro avg row.
    html += f"""
        <tr>
            <td>Accuracy</td>
            <td colspan="3">{report_dict['accuracy']:.2f}</td>
            <td>{report_dict['macro avg']['support']}</td>
        </tr>
        <tr>
            <td>Macro Avg</td>
            <td>{report_dict['macro avg']['precision']:.2f}</td>
            <td>{report_dict['macro avg']['recall']:.2f}</td>
            <td>{report_dict['macro avg']['f1-score']:.2f}</td>
            <td>{report_dict['macro avg']['support']}</td>
        </tr>
        <tr>
            <td>Weighted Avg</td>
            <td>{report_dict['weighted avg']['precision']:.2f}</td>
            <td>{report_dict['weighted avg']['recall']:.2f}</td>
            <td>{report_dict['weighted avg']['f1-score']:.2f}</td>
            <td>{report_dict['weighted avg']['support']}</td>
        </tr>
    </table>
    """

    return html
146
+
147
+ def evaluate_model(labels, preds):
148
+ cm = confusion_matrix(labels, preds)
149
+ accuracy = accuracy_score(labels, preds)
150
+ roc_score = roc_auc_score(labels, preds)
151
+ report_html = format_classification_report(labels, preds)
152
  fpr, tpr, _ = roc_curve(labels, preds)
153
  roc_auc = auc(fpr, tpr)
154
 
 
155
  fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
156
 
157
  ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=LABELS).plot(cmap=plt.cm.Blues, ax=ax1)
 
210
  gr.Markdown(f"### Results for {MODEL_NAME}")
211
  output_acc = gr.Label(label="Accuracy")
212
  output_roc = gr.Label(label="ROC Score")
213
+ output_report = gr.HTML(label="Classification Report")
214
  output_plots = gr.Plot(label="Confusion Matrix and ROC Curve")
215
 
216
  load_btn.click(load_url, in_url, [inp, message])
 
230
 
231
  if __name__ == "__main__":
232
  app = create_gradio_interface()
233
+ app.launch(show_api=False, max_threads=24)