3v324v23 committed on
Commit f519c0b · 1 Parent(s): 6d2a042

put comments

Files changed (2)
  1. pages/Model_Evaluation.py +33 -13
  2. training/training.ipynb +37 -51
pages/Model_Evaluation.py CHANGED
@@ -161,64 +161,76 @@ with st.expander("ℹ️ **What is Model Evaluation?**", expanded=True):
 
 
 # ---- Evaluation Logic ----
+# Check if evaluation should be triggered
 if st.session_state.trigger_eval:
     st.markdown("### ⏱️ Evaluation Results")
 
+    # Start timing the evaluation
     start_time = time.time()
-    y_true = []
-    y_pred = []
-    y_score = []
-    misclassified_images = []
+    y_true = []  # Ground truth labels
+    y_pred = []  # Predicted labels
+    y_score = []  # Raw model outputs
+    misclassified_images = []  # List to store misclassified samples
 
-    total_batches = len(test_loader)
-    progress_bar = st.progress(0)
-    status_text = st.empty()
-    stop_info = st.empty()
+    total_batches = len(test_loader)  # Total number of batches
+    progress_bar = st.progress(0)  # Initialize progress bar
+    status_text = st.empty()  # Placeholder for status updates
+    stop_info = st.empty()  # Placeholder for stop message
 
+    # Disable gradient calculation for faster evaluation
     with torch.no_grad():
         for i, (images, labels) in enumerate(test_loader):
+            # Allow user to stop the evaluation
             if st.session_state.stop_eval:
                 stop_info.warning("🚩 Evaluation stopped by user.")
                 break
 
+            # Run model on input images
             outputs = model(images)
-            _, predicted = torch.max(outputs, 1)
+            _, predicted = torch.max(outputs, 1)  # Get predicted class
             y_true.extend(labels.numpy())
             y_pred.extend(predicted.numpy())
             y_score.extend(outputs.detach().numpy())
 
+            # Store misclassified samples
             for j in range(len(labels)):
                 if predicted[j] != labels[j]:
                     misclassified_images.append((images[j], predicted[j].item(), labels[j].item()))
 
+            # Update progress bar and status text
             percent_complete = (i + 1) / total_batches
             progress_bar.progress(min(percent_complete, 1.0))
             status_text.text(f"Evaluating on Test Set: {int(percent_complete * 100)}% | Batch {i+1}/{total_batches}")
-            time.sleep(0.1)
+            time.sleep(0.1)  # Add delay for UI responsiveness
 
     end_time = time.time()
-    eval_time = end_time - start_time
+    eval_time = end_time - start_time  # Total evaluation time
 
+    # Finalize evaluation if not stopped
     if not st.session_state.stop_eval:
         st.session_state.evaluation_done = True
         st.session_state.trigger_eval = False  # ✅ Reset the trigger
         st.success(f"✅ Evaluation completed in **{eval_time:.2f} seconds**")
 
+        # Generate classification report and display as a DataFrame
         report = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)
         report_df = pd.DataFrame(report).transpose()
         st.dataframe(report_df.style.format("{:.2f}"))
 
+        # Initialize PDF report
         pdf = FPDF()
         pdf.add_page()
         pdf.set_font("Arial", size=12)
         pdf.cell(200, 10, txt=clean_text("Classification Report"), ln=True, align='C')
 
+        # Add table headers
         col_widths = [40, 40, 40, 40]
         headers = ["Class", "Precision", "Recall", "F1-Score"]
         for i, header in enumerate(headers):
             pdf.cell(col_widths[i], 10, header, border=1)
         pdf.ln()
 
+        # Add metrics for each class
         for idx, row in report_df.iterrows():
             if idx in ['accuracy', 'macro avg', 'weighted avg']:
                 continue
@@ -228,6 +240,7 @@ if st.session_state.trigger_eval:
             pdf.cell(col_widths[3], 10, f"{row['f1-score']:.2f}", border=1)
             pdf.ln()
 
+        # Create and display confusion matrix
         cm = confusion_matrix(y_true, y_pred)
         fig_cm, ax = plt.subplots()
         sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names, cmap="Blues", ax=ax)
@@ -235,12 +248,15 @@ if st.session_state.trigger_eval:
         ax.set_ylabel('True')
         ax.set_title("Confusion Matrix")
         st.pyplot(fig_cm)
+
+        # Save confusion matrix to PDF
         cm_path = "confusion_matrix.png"
         fig_cm.savefig(cm_path, format='png', dpi=300, bbox_inches='tight')
         plt.close(fig_cm)
         if os.path.exists(cm_path):
             pdf.image(cm_path, x=10, y=None, w=180)
 
+        # Create and display ROC curve for each class
         y_true_bin = label_binarize(y_true, classes=list(range(len(class_names))))
         y_score_np = np.array(y_score)
         fig_roc, ax = plt.subplots()
@@ -249,26 +265,30 @@ if st.session_state.trigger_eval:
             roc_auc = auc(fpr, tpr)
             ax.plot(fpr, tpr, label=f'{class_names[i]} (AUC = {roc_auc:.2f})')
 
-        ax.plot([0, 1], [0, 1], 'k--')
+        ax.plot([0, 1], [0, 1], 'k--')  # Diagonal reference line
         ax.set_xlabel('False Positive Rate')
         ax.set_ylabel('True Positive Rate')
         ax.set_title('Multi-class ROC Curve')
         ax.legend(loc='lower right')
         st.pyplot(fig_roc)
+
+        # Save ROC curve to PDF
         roc_path = "roc_curve.png"
         fig_roc.savefig(roc_path, format='png', dpi=300, bbox_inches='tight')
         plt.close(fig_roc)
         if os.path.exists(roc_path):
             pdf.image(roc_path, x=10, y=None, w=180)
 
+        # Show misclassified samples (up to 5)
         st.markdown("### ❌ Misclassified Samples")
         fig_mis, axs = plt.subplots(1, min(5, len(misclassified_images)), figsize=(15, 4))
         for idx, (img, pred, true) in enumerate(misclassified_images[:5]):
-            axs[idx].imshow(img.permute(1, 2, 0))
+            axs[idx].imshow(img.permute(1, 2, 0))  # Convert tensor to image format
            axs[idx].set_title(f"True: {class_names[true]}\nPred: {class_names[pred]}")
             axs[idx].axis('off')
         st.pyplot(fig_mis)
 
+        # Save PDF and provide download button
         output_pdf = "evaluation_report.pdf"
         pdf.output(output_pdf)
         with open(output_pdf, "rb") as f:
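
A note on the misclassified-samples block touched above: plt.subplots(1, n) returns a single Axes object rather than an array when n == 1, so axs[idx] would raise a TypeError if exactly one image is misclassified, and plt.subplots(1, 0) fails when the list is empty. A minimal, hypothetical guard, not part of this commit and assuming the surrounding st, class_names, and misclassified_images names from Model_Evaluation.py, could look like this:

import numpy as np
import matplotlib.pyplot as plt

if misclassified_images:  # only plot if at least one sample was misclassified
    n = min(5, len(misclassified_images))
    fig_mis, axs = plt.subplots(1, n, figsize=(15, 4))
    axs = np.atleast_1d(axs)  # make axs indexable even when n == 1
    for idx, (img, pred, true) in enumerate(misclassified_images[:n]):
        axs[idx].imshow(img.permute(1, 2, 0))  # CHW tensor -> HWC image
        axs[idx].set_title(f"True: {class_names[true]}\nPred: {class_names[pred]}")
        axs[idx].axis('off')
    st.pyplot(fig_mis)
else:
    st.info("No misclassified samples found.")
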
training/training.ipynb CHANGED
@@ -1017,7 +1017,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 29,
+   "execution_count": null,
    "id": "560a2a1b",
    "metadata": {},
    "outputs": [
@@ -1035,38 +1035,70 @@
    "source": [
     "from sklearn.metrics import roc_curve, auc\n",
     "from sklearn.preprocessing import label_binarize\n",
+    "import numpy as np\n",
+    "import matplotlib.pyplot as plt\n",
+    "import torch\n",
+    "\n",
+    "# Number of classes for Diabetic Retinopathy (DR) classification\n",
+    "n_classes = 5\n",
     "\n",
-    "n_classes = 5 # for DR classification\n",
+    "# Convert class labels to one-hot encoded format (needed for multi-class ROC)\n",
+    "# Example: label 2 becomes [0, 0, 1, 0, 0]\n",
     "y_true_bin = label_binarize(all_labels, classes=[0, 1, 2, 3, 4])\n",
+    "\n",
+    "# Will hold predicted probabilities for each class\n",
     "y_scores = []\n",
     "\n",
+    "# Set model to evaluation mode\n",
     "model.eval()\n",
+    "\n",
+    "# Disable gradient calculation for faster inference\n",
     "with torch.no_grad():\n",
     "    for inputs, labels in test_loader:\n",
     "        inputs = inputs.to(device)\n",
+    "\n",
+    "        # Forward pass through the model\n",
     "        outputs = model(inputs)\n",
+    "\n",
+    "        # Apply softmax to get class probabilities\n",
     "        probs = torch.softmax(outputs, dim=1)\n",
+    "\n",
+    "        # Append the probabilities to y_scores list\n",
     "        y_scores.extend(probs.cpu().numpy())\n",
     "\n",
-    "# Compute ROC curve and AUC for each class\n",
-    "fpr, tpr, roc_auc = dict(), dict(), dict()\n",
+    "# Convert the list of predictions to a NumPy array\n",
     "y_scores = np.array(y_scores)\n",
     "\n",
+    "# Initialize dictionaries to store False Positive Rate (FPR), True Positive Rate (TPR), and AUC\n",
+    "fpr, tpr, roc_auc = dict(), dict(), dict()\n",
+    "\n",
+    "# Compute ROC curve and AUC for each class\n",
     "for i in range(n_classes):\n",
+    "    # Calculate FPR and TPR for class `i`\n",
     "    fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_scores[:, i])\n",
+    "\n",
+    "    # Calculate Area Under Curve (AUC) for class `i`\n",
    "    roc_auc[i] = auc(fpr[i], tpr[i])\n",
     "\n",
-    "# Plot all ROC curves\n",
+    "# Plot ROC curves for all classes\n",
     "plt.figure(figsize=(10, 7))\n",
+    "\n",
     "for i in range(n_classes):\n",
     "    plt.plot(fpr[i], tpr[i], label=f'{class_names[i]} (AUC = {roc_auc[i]:.2f})')\n",
     "\n",
+    "# Plot the diagonal line (chance level)\n",
     "plt.plot([0, 1], [0, 1], 'k--')\n",
+    "\n",
+    "# Set plot title and axis labels\n",
     "plt.title('Multi-class ROC Curve')\n",
     "plt.xlabel('False Positive Rate')\n",
     "plt.ylabel('True Positive Rate')\n",
+    "\n",
+    "# Add legend and grid\n",
     "plt.legend(loc='lower right')\n",
     "plt.grid(True)\n",
+    "\n",
+    "# Show the plot\n",
     "plt.show()\n"
    ]
   },
@@ -1474,52 +1506,6 @@
     "print(f\"Predicted Class: {predicted_class}\")\n",
     "print(f\"Confidence: {confidence_percentage:.2f}%\")"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "id": "eb2308ed",
-   "metadata": {},
-   "outputs": [
-    {
-     "ename": "FileNotFoundError",
-     "evalue": "[Errno 2] No such file or directory: 'D:\\\\DR_Classification\\\\dataset\\\\splits\\\\test_labels.csv'",
-     "output_type": "error",
-     "traceback": [
-      "FileNotFoundError: [Errno 2] No such file or directory: 'D:\\\\DR_Classification\\\\dataset\\\\splits\\\\test_labels.csv'"
-     ]
-    }
-   ],
-   "source": [
-    "import pandas as pd\n",
-    "\n",
-    "# === File paths ===\n",
-    "csv_path = r\"D:\\DR_Classification\\dataset\\splits\\test_labels.csv\"\n",
-    "output_csv_path = r\"D:\\DR_Classification\\dataset\\Splitted_data\\splits\\test_labels.csv\"\n",
-    "\n",
-    "# === Old and new base directory paths ===\n",
-    "old_dir = r\"D:\\DR_Classification\\splits\\test\"\n",
-    "new_dir = r\"D:\\DR_Classification\\dataset\\splitted-data\\test\"\n",
-    "\n",
-    "# === Load the CSV ===\n",
-    "df = pd.read_csv(csv_path)\n",
-    "\n",
-    "# === Replace old path with new path in 'new_path' column ===\n",
-    "df['new_path'] = df['new_path'].str.replace(old_dir, new_dir, regex=False)\n",
-    "\n",
-    "# === Save the updated CSV ===\n",
-    "df.to_csv(output_csv_path, index=False)\n",
-    "\n",
-    "print(\"✅ CSV updated and saved at:\", output_csv_path)\n"
-   ]
   }
  ],
  "metadata": {