Histoformer / basicsr /models /image_restoration_model.py
sunshangquan
c
d12e905
import importlib
import torch
from collections import OrderedDict
from copy import deepcopy
from os import path as osp
from tqdm import tqdm
from basicsr.models.archs import define_network
from basicsr.models.base_model import BaseModel
from basicsr.utils import get_root_logger, imwrite, tensor2img
from huggingface_hub import PyTorchModelHubMixin
loss_module = importlib.import_module('basicsr.models.losses')
metric_module = importlib.import_module('basicsr.metrics')
import os
import random
import numpy as np
import cv2
import torch.nn.functional as F
from functools import partial
#from audtorch.metrics.functional import pearsonr
import torch.autograd as autograd
class Mixing_Augment:
def __init__(self, mixup_beta, use_identity, device):
self.dist = torch.distributions.beta.Beta(torch.tensor([mixup_beta]), torch.tensor([mixup_beta]))
self.device = device
self.use_identity = use_identity
self.augments = [self.mixup]
def mixup(self, target, input_):
lam = self.dist.rsample((1,1)).item()
r_index = torch.randperm(target.size(0)).to(self.device)
target = lam * target + (1-lam) * target[r_index, :]
input_ = lam * input_ + (1-lam) * input_[r_index, :]
return target, input_
def __call__(self, target, input_):
if self.use_identity:
augment = random.randint(0, len(self.augments))
if augment < len(self.augments):
target, input_ = self.augments[augment](target, input_)
else:
augment = random.randint(0, len(self.augments)-1)
target, input_ = self.augments[augment](target, input_)
return target, input_
class ImageCleanModel(BaseModel):
"""Base Deblur model for single image deblur."""
def __init__(self, opt):
super(ImageCleanModel, self).__init__(opt)
# define network
self.mixing_flag = self.opt['train']['mixing_augs'].get('mixup', False)
if self.mixing_flag:
mixup_beta = self.opt['train']['mixing_augs'].get('mixup_beta', 1.2)
use_identity = self.opt['train']['mixing_augs'].get('use_identity', False)
self.mixing_augmentation = Mixing_Augment(mixup_beta, use_identity, self.device)
self.net_g = define_network(deepcopy(opt['network_g']))
self.net_g = self.model_to_device(self.net_g)
self.print_network(self.net_g)
# load pretrained models
load_path = self.opt['path'].get('pretrain_network_g', None)
if load_path is not None:
self.load_network(self.net_g, load_path,
self.opt['path'].get('strict_load_g', True), param_key=self.opt['path'].get('param_key', 'params'))
if self.is_train:
self.init_training_settings()
self.psnr_best = -1
def init_training_settings(self):
self.net_g.train()
train_opt = self.opt['train']
self.ema_decay = train_opt.get('ema_decay', 0)
if self.ema_decay > 0:
logger = get_root_logger()
logger.info(
f'Use Exponential Moving Average with decay: {self.ema_decay}')
# define network net_g with Exponential Moving Average (EMA)
# net_g_ema is used only for testing on one GPU and saving
# There is no need to wrap with DistributedDataParallel
self.net_g_ema = define_network(self.opt['network_g']).to(
self.device)
# load pretrained model
load_path = self.opt['path'].get('pretrain_network_g', None)
if load_path is not None:
self.load_network(self.net_g_ema, load_path,
self.opt['path'].get('strict_load_g',
True), 'params_ema')
else:
self.model_ema(0) # copy net_g weight
self.net_g_ema.eval()
# define losses
if train_opt.get('pixel_opt'):
pixel_type = train_opt['pixel_opt'].pop('type')
cri_pix_cls = getattr(loss_module, pixel_type)
self.cri_pix = cri_pix_cls(**train_opt['pixel_opt']).to(
self.device)
else:
raise ValueError('pixel loss are None.')
if train_opt.get('seq_opt'):
# from audtorch.metrics.functional import pearsonr
# self.cri_seq = pearsonr
self.cri_seq = self.pearson_correlation_loss #
self.cri_celoss = torch.nn.CrossEntropyLoss()
# set up optimizers and schedulers
self.setup_optimizers()
self.setup_schedulers()
def pearson_correlation_loss(self, x1, x2):
assert x1.shape == x2.shape
b, c = x1.shape[:2]
dim = -1
x1, x2 = x1.reshape(b, -1), x2.reshape(b, -1)
x1_mean, x2_mean = x1.mean(dim=dim, keepdims=True), x2.mean(dim=dim, keepdims=True)
numerator = ((x1 - x1_mean) * (x2 - x2_mean)).sum( dim=dim, keepdims=True )
std1 = (x1 - x1_mean).pow(2).sum(dim=dim, keepdims=True).sqrt()
std2 = (x2 - x2_mean).pow(2).sum(dim=dim, keepdims=True).sqrt()
denominator = std1 * std2
corr = numerator.div(denominator + 1e-6)
return corr
def setup_optimizers(self):
train_opt = self.opt['train']
optim_params = []
for k, v in self.net_g.named_parameters():
if v.requires_grad:
optim_params.append(v)
else:
logger = get_root_logger()
logger.warning(f'Params {k} will not be optimized.')
optim_type = train_opt['optim_g'].pop('type')
if optim_type == 'Adam':
self.optimizer_g = torch.optim.Adam(optim_params, **train_opt['optim_g'])
elif optim_type == 'AdamW':
self.optimizer_g = torch.optim.AdamW(optim_params, **train_opt['optim_g'])
else:
raise NotImplementedError(
f'optimizer {optim_type} is not supperted yet.')
self.optimizers.append(self.optimizer_g)
def feed_train_data(self, data):
self.lq = data['lq'].to(self.device)
if 'gt' in data:
self.gt = data['gt'].to(self.device)
if 'label' in data:
self.label = data['label']
# self.label = torch.nn.functional.one_hot(data['label'], num_classes=3)
if self.mixing_flag:
self.gt, self.lq = self.mixing_augmentation(self.gt, self.lq)
def feed_data(self, data):
self.lq = data['lq'].to(self.device)
if 'gt' in data:
self.gt = data['gt'].to(self.device)
def check_inf_nan(self, x):
x[x.isnan()] = 0
x[x.isinf()] = 1e7
return x
def compute_correlation_loss(self, x1, x2):
b, c = x1.shape[0:2]
x1 = x1.view(b, -1)
x2 = x2.view(b, -1)
# print(x1, x2)
pearson = (1. - self.cri_seq(x1, x2)) / 2.
return pearson[~pearson.isnan()*~pearson.isinf()].mean()
def optimize_parameters(self, current_iter):
self.optimizer_g.zero_grad()
self.output = self.net_g(self.lq, )
loss_dict = OrderedDict()
# pixel loss
l_pix = self.cri_pix(self.output, self.gt)
loss_dict['l_pix'] = l_pix
'''
l_mask = self.cri_pix(self.pred_mask, self.gt - self.output.detach())
loss_dict['l_mask'] = l_mask
'''
l_pear = self.compute_correlation_loss(self.output, self.gt)
loss_dict['l_pear'] = l_pear
# l_pred = self.cri_celoss(self.pred, self.label.to(self.pred.device))
# loss_dict['l_pred'] = l_pred
# print("pear:", l_pear, "pix:", l_pix)
loss_total = l_pix + l_pear #+ 0.01*l_pred#+ l_mask
loss_total.backward()
if self.opt['train']['use_grad_clip']:
torch.nn.utils.clip_grad_norm_(self.net_g.parameters(), 0.01, error_if_nonfinite=False)
self.optimizer_g.step()
self.log_dict, self.loss_total = self.reduce_loss_dict(loss_dict)
self.loss_dict = loss_dict
if self.ema_decay > 0:
self.model_ema(decay=self.ema_decay)
def pad_test(self, window_size):
scale = self.opt.get('scale', 1)
mod_pad_h, mod_pad_w = 0, 0
_, _, h, w = self.lq.size()
if h % window_size != 0:
mod_pad_h = window_size - h % window_size
if w % window_size != 0:
mod_pad_w = window_size - w % window_size
img = F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
self.nonpad_test(img)
_, _, h, w = self.output.size()
self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale]
def nonpad_test(self, img=None):
if img is None:
img = self.lq
if hasattr(self, 'net_g_ema'):
self.net_g_ema.eval()
with torch.no_grad():
pred = self.net_g_ema(img)
if isinstance(pred, list):
pred = pred[-1]
self.output = pred
else:
self.net_g.eval()
with torch.no_grad():
pred = self.net_g(img)
if isinstance(pred, list):
pred = pred[-1]
self.output = pred
self.net_g.train()
def dist_validation(self, dataloader, current_iter, tb_logger, save_img, rgb2bgr, use_image):
if os.environ['LOCAL_RANK'] == '0':
return self.nondist_validation(dataloader, current_iter, tb_logger, save_img, rgb2bgr, use_image)
else:
return 0.
def nondist_validation(self, dataloader, current_iter, tb_logger,
save_img, rgb2bgr, use_image):
dataset_name = dataloader.dataset.opt['name']
with_metrics = self.opt['val'].get('metrics') is not None
if with_metrics:
self.metric_results = {
metric: 0
for metric in self.opt['val']['metrics'].keys()
}
# pbar = tqdm(total=len(dataloader), unit='image')
window_size = self.opt['val'].get('window_size', 0)
if window_size:
test = partial(self.pad_test, window_size)
else:
test = self.nonpad_test
cnt = 0
for idx, val_data in enumerate(dataloader):
if idx >= 60:
break
img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
self.feed_data(val_data)
test()
visuals = self.get_current_visuals()
sr_img = tensor2img([visuals['result']], rgb2bgr=rgb2bgr)
if 'gt' in visuals:
gt_img = tensor2img([visuals['gt']], rgb2bgr=rgb2bgr)
del self.gt
# tentative for out of GPU memory
del self.lq
del self.output
torch.cuda.empty_cache()
if save_img:
if self.opt['is_train']:
save_img_path = osp.join(self.opt['path']['visualization'],
img_name,
f'{img_name}_{current_iter}.png')
save_gt_img_path = osp.join(self.opt['path']['visualization'],
img_name,
f'{img_name}_{current_iter}_gt.png')
else:
save_img_path = osp.join(
self.opt['path']['visualization'], dataset_name,
f'{img_name}.png')
save_gt_img_path = osp.join(
self.opt['path']['visualization'], dataset_name,
f'{img_name}_gt.png')
imwrite(sr_img, save_img_path)
imwrite(gt_img, save_gt_img_path)
if with_metrics:
# calculate metrics
opt_metric = deepcopy(self.opt['val']['metrics'])
if use_image:
for name, opt_ in opt_metric.items():
metric_type = opt_.pop('type')
self.metric_results[name] += getattr(
metric_module, metric_type)(sr_img, gt_img, **opt_)
else:
for name, opt_ in opt_metric.items():
metric_type = opt_.pop('type')
self.metric_results[name] += getattr(
metric_module, metric_type)(visuals['result'], visuals['gt'], **opt_)
cnt += 1
current_metric = 0.
if with_metrics:
for metric in self.metric_results.keys():
self.metric_results[metric] /= cnt
current_metric = max(current_metric, self.metric_results[metric])
self._log_validation_metric_values(current_iter, dataset_name,
tb_logger)
return current_metric
def _log_validation_metric_values(self, current_iter, dataset_name,
tb_logger):
log_str = f'Validation {dataset_name},\t'
for metric, value in self.metric_results.items():
log_str += f'\t # {metric}: {value:.4f}'
if metric == 'psnr' and value >= self.psnr_best:
self.save(0, current_iter, best=True)
self.psnr_best = value
logger = get_root_logger()
logger.info(log_str)
if tb_logger:
for metric, value in self.metric_results.items():
tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
def get_current_visuals(self):
out_dict = OrderedDict()
out_dict['lq'] = self.lq.detach().cpu()
out_dict['result'] = self.output.detach().cpu()
if hasattr(self, 'gt'):
out_dict['gt'] = self.gt.detach().cpu()
return out_dict
def save(self, epoch, current_iter, best=False):
if self.ema_decay > 0:
self.save_network([self.net_g, self.net_g_ema],
'net_g',
current_iter,
param_key=['params', 'params_ema'], best=best)
else:
self.save_network(self.net_g, 'net_g', current_iter, best=best)
self.save_training_state(epoch, current_iter, best=best)