3v324v23 commited on
Commit
6a5be80
·
1 Parent(s): 97c84ef

Added functionalities

Browse files
Files changed (1) hide show
  1. pages/Model_Evaluation.py +188 -70
pages/Model_Evaluation.py CHANGED
@@ -1,59 +1,78 @@
1
- import streamlit as st
 
 
 
 
 
 
2
  import torch
 
3
  from torch.utils.data import DataLoader, Dataset
4
  from torchvision import transforms, models
5
- import torch.nn as nn
6
  from PIL import Image
7
- import pandas as pd
8
- import os
9
- import cv2
10
- import numpy as np
 
11
 
 
 
 
 
 
 
 
 
12
  st.markdown("<h2 style='color: #2E86C1;'>📈 Model Evaluation</h2>", unsafe_allow_html=True)
13
 
14
- # Define class names and label map
15
  class_names = ['No_DR', 'Mild', 'Moderate', 'Severe', 'Proliferative_DR']
16
  label_map = {label: idx for idx, label in enumerate(class_names)}
17
 
18
- # Define your image preprocessing functions
 
 
 
 
19
  def apply_median_filter(image):
20
  return cv2.medianBlur(image, 5)
21
 
22
  def apply_clahe(image):
23
  lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
24
  l, a, b = cv2.split(lab)
25
- clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
26
  cl = clahe.apply(l)
27
  merged = cv2.merge((cl, a, b))
28
  return cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)
29
 
30
- def apply_gamma_correction(image, gamma=1.5):
31
  invGamma = 1.0 / gamma
32
- table = np.array([(i / 255.0) ** invGamma * 255 for i in np.arange(256)]).astype("uint8")
33
  return cv2.LUT(image, table)
34
 
35
  def apply_gaussian_filter(image, kernel_size=(5, 5), sigma=1.0):
36
  return cv2.GaussianBlur(image, kernel_size, sigma)
37
 
38
- # Custom dataset with preprocessing
39
  class DDRDataset(Dataset):
40
- def __init__(self, csv_path, img_dir, transform=None):
41
  self.data = pd.read_csv(csv_path)
42
- self.img_dir = img_dir
 
43
  self.transform = transform
44
 
45
  def __len__(self):
46
- return len(self.data)
47
 
48
  def __getitem__(self, idx):
49
- img_name = self.data.iloc[idx, 0]
50
- label_name = self.data.iloc[idx, 1]
51
- label = int(label_map.get(label_name, 0)) # fallback to 0
52
 
53
- img_path = os.path.join(self.img_dir, img_name)
54
  image = cv2.imread(img_path)
55
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
56
 
 
57
  image = apply_median_filter(image)
58
  image = apply_clahe(image)
59
  image = apply_gamma_correction(image)
@@ -63,67 +82,166 @@ class DDRDataset(Dataset):
63
  if self.transform:
64
  image = self.transform(image)
65
 
66
- return image, label
67
 
68
- # -------------------------------
69
- # Load Test Data with Caching
70
- # -------------------------------
 
 
 
 
 
71
  @st.cache_resource
72
- def load_test_data():
73
- transform = transforms.Compose([
74
- transforms.Resize((224, 224)),
75
- transforms.ToTensor(),
76
- transforms.Normalize([0.485, 0.456, 0.406],
77
- [0.229, 0.224, 0.225])
78
- ])
79
- dataset = DDRDataset(
80
- csv_path="D:/DR_Classification/splits/test_labels.csv",
81
- img_dir="D:/DR_Classification/splits/test",
82
- transform=transform
83
- )
84
  return DataLoader(dataset, batch_size=32, shuffle=False)
85
 
86
- # -------------------------------
87
- # Evaluation Function
88
- # -------------------------------
89
- def evaluation_test_model(model, test_loader, criterion, device='cpu'):
 
 
90
  model.eval()
91
- running_loss = 0.0
92
- correct = 0
93
- total = 0
94
 
95
- with torch.no_grad():
96
- for inputs, labels in test_loader:
97
- inputs, labels = inputs.to(device), labels.to(device)
98
- outputs = model(inputs)
99
- loss = criterion(outputs, labels)
100
 
