3v324v23 commited on
Commit
97c84ef
Β·
1 Parent(s): eaf9896

Added multi pages

Browse files
Files changed (3) hide show
  1. app.py +11 -113
  2. pages/Model_Evaluation.py +129 -0
  3. pages/Upload_and_Predict.py +44 -0
app.py CHANGED
@@ -1,125 +1,23 @@
1
- # Directory Structure Suggestion:
2
- # diabetic_retinopathy_app/
3
- # β”œβ”€β”€ Home.py (Landing Page)
4
- # β”œβ”€β”€ pages/
5
- # β”‚ β”œβ”€β”€ 1_Upload_and_Predict.py
6
- # β”‚ └── 2_Model_Evaluation.py
7
- # └── assets/
8
- # └── banner.jpg
9
-
10
- # Home.py (Landing Page)
11
  import streamlit as st
12
  from PIL import Image
13
 
14
  def main():
15
  st.set_page_config(page_title="DR Assistive Tool", layout="centered")
16
- st.title("Welcome to the Diabetic Retinopathy Assistive Tool")
 
17
 
18
  st.markdown("""
19
- ### 🌟 Your AI-powered assistant for early detection of Diabetic Retinopathy.
20
-
21
- #### Features:
22
- - πŸ–ΌοΈ Upload a retinal image and receive a prediction of its DR stage.
23
- - πŸ“Š Evaluate model performance using real test datasets.
24
 
25
- Select a page from the left sidebar to get started.
 
 
 
 
 
 
26
  """)
27
 
28
- # image = Image.open("assets/banner.jpg") # Optional banner image
29
- # st.image(image, use_column_width=True)
30
-
31
  if __name__ == '__main__':
32
  main()
33
-
34
-
35
- # pages/1_Upload_and_Predict.py
36
- import streamlit as st
37
- import torch
38
- from torchvision import transforms, models
39
- from PIL import Image
40
- import numpy as np
41
-
42
- st.title("πŸ“· Upload & Predict Diabetic Retinopathy")
43
-
44
- class_names = ['No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative DR']
45
-
46
- def load_model():
47
- model = models.densenet121(pretrained=False)
48
- num_ftrs = model.classifier.in_features
49
- model.classifier = torch.nn.Linear(num_ftrs, len(class_names))
50
- model.load_state_dict(torch.load("training/Pretrained_Densenet-121.pth", map_location='cpu'))
51
- model.eval()
52
- return model
53
-
54
- transform = transforms.Compose([
55
- transforms.Resize(256),
56
- transforms.CenterCrop(224),
57
- transforms.ToTensor(),
58
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
59
- ])
60
-
61
- def predict_image(model, image):
62
- img_tensor = transform(image).unsqueeze(0)
63
- with torch.no_grad():
64
- outputs = model(img_tensor)
65
- _, pred = torch.max(outputs, 1)
66
- prob = torch.nn.functional.softmax(outputs, dim=1)[0][pred].item() * 100
67
- return class_names[pred.item()], prob
68
-
69
- uploaded_file = st.file_uploader("Choose a retinal image", type=["jpg", "png"])
70
- if uploaded_file is not None:
71
- image = Image.open(uploaded_file).convert('RGB')
72
- st.image(image, caption='Uploaded Retinal Image', use_column_width=True)
73
-
74
- if st.button("🧠 Predict"):
75
- with st.spinner('Analyzing image...'):
76
- model = load_model()
77
- pred_class, prob = predict_image(model, image)
78
- st.success(f"Prediction: **{pred_class}** ({prob:.2f}% confidence)")
79
-
80
-
81
- # pages/2_Model_Evaluation.py
82
- import streamlit as st
83
- import torch
84
- from torch.utils.data import DataLoader
85
- from torchvision import datasets, transforms, models
86
- import torch.nn as nn
87
- from tqdm import tqdm
88
-
89
- st.title("πŸ“ˆ Model Evaluation on Test Dataset")
90
-
91
- @st.cache_data
92
-
93
- def load_test_data():
94
- transform = transforms.Compose([
95
- transforms.Resize(256),
96
- transforms.CenterCrop(224),
97
- transforms.ToTensor(),
98
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
99
- ])
100
- test_data = datasets.ImageFolder("test_dataset_path", transform=transform)
101
- return DataLoader(test_data, batch_size=32, shuffle=False)
102
-
103
- def evaluate(model, loader):
104
- model.eval()
105
- correct, total, loss = 0, 0, 0.0
106
- criterion = nn.CrossEntropyLoss()
107
- with torch.no_grad():
108
- for inputs, labels in loader:
109
- outputs = model(inputs)
110
- loss += criterion(outputs, labels).item()
111
- _, pred = torch.max(outputs, 1)
112
- correct += (pred == labels).sum().item()
113
- total += labels.size(0)
114
- return loss / len(loader), correct / total * 100
115
-
116
- if st.button("πŸ§ͺ Evaluate Trained Model"):
117
- test_loader = load_test_data()
118
- model = models.densenet121(pretrained=False)
119
- model.classifier = nn.Linear(model.classifier.in_features, 5)
120
- model.load_state_dict(torch.load("dr_densenet121.pth", map_location='cpu'))
121
- model.eval()
122
-
123
- loss, acc = evaluate(model, test_loader)
124
- st.write(f"**Test Loss:** {loss:.4f}")
125
- st.write(f"**Test Accuracy:** {acc:.2f}%")
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  from PIL import Image
3
 
