3v324v23 commited on
Commit
177b675
·
1 Parent(s): f6e756e

REvert "Add Model Evaluation Notebook"

Browse files
Files changed (1) hide show
  1. pages/Model_Evaluation.py +42 -78
pages/Model_Evaluation.py CHANGED
@@ -15,16 +15,14 @@ from sklearn.preprocessing import label_binarize
15
  import streamlit as st
16
  import matplotlib.pyplot as plt
17
  from fpdf import FPDF
18
- from datasets import load_dataset
19
- from huggingface_hub import hf_hub_download
20
- import requests
21
- from io import BytesIO
22
 
23
  # ---- Streamlit State Initialization ----
24
  if 'stop_eval' not in st.session_state:
25
  st.session_state.stop_eval = False
 
26
  if 'evaluation_done' not in st.session_state:
27
  st.session_state.evaluation_done = False
 
28
  if 'trigger_eval' not in st.session_state:
29
  st.session_state.trigger_eval = False
30
 
@@ -35,7 +33,7 @@ st.markdown("<h2 style='color: #2E86C1;'>📈 Model Evaluation</h2>", unsafe_all
35
  class_names = ['No_DR', 'Mild', 'Moderate', 'Severe', 'Proliferative_DR']
36
  label_map = {label: idx for idx, label in enumerate(class_names)}
37
 
38
- # ---- Text Cleaning Function ----
39
  def clean_text(text):
40
  return text.encode('utf-8', 'ignore').decode('utf-8')
41
 
@@ -60,104 +58,70 @@ def apply_gaussian_filter(image, kernel_size=(5, 5), sigma=1.0):
60
  return cv2.GaussianBlur(image, kernel_size, sigma)
61
 
62
  # ---- Custom Dataset ----
63
- def __getitem__(self, idx):
64
- # Get the image URL and label
65
- raw_url = self.image_paths[idx]
66
- label = int(self.labels[idx])
67
-
68
- # Fix the URL if it's in blob format
69
- img_url = raw_url.replace("/blob/", "/resolve/")
70
-
71
- try:
72
- print(f"Downloading: {img_url}")
73
- response = requests.get(img_url, stream=True)
74
- if response.status_code != 200:
75
- raise ValueError(f"Failed to download image from {img_url} - status code: {response.status_code}")
76
-
77
- # Read and convert the image
78
- image = Image.open(BytesIO(response.content)).convert("RGB")
79
- except Exception as e:
80
- raise RuntimeError(f"Error loading image from {img_url}: {e}")
81
-
82
- image = np.array(image)
83
-
84
- image = apply_median_filter(image)
85
- image = apply_clahe(image)
86
- image = apply_gamma_correction(image)
87
- image = apply_gaussian_filter(image)
88
-
89
- image = Image.fromarray(image)
90
- if self.transform:
91
- image = self.transform(image)
92
-
93
- return image, torch.tensor(label, dtype=torch.long)
94
-
95
  # ---- Image Transforms ----
96
  val_transform = transforms.Compose([
97
  transforms.Resize((224, 224)),
98
  transforms.ToTensor(),
99
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
100
  ])
101
-
102
- # ---- DDRDataset Class Definition ----
103
- class DDRDataset(Dataset):
104
- def __init__(self, csv_path, transform=None):
105
- self.data = pd.read_csv(csv_path)
106
- self.image_paths = self.data['new_path']
107
- self.labels = self.data['label']
108
- self.transform = transform
109
-
110
- def __len__(self):
111
- return len(self.data)
112
-
113
- def __getitem__(self, idx):
114
- img_path = self.image_paths[idx]
115
- label = self.labels[idx]
116
-
117
- try:
118
- image = Image.open(img_path).convert("RGB")
119
- except Exception as e:
120
- raise RuntimeError(f"Error loading image from {img_path}: {e}")
121
-
122
- if self.transform:
123
- image = self.transform(image)
124
-
125
- return image, torch.tensor(label, dtype=torch.long)
126
-
127
- # ---- Load Data from Hugging Face (cached) ----
128
  @st.cache_resource
129
- def load_test_data_from_huggingface():
130
- dataset = load_dataset(
131
- "Ci-Dave/DDR_dataset_train_test",
132
- data_files={"test": "splits/test_labels_newpath.csv"},
133
- split="test"
134
- )
135
- df = dataset.to_pandas()
136
- csv_path = "test_labels_temp.csv"
137
- df.to_csv(csv_path, index=False)
138
  dataset = DDRDataset(csv_path=csv_path, transform=val_transform)
139
  return DataLoader(dataset, batch_size=32, shuffle=False)
140
 
141
- # ---- Load Model from Hugging Face (cached) ----
142
  @st.cache_resource
143
  def load_model():
144
- model_path = hf_hub_download(repo_id="Ci-Dave/Densenet121", filename="Pretrained_Densenet-121.pth")
145
  model = models.densenet121(pretrained=False)
146
  model.classifier = nn.Linear(model.classifier.in_features, len(class_names))
147
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
148
  model.eval()
149
  return model
150
 
151
- # ---- UI Buttons ----
 
152
  model = load_model()
153
- test_loader = load_test_data_from_huggingface()
154
 
155
  col1, col2 = st.columns([1, 1])
 
156
  with col1:
157
  if st.button("🚀 Start Evaluation"):
158
  st.session_state.stop_eval = False
159
  st.session_state.evaluation_done = False
160
  st.session_state.trigger_eval = True
 
161
  with col2:
162
  if st.button("🚩 Stop Evaluation"):
163
  st.session_state.stop_eval = True
 
15
  import streamlit as st
16
  import matplotlib.pyplot as plt
17
  from fpdf import FPDF
 
 
 
 
18
 
19
  # ---- Streamlit State Initialization ----
20
  if 'stop_eval' not in st.session_state:
21
  st.session_state.stop_eval = False
22
+
23
  if 'evaluation_done' not in st.session_state:
24
  st.session_state.evaluation_done = False
25
+
26
  if 'trigger_eval' not in st.session_state:
27
  st.session_state.trigger_eval = False
28
 
 
33
  class_names = ['No_DR', 'Mild', 'Moderate', 'Severe', 'Proliferative_DR']
34
  label_map = {label: idx for idx, label in enumerate(class_names)}
35
 
36
+ # ---- Text Cleaning Function for PDF ----
37
  def clean_text(text):
38
  return text.encode('utf-8', 'ignore').decode('utf-8')
39
 
 
58
  return cv2.GaussianBlur(image, kernel_size, sigma)
59
 
60
  # ---- Custom Dataset ----
61
+ class DDRDataset(Dataset):
62
+ def __init__(self, csv_path, transform=None):
63
+ self.data = pd.read_csv(csv_path)
64
+ self.image_paths = self.data['new_path'].tolist()
65
+ self.labels = self.data['label'].tolist()
66
+ self.transform = transform
67
+
68
+ def __len__(self):
69
+ return len(self.image_paths)
70
+
71
+ def __getitem__(self, idx):
72
+ img_path = self.image_paths[idx]
73
+ label = int(self.labels[idx])
74
+
75
+ image = cv2.imread(img_path)
76
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
77
+
78
+ # Apply preprocessing
79
+ image = apply_median_filter(image)
80
+ image = apply_clahe(image)
81
+ image = apply_gamma_correction(image)
82
+ image = apply_gaussian_filter(image)
83
+
84
+ image = Image.fromarray(image)
85
+ if self.transform:
86
+ image = self.transform(image)
87
+
88
+ return image, torch.tensor(label, dtype=torch.long)
89
+
 
 
 
90
  # ---- Image Transforms ----
91
  val_transform = transforms.Compose([
92
  transforms.Resize((224, 224)),
93
  transforms.ToTensor(),
94
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
95
  ])
96
+
97
+ # ---- Load Data (with caching) ----
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  @st.cache_resource
99
+ def load_test_data(csv_path):
 
 
 
 
 
 
 
 
100
  dataset = DDRDataset(csv_path=csv_path, transform=val_transform)
101
  return DataLoader(dataset, batch_size=32, shuffle=False)
102
 
103
+ # ---- Load Model (with caching) ----
104
  @st.cache_resource
105
  def load_model():
 
106
  model = models.densenet121(pretrained=False)
107
  model.classifier = nn.Linear(model.classifier.in_features, len(class_names))
108
+ model.load_state_dict(torch.load(r"D:\\DR_Classification\\training\\Pretrained_Densenet-121.pth", map_location=torch.device('cpu')))
109
  model.eval()
110
  return model
111
 
112
+ # ---- Main UI Buttons ----
113
+ csv_path = r"D:\\DR_Classification\\splits\\test_labels.csv"
114
  model = load_model()
115
+ test_loader = load_test_data(csv_path)
116
 
117
  col1, col2 = st.columns([1, 1])
118
+
119
  with col1:
120
  if st.button("🚀 Start Evaluation"):
121
  st.session_state.stop_eval = False
122
  st.session_state.evaluation_done = False
123
  st.session_state.trigger_eval = True
124
+
125
  with col2:
126
  if st.button("🚩 Stop Evaluation"):
127
  st.session_state.stop_eval = True