added image split
Browse files
- .gitignore +3 -0
- app.py +0 -0
- training/training.ipynb +704 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
+dataset/
+splits/
+.venv/
app.py
ADDED
File without changes
training/training.ipynb
ADDED
@@ -0,0 +1,704 @@
## Classification model training on the DDR dataset, using a pretrained DenseNet-121
### Steps for model training
##### Step 1. Define preprocessing functions
###### Median filtering, CLAHE, gamma correction, ESRGAN (optional)
##### Step 2. Create a custom Dataset that applies the preprocessing (PyTorch's built-in transforms don't cover these custom OpenCV steps)
##### Step 3. Define transforms (with data augmentation)
##### Step 4. Create datasets and DataLoaders
##### Step 5. Load the DenseNet-121 model
##### Step 6. Define loss, optimizer, scheduler
##### Step 7. Model training
##### Step 8. Validation and evaluation
##### Step 9. Save the model
## Dataset Description

### The DDR dataset contains 13,673 fundus images from 147 hospitals, covering 23 provinces in China. The images are classified into 5 classes according to DR severity: none, mild, moderate, severe, and proliferative DR. There is a sixth category for images of poor quality. The dataset used here does not include those poor-quality images, and all images have been preprocessed to remove the black background. https://www.kaggle.com/datasets/mariaherrerot/ddrdataset
### Import Libraries

```python
import os
import cv2
import numpy as np
import pandas as pd
from PIL import Image
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
import shutil
import torch
import torch.nn as nn
from tqdm import tqdm
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
```
## Load the DDR dataset and split it into train and test sets

```python
# Load CSV
df = pd.read_csv("D:\\DR_Classification\\dataset\\DR_grading.csv")
df['image_path'] = df['id_code'].apply(lambda x: os.path.join("D:\\DR_Classification\\dataset\\images", x))
df['label'] = df['diagnosis']

# Create output directories
output_root = "D:\\DR_Classification\\splits"
os.makedirs(os.path.join(output_root, "train"), exist_ok=True)
os.makedirs(os.path.join(output_root, "test"), exist_ok=True)

# Stratified split: train vs test
train_df, test_df = train_test_split(df, test_size=0.3, stratify=df['label'], random_state=42)

# Copy the images of a split into its folder and save its label CSV
def save_split(df_split, split_name):
    split_folder = os.path.join(output_root, split_name)
    new_paths = []

    # Check if the CSV for the current split already exists
    csv_path = os.path.join(output_root, f"{split_name}_labels.csv")
    if os.path.exists(csv_path):
        print(f"{split_name}_labels.csv already exists. Skipping this split.")
        return  # Skip processing if the CSV file exists

    # Copy images and save the new CSV if it does not exist yet
    for _, row in df_split.iterrows():
        src = row['image_path']
        dst = os.path.join(split_folder, os.path.basename(src))

        if os.path.exists(src):
            shutil.copy(src, dst)
            new_paths.append(dst)
        else:
            # Note: assumes every listed image exists; missing files would make
            # new_paths shorter than df_split and break the assignment below.
            print(f"Warning: Missing image file {src}")

    df_split = df_split.copy()
    df_split['new_path'] = new_paths
    df_split[['id_code', 'label', 'new_path']].to_csv(csv_path, index=False)
    print(f"{split_name}_labels.csv saved successfully!")

# Save each split
save_split(train_df, "train")
save_split(test_df, "test")

print("Splits and CSVs checked and saved successfully!")
```

Output:
```
train_labels.csv already exists. Skipping this split.
test_labels.csv already exists. Skipping this split.
Splits and CSVs checked and saved successfully!
```
```python
print(df['image_path'].head())
```

Output:
```
0    D:\DR_Classification\dataset\images\2017041310...
1    D:\DR_Classification\dataset\images\2017041311...
2    D:\DR_Classification\dataset\images\2017041311...
3    D:\DR_Classification\dataset\images\2017041311...
4    D:\DR_Classification\dataset\images\2017041311...
Name: image_path, dtype: object
```
```python
# Count samples per class in each split
print("Train class distribution:")
print(train_df['label'].value_counts())

print("\nTest class distribution:")
print(test_df['label'].value_counts())

# Count the total number of images in each split
print(f"Total: {len(df)} images")
print(f"Train: {len(train_df)}")
print(f"Test: {len(test_df)}")
```