4
  def main():
5
  st.set_page_config(page_title="DR Assistive Tool", layout="centered")
6
+ # st.image("assets/banner.jpg", use_column_width=True)
7
+ st.markdown("<h1 style='text-align: center; color: #2E86C1;'>DR Assistive Tool</h1>", unsafe_allow_html=True)
8
 
9
  st.markdown("""
10
+ <h4 style='text-align: center; color: grey;'>An AI-powered assistant for early detection of Diabetic Retinopathy</h4>
11
+ """, unsafe_allow_html=True)
 
 
 
12
 
13
+ st.markdown("""
14
+ ---
15
+ ### πŸ›  Features:
16
+ - **Upload** retinal images to predict DR stage.
17
+ - **Evaluate** the model using real test datasets.
18
+
19
+ πŸ‘‰ Use the sidebar to navigate between features.
20
  """)
21
 
 
 
 
22
  if __name__ == '__main__':
23
  main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pages/Model_Evaluation.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from torch.utils.data import DataLoader, Dataset
4
+ from torchvision import transforms, models
5
+ import torch.nn as nn
6
+ from PIL import Image
7
+ import pandas as pd
8
+ import os
9
+ import cv2
10
+ import numpy as np
11
+
12
+ st.markdown("<h2 style='color: #2E86C1;'>πŸ“ˆ Model Evaluation</h2>", unsafe_allow_html=True)
13
+
14
+ # Define class names and label map
15
+ class_names = ['No_DR', 'Mild', 'Moderate', 'Severe', 'Proliferative_DR']
16
+ label_map = {label: idx for idx, label in enumerate(class_names)}
17
+
18
+ # Define your image preprocessing functions
19
+ def apply_median_filter(image):
20
+ return cv2.medianBlur(image, 5)
21
+
22
+ def apply_clahe(image):
23
+ lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
24
+ l, a, b = cv2.split(lab)
25
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
26
+ cl = clahe.apply(l)
27
+ merged = cv2.merge((cl, a, b))
28
+ return cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)
29
+
30
+ def apply_gamma_correction(image, gamma=1.5):
31
+ invGamma = 1.0 / gamma
32
+ table = np.array([(i / 255.0) ** invGamma * 255 for i in np.arange(256)]).astype("uint8")
33
+ return cv2.LUT(image, table)
34
+
35
+ def apply_gaussian_filter(image, kernel_size=(5, 5), sigma=1.0):
36
+ return cv2.GaussianBlur(image, kernel_size, sigma)
37
+
38
+ # Custom dataset with preprocessing
39
+ class DDRDataset(Dataset):
40
+ def __init__(self, csv_path, img_dir, transform=None):
41
+ self.data = pd.read_csv(csv_path)
42
+ self.img_dir = img_dir
43
+ self.transform = transform
44
+
45
+ def __len__(self):
46
+ return len(self.data)
47
+
48
+ def __getitem__(self, idx):
49
+ img_name = self.data.iloc[idx, 0]
50
+ label_name = self.data.iloc[idx, 1]
51
+ label = int(label_map.get(label_name, 0)) # fallback to 0
52
+
53
+ img_path = os.path.join(self.img_dir, img_name)
54
+ image = cv2.imread(img_path)
55
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
56
+
57
+ image = apply_median_filter(image)
58
+ image = apply_clahe(image)
59
+ image = apply_gamma_correction(image)
60
+ image = apply_gaussian_filter(image)
61
+
62
+ image = Image.fromarray(image)
63
+ if self.transform:
64
+ image = self.transform(image)
65
+
66
+ return image, label
67
+
68
+ # -------------------------------
69
+ # Load Test Data with Caching
70
+ # -------------------------------
71
+ @st.cache_resource
72
+ def load_test_data():
73
+ transform = transforms.Compose([
74
+ transforms.Resize((224, 224)),
75
+ transforms.ToTensor(),
76
+ transforms.Normalize([0.485, 0.456, 0.406],
77
+ [0.229, 0.224, 0.225])
78
+ ])
79
+ dataset = DDRDataset(
80
+ csv_path="D:/DR_Classification/splits/test_labels.csv",
81
+ img_dir="D:/DR_Classification/splits/test",
82
+ transform=transform
83
+ )
84
+ return DataLoader(dataset, batch_size=32, shuffle=False)
85
+
86
+ # -------------------------------
87
+ # Evaluation Function
88
+ # -------------------------------
89
+ def evaluation_test_model(model, test_loader, criterion, device='cpu'):
90
+ model.eval()
91
+ running_loss = 0.0
92
+ correct = 0
93
+ total = 0
94
+
95
+ with torch.no_grad():
96
+ for inputs, labels in test_loader:
97
+ inputs, labels = inputs.to(device), labels.to(device)
98
+ outputs = model(inputs)
99
+ loss = criterion(outputs, labels)
100
+
101
+ running_loss += loss.item()
102
+ _, predicted = torch.max(outputs, 1)
103
+ correct += (predicted == labels).sum().item()
104
+ total += labels.size(0)
105
+
106
+ val_loss = running_loss / len(test_loader)
107
+ val_acc = correct / total * 100
108
+ return val_loss, val_acc
109
+
110
+ # -------------------------------
111
+ # Evaluation Button Trigger
112
+ # -------------------------------
113
+ if st.button("πŸ” Evaluate Trained Model"):
114
+ with st.spinner("Evaluating on test data..."):
115
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
116
+
117
+ test_loader = load_test_data()
118
+
119
+ model = models.densenet121(pretrained=False)
120
+ model.classifier = nn.Linear(model.classifier.in_features, len(class_names))
121
+ model.load_state_dict(torch.load("training/Pretrained_Densenet-121.pth", map_location=device))
122
+ model.to(device)
123
+
124
+ criterion = nn.CrossEntropyLoss()
125
+ val_loss, val_acc = evaluation_test_model(model, test_loader, criterion, device)
126
+
127
+ st.success("βœ… Evaluation Complete")
128
+ st.metric("Test Loss", f"{val_loss:.4f}")
129
+ st.metric("Test Accuracy", f"{val_acc:.2f}%")
pages/Upload_and_Predict.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from torchvision import transforms, models
4
+ from PIL import Image
5
+ import numpy as np
6
+
7
+ st.markdown("<h2 style='color: #2E86C1;'>πŸ“· Upload & Predict</h2>", unsafe_allow_html=True)
8
+
9
+ class_names = ['No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative DR']
10
+
11
+ @st.cache_resource
12
+ def load_model():
13
+ model = models.densenet121(pretrained=False)
14
+ model.classifier = torch.nn.Linear(model.classifier.in_features, len(class_names))
15
+ model.load_state_dict(torch.load("training/Pretrained_Densenet-121.pth", map_location='cpu'))
16
+ model.eval()
17
+ return model
18
+
19
+ transform = transforms.Compose([
20
+ transforms.Resize(256),
21
+ transforms.CenterCrop(224),
22
+ transforms.ToTensor(),
23
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
24
+ ])
25
+
26
+ def predict_image(model, image):
27
+ img_tensor = transform(image).unsqueeze(0)
28
+ with torch.no_grad():
29
+ outputs = model(img_tensor)
30
+ _, pred = torch.max(outputs, 1)
31
+ prob = torch.nn.functional.softmax(outputs, dim=1)[0][pred].item() * 100
32
+ return class_names[pred.item()], prob
33
+
34
+ uploaded_file = st.file_uploader("πŸ“ Upload Retinal Image", type=["jpg", "png"])
35
+
36
+ if uploaded_file is not None:
37
+ image = Image.open(uploaded_file).convert('RGB')
38
+ st.image(image, caption='πŸ–Ό Uploaded Image', use_column_width=True)
39
+
40
+ if st.button("🧠 Predict"):
41
+ with st.spinner('Analyzing image...'):
42
+ model = load_model()
43
+ pred_class, prob = predict_image(model, image)
44
+ st.success(f"🎯 Prediction: **{pred_class}** ({prob:.2f}% confidence)")