Commit 4bef05b · verified · 1 Parent(s): 9c9a4d4
MultiMatrix committed

Upload train_stage1.py

Files changed (1)
  1. train_stage1.py +249 -0
train_stage1.py ADDED
@@ -0,0 +1,249 @@
import os
from argparse import ArgumentParser
import warnings

from omegaconf import OmegaConf
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
from accelerate import Accelerator
from accelerate.utils import set_seed
from einops import rearrange
from tqdm import tqdm
import lpips

from model import SwinIR
from utils.common import instantiate_from_config


# https://github.com/XPixelGroup/BasicSR/blob/033cd6896d898fdd3dcda32e3102a792efa1b8f4/basicsr/utils/color_util.py#L186
def rgb2ycbcr_pt(img, y_only=False):
    """Convert RGB images to YCbCr images (PyTorch version).

    It implements the ITU-R BT.601 conversion for standard-definition television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    Args:
        img (Tensor): Images with shape (n, 3, h, w), the range [0, 1], float, RGB format.
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        (Tensor): converted images with the shape (n, 3/1, h, w), the range [0, 1], float.
    """
    if y_only:
        weight = torch.tensor([[65.481], [128.553], [24.966]]).to(img)
        out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0
    else:
        weight = torch.tensor([[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]).to(img)
        bias = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img)
        out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias

    out_img = out_img / 255.
    return out_img

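# Quick sanity check for rgb2ycbcr_pt (illustrative): for a pure-white input
# (r = g = b = 1), y_only=True gives (65.481 + 128.553 + 24.966 + 16) / 255
# = 235 / 255 ≈ 0.9216, i.e. BT.601 limited-range white.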

# https://github.com/XPixelGroup/BasicSR/blob/033cd6896d898fdd3dcda32e3102a792efa1b8f4/basicsr/metrics/psnr_ssim.py#L52
def calculate_psnr_pt(img, img2, crop_border, test_y_channel=False):
    """Calculate PSNR (Peak Signal-to-Noise Ratio) (PyTorch version).

    Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

    Args:
        img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
        img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
        crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

    Returns:
        (Tensor): PSNR result for each image in the batch, shape (n,).
    """

    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')

    if crop_border != 0:
        img = img[:, :, crop_border:-crop_border, crop_border:-crop_border]
        img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border]

    if test_y_channel:
        img = rgb2ycbcr_pt(img, y_only=True)
        img2 = rgb2ycbcr_pt(img2, y_only=True)

    img = img.to(torch.float64)
    img2 = img2.to(torch.float64)

    mse = torch.mean((img - img2)**2, dim=[1, 2, 3])
    return 10. * torch.log10(1. / (mse + 1e-8))

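# Note: the 1e-8 epsilon above means identical images (mse = 0) report a
# finite PSNR of 10 * log10(1 / 1e-8) = 80 dB rather than infinity.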

def main(args) -> None:
    # Setup accelerator (split_batches=True: each dataloader batch is split
    # across processes, so cfg.train.batch_size is the global batch size):
    accelerator = Accelerator(split_batches=True)
    set_seed(231)
    device = accelerator.device
    cfg = OmegaConf.load(args.config)

    # Setup an experiment folder:
    if accelerator.is_local_main_process:
        exp_dir = cfg.train.exp_dir
        os.makedirs(exp_dir, exist_ok=True)
        ckpt_dir = os.path.join(exp_dir, "checkpoints")
        os.makedirs(ckpt_dir, exist_ok=True)
        print(f"Experiment directory created at {exp_dir}")

    # Create model:
    swinir: SwinIR = instantiate_from_config(cfg.model.swinir)
    if cfg.train.resume:
        swinir.load_state_dict(torch.load(cfg.train.resume, map_location="cpu"), strict=True)
        if accelerator.is_local_main_process:
            print(f"strictly load weight from checkpoint: {cfg.train.resume}")
    else:
        if accelerator.is_local_main_process:
            print("initialize from scratch")

    # Setup optimizer:
    opt = torch.optim.AdamW(
        swinir.parameters(), lr=cfg.train.learning_rate,
        weight_decay=0
    )

    # Setup data:
    dataset = instantiate_from_config(cfg.dataset.train)
    loader = DataLoader(
        dataset=dataset, batch_size=cfg.train.batch_size,
        num_workers=cfg.train.num_workers,
        shuffle=True, drop_last=True
    )
    val_dataset = instantiate_from_config(cfg.dataset.val)
    val_loader = DataLoader(
        dataset=val_dataset, batch_size=cfg.train.batch_size,
        num_workers=cfg.train.num_workers,
        shuffle=False, drop_last=False
    )
    if accelerator.is_local_main_process:
        print(f"Dataset contains {len(dataset):,} images from {dataset.file_list}")

    # Prepare models for training:
    swinir.train().to(device)
    swinir, opt, loader, val_loader = accelerator.prepare(swinir, opt, loader, val_loader)
    pure_swinir = accelerator.unwrap_model(swinir)

    # Variables for monitoring/logging purposes:
    global_step = 0
    max_steps = cfg.train.train_steps
    step_loss = []
    epoch = 0
    epoch_loss = []
    with warnings.catch_warnings():
        # avoid warnings from lpips internal
        warnings.simplefilter("ignore")
        lpips_model = lpips.LPIPS(net="alex", verbose=accelerator.is_local_main_process).eval().to(device)
    if accelerator.is_local_main_process:
        writer = SummaryWriter(exp_dir)
        print(f"Training for {max_steps} steps...")

    while global_step < max_steps:
        pbar = tqdm(iterable=None, disable=not accelerator.is_local_main_process, unit="batch", total=len(loader))
        for gt, lq, _ in loader:
            # gt arrives in [-1, 1], channel-last; map to [0, 1], channel-first
            gt = rearrange((gt + 1) / 2, "b h w c -> b c h w").contiguous().float().to(device)
            lq = rearrange(lq, "b h w c -> b c h w").contiguous().float().to(device)
            pred = swinir(lq)
            loss = F.mse_loss(input=pred, target=gt, reduction="sum")

            opt.zero_grad()
            accelerator.backward(loss)
            opt.step()
            accelerator.wait_for_everyone()

            global_step += 1
            step_loss.append(loss.item())
            epoch_loss.append(loss.item())
            pbar.update(1)
            pbar.set_description(f"Epoch: {epoch:04d}, Global Step: {global_step:07d}, Loss: {loss.item():.6f}")

            # Log loss values:
            if global_step % cfg.train.log_every == 0:
                # Gather values from all processes
                avg_loss = accelerator.gather(torch.tensor(step_loss, device=device).unsqueeze(0)).mean().item()
                step_loss.clear()
                if accelerator.is_local_main_process:
                    writer.add_scalar("train/loss_step", avg_loss, global_step)

            # Save checkpoint:
            if global_step % cfg.train.ckpt_every == 0:
                if accelerator.is_local_main_process:
                    checkpoint = pure_swinir.state_dict()
                    ckpt_path = f"{ckpt_dir}/{global_step:07d}.pt"
                    torch.save(checkpoint, ckpt_path)

            if global_step % cfg.train.image_every == 0 or global_step == 1:
                swinir.eval()
                N = 12
                log_gt, log_lq = gt[:N], lq[:N]
                with torch.no_grad():
                    log_pred = swinir(log_lq)
                if accelerator.is_local_main_process:
                    for tag, image in [
                        ("image/pred", log_pred),
                        ("image/gt", log_gt),
                        ("image/lq", log_lq),
                    ]:
                        writer.add_image(tag, make_grid(image, nrow=4), global_step)
                swinir.train()

            # Evaluate model:
            if global_step % cfg.train.val_every == 0:
                swinir.eval()
                val_loss = []
                val_lpips = []
                val_psnr = []
                val_pbar = tqdm(iterable=None, disable=not accelerator.is_local_main_process, unit="batch",
                                total=len(val_loader), leave=False, desc="Validation")
                # TODO: use accelerator.gather_for_metrics for more precise metric calculation?
                for val_gt, val_lq, _ in val_loader:
                    val_gt = rearrange((val_gt + 1) / 2, "b h w c -> b c h w").contiguous().float().to(device)
                    val_lq = rearrange(val_lq, "b h w c -> b c h w").contiguous().float().to(device)
                    with torch.no_grad():
                        # forward
                        val_pred = swinir(val_lq)
                        # compute metrics (loss, lpips, psnr)
                        val_loss.append(F.mse_loss(input=val_pred, target=val_gt, reduction="sum").item())
                        val_lpips.append(lpips_model(val_pred, val_gt, normalize=True).mean().item())
                        val_psnr.append(calculate_psnr_pt(val_pred, val_gt, crop_border=0).mean().item())
                    val_pbar.update(1)
                val_pbar.close()
                avg_val_loss = accelerator.gather(torch.tensor(val_loss, device=device).unsqueeze(0)).mean().item()
                avg_val_lpips = accelerator.gather(torch.tensor(val_lpips, device=device).unsqueeze(0)).mean().item()
                avg_val_psnr = accelerator.gather(torch.tensor(val_psnr, device=device).unsqueeze(0)).mean().item()
                if accelerator.is_local_main_process:
                    for tag, val in [
                        ("val/loss", avg_val_loss),
                        ("val/lpips", avg_val_lpips),
                        ("val/psnr", avg_val_psnr)
                    ]:
                        writer.add_scalar(tag, val, global_step)
                swinir.train()

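            # One way to address the TODO above (a sketch, not from this
            # commit): gather per-sample metric tensors with
            # accelerator.gather_for_metrics(...), which drops the samples
            # that distributed evaluation pads into the last batch, e.g.
            #     all_psnr = accelerator.gather_for_metrics(
            #         calculate_psnr_pt(val_pred, val_gt, crop_border=0))
            # and average once after the loop.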
            accelerator.wait_for_everyone()

            if global_step == max_steps:
                break

        pbar.close()
        epoch += 1
        avg_epoch_loss = accelerator.gather(torch.tensor(epoch_loss, device=device).unsqueeze(0)).mean().item()
        epoch_loss.clear()
        if accelerator.is_local_main_process:
            writer.add_scalar("train/loss_epoch", avg_epoch_loss, global_step)

    if accelerator.is_local_main_process:
        print("done!")
        writer.close()


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    args = parser.parse_args()
    main(args)
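Usage notes: the script is configured entirely through the OmegaConf YAML file passed via --config. The sketch below (written with OmegaConf.create for illustration) lists the keys the script actually reads; the target paths, parameter blocks, and numeric values are placeholders and assumptions, not taken from this commit. Each dataset object is expected to yield (gt, lq, _) tuples with gt in [-1, 1] and lq in [0, 1], both channel-last, and to expose a file_list attribute.

from omegaconf import OmegaConf

# Minimal config sketch; every concrete value below is illustrative.
cfg = OmegaConf.create({
    "model": {
        # consumed by utils.common.instantiate_from_config; the
        # target/params layout is an assumption about that helper
        "swinir": {"target": "model.SwinIR", "params": {}},
    },
    "dataset": {
        "train": {"target": "my_datasets.PairedDataset", "params": {}},  # hypothetical class
        "val": {"target": "my_datasets.PairedDataset", "params": {}},    # hypothetical class
    },
    "train": {
        "exp_dir": "./experiments/stage1",
        "resume": "",            # path to a .pt state dict, or empty to start from scratch
        "learning_rate": 1e-4,   # illustrative
        "batch_size": 16,        # global batch size; split_batches=True divides it across processes
        "num_workers": 4,
        "train_steps": 150000,   # illustrative
        "log_every": 100,
        "ckpt_every": 10000,
        "image_every": 1000,
        "val_every": 5000,
    },
})

With such a config saved as YAML, a typical run would go through the Accelerate CLI, e.g. accelerate launch train_stage1.py --config path/to/config.yaml (the config path is a placeholder).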