# train.py
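# Trains SRCNN / VDSR / EDSR super-resolution models on DIV2K image pairs,
# with PSNR-based validation and early stopping.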
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from utils.dataset import DIV2KDataset
from models.srcnn import SRCNN
from models.vdsr import VDSR
from models.edsr import EDSR
import math
import numpy as np
class EarlyStopping:
    """Stops training when neither the loss nor the PSNR improves for `patience` epochs."""
    def __init__(self, patience=7, min_delta=0.01, min_psnr_improvement=0.1):
        self.patience = patience
        self.min_delta = min_delta
        self.min_psnr_improvement = min_psnr_improvement
        self.counter = 0
        self.best_loss = None
        self.best_psnr = None
        self.early_stop = False

    def __call__(self, loss, psnr):
        if self.best_loss is None:
            self.best_loss = loss
            self.best_psnr = psnr
        elif (loss > self.best_loss - self.min_delta) and (psnr < self.best_psnr + self.min_psnr_improvement):
            # Neither metric improved enough: count towards the patience limit
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = min(loss, self.best_loss)
            self.best_psnr = max(psnr, self.best_psnr)
            self.counter = 0
def calculate_psnr(img1, img2):
    """PSNR in dB, assuming pixel values are scaled to [0, 1]."""
    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(1.0 / math.sqrt(mse.item()))
def train_model(model_name, train_loader, val_loader, device, num_epochs=100):
    # Initialize model
    if model_name == 'srcnn':
        model = SRCNN()
    elif model_name == 'vdsr':
        model = VDSR()
    else:
        model = EDSR()
    model = model.to(device)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)

    # Initialize early stopping
    early_stopping = EarlyStopping(patience=10, min_delta=0.00001, min_psnr_improvement=0.1)

    best_psnr = 0
    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss = 0
        num_batches = 0
        for batch_idx, (lr_img, hr_img) in enumerate(train_loader):
            lr_img, hr_img = lr_img.to(device), hr_img.to(device)
            optimizer.zero_grad()
            output = model(lr_img)
            loss = criterion(output, hr_img)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            num_batches += 1
            if batch_idx % 100 == 0:
                print(f'Train Epoch: {epoch} [{batch_idx}/{len(train_loader)}]\tLoss: {loss.item():.6f}')
        avg_train_loss = train_loss / num_batches

        # Validation
        model.eval()
        val_psnr = 0
        with torch.no_grad():
            for lr_img, hr_img in val_loader:
                lr_img, hr_img = lr_img.to(device), hr_img.to(device)
                output = model(lr_img)
                val_psnr += calculate_psnr(output, hr_img)
        val_psnr /= len(val_loader)

        print(f'Epoch: {epoch}, Average Loss: {avg_train_loss:.6f}, Average PSNR: {val_psnr:.2f}dB')

        # Save best model before the early-stopping check so a final improvement is not lost
        if val_psnr > best_psnr:
            best_psnr = val_psnr
            torch.save(model.state_dict(), f'checkpoints/{model_name}_best.pth')
            print(f'Saved new best model with PSNR: {best_psnr:.2f}dB')

        # Early stopping check
        early_stopping(avg_train_loss, val_psnr)
        if early_stopping.early_stop:
            print(f"Early stopping triggered at epoch {epoch}")
            break
def main():
    # Setup
    device = torch.device('cpu')

    # Data paths
    train_hr_dir = 'data/DIV2K_train_HR/DIV2K_train_HR/'
    train_lr_dir = 'data/DIV2K_train_LR_bicubic_X4/DIV2K_train_LR_bicubic/X4'
    val_hr_dir = 'data/DIV2K_valid_HR/DIV2K_valid_HR'
    val_lr_dir = 'data/DIV2K_valid_LR_bicubic_X4/DIV2K_valid_LR_bicubic/X4'

    # Create datasets
    train_dataset = DIV2KDataset(train_hr_dir, train_lr_dir, patch_size=48)
    val_dataset = DIV2KDataset(val_hr_dir, val_lr_dir, patch_size=48)

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    # Create checkpoints directory
    os.makedirs('checkpoints', exist_ok=True)

    # Train models
    models = ['edsr']
    for model_name in models:
        print(f'Training {model_name.upper()}...')
        train_model(model_name, train_loader, val_loader, device)


if __name__ == '__main__':
    main()