cmckinle commited on
Commit
8f77a40
·
verified ·
1 Parent(s): a38dfd7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -2
app.py CHANGED
@@ -68,11 +68,33 @@ def process_zip(zip_file):
68
  shutil.rmtree(temp_dir)
69
  return evaluate_model(labels, preds)
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  def evaluate_model(labels, preds):
72
  cm = confusion_matrix(labels, preds)
73
  accuracy = accuracy_score(labels, preds)
74
  roc_score = roc_auc_score(labels, preds)
75
- report = classification_report(labels, preds)
 
76
  fpr, tpr, _ = roc_curve(labels, preds)
77
  roc_auc = auc(fpr, tpr)
78
 
@@ -92,7 +114,7 @@ def evaluate_model(labels, preds):
92
 
93
  plt.tight_layout()
94
 
95
- return accuracy, roc_score, report, fig
96
 
97
  def load_url(url):
98
  try:
 
68
  shutil.rmtree(temp_dir)
69
  return evaluate_model(labels, preds)
70
 
71
+ def format_classification_report(report_dict):
72
+ headers = ["precision", "recall", "f1-score", "support"]
73
+ row_format = "{:>12}" * (len(headers) + 1)
74
+
75
+ formatted_report = row_format.format("", *headers) + "\n\n"
76
+
77
+ for label, metrics in report_dict.items():
78
+ if label in ["accuracy", "macro avg", "weighted avg"]:
79
+ formatted_report += "\n"
80
+ row = [label]
81
+ for header in headers:
82
+ value = metrics.get(header, "")
83
+ if isinstance(value, float):
84
+ value = f"{value:.2f}"
85
+ elif isinstance(value, int):
86
+ value = str(value)
87
+ row.append(value)
88
+ formatted_report += row_format.format(*row) + "\n"
89
+
90
+ return formatted_report
91
+
92
  def evaluate_model(labels, preds):
93
  cm = confusion_matrix(labels, preds)
94
  accuracy = accuracy_score(labels, preds)
95
  roc_score = roc_auc_score(labels, preds)
96
+ report_dict = classification_report(labels, preds, target_names=LABELS, output_dict=True)
97
+ formatted_report = format_classification_report(report_dict)
98
  fpr, tpr, _ = roc_curve(labels, preds)
99
  roc_auc = auc(fpr, tpr)
100
 
 
114
 
115
  plt.tight_layout()
116
 
117
+ return accuracy, roc_score, formatted_report, fig
118
 
119
  def load_url(url):
120
  try: