vericudebuget commited on
Commit
f8d23e1
·
verified ·
1 Parent(s): ff71fa8

Create denoising_model.py

Browse files
Files changed (1) hide show
  1. denoising_model.py +259 -0
denoising_model.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ from PIL import Image
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from torchvision.transforms import ToTensor
7
+ import time
8
+ from datetime import datetime
9
+ import multiprocessing
10
+
11
+ def get_optimal_threads():
12
+ """Calculate optimal number of threads based on CPU cores"""
13
+ return max(1, multiprocessing.cpu_count() - 1) # Leave one core free for system
14
+
15
+ # Simple UNet-style denoising model
16
+ class DenoisingModel(nn.Module):
17
+ def __init__(self):
18
+ super(DenoisingModel, self).__init__()
19
+ # Encoder
20
+ self.enc1 = nn.Sequential(
21
+ nn.Conv2d(3, 64, 3, padding=1),
22
+ nn.ReLU(),
23
+ nn.Conv2d(64, 64, 3, padding=1),
24
+ nn.ReLU()
25
+ )
26
+ self.pool1 = nn.MaxPool2d(2, 2)
27
+
28
+ # Decoder
29
+ self.up1 = nn.ConvTranspose2d(64, 64, 2, stride=2)
30
+ self.dec1 = nn.Sequential(
31
+ nn.Conv2d(64, 64, 3, padding=1),
32
+ nn.ReLU(),
33
+ nn.Conv2d(64, 3, 3, padding=1)
34
+ )
35
+
36
+ def forward(self, x):
37
+ # Encoder
38
+ e1 = self.enc1(x)
39
+ p1 = self.pool1(e1)
40
+
41
+ # Decoder
42
+ u1 = self.up1(p1)
43
+ d1 = self.dec1(u1)
44
+ return d1
45
+
46
+ class DenoiseDataset(Dataset):
47
+ def __init__(self, noisy_folder, target_folder, patch_size=256):
48
+ self.noisy_folder = noisy_folder
49
+ self.target_folder = target_folder
50
+ self.patch_size = patch_size
51
+ self.image_pairs = [
52
+ (os.path.join(noisy_folder, f), os.path.join(target_folder, f.replace("_noisy", "_target")))
53
+ for f in os.listdir(noisy_folder) if "_noisy" in f
54
+ ]
55
+ self.transform = ToTensor()
56
+
57
+ print(f"Dataset initialization:")
58
+ print(f"- Noisy folder: {noisy_folder}")
59
+ print(f"- Target folder: {target_folder}")
60
+ print(f"- Patch size: {patch_size}")
61
+ print(f"- Found {len(self.image_pairs)} image pairs")
62
+
63
+ if not self.image_pairs:
64
+ raise ValueError("No image pairs found. Check if noisy and target images are correctly named.")
65
+
66
+ # Precalculate number of patches per image for better performance
67
+ self.patches_per_image = {}
68
+ for noisy_path, _ in self.image_pairs:
69
+ try:
70
+ self.patches_per_image[noisy_path] = self._get_num_patches_per_image(noisy_path)
71
+ except Exception as e:
72
+ print(f"Error calculating patches for {noisy_path}: {e}. Skipping this image pair.")
73
+ self.image_pairs = [(n, t) for n, t in self.image_pairs if n != noisy_path]
74
+
75
+ self.total_patches = sum(self.patches_per_image.values())
76
+
77
+ def __len__(self):
78
+ return self.total_patches
79
+
80
+ def __getitem__(self, idx):
81
+ image_idx = 0
82
+ cumulative_patches = 0
83
+
84
+ for i, (noisy_path, _) in enumerate(self.image_pairs):
85
+ num_patches = self.patches_per_image[noisy_path]
86
+ if cumulative_patches + num_patches > idx:
87
+ image_idx = i
88
+ break
89
+ cumulative_patches += num_patches
90
+
91
+ patch_idx = idx - cumulative_patches
92
+ noisy_path, target_path = self.image_pairs[image_idx]
93
+
94
+ try:
95
+ noisy_image = self._load_image(noisy_path)
96
+ target_image = self._load_image(target_path)
97
+ except Exception as e:
98
+ print(f"Error loading image pair ({noisy_path}, {target_path}): {e}. Returning default values.")
99
+ return torch.zeros((3, self.patch_size, self.patch_size)), torch.zeros((3, self.patch_size, self.patch_size))
100
+
101
+ try:
102
+ noisy_patch = self._get_patch(noisy_image, patch_idx)
103
+ target_patch = self._get_patch(target_image, patch_idx)
104
+ except Exception as e:
105
+ print(f"Error getting patch from image pair ({noisy_path}, {target_path}): {e}. Returning default values.")
106
+ return torch.zeros((3, self.patch_size, self.patch_size)), torch.zeros((3, self.patch_size, self.patch_size))
107
+
108
+ return noisy_patch, target_patch
109
+
110
+ def _load_image(self, image_path):
111
+ try:
112
+ image = Image.open(image_path).convert("RGB")
113
+ return self.transform(image)
114
+ except Exception as e:
115
+ raise Exception(f"Error loading image {image_path}: {e}")
116
+
117
+ def _get_num_patches_per_image(self, image_path):
118
+ try:
119
+ image = Image.open(image_path)
120
+ width, height = image.size
121
+ num_patches = (width // self.patch_size) * (height // self.patch_size)
122
+ return num_patches
123
+ except Exception as e:
124
+ raise Exception(f"Error calculating patches for {image_path}: {e}")
125
+
126
+ def _get_patch(self, image, patch_idx):
127
+ width, height = image.shape[2], image.shape[1]
128
+ patches_per_row = width // self.patch_size
129
+ row = patch_idx // patches_per_row
130
+ col = patch_idx % patches_per_row
131
+
132
+ x_start = col * self.patch_size
133
+ y_start = row * self.patch_size
134
+ return image[:, y_start:y_start+self.patch_size, x_start:x_start+self.patch_size]
135
+
136
+
137
+ def train_model(noisy_dir, target_dir, epochs, batch_size, learning_rate, save_interval, num_workers):
138
+ # Set up CUDA if available
139
+ if torch.cuda.is_available():
140
+ torch.backends.cudnn.benchmark = True # Enable cuDNN auto-tuner
141
+ device = torch.device("cuda")
142
+ print(f"\nUsing GPU: {torch.cuda.get_device_name(0)}")
143
+ print(f"CUDA version: {torch.version.cuda}")
144
+ else:
145
+ device = torch.device("cpu")
146
+ print("\nNo GPU detected, using CPU")
147
+
148
+ # Create output directory for models
149
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
150
+ output_dir = f"model_checkpoints_{timestamp}"
151
+ os.makedirs(output_dir, exist_ok=True)
152
+
153
+ print("\nTraining Configuration:")
154
+ print(f"- Number of epochs: {epochs}")
155
+ print(f"- Batch size: {batch_size}")
156
+ print(f"- Learning rate: {learning_rate}")
157
+ print(f"- Number of worker threads: {num_workers}")
158
+ print(f"- Model checkpoint directory: {output_dir}")
159
+
160
+ # Initialize dataset and dataloader with specified number of workers
161
+ dataset = DenoiseDataset(noisy_dir, target_dir)
162
+ dataloader = DataLoader(
163
+ dataset,
164
+ batch_size=batch_size,
165
+ shuffle=True,
166
+ num_workers=num_workers,
167
+ pin_memory=True if torch.cuda.is_available() else False
168
+ )
169
+
170
+ # Initialize model, loss function, and optimizer
171
+ model = DenoisingModel().to(device)
172
+ if torch.cuda.device_count() > 1:
173
+ print(f"Using {torch.cuda.device_count()} GPUs!")
174
+ model = nn.DataParallel(model)
175
+
176
+ criterion = nn.MSELoss()
177
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
178
+
179
+ # Training loop
180
+ total_batches = len(dataloader)
181
+ start_time = time.time()
182
+
183
+ print("\nStarting training...")
184
+ for epoch in range(epochs):
185
+ epoch_loss = 0.0
186
+ for batch_idx, (noisy_patches, target_patches) in enumerate(dataloader):
187
+ # Move data to device
188
+ noisy_patches = noisy_patches.to(device, non_blocking=True)
189
+ target_patches = target_patches.to(device, non_blocking=True)
190
+
191
+ # Forward pass
192
+ outputs = model(noisy_patches)
193
+ loss = criterion(outputs, target_patches)
194
+
195
+ # Backward pass and optimize
196
+ optimizer.zero_grad()
197
+ loss.backward()
198
+ optimizer.step()
199
+
200
+ # Update epoch loss
201
+ epoch_loss += loss.item()
202
+
203
+ # Print progress
204
+ if (batch_idx + 1) % 100 == 0:
205
+ elapsed_time = time.time() - start_time
206
+ print(f"Epoch [{epoch+1}/{epochs}], "
207
+ f"Batch [{batch_idx+1}/{total_batches}], "
208
+ f"Loss: {loss.item():.6f}, "
209
+ f"Time: {elapsed_time:.2f}s")
210
+
211
+ # Save model checkpoint
212
+ if (batch_idx + 1) % save_interval == 0:
213
+ checkpoint_path = os.path.join(output_dir,
214
+ f"denoising_model_epoch{epoch+1}_batch{batch_idx+1}.pth")
215
+ torch.save({
216
+ 'epoch': epoch,
217
+ 'batch': batch_idx,
218
+ 'model_state_dict': model.state_dict(),
219
+ 'optimizer_state_dict': optimizer.state_dict(),
220
+ 'loss': loss.item(),
221
+ }, checkpoint_path)
222
+ print(f"\nCheckpoint saved: {checkpoint_path}")
223
+
224
+ # End of epoch summary
225
+ avg_epoch_loss = epoch_loss / total_batches
226
+ print(f"\nEpoch [{epoch+1}/{epochs}] completed. "
227
+ f"Average loss: {avg_epoch_loss:.6f}")
228
+
229
+ # Save epoch checkpoint
230
+ checkpoint_path = os.path.join(output_dir, f"denoising_model_epoch{epoch+1}.pth")
231
+ torch.save({
232
+ 'epoch': epoch,
233
+ 'model_state_dict': model.state_dict(),
234
+ 'optimizer_state_dict': optimizer.state_dict(),
235
+ 'loss': avg_epoch_loss,
236
+ }, checkpoint_path)
237
+ print(f"Epoch checkpoint saved: {checkpoint_path}")
238
+
239
+ print("\nTraining completed!")
240
+ print(f"Total training time: {time.time() - start_time:.2f} seconds")
241
+
242
+ # Save final model
243
+ final_model_path = os.path.join(output_dir, "denoising_model_final.pth")
244
+ torch.save(model.state_dict(), final_model_path)
245
+ print(f"Final model saved: {final_model_path}")
246
+
247
+ def main():
248
+ noisy_dir = 'noisy_images'
249
+ target_dir = 'target_images'
250
+ epochs = 10
251
+ batch_size = 4
252
+ learning_rate = 0.001
253
+ save_interval = 1000
254
+ num_workers = get_optimal_threads()
255
+
256
+ train_model(noisy_dir, target_dir, epochs, batch_size, learning_rate, save_interval, num_workers)
257
+
258
+ if __name__ == "__main__":
259
+ main()