File size: 4,667 Bytes
6221b96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# train.py
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from utils.dataset import DIV2KDataset
from models.srcnn import SRCNN
from models.vdsr import VDSR
from models.edsr import EDSR
import math
import numpy as np

class EarlyStopping:
    """Signal when neither the loss nor the PSNR has improved for a while.

    An epoch counts as "no progress" only when BOTH hold:
      * the loss is not more than ``min_delta`` below the best loss seen, and
      * the PSNR is less than ``min_psnr_improvement`` above the best PSNR seen.
    After ``patience`` consecutive no-progress epochs, ``early_stop`` becomes
    True. Any epoch that makes progress on either metric resets the counter
    and updates the stored bests.
    """

    def __init__(self, patience=7, min_delta=0.01, min_psnr_improvement=0.1):
        self.patience = patience
        self.min_delta = min_delta
        self.min_psnr_improvement = min_psnr_improvement
        # Number of consecutive epochs without progress.
        self.counter = 0
        # Best values observed so far; None until the first call seeds them.
        self.best_loss = None
        self.best_psnr = None
        # Set once patience runs out; nothing in this class clears it again.
        self.early_stop = False

    def __call__(self, loss, psnr):
        """Record one epoch's ``loss`` and ``psnr`` and update the stop state."""
        if self.best_loss is None:
            # The first observation only establishes the baselines.
            self.best_loss, self.best_psnr = loss, psnr
            return

        stalled = (loss > self.best_loss - self.min_delta
                   and psnr < self.best_psnr + self.min_psnr_improvement)

        if not stalled:
            # Progress on at least one metric: keep the best of each and restart.
            self.best_loss = min(loss, self.best_loss)
            self.best_psnr = max(psnr, self.best_psnr)
            self.counter = 0
            return

        self.counter += 1
        print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
        if self.counter >= self.patience:
            self.early_stop = True

def calculate_psnr(img1, img2, max_val=1.0):
    """Return the peak signal-to-noise ratio between two tensors, in dB.

    PSNR = 10 * log10(max_val**2 / MSE), where MSE is the mean squared
    difference over every element of the inputs. For a batched input this is
    the PSNR of the pooled error, not the mean of per-image PSNRs.

    Args:
        img1, img2: tensors of broadcast-compatible shape.
        max_val: peak possible pixel value. Defaults to 1.0, which matches
            images scaled to [0, 1].

    Returns:
        The PSNR as a Python float, or ``float('inf')`` when the inputs are
        identical (zero MSE).
    """
    # Extract the scalar once so the zero check and the log both work on a
    # plain float, with a single tensor -> Python conversion.
    mse = torch.mean((img1 - img2) ** 2).item()
    if mse == 0:
        return float('inf')
    return 10 * math.log10(max_val ** 2 / mse)

def _build_model(model_name):
    """Instantiate the network named by ``model_name``; unknown names get EDSR."""
    if model_name == 'srcnn':
        return SRCNN()
    if model_name == 'vdsr':
        return VDSR()
    return EDSR()


def _train_one_epoch(model, loader, criterion, optimizer, device, epoch):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch_idx, (lr_img, hr_img) in enumerate(loader):
        lr_img, hr_img = lr_img.to(device), hr_img.to(device)

        optimizer.zero_grad()
        output = model(lr_img)
        loss = criterion(output, hr_img)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1

        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx}/{len(loader)}]\tLoss: {batch_loss:.6f}')
    return total_loss / num_batches


