|
import logging

from torch.nn.parallel import DistributedDataParallel

from basicsr.utils.registry import MODEL_REGISTRY

from .video_base_model import VideoBaseModel
|
|
|
# Module-level logger; fetched by the fixed name 'basicsr' rather than __name__.
logger = logging.getLogger('basicsr')
|
|
|
|
|
@MODEL_REGISTRY.register()
class EDVRModel(VideoBaseModel):
    """EDVR Model.

    Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional
    Networks.

    Compared with its parent class, this model customises two things:

    * ``setup_optimizers``: parameters of ``net_g`` whose name contains
      ``'dcn'`` can be given a learning rate scaled by
      ``opt['train']['dcn_lr_mul']``.
    * ``optimize_parameters``: when ``opt['train']['tsa_iter']`` is set, only
      the parameters whose name contains ``'fusion'`` are trainable between
      iteration 1 and iteration ``tsa_iter``.
    """

    def __init__(self, opt):
        """Initialise the parent model and read the TSA warm-up length.

        Args:
            opt: Option mapping forwarded to the parent constructor. When
                ``self.is_train`` is truthy, ``opt['train'].get('tsa_iter')``
                is stored as ``self.train_tsa_iter``.
        """
        super(EDVRModel, self).__init__(opt)
        # Always define the attribute (None outside training) so that reading
        # it later can never raise AttributeError. ``opt['train']`` is only
        # touched in training mode, exactly as before.
        self.train_tsa_iter = opt['train'].get('tsa_iter') if self.is_train else None

    def setup_optimizers(self):
        """Create ``self.optimizer_g`` and register it in ``self.optimizers``.

        Reads ``self.opt['train']``:

        * ``dcn_lr_mul`` (default ``1``): multiplier applied to the learning
          rate of every ``net_g`` parameter whose name contains ``'dcn'``.
          With the default value all parameters go into a single group.
        * ``optim_g``: must contain ``'type'``; the remaining entries are
          forwarded as keyword arguments to ``self.get_optimizer``. ``'lr'``
          is additionally required when ``dcn_lr_mul != 1``.

        Raises:
            KeyError: If ``'type'`` (or ``'lr'`` when it is needed) is missing
                from ``optim_g``.
        """
        train_opt = self.opt['train']
        dcn_lr_mul = train_opt.get('dcn_lr_mul', 1)
        logger.info(f'Multiple the learning rate for dcn with {dcn_lr_mul}.')

        if dcn_lr_mul == 1:
            optim_params = self.net_g.parameters()
        else:
            # Split the parameters into two groups so the DCN ones can use
            # their own learning rate.
            normal_params = []
            dcn_params = []
            for name, param in self.net_g.named_parameters():
                if 'dcn' in name:
                    dcn_params.append(param)
                else:
                    normal_params.append(param)
            base_lr = train_opt['optim_g']['lr']
            optim_params = [
                # Normal parameters come first, then the DCN group.
                {
                    'params': normal_params,
                    'lr': base_lr,
                },
                {
                    'params': dcn_params,
                    'lr': base_lr * dcn_lr_mul,
                },
            ]

        # Pop 'type' from a shallow copy: popping it from the shared option
        # dict would make a second call to this method (or a second model
        # built from the same options) fail with KeyError.
        optim_kwargs = dict(train_opt['optim_g'])
        optim_type = optim_kwargs.pop('type')
        self.optimizer_g = self.get_optimizer(optim_type, optim_params, **optim_kwargs)
        self.optimizers.append(self.optimizer_g)

    def optimize_parameters(self, current_iter):
        """Apply the TSA-only warm-up schedule, then run the inherited step.

        When ``self.train_tsa_iter`` is truthy:

        * at ``current_iter == 1`` every ``net_g`` parameter whose name does
          not contain ``'fusion'`` gets ``requires_grad = False``;
        * at ``current_iter == self.train_tsa_iter`` every parameter gets
          ``requires_grad = True`` again and, if ``net_g`` is wrapped in
          ``DistributedDataParallel``, its ``find_unused_parameters``
          attribute is set to ``False``.

        Args:
            current_iter (int): Current training iteration; compared against
                ``1`` and ``self.train_tsa_iter``.
        """
        if self.train_tsa_iter:
            # NOTE(review): freezing is keyed on current_iter == 1 exactly, so
            # a run whose first call uses a larger iteration never enters this
            # branch. Also, with tsa_iter == 1 this branch wins over the
            # ``elif`` and the parameters are never unfrozen. Confirm both are
            # acceptable for the callers.
            if current_iter == 1:
                logger.info(f'Only train TSA module for {self.train_tsa_iter} iters.')
                for name, param in self.net_g.named_parameters():
                    if 'fusion' not in name:
                        param.requires_grad = False
            elif current_iter == self.train_tsa_iter:
                logger.warning('Train all the parameters.')
                for param in self.net_g.parameters():
                    param.requires_grad = True
                if isinstance(self.net_g, DistributedDataParallel):
                    # Every parameter receives a gradient again from here on.
                    logger.warning('Set net_g.find_unused_parameters = False.')
                    self.net_g.find_unused_parameters = False

        # NOTE(review): the lookup starts *after* VideoBaseModel in the MRO, so
        # an ``optimize_parameters`` defined on VideoBaseModel itself would be
        # skipped. Kept as in the original; confirm this is intentional.
        super(VideoBaseModel, self).optimize_parameters(current_iter)
|
|