Spaces:
Sleeping
Sleeping
put comments
Browse files- pages/Model_Evaluation.py +33 -13
- training/training.ipynb +37 -51
pages/Model_Evaluation.py
CHANGED
@@ -161,64 +161,76 @@ with st.expander("ℹ️ **What is Model Evaluation?**", expanded=True):
|
|
161 |
|
162 |
|
163 |
# ---- Evaluation Logic ----
|
|
|
164 |
if st.session_state.trigger_eval:
|
165 |
st.markdown("### ⏱️ Evaluation Results")
|
166 |
|
|
|
167 |
start_time = time.time()
|
168 |
-
y_true = []
|
169 |
-
y_pred = []
|
170 |
-
y_score = []
|
171 |
-
misclassified_images = []
|
172 |
|
173 |
-
total_batches = len(test_loader)
|
174 |
-
progress_bar = st.progress(0)
|
175 |
-
status_text = st.empty()
|
176 |
-
stop_info = st.empty()
|
177 |
|
|
|
178 |
with torch.no_grad():
|
179 |
for i, (images, labels) in enumerate(test_loader):
|
|
|
180 |
if st.session_state.stop_eval:
|
181 |
stop_info.warning("🚩 Evaluation stopped by user.")
|
182 |
break
|
183 |
|
|
|
184 |
outputs = model(images)
|
185 |
-
_, predicted = torch.max(outputs, 1)
|
186 |
y_true.extend(labels.numpy())
|
187 |
y_pred.extend(predicted.numpy())
|
188 |
y_score.extend(outputs.detach().numpy())
|
189 |
|
|
|
190 |
for j in range(len(labels)):
|
191 |
if predicted[j] != labels[j]:
|
192 |
misclassified_images.append((images[j], predicted[j].item(), labels[j].item()))
|
193 |
|
|
|
194 |
percent_complete = (i + 1) / total_batches
|
195 |
progress_bar.progress(min(percent_complete, 1.0))
|
196 |
status_text.text(f"Evaluating on Test Set: {int(percent_complete * 100)}% | Batch {i+1}/{total_batches}")
|
197 |
-
time.sleep(0.1)
|
198 |
|
199 |
end_time = time.time()
|
200 |
-
eval_time = end_time - start_time
|
201 |
|
|
|
202 |
if not st.session_state.stop_eval:
|
203 |
st.session_state.evaluation_done = True
|
204 |
st.session_state.trigger_eval = False # ✅ Reset the trigger
|
205 |
st.success(f"✅ Evaluation completed in **{eval_time:.2f} seconds**")
|
206 |
|
|
|
207 |
report = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)
|
208 |
report_df = pd.DataFrame(report).transpose()
|
209 |
st.dataframe(report_df.style.format("{:.2f}"))
|
210 |
|
|
|
211 |
pdf = FPDF()
|
212 |
pdf.add_page()
|
213 |
pdf.set_font("Arial", size=12)
|
214 |
pdf.cell(200, 10, txt=clean_text("Classification Report"), ln=True, align='C')
|
215 |
|
|
|
216 |
col_widths = [40, 40, 40, 40]
|
217 |
headers = ["Class", "Precision", "Recall", "F1-Score"]
|
218 |
for i, header in enumerate(headers):
|
219 |
pdf.cell(col_widths[i], 10, header, border=1)
|
220 |
pdf.ln()
|
221 |
|
|
|
222 |
for idx, row in report_df.iterrows():
|
223 |
if idx in ['accuracy', 'macro avg', 'weighted avg']:
|
224 |
continue
|
@@ -228,6 +240,7 @@ if st.session_state.trigger_eval:
|
|
228 |
pdf.cell(col_widths[3], 10, f"{row['f1-score']:.2f}", border=1)
|
229 |
pdf.ln()
|
230 |
|
|
|
231 |
cm = confusion_matrix(y_true, y_pred)
|
232 |
fig_cm, ax = plt.subplots()
|
233 |
sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names, cmap="Blues", ax=ax)
|
@@ -235,12 +248,15 @@ if st.session_state.trigger_eval:
|
|
235 |
ax.set_ylabel('True')
|
236 |
ax.set_title("Confusion Matrix")
|
237 |
st.pyplot(fig_cm)
|
|
|
|
|
238 |
cm_path = "confusion_matrix.png"
|
239 |
fig_cm.savefig(cm_path, format='png', dpi=300, bbox_inches='tight')
|
240 |
plt.close(fig_cm)
|
241 |
if os.path.exists(cm_path):
|
242 |
pdf.image(cm_path, x=10, y=None, w=180)
|
243 |
|
|
|
244 |
y_true_bin = label_binarize(y_true, classes=list(range(len(class_names))))
|
245 |
y_score_np = np.array(y_score)
|
246 |
fig_roc, ax = plt.subplots()
|
@@ -249,26 +265,30 @@ if st.session_state.trigger_eval:
|
|
249 |
roc_auc = auc(fpr, tpr)
|
250 |
ax.plot(fpr, tpr, label=f'{class_names[i]} (AUC = {roc_auc:.2f})')
|
251 |
|
252 |
-
ax.plot([0, 1], [0, 1], 'k--')
|
253 |
ax.set_xlabel('False Positive Rate')
|
254 |
ax.set_ylabel('True Positive Rate')
|
255 |
ax.set_title('Multi-class ROC Curve')
|
256 |
ax.legend(loc='lower right')
|
257 |
st.pyplot(fig_roc)
|
|
|
|
|
258 |
roc_path = "roc_curve.png"
|
259 |
fig_roc.savefig(roc_path, format='png', dpi=300, bbox_inches='tight')
|
260 |
plt.close(fig_roc)
|
261 |
if os.path.exists(roc_path):
|
262 |
pdf.image(roc_path, x=10, y=None, w=180)
|
263 |
|
|
|
264 |
st.markdown("### ❌ Misclassified Samples")
|
265 |
fig_mis, axs = plt.subplots(1, min(5, len(misclassified_images)), figsize=(15, 4))
|
266 |
for idx, (img, pred, true) in enumerate(misclassified_images[:5]):
|
267 |
-
axs[idx].imshow(img.permute(1, 2, 0))
|
268 |
axs[idx].set_title(f"True: {class_names[true]}\nPred: {class_names[pred]}")
|
269 |
axs[idx].axis('off')
|
270 |
st.pyplot(fig_mis)
|
271 |
|
|
|
272 |
output_pdf = "evaluation_report.pdf"
|
273 |
pdf.output(output_pdf)
|
274 |
with open(output_pdf, "rb") as f:
|
|
|
161 |
|
162 |
|
163 |
# ---- Evaluation Logic ----
|
164 |
+
# Check if evaluation should be triggered
|
165 |
if st.session_state.trigger_eval:
|
166 |
st.markdown("### ⏱️ Evaluation Results")
|
167 |
|
168 |
+
# Start timing the evaluation
|
169 |
start_time = time.time()
|
170 |
+
y_true = [] # Ground truth labels
|
171 |
+
y_pred = [] # Predicted labels
|
172 |
+
y_score = [] # Raw model outputs
|
173 |
+
misclassified_images = [] # List to store misclassified samples
|
174 |
|
175 |
+
total_batches = len(test_loader) # Total number of batches
|
176 |
+
progress_bar = st.progress(0) # Initialize progress bar
|
177 |
+
status_text = st.empty() # Placeholder for status updates
|
178 |
+
stop_info = st.empty() # Placeholder for stop message
|
179 |
|
180 |
+
# Disable gradient calculation for faster evaluation
|
181 |
with torch.no_grad():
|
182 |
for i, (images, labels) in enumerate(test_loader):
|
183 |
+
# Allow user to stop the evaluation
|
184 |
if st.session_state.stop_eval:
|
185 |
stop_info.warning("🚩 Evaluation stopped by user.")
|
186 |
break
|
187 |
|
188 |
+
# Run model on input images
|
189 |
outputs = model(images)
|
190 |
+
_, predicted = torch.max(outputs, 1) # Get predicted class
|
191 |
y_true.extend(labels.numpy())
|
192 |
y_pred.extend(predicted.numpy())
|
193 |
y_score.extend(outputs.detach().numpy())
|
194 |
|
195 |
+
# Store misclassified samples
|
196 |
for j in range(len(labels)):
|
197 |
if predicted[j] != labels[j]:
|
198 |
misclassified_images.append((images[j], predicted[j].item(), labels[j].item()))
|
199 |
|
200 |
+
# Update progress bar and status text
|
201 |
percent_complete = (i + 1) / total_batches
|
202 |
progress_bar.progress(min(percent_complete, 1.0))
|
203 |
status_text.text(f"Evaluating on Test Set: {int(percent_complete * 100)}% | Batch {i+1}/{total_batches}")
|
204 |
+
time.sleep(0.1) # Add delay for UI responsiveness
|
205 |
|
206 |
end_time = time.time()
|
207 |
+
eval_time = end_time - start_time # Total evaluation time
|
208 |
|
209 |
+
# Finalize evaluation if not stopped
|
210 |
if not st.session_state.stop_eval:
|
211 |
st.session_state.evaluation_done = True
|
212 |
st.session_state.trigger_eval = False # ✅ Reset the trigger
|
213 |
st.success(f"✅ Evaluation completed in **{eval_time:.2f} seconds**")
|
214 |
|
215 |
+
# Generate classification report and display as a DataFrame
|
216 |
report = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)
|
217 |
report_df = pd.DataFrame(report).transpose()
|
218 |
st.dataframe(report_df.style.format("{:.2f}"))
|
219 |
|
220 |
+
# Initialize PDF report
|
221 |
pdf = FPDF()
|
222 |
pdf.add_page()
|
223 |
pdf.set_font("Arial", size=12)
|
224 |
pdf.cell(200, 10, txt=clean_text("Classification Report"), ln=True, align='C')
|
225 |
|
226 |
+
# Add table headers
|
227 |
col_widths = [40, 40, 40, 40]
|
228 |
headers = ["Class", "Precision", "Recall", "F1-Score"]
|
229 |
for i, header in enumerate(headers):
|
230 |
pdf.cell(col_widths[i], 10, header, border=1)
|
231 |
pdf.ln()
|
232 |
|
233 |
+
# Add metrics for each class
|
234 |
for idx, row in report_df.iterrows():
|
235 |
if idx in ['accuracy', 'macro avg', 'weighted avg']:
|
236 |
continue
|
|
|
240 |
pdf.cell(col_widths[3], 10, f"{row['f1-score']:.2f}", border=1)
|
241 |
pdf.ln()
|
242 |
|
243 |
+
# Create and display confusion matrix
|
244 |
cm = confusion_matrix(y_true, y_pred)
|
245 |
fig_cm, ax = plt.subplots()
|
246 |
sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names, cmap="Blues", ax=ax)
|
|
|
248 |
ax.set_ylabel('True')
|
249 |
ax.set_title("Confusion Matrix")
|
250 |
st.pyplot(fig_cm)
|
251 |
+
|
252 |
+
# Save confusion matrix to PDF
|
253 |
cm_path = "confusion_matrix.png"
|
254 |
fig_cm.savefig(cm_path, format='png', dpi=300, bbox_inches='tight')
|
255 |
plt.close(fig_cm)
|
256 |
if os.path.exists(cm_path):
|
257 |
pdf.image(cm_path, x=10, y=None, w=180)
|
258 |
|
259 |
+
# Create and display ROC curve for each class
|
260 |
y_true_bin = label_binarize(y_true, classes=list(range(len(class_names))))
|
261 |
y_score_np = np.array(y_score)
|
262 |
fig_roc, ax = plt.subplots()
|
|
|
265 |
roc_auc = auc(fpr, tpr)
|
266 |
ax.plot(fpr, tpr, label=f'{class_names[i]} (AUC = {roc_auc:.2f})')
|
267 |
|
268 |
+
ax.plot([0, 1], [0, 1], 'k--') # Diagonal reference line
|
269 |
ax.set_xlabel('False Positive Rate')
|
270 |
ax.set_ylabel('True Positive Rate')
|
271 |
ax.set_title('Multi-class ROC Curve')
|
272 |
ax.legend(loc='lower right')
|
273 |
st.pyplot(fig_roc)
|
274 |
+
|
275 |
+
# Save ROC curve to PDF
|
276 |
roc_path = "roc_curve.png"
|
277 |
fig_roc.savefig(roc_path, format='png', dpi=300, bbox_inches='tight')
|
278 |
plt.close(fig_roc)
|
279 |
if os.path.exists(roc_path):
|
280 |
pdf.image(roc_path, x=10, y=None, w=180)
|
281 |
|
282 |
+
# Show misclassified samples (up to 5)
|
283 |
st.markdown("### ❌ Misclassified Samples")
|
284 |
fig_mis, axs = plt.subplots(1, min(5, len(misclassified_images)), figsize=(15, 4))
|
285 |
for idx, (img, pred, true) in enumerate(misclassified_images[:5]):
|
286 |
+
axs[idx].imshow(img.permute(1, 2, 0)) # Convert tensor to image format
|
287 |
axs[idx].set_title(f"True: {class_names[true]}\nPred: {class_names[pred]}")
|
288 |
axs[idx].axis('off')
|
289 |
st.pyplot(fig_mis)
|
290 |
|
291 |
+
# Save PDF and provide download button
|
292 |
output_pdf = "evaluation_report.pdf"
|
293 |
pdf.output(output_pdf)
|
294 |
with open(output_pdf, "rb") as f:
|
training/training.ipynb
CHANGED
@@ -1017,7 +1017,7 @@
|
|
1017 |
},
|
1018 |
{
|
1019 |
"cell_type": "code",
|
1020 |
-
"execution_count":
|
1021 |
"id": "560a2a1b",
|
1022 |
"metadata": {},
|
1023 |
"outputs": [
|
@@ -1035,38 +1035,70 @@
|
|
1035 |
"source": [
|
1036 |
"from sklearn.metrics import roc_curve, auc\n",
|
1037 |
"from sklearn.preprocessing import label_binarize\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
1038 |
"\n",
|
1039 |
-
"
|
|
|
1040 |
"y_true_bin = label_binarize(all_labels, classes=[0, 1, 2, 3, 4])\n",
|
|
|
|
|
1041 |
"y_scores = []\n",
|
1042 |
"\n",
|
|
|
1043 |
"model.eval()\n",
|
|
|
|
|
1044 |
"with torch.no_grad():\n",
|
1045 |
" for inputs, labels in test_loader:\n",
|
1046 |
" inputs = inputs.to(device)\n",
|
|
|
|
|
1047 |
" outputs = model(inputs)\n",
|
|
|
|
|
1048 |
" probs = torch.softmax(outputs, dim=1)\n",
|
|
|
|
|
1049 |
" y_scores.extend(probs.cpu().numpy())\n",
|
1050 |
"\n",
|
1051 |
-
"#
|
1052 |
-
"fpr, tpr, roc_auc = dict(), dict(), dict()\n",
|
1053 |
"y_scores = np.array(y_scores)\n",
|
1054 |
"\n",
|
|
|
|
|
|
|
|
|
1055 |
"for i in range(n_classes):\n",
|
|
|
1056 |
" fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_scores[:, i])\n",
|
|
|
|
|
1057 |
" roc_auc[i] = auc(fpr[i], tpr[i])\n",
|
1058 |
"\n",
|
1059 |
-
"# Plot
|
1060 |
"plt.figure(figsize=(10, 7))\n",
|
|
|
1061 |
"for i in range(n_classes):\n",
|
1062 |
" plt.plot(fpr[i], tpr[i], label=f'{class_names[i]} (AUC = {roc_auc[i]:.2f})')\n",
|
1063 |
"\n",
|
|
|
1064 |
"plt.plot([0, 1], [0, 1], 'k--')\n",
|
|
|
|
|
1065 |
"plt.title('Multi-class ROC Curve')\n",
|
1066 |
"plt.xlabel('False Positive Rate')\n",
|
1067 |
"plt.ylabel('True Positive Rate')\n",
|
|
|
|
|
1068 |
"plt.legend(loc='lower right')\n",
|
1069 |
"plt.grid(True)\n",
|
|
|
|
|
1070 |
"plt.show()\n"
|
1071 |
]
|
1072 |
},
|
@@ -1474,52 +1506,6 @@
|
|
1474 |
"print(f\"Predicted Class: {predicted_class}\")\n",
|
1475 |
"print(f\"Confidence: {confidence_percentage:.2f}%\")"
|
1476 |
]
|
1477 |
-
},
|
1478 |
-
{
|
1479 |
-
"cell_type": "code",
|
1480 |
-
"execution_count": 1,
|
1481 |
-
"id": "eb2308ed",
|
1482 |
-
"metadata": {},
|
1483 |
-
"outputs": [
|
1484 |
-
{
|
1485 |
-
"ename": "FileNotFoundError",
|
1486 |
-
"evalue": "[Errno 2] No such file or directory: 'D:\\\\DR_Classification\\\\dataset\\\\splits\\\\test_labels.csv'",
|
1487 |
-
"output_type": "error",
|
1488 |
-
"traceback": [
|
1489 |
-
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
1490 |
-
"\u001b[31mFileNotFoundError\u001b[39m Traceback (most recent call last)",
|
1491 |
-
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 12\u001b[39m\n\u001b[32m 9\u001b[39m new_dir = \u001b[33mr\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mD:\u001b[39m\u001b[33m\\\u001b[39m\u001b[33mDR_Classification\u001b[39m\u001b[33m\\\u001b[39m\u001b[33mdataset\u001b[39m\u001b[33m\\\u001b[39m\u001b[33msplitted-data\u001b[39m\u001b[33m\\\u001b[39m\u001b[33mtest\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 11\u001b[39m \u001b[38;5;66;03m# === Load the CSV ===\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m12\u001b[39m df = \u001b[43mpd\u001b[49m\u001b[43m.\u001b[49m\u001b[43mread_csv\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcsv_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 14\u001b[39m \u001b[38;5;66;03m# === Replace old path with new path in 'new_path' column ===\u001b[39;00m\n\u001b[32m 15\u001b[39m df[\u001b[33m'\u001b[39m\u001b[33mnew_path\u001b[39m\u001b[33m'\u001b[39m] = df[\u001b[33m'\u001b[39m\u001b[33mnew_path\u001b[39m\u001b[33m'\u001b[39m].str.replace(old_dir, new_dir, regex=\u001b[38;5;28;01mFalse\u001b[39;00m)\n",
|
1492 |
-
"\u001b[36mFile \u001b[39m\u001b[32md:\\DR_Classification\\.venv\\Lib\\site-packages\\pandas\\io\\parsers\\readers.py:1026\u001b[39m, in \u001b[36mread_csv\u001b[39m\u001b[34m(filepath_or_buffer, sep, delimiter, header, names, index_col, usecols, dtype, engine, converters, true_values, false_values, skipinitialspace, skiprows, skipfooter, nrows, na_values, keep_default_na, na_filter, verbose, skip_blank_lines, parse_dates, infer_datetime_format, keep_date_col, date_parser, date_format, dayfirst, cache_dates, iterator, chunksize, compression, thousands, decimal, lineterminator, quotechar, quoting, doublequote, escapechar, comment, encoding, encoding_errors, dialect, on_bad_lines, delim_whitespace, low_memory, memory_map, float_precision, storage_options, dtype_backend)\u001b[39m\n\u001b[32m 1013\u001b[39m kwds_defaults = _refine_defaults_read(\n\u001b[32m 1014\u001b[39m dialect,\n\u001b[32m 1015\u001b[39m delimiter,\n\u001b[32m (...)\u001b[39m\u001b[32m 1022\u001b[39m dtype_backend=dtype_backend,\n\u001b[32m 1023\u001b[39m )\n\u001b[32m 1024\u001b[39m kwds.update(kwds_defaults)\n\u001b[32m-> \u001b[39m\u001b[32m1026\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_read\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfilepath_or_buffer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m\n",
|
1493 |
-
"\u001b[36mFile \u001b[39m\u001b[32md:\\DR_Classification\\.venv\\Lib\\site-packages\\pandas\\io\\parsers\\readers.py:620\u001b[39m, in \u001b[36m_read\u001b[39m\u001b[34m(filepath_or_buffer, kwds)\u001b[39m\n\u001b[32m 617\u001b[39m _validate_names(kwds.get(\u001b[33m\"\u001b[39m\u001b[33mnames\u001b[39m\u001b[33m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m))\n\u001b[32m 619\u001b[39m \u001b[38;5;66;03m# Create the parser.\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m620\u001b[39m parser = \u001b[43mTextFileReader\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfilepath_or_buffer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 622\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m chunksize \u001b[38;5;129;01mor\u001b[39;00m iterator:\n\u001b[32m 623\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m parser\n",
|
1494 |
-
"\u001b[36mFile \u001b[39m\u001b[32md:\\DR_Classification\\.venv\\Lib\\site-packages\\pandas\\io\\parsers\\readers.py:1620\u001b[39m, in \u001b[36mTextFileReader.__init__\u001b[39m\u001b[34m(self, f, engine, **kwds)\u001b[39m\n\u001b[32m 1617\u001b[39m \u001b[38;5;28mself\u001b[39m.options[\u001b[33m\"\u001b[39m\u001b[33mhas_index_names\u001b[39m\u001b[33m\"\u001b[39m] = kwds[\u001b[33m\"\u001b[39m\u001b[33mhas_index_names\u001b[39m\u001b[33m\"\u001b[39m]\n\u001b[32m 1619\u001b[39m \u001b[38;5;28mself\u001b[39m.handles: IOHandles | \u001b[38;5;28;01mNone\u001b[39;00m = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m-> \u001b[39m\u001b[32m1620\u001b[39m \u001b[38;5;28mself\u001b[39m._engine = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_make_engine\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mengine\u001b[49m\u001b[43m)\u001b[49m\n",
|
1495 |
-
"\u001b[36mFile \u001b[39m\u001b[32md:\\DR_Classification\\.venv\\Lib\\site-packages\\pandas\\io\\parsers\\readers.py:1880\u001b[39m, in \u001b[36mTextFileReader._make_engine\u001b[39m\u001b[34m(self, f, engine)\u001b[39m\n\u001b[32m 1878\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[33m\"\u001b[39m\u001b[33mb\u001b[39m\u001b[33m\"\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m mode:\n\u001b[32m 1879\u001b[39m mode += \u001b[33m\"\u001b[39m\u001b[33mb\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m-> \u001b[39m\u001b[32m1880\u001b[39m \u001b[38;5;28mself\u001b[39m.handles = \u001b[43mget_handle\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 1881\u001b[39m \u001b[43m \u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1882\u001b[39m \u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1883\u001b[39m \u001b[43m \u001b[49m\u001b[43mencoding\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43moptions\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mencoding\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1884\u001b[39m \u001b[43m \u001b[49m\u001b[43mcompression\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43moptions\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mcompression\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1885\u001b[39m \u001b[43m \u001b[49m\u001b[43mmemory_map\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43moptions\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mmemory_map\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1886\u001b[39m \u001b[43m \u001b[49m\u001b[43mis_text\u001b[49m\u001b[43m=\u001b[49m\u001b[43mis_text\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1887\u001b[39m \u001b[43m \u001b[49m\u001b[43merrors\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43moptions\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mencoding_errors\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mstrict\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1888\u001b[39m \u001b[43m \u001b[49m\u001b[43mstorage_options\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43moptions\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mstorage_options\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1889\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1890\u001b[39m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m.handles \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1891\u001b[39m f = \u001b[38;5;28mself\u001b[39m.handles.handle\n",
|
1496 |
-
"\u001b[36mFile \u001b[39m\u001b[32md:\\DR_Classification\\.venv\\Lib\\site-packages\\pandas\\io\\common.py:873\u001b[39m, in \u001b[36mget_handle\u001b[39m\u001b[34m(path_or_buf, mode, encoding, compression, memory_map, is_text, errors, storage_options)\u001b[39m\n\u001b[32m 868\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(handle, \u001b[38;5;28mstr\u001b[39m):\n\u001b[32m 869\u001b[39m \u001b[38;5;66;03m# Check whether the filename is to be opened in binary mode.\u001b[39;00m\n\u001b[32m 870\u001b[39m \u001b[38;5;66;03m# Binary mode does not support 'encoding' and 'newline'.\u001b[39;00m\n\u001b[32m 871\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m ioargs.encoding \u001b[38;5;129;01mand\u001b[39;00m \u001b[33m\"\u001b[39m\u001b[33mb\u001b[39m\u001b[33m\"\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m ioargs.mode:\n\u001b[32m 872\u001b[39m \u001b[38;5;66;03m# Encoding\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m873\u001b[39m handle = \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\n\u001b[32m 874\u001b[39m \u001b[43m \u001b[49m\u001b[43mhandle\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 875\u001b[39m \u001b[43m \u001b[49m\u001b[43mioargs\u001b[49m\u001b[43m.\u001b[49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 876\u001b[39m \u001b[43m \u001b[49m\u001b[43mencoding\u001b[49m\u001b[43m=\u001b[49m\u001b[43mioargs\u001b[49m\u001b[43m.\u001b[49m\u001b[43mencoding\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 877\u001b[39m \u001b[43m \u001b[49m\u001b[43merrors\u001b[49m\u001b[43m=\u001b[49m\u001b[43merrors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 878\u001b[39m \u001b[43m \u001b[49m\u001b[43mnewline\u001b[49m\u001b[43m=\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m 879\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 880\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 881\u001b[39m \u001b[38;5;66;03m# Binary mode\u001b[39;00m\n\u001b[32m 882\u001b[39m handle = \u001b[38;5;28mopen\u001b[39m(handle, ioargs.mode)\n",
|
1497 |
-
"\u001b[31mFileNotFoundError\u001b[39m: [Errno 2] No such file or directory: 'D:\\\\DR_Classification\\\\dataset\\\\splits\\\\test_labels.csv'"
|
1498 |
-
]
|
1499 |
-
}
|
1500 |
-
],
|
1501 |
-
"source": [
|
1502 |
-
"import pandas as pd\n",
|
1503 |
-
"\n",
|
1504 |
-
"# === File paths ===\n",
|
1505 |
-
"csv_path = r\"D:\\DR_Classification\\dataset\\splits\\test_labels.csv\"\n",
|
1506 |
-
"output_csv_path = r\"D:\\DR_Classification\\dataset\\Splitted_data\\splits\\test_labels.csv\"\n",
|
1507 |
-
"\n",
|
1508 |
-
"# === Old and new base directory paths ===\n",
|
1509 |
-
"old_dir = r\"D:\\DR_Classification\\splits\\test\"\n",
|
1510 |
-
"new_dir = r\"D:\\DR_Classification\\dataset\\splitted-data\\test\"\n",
|
1511 |
-
"\n",
|
1512 |
-
"# === Load the CSV ===\n",
|
1513 |
-
"df = pd.read_csv(csv_path)\n",
|
1514 |
-
"\n",
|
1515 |
-
"# === Replace old path with new path in 'new_path' column ===\n",
|
1516 |
-
"df['new_path'] = df['new_path'].str.replace(old_dir, new_dir, regex=False)\n",
|
1517 |
-
"\n",
|
1518 |
-
"# === Save the updated CSV ===\n",
|
1519 |
-
"df.to_csv(output_csv_path, index=False)\n",
|
1520 |
-
"\n",
|
1521 |
-
"print(\"✅ CSV updated and saved at:\", output_csv_path)\n"
|
1522 |
-
]
|
1523 |
}
|
1524 |
],
|
1525 |
"metadata": {
|
|
|
1017 |
},
|
1018 |
{
|
1019 |
"cell_type": "code",
|
1020 |
+
"execution_count": null,
|
1021 |
"id": "560a2a1b",
|
1022 |
"metadata": {},
|
1023 |
"outputs": [
|
|
|
1035 |
"source": [
|
1036 |
"from sklearn.metrics import roc_curve, auc\n",
|
1037 |
"from sklearn.preprocessing import label_binarize\n",
|
1038 |
+
"import numpy as np\n",
|
1039 |
+
"import matplotlib.pyplot as plt\n",
|
1040 |
+
"import torch\n",
|
1041 |
+
"\n",
|
1042 |
+
"# Number of classes for Diabetic Retinopathy (DR) classification\n",
|
1043 |
+
"n_classes = 5 \n",
|
1044 |
"\n",
|
1045 |
+
"# Convert class labels to one-hot encoded format (needed for multi-class ROC)\n",
|
1046 |
+
"# Example: label 2 becomes [0, 0, 1, 0, 0]\n",
|
1047 |
"y_true_bin = label_binarize(all_labels, classes=[0, 1, 2, 3, 4])\n",
|
1048 |
+
"\n",
|
1049 |
+
"# Will hold predicted probabilities for each class\n",
|
1050 |
"y_scores = []\n",
|
1051 |
"\n",
|
1052 |
+
"# Set model to evaluation mode\n",
|
1053 |
"model.eval()\n",
|
1054 |
+
"\n",
|
1055 |
+
"# Disable gradient calculation for faster inference\n",
|
1056 |
"with torch.no_grad():\n",
|
1057 |
" for inputs, labels in test_loader:\n",
|
1058 |
" inputs = inputs.to(device)\n",
|
1059 |
+
" \n",
|
1060 |
+
" # Forward pass through the model\n",
|
1061 |
" outputs = model(inputs)\n",
|
1062 |
+
" \n",
|
1063 |
+
" # Apply softmax to get class probabilities\n",
|
1064 |
" probs = torch.softmax(outputs, dim=1)\n",
|
1065 |
+
" \n",
|
1066 |
+
" # Append the probabilities to y_scores list\n",
|
1067 |
" y_scores.extend(probs.cpu().numpy())\n",
|
1068 |
"\n",
|
1069 |
+
"# Convert the list of predictions to a NumPy array\n",
|
|
|
1070 |
"y_scores = np.array(y_scores)\n",
|
1071 |
"\n",
|
1072 |
+
"# Initialize dictionaries to store False Positive Rate (FPR), True Positive Rate (TPR), and AUC\n",
|
1073 |
+
"fpr, tpr, roc_auc = dict(), dict(), dict()\n",
|
1074 |
+
"\n",
|
1075 |
+
"# Compute ROC curve and AUC for each class\n",
|
1076 |
"for i in range(n_classes):\n",
|
1077 |
+
" # Calculate FPR and TPR for class `i`\n",
|
1078 |
" fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_scores[:, i])\n",
|
1079 |
+
" \n",
|
1080 |
+
" # Calculate Area Under Curve (AUC) for class `i`\n",
|
1081 |
" roc_auc[i] = auc(fpr[i], tpr[i])\n",
|
1082 |
"\n",
|
1083 |
+
"# Plot ROC curves for all classes\n",
|
1084 |
"plt.figure(figsize=(10, 7))\n",
|
1085 |
+
"\n",
|
1086 |
"for i in range(n_classes):\n",
|
1087 |
" plt.plot(fpr[i], tpr[i], label=f'{class_names[i]} (AUC = {roc_auc[i]:.2f})')\n",
|
1088 |
"\n",
|
1089 |
+
"# Plot the diagonal line (chance level)\n",
|
1090 |
"plt.plot([0, 1], [0, 1], 'k--')\n",
|
1091 |
+
"\n",
|
1092 |
+
"# Set plot title and axis labels\n",
|
1093 |
"plt.title('Multi-class ROC Curve')\n",
|
1094 |
"plt.xlabel('False Positive Rate')\n",
|
1095 |
"plt.ylabel('True Positive Rate')\n",
|
1096 |
+
"\n",
|
1097 |
+
"# Add legend and grid\n",
|
1098 |
"plt.legend(loc='lower right')\n",
|
1099 |
"plt.grid(True)\n",
|
1100 |
+
"\n",
|
1101 |
+
"# Show the plot\n",
|
1102 |
"plt.show()\n"
|
1103 |
]
|
1104 |
},
|
|
|
1506 |
"print(f\"Predicted Class: {predicted_class}\")\n",
|
1507 |
"print(f\"Confidence: {confidence_percentage:.2f}%\")"
|
1508 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1509 |
}
|
1510 |
],
|
1511 |
"metadata": {
|