def _validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader``; return (mean loss, mean PSNR) over batches."""
    model.eval()
    total_loss = 0.0
    total_psnr = 0.0
    num_batches = 0
    with torch.no_grad():
        for lr_img, hr_img in loader:
            lr_img, hr_img = lr_img.to(device), hr_img.to(device)
            output = model(lr_img)
            total_loss += criterion(output, hr_img).item()
            total_psnr += calculate_psnr(output, hr_img)
            num_batches += 1
    return total_loss / num_batches, total_psnr / num_batches


def train_model(model_name, train_loader, val_loader, device, num_epochs=100,
                checkpoint_dir='checkpoints'):
    """Train one super-resolution model and checkpoint its best weights.

    Each epoch runs one pass of MSE training with Adam (lr=1e-4), then
    evaluates on ``val_loader``. Whenever the mean validation PSNR beats the
    best so far, the model's state dict is written to
    ``<checkpoint_dir>/<model_name>_best.pth``. Training stops early once
    ``EarlyStopping`` (patience 10) reports no progress in either validation
    loss or validation PSNR.

    Args:
        model_name: 'srcnn' or 'vdsr' select those models; any other value
            falls back to EDSR.
        train_loader: sized iterable of ``(lr_img, hr_img)`` training batches.
        val_loader: sized iterable of ``(lr_img, hr_img)`` validation batches.
        device: torch device to train on.
        num_epochs: maximum number of epochs.
        checkpoint_dir: directory for the best checkpoint; created if missing.

    Returns:
        The best mean validation PSNR in dB, or ``-inf`` if no epoch ran.

    Raises:
        ValueError: if either loader has no batches, since the per-epoch
            averages would be undefined.
    """
    # Fail fast instead of dividing by zero after a full training pass.
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError('train_loader and val_loader must each yield at least one batch')

    model = _build_model(model_name).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)

    early_stopping = EarlyStopping(patience=10, min_delta=0.00001, min_psnr_improvement=0.1)
    # Start below any achievable PSNR so the first epoch is always checkpointed.
    best_psnr = -math.inf
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f'{model_name}_best.pth')

    for epoch in range(num_epochs):
        avg_train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device, epoch)
        avg_val_loss, val_psnr = _validate(model, val_loader, criterion, device)
        print(f'Epoch: {epoch}, Average Loss: {avg_train_loss:.6f}, '
              f'Val Loss: {avg_val_loss:.6f}, Average PSNR: {val_psnr:.2f}dB')

        # Checkpoint BEFORE the early-stopping check. The stopping epoch can
        # still beat best_psnr by less than min_psnr_improvement, and breaking
        # first would discard those weights.
        if val_psnr > best_psnr:
            best_psnr = val_psnr
            torch.save(model.state_dict(), checkpoint_path)
            print(f'Saved new best model with PSNR: {best_psnr:.2f}dB')

        # Monitor validation (not training) loss, so the criterion reflects
        # held-out performance rather than continued fitting of the train set.
        early_stopping(avg_val_loss, val_psnr)
        if early_stopping.early_stop:
            print(f"Early stopping triggered at epoch {epoch}")
            break

    return best_psnr

def main():
    """Entry point: build the DIV2K x4 train/val loaders and train each listed model."""
    device = torch.device('cpu')

    # (HR dir, LR dir) pairs for the DIV2K bicubic x4 splits, relative to the
    # current working directory.
    train_dirs = (
        'data/DIV2K_train_HR/DIV2K_train_HR/',
        'data/DIV2K_train_LR_bicubic_X4/DIV2K_train_LR_bicubic/X4',
    )
    val_dirs = (
        'data/DIV2K_valid_HR/DIV2K_valid_HR',
        'data/DIV2K_valid_LR_bicubic_X4/DIV2K_valid_LR_bicubic/X4',
    )

    train_set = DIV2KDataset(*train_dirs, patch_size=48)
    val_set = DIV2KDataset(*val_dirs, patch_size=48)

    # Shuffled multi-worker batches for training; one sample per batch, in a
    # fixed order, for validation.
    loaders = {
        'train': DataLoader(train_set, batch_size=16, shuffle=True, num_workers=4),
        'val': DataLoader(val_set, batch_size=1, shuffle=False),
    }

    # train_model writes its best weights under this directory.
    os.makedirs('checkpoints', exist_ok=True)

    for name in ('edsr',):
        print(f'Training {name.upper()}...')
        train_model(name, loaders['train'], loaders['val'], device)

if __name__ == "__main__":
    # Train only when run as a script, not when this module is imported.
    main()