Output:
```
Train class distribution:
label
0    4386
2    3134
4     639
1     441
3     165
Name: count, dtype: int64

Test class distribution:
label
0    1880
2    1343
4     274
1     189
3      71
Name: count, dtype: int64
Total: 12522 images
Train: 8765
Test: 3757
```
```python
# Load the split CSVs
train_df = pd.read_csv("D:/DR_Classification/splits/train_labels.csv")
test_df = pd.read_csv("D:/DR_Classification/splits/test_labels.csv")

# Extract paths and labels
train_paths = train_df['new_path'].tolist()
train_labels = train_df['label'].tolist()

test_paths = test_df['new_path'].tolist()
test_labels = test_df['label'].tolist()
```
### Step 1. Define Preprocessing Functions

```python
def apply_median_filter(image):
    # 3x3 median filter to suppress salt-and-pepper noise
    return cv2.medianBlur(image, 3)

def apply_clahe(image):
    # Contrast Limited Adaptive Histogram Equalization on the L channel in LAB space
    lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
    l, a, b = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=2.0)
    cl = clahe.apply(l)
    merged = cv2.merge((cl, a, b))
    return cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)

def apply_gamma_correction(image, gamma=1.2):
    # Brightness adjustment via a per-pixel lookup table
    invGamma = 1.0 / gamma
    table = np.array([(i / 255.0) ** invGamma * 255 for i in np.arange(0, 256)]).astype("uint8")
    return cv2.LUT(image, table)

def apply_gaussian_filter(image, kernel_size=(5, 5), sigma=1.0):
    # Gaussian smoothing
    return cv2.GaussianBlur(image, kernel_size, sigma)
```
#### ESRGAN enhancement (optional, not enabled here)

```python
# def apply_esrgan(image):
#     # Placeholder: replace with an actual ESRGAN model's output
#     return enhance_with_esrgan(image)
```
### Step 2. Create Custom Dataset

```python
# Custom Dataset that applies the OpenCV preprocessing before the torchvision transforms
class DDRDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread fails silently on unreadable paths; fail loudly instead
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Preprocessing steps
        image = apply_median_filter(image)
        image = apply_clahe(image)
        image = apply_gamma_correction(image)

        # Apply Gaussian filter
        image = apply_gaussian_filter(image, kernel_size=(5, 5), sigma=1.0)

        image = Image.fromarray(image)
        if self.transform:
            image = self.transform(image)

        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return image, label
```
### Step 3. Define Transforms with data augmentation

```python
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),                      # add vertical flip
    transforms.RandomRotation(degrees=15),                # random rotation
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),  # random resized crop
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet pretraining normalization
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # same normalization as training
])
```
### Step 4. Create Datasets and DataLoaders

```python
from sklearn.model_selection import train_test_split  # already imported above; kept from the original notebook
```

```python
# The Dataset class accepts image paths and labels
train_dataset = DDRDataset(train_paths, train_labels, transform=train_transform)
test_dataset = DDRDataset(test_paths, test_labels, transform=val_transform)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
```
### Step 5. Load DenseNet-121 model

```python
model = models.densenet121(pretrained=True)
model.classifier = nn.Linear(model.classifier.in_features, 5)  # DDR has 5 classes

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
```

Output (stderr):
```
d:\DR_Classification\.venv\Lib\site-packages\torchvision\models\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
d:\DR_Classification\.venv\Lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=DenseNet121_Weights.IMAGENET1K_V1`. You can also use `weights=DenseNet121_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
```
### Step 6. Define Loss, Optimizer, Scheduler, Checkpoint

```python
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
scheduler = StepLR(optimizer, step_size=5, gamma=0.1)  # decay the learning rate by 10x every 5 epochs
```
### Step 7. Train the Model