101
- running_loss += loss.item()
102
- _, predicted = torch.max(outputs, 1)
103
- correct += (predicted == labels).sum().item()
104
- total += labels.size(0)
 
 
 
 
 
105
 
106
- val_loss = running_loss / len(test_loader)
107
- val_acc = correct / total * 100
108
- return val_loss, val_acc
109
 
110
- # -------------------------------
111
- # Evaluation Button Trigger
112
- # -------------------------------
113
- if st.button("🔍 Evaluate Trained Model"):
114
- with st.spinner("Evaluating on test data..."):
115
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
116
 
117
- test_loader = load_test_data()
 
118
 
119
- model = models.densenet121(pretrained=False)
120
- model.classifier = nn.Linear(model.classifier.in_features, len(class_names))
121
- model.load_state_dict(torch.load("training/Pretrained_Densenet-121.pth", map_location=device))
122
- model.to(device)
 
123
 
124
- criterion = nn.CrossEntropyLoss()
125
- val_loss, val_acc = evaluation_test_model(model, test_loader, criterion, device)
 
 
126
 
127
- st.success("✅ Evaluation Complete")
128
- st.metric("Test Loss", f"{val_loss:.4f}")
129
- st.metric("Test Accuracy", f"{val_acc:.2f}%")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import time
4
+ import cv2
5
+ import numpy as np
6
+ import pandas as pd
7
+ import seaborn as sns
8
  import torch
9
+ import torch.nn as nn
10
  from torch.utils.data import DataLoader, Dataset
11
  from torchvision import transforms, models
 
12
  from PIL import Image
13
+ from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
14
+ from sklearn.preprocessing import label_binarize
15
+ import streamlit as st
16
+ import matplotlib.pyplot as plt
17
+ from fpdf import FPDF
18
 
19
+ # ---- Stop Evaluation button -----
20
+ if 'stop_eval' not in st.session_state:
21
+ st.session_state.stop_eval = False
22
+
23
+ if 'evaluation_done' not in st.session_state:
24
+ st.session_state.evaluation_done = False
25
+
26
+ # ---- Streamlit Title ----
27
  st.markdown("<h2 style='color: #2E86C1;'>📈 Model Evaluation</h2>", unsafe_allow_html=True)
28
 
29
+ # ---- Class Names & Label Mapping ----
30
  class_names = ['No_DR', 'Mild', 'Moderate', 'Severe', 'Proliferative_DR']
31
  label_map = {label: idx for idx, label in enumerate(class_names)}
32
 
33
+ # ---- Text Cleaning Function for PDF ----
34
+ def clean_text(text):
35
+ return text.encode('utf-8', 'ignore').decode('utf-8')
36
+
37
+ # ---- Preprocessing Functions ----
38
  def apply_median_filter(image):
39
  return cv2.medianBlur(image, 5)
40
 
41
  def apply_clahe(image):
42
  lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
43
  l, a, b = cv2.split(lab)
44
+ clahe = cv2.createCLAHE(clipLimit=2.0)
45
  cl = clahe.apply(l)
46
  merged = cv2.merge((cl, a, b))
47
  return cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)
48
 
49
+ def apply_gamma_correction(image, gamma=1.2):
50
  invGamma = 1.0 / gamma
51
+ table = np.array([(i / 255.0) ** invGamma * 255 for i in np.arange(0, 256)]).astype("uint8")
52
  return cv2.LUT(image, table)
53
 
54
  def apply_gaussian_filter(image, kernel_size=(5, 5), sigma=1.0):
55
  return cv2.GaussianBlur(image, kernel_size, sigma)
56
 
57
+ # ---- Custom Dataset ----
58
  class DDRDataset(Dataset):
59
+ def __init__(self, csv_path, transform=None):
60
  self.data = pd.read_csv(csv_path)
61
+ self.image_paths = self.data['new_path'].tolist()
62
+ self.labels = self.data['label'].tolist()
63
  self.transform = transform
64
 
65
  def __len__(self):
66
+ return len(self.image_paths)
67
 
68
  def __getitem__(self, idx):
