|
import logging
|
|
from torch.nn.parallel import DistributedDataParallel
|
|
|
|
from basicsr.utils.registry import MODEL_REGISTRY
|
|
from .video_base_model import VideoBaseModel
|
|
|
|
# Shared 'basicsr' logger used for the training-phase messages in this module.
logger = logging.getLogger(name='basicsr')
|
|
|
|
|
|
@MODEL_REGISTRY.register()
class EDVRModel(VideoBaseModel):
    """EDVR Model.

    Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks.  # noqa: E501

    Training options read on top of ``VideoBaseModel``:

    * ``train.dcn_lr_mul`` (default 1): learning-rate multiplier for every
      ``net_g`` parameter whose name contains ``'dcn'``.
    * ``train.tsa_iter`` (optional): while ``current_iter < tsa_iter`` only
      parameters whose name contains ``'fusion'`` (logged as the TSA module)
      are trainable; from ``tsa_iter`` on, every parameter is trained.
    """

    # True while the TSA-only warm-up is in effect, i.e. non-'fusion'
    # parameters have been frozen by optimize_parameters. Class-level default
    # so the flag exists even before the first optimization step.
    _tsa_only = False

    def __init__(self, opt):
        super(EDVRModel, self).__init__(opt)
        # Length of the TSA-only warm-up; None/0 disables it. Always defined
        # (None outside training) so attribute access never fails.
        self.train_tsa_iter = opt['train'].get('tsa_iter') if self.is_train else None

    def setup_optimizers(self):
        """Create ``self.optimizer_g`` for ``net_g`` and register it in ``self.optimizers``.

        When ``train.dcn_lr_mul`` differs from 1, parameters whose name contains
        ``'dcn'`` form a second param group with ``lr * dcn_lr_mul``; all other
        parameters use the base ``lr`` from ``train.optim_g``.
        """
        train_opt = self.opt['train']
        # Work on a copy: popping 'type' from the shared config in place made
        # any second call (e.g. re-initialisation) fail with KeyError.
        optim_opt = dict(train_opt['optim_g'])
        optim_type = optim_opt.pop('type')

        dcn_lr_mul = train_opt.get('dcn_lr_mul', 1)
        logger.info(f'Multiple the learning rate for dcn with {dcn_lr_mul}.')
        if dcn_lr_mul == 1:
            optim_params = self.net_g.parameters()
        else:
            normal_params = []
            dcn_params = []
            for name, param in self.net_g.named_parameters():
                if 'dcn' in name:
                    dcn_params.append(param)
                else:
                    normal_params.append(param)
            base_lr = optim_opt['lr']
            # Per-parameter-group options: the group 'lr' overrides the
            # optimizer-level default passed via **optim_opt.
            optim_params = [
                {'params': normal_params, 'lr': base_lr},
                {'params': dcn_params, 'lr': base_lr * dcn_lr_mul},
            ]

        self.optimizer_g = self.get_optimizer(optim_type, optim_params, **optim_opt)
        self.optimizers.append(self.optimizer_g)

    def optimize_parameters(self, current_iter):
        """Run one optimization step, applying the optional TSA-only warm-up first.

        Args:
            current_iter: Current training iteration.

        While ``current_iter < self.train_tsa_iter`` only parameters whose name
        contains ``'fusion'`` keep ``requires_grad``; the first step at or past
        ``self.train_tsa_iter`` re-enables gradients for every parameter.
        """
        if self.train_tsa_iter:
            want_tsa_only = current_iter < self.train_tsa_iter
            # Switch on phase transitions rather than exact iteration numbers:
            # this still freezes when training resumes mid-warm-up, and
            # tsa_iter == 1 no longer leaves the network frozen forever.
            if want_tsa_only and not self._tsa_only:
                logger.info(f'Only train TSA module for {self.train_tsa_iter} iters.')
                for name, param in self.net_g.named_parameters():
                    if 'fusion' not in name:
                        param.requires_grad = False
                self._tsa_only = True
            elif not want_tsa_only and self._tsa_only:
                logger.warning('Train all the parameters.')
                for param in self.net_g.parameters():
                    param.requires_grad = True
                if isinstance(self.net_g, DistributedDataParallel):
                    logger.warning('Set net_g.find_unused_parameters = False.')
                    # NOTE(review): this flips the attribute on an already
                    # constructed DDP wrapper; whether the reducer honours the
                    # change depends on the PyTorch version — confirm.
                    self.net_g.find_unused_parameters = False
                self._tsa_only = False

        # Delegate to the next class in the MRO (previously VideoBaseModel was
        # skipped by calling super(VideoBaseModel, self)).
        super(EDVRModel, self).optimize_parameters(current_iter)
|
|
|