```python
def train_model(model, train_loader, criterion, optimizer, scheduler, num_epochs=15, device='cuda'):
    model.train()  # set model to training mode
    for epoch in range(num_epochs):
        running_train_loss = 0.0
        correct_train = 0
        total_train = 0

        # tqdm progress bar over the training batches
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
            for batch_idx, (inputs, labels) in enumerate(train_loader):
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)            # forward pass
                loss = criterion(outputs, labels)  # compute loss
                loss.backward()                    # backpropagate
                optimizer.step()                   # update weights

                running_train_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                correct_train += (predicted == labels).sum().item()
                total_train += labels.size(0)

                pbar.set_postfix(loss=running_train_loss / (batch_idx + 1), accuracy=correct_train / total_train * 100)
                pbar.update(1)  # update progress bar

        train_loss = running_train_loss / len(train_loader)
        train_acc = correct_train / total_train * 100

        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"Train Loss: {train_loss:.4f} | Train Accuracy: {train_acc:.2f}%")
        scheduler.step()  # step the learning rate scheduler

# Call the function
train_model(model, train_loader, criterion, optimizer, scheduler, num_epochs=15, device=device)
```

Output (run interrupted during epoch 2):
```
Epoch 1/15:   0%|          | 0/274 [00:00<?, ?it/s]
Epoch 1/15: 100%|██████████| 274/274 [41:09<00:00, 9.01s/it, accuracy=68.1, loss=0.833]
Epoch 1/15
Train Loss: 0.8326 | Train Accuracy: 68.09%
Epoch 2/15:   4%|▍ | 12/274 [01:44<37:59, 8.70s/it, accuracy=71.6, loss=0.743]
```
```python
# # Training loop (earlier version, kept commented out; relies on start_epoch and save_checkpoint)
# num_epochs = 15
# for epoch in range(start_epoch, num_epochs):
#     model.train()
#     running_loss = 0.0
#     correct = 0
#     total = 0

#     for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
#         images, labels = images.to(device), labels.to(device)

#         optimizer.zero_grad()
#         outputs = model(images)
#         loss = criterion(outputs, labels)
#         loss.backward()

#         optimizer.step()

#         running_loss += loss.item()
#         _, preds = torch.max(outputs, 1)
#         correct += (preds == labels).sum().item()
#         total += labels.size(0)

#     # Step the learning rate scheduler
#     scheduler.step()

#     # Calculate accuracy for this epoch
#     epoch_accuracy = 100 * correct / total
#     print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%")

#     # Save checkpoint at the end of each epoch
#     save_checkpoint(epoch, model, optimizer, scheduler, running_loss)
```
### Step 8. Validate and Evaluate the Model

```python
def evaluate_on_test_as_val(model, test_loader, criterion, device='cuda'):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    val_loss = running_loss / len(test_loader)
    val_acc = correct / total * 100
    return val_loss, val_acc

val_loss, val_acc = evaluate_on_test_as_val(model, test_loader, criterion, device)
print(f"Val Loss: {val_loss:.4f} | Val Accuracy: {val_acc:.2f}%")  # report the result
```
### Step 9. Visualize Predictions

```python
import matplotlib.pyplot as plt
import numpy as np

def visualize_predictions(model, dataloader, class_names, device='cuda', num_images=6):
    model.eval()
    images_shown = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            inputs = inputs.cpu()
            labels = labels.cpu()
            preds = preds.cpu()

            for i in range(inputs.size(0)):
                if images_shown >= num_images:
                    return
                img = inputs[i].permute(1, 2, 0).numpy()
                img = (img - img.min()) / (img.max() - img.min())  # normalize for display

                plt.imshow(img)
                plt.title(f"True: {class_names[labels[i]]}, Pred: {class_names[preds[i]]}")
                plt.axis("off")
                plt.show()
                images_shown += 1

class_names = ['No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative DR']
visualize_predictions(model, test_loader, class_names, device=device, num_images=6)
```
### Step 10. Save the Model

```python
torch.save(model.state_dict(), "Densenet121.pth")
print("✅ Final model saved as 'Densenet121.pth'")
```
Notebook metadata: kernelspec ".venv" (Python 3), language_info Python 3.11.1, nbformat 4.5.