69
+ img_path = self.image_paths[idx]
70
+ label = int(self.labels[idx])
 
71
 
 
72
  image = cv2.imread(img_path)
73
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
74
 
75
+ # Apply preprocessing
76
  image = apply_median_filter(image)
77
  image = apply_clahe(image)
78
  image = apply_gamma_correction(image)
 
82
  if self.transform:
83
  image = self.transform(image)
84
 
85
+ return image, torch.tensor(label, dtype=torch.long)
86
 
87
+ # ---- Image Transforms ----
88
+ val_transform = transforms.Compose([
89
+ transforms.Resize((224, 224)),
90
+ transforms.ToTensor(),
91
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
92
+ ])
93
+
94
+ # ---- Load Data (with caching) ----
95
  @st.cache_resource
96
+ def load_test_data(csv_path):
97
+ dataset = DDRDataset(csv_path=csv_path, transform=val_transform)
 
 
 
 
 
 
 
 
 
 
98
  return DataLoader(dataset, batch_size=32, shuffle=False)
99
 
100
+ # ---- Load Model (with caching) ----
101
+ @st.cache_resource
102
+ def load_model():
103
+ model = models.densenet121(pretrained=False)
104
+ model.classifier = nn.Linear(model.classifier.in_features, len(class_names))
105
+ model.load_state_dict(torch.load(r"D:\\DR_Classification\\training\\Pretrained_Densenet-121.pth", map_location=torch.device('cpu')))
106
  model.eval()
107
+ return model
 
 
108
 
109
+ # ---- Main Evaluation ----
110
+ csv_path = r"D:\\DR_Classification\\splits\\test_labels.csv"
111
+ model = load_model()
112
+ test_loader = load_test_data(csv_path)
 
113
 
114
+ col1, col2 = st.columns([1, 1])
115
+
116
+ with col1:
117
+ if st.button("🚀 Start Evaluation"):
118
+ st.session_state.stop_eval = False
119
+ st.session_state.evaluation_done = False
120
+ run_eval = True
121
+ else:
122
+ run_eval = False
123
 
124
+ with col2:
125
+ if st.button("🚩 Stop Evaluation"):
126
+ st.session_state.stop_eval = True
127
 
128
+ if st.session_state.evaluation_done:
129
+ reevaluate_col, download_col = st.columns([1, 1])
 
 
 
 
130
 
131
+ if run_eval or st.session_state.evaluation_done:
132
+ st.markdown("### ⏱️ Evaluation Results")
133
 
134
+ start_time = time.time()
135
+ y_true = []
136
+ y_pred = []
137
+ y_score = []
138
+ misclassified_images = []
139
 
140
+ total_batches = len(test_loader)
141
+ progress_bar = st.progress(0)
142
+ status_text = st.empty()
143
+ stop_info = st.empty()
144
 
145
+ with torch.no_grad():
146
+ for i, (images, labels) in enumerate(test_loader):
147
+ if st.session_state.stop_eval:
148
+ stop_info.warning("🚩 Evaluation stopped by user.")
149
+ break
150
+
151
+ outputs = model(images)
152
+ _, predicted = torch.max(outputs, 1)
153
+ y_true.extend(labels.numpy())
154
+ y_pred.extend(predicted.numpy())
155
+ y_score.extend(outputs.detach().numpy())
156
+
157
+ for j in range(len(labels)):
158
+ if predicted[j] != labels[j]:
159
+ misclassified_images.append((images[j], predicted[j].item(), labels[j].item()))
160
+
161
+ percent_complete = (i + 1) / total_batches
162
+ progress_bar.progress(min(percent_complete, 1.0))
163
+ status_text.text(f"Evaluating on Test Set: {int(percent_complete * 100)}% | Batch {i+1}/{total_batches}")
164
+ time.sleep(0.1)
165
+
166
+ end_time = time.time()
167
+ eval_time = end_time - start_time
168
+
169
+ if not st.session_state.stop_eval:
170
+ st.session_state.evaluation_done = True
171
+ st.success(f"✅ Evaluation completed in **{eval_time:.2f} seconds**")
172
+
173
+ report = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)
174
+ report_df = pd.DataFrame(report).transpose()
175
+ st.dataframe(report_df.style.format("{:.2f}"))
176
+
177
+ pdf = FPDF()
178
+ pdf.add_page()
179
+ pdf.set_font("Arial", size=12)
180
+ pdf.cell(200, 10, txt=clean_text("Classification Report"), ln=True, align='C')
181
+
182
+ col_widths = [40, 40, 40, 40]
183
+ headers = ["Class", "Precision", "Recall", "F1-Score"]
184
+ for i, header in enumerate(headers):
185
+ pdf.cell(col_widths[i], 10, header, border=1)
186
+ pdf.ln()
187
+
188
+ for idx, row in report_df.iterrows():
189
+ if idx in ['accuracy', 'macro avg', 'weighted avg']:
190
+ continue
191
+ pdf.cell(col_widths[0], 10, str(idx), border=1)
192
+ pdf.cell(col_widths[1], 10, f"{row['precision']:.2f}", border=1)
193
+ pdf.cell(col_widths[2], 10, f"{row['recall']:.2f}", border=1)
194
+ pdf.cell(col_widths[3], 10, f"{row['f1-score']:.2f}", border=1)
195
+ pdf.ln()
196
+
197
+ cm = confusion_matrix(y_true, y_pred)
198
+ fig_cm, ax = plt.subplots()
199
+ sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names, cmap="Blues", ax=ax)
200
+ ax.set_xlabel('Predicted')
201
+ ax.set_ylabel('True')
202
+ ax.set_title("Confusion Matrix")
203
+ st.pyplot(fig_cm)
204
+ cm_path = "confusion_matrix.png"
205
+ fig_cm.savefig(cm_path, format='png', dpi=300, bbox_inches='tight')
206
+ plt.close(fig_cm)
207
+ if os.path.exists(cm_path):
208
+ pdf.image(cm_path, x=10, y=None, w=180)
209
+
210
+ y_true_bin = label_binarize(y_true, classes=list(range(len(class_names))))
211
+ y_score_np = np.array(y_score)
212
+ fig_roc, ax = plt.subplots()
213
+ for i in range(len(class_names)):
214
+ fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_score_np[:, i])
215
+ roc_auc = auc(fpr, tpr)
216
+ ax.plot(fpr, tpr, label=f'{class_names[i]} (AUC = {roc_auc:.2f})')
217
+
218
+ ax.plot([0, 1], [0, 1], 'k--')
219
+ ax.set_xlabel('False Positive Rate')
220
+ ax.set_ylabel('True Positive Rate')
221
+ ax.set_title('Multi-class ROC Curve')
222
+ ax.legend(loc='lower right')
223
+ st.pyplot(fig_roc)
224
+ roc_path = "roc_curve.png"
225
+ fig_roc.savefig(roc_path, format='png', dpi=300, bbox_inches='tight')
226
+ plt.close(fig_roc)
227
+ if os.path.exists(roc_path):
228
+ pdf.image(roc_path, x=10, y=None, w=180)
229
+
230
+ st.markdown("### ❌ Misclassified Samples")
231
+ fig_mis, axs = plt.subplots(1, min(5, len(misclassified_images)), figsize=(15, 4))
232
+ for idx, (img, pred, true) in enumerate(misclassified_images[:5]):
233
+ axs[idx].imshow(img.permute(1, 2, 0))
234
+ axs[idx].set_title(f"True: {class_names[true]}\nPred: {class_names[pred]}")
235
+ axs[idx].axis('off')
236
+ st.pyplot(fig_mis)
237
+
238
+ output_pdf = "evaluation_report.pdf"
239
+ pdf.output(output_pdf)
240
+ with open(output_pdf, "rb") as f:
241
+ reevaluate_col, download_col = st.columns([1, 1])
242
+ with download_col:
243
+ st.download_button("📄 Download Full Evaluation PDF", f, file_name="evaluation_report.pdf")
244
+ with reevaluate_col:
245
+ if st.button("🔁 Re-evaluate"):
246
+ st.session_state.stop_eval = False
247
+ st.session_state.evaluation_done = False