Spaces:
Running
Running
File size: 6,044 Bytes
42d27cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
#
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
import os
import os.path as osp
import logging
import yaml
from utils.util import OrderedYaml
# Module-level YAML Loader/Dumper pair obtained from the project's OrderedYaml helper.
# parse() passes Loader to yaml.load; Dumper is not used in the code visible here.
# NOTE(review): presumably these preserve mapping key order - confirm in utils.util.
Loader, Dumper = OrderedYaml()
def parse(opt_path, is_train=True):
    '''Load a YAML option file and fill in the derived settings.

    The loaded mapping is updated in place:

    * opt['is_train'] is set to is_train.
    * Every dataset gets 'phase' (its key up to the first '_'), 'scale'
      (only when opt['distortion'] == 'sr') and 'data_type' ('lmdb', 'img'
      or 'mc'); a dataset 'mode' ending in 'mc' has '_mc' removed.
    * String entries of opt['path'] are user-expanded ('~').
    * opt['path']['root'] is set to a fixed directory, from which the
      experiment (training) or results (testing) directories are derived.
    * Relative schedule entries in opt['train'] ('*_rel') are converted to
      absolute iteration counts using opt['train']['niter'].

    Args:
        opt_path: Path of the YAML option file.
        is_train: True to derive training directories, False to derive the
            results directory used for testing.

    Returns:
        The option mapping produced by the YAML loader, with the updates above.
    '''
    # NOTE(review): Loader is built by utils.util.OrderedYaml(), which is not
    # visible here. If it derives from the full yaml.Loader, yaml.load can
    # construct arbitrary Python objects, so only pass trusted option files.
    with open(opt_path, mode='r') as f:
        opt = yaml.load(f, Loader=Loader)

    # opt['gpu_ids'] is left untouched; CUDA_VISIBLE_DEVICES is not exported here.
    opt['is_train'] = is_train

    is_sr = opt['distortion'] == 'sr'
    # Bind scale unconditionally so it can never be referenced while undefined.
    scale = opt['scale'] if is_sr else None

    # datasets
    for phase_key, dataset in opt['datasets'].items():
        dataset['phase'] = phase_key.split('_')[0]
        if is_sr:
            dataset['scale'] = scale

        is_lmdb = False
        for root_key in ('dataroot_GT', 'dataroot_LQ'):
            root = dataset.get(root_key, None)
            if root is not None:
                root = osp.expanduser(root)
                dataset[root_key] = root
                if root.endswith('lmdb'):
                    is_lmdb = True
        dataset['data_type'] = 'lmdb' if is_lmdb else 'img'

        if dataset['mode'].endswith('mc'):  # for memcached
            dataset['data_type'] = 'mc'
            dataset['mode'] = dataset['mode'].replace('_mc', '')

    # path
    paths = opt['path']
    # Iterate over a snapshot because values are rewritten during the loop.
    for key, path in list(paths.items()):
        # Only non-empty strings can be user-expanded; other value types
        # (booleans, numbers, ...) would make expanduser raise TypeError.
        # 'strict_load' is never treated as a path.
        if path and isinstance(path, str) and key != 'strict_load':
            paths[key] = osp.expanduser(path)

    # NOTE: the root is hard-coded and replaces any 'root' from the option file.
    paths['root'] = '/kaggle/working/'

    if is_train:
        experiments_root = osp.join(paths['root'], 'experiments', opt['name'])
        paths['experiments_root'] = experiments_root
        paths['models'] = osp.join(experiments_root, 'models')
        paths['training_state'] = osp.join(experiments_root, 'training_state')
        paths['log'] = experiments_root
        paths['val_images'] = osp.join(experiments_root, 'val_images')

        # change some options for debug mode
        if 'debug' in opt['name']:
            opt['train']['val_freq'] = 8
            opt['logger']['print_freq'] = 1
            opt['logger']['save_checkpoint_freq'] = 8
    else:  # test
        # Keep an explicitly configured results_root; otherwise derive one.
        if not paths.get('results_root', None):
            paths['results_root'] = osp.join(paths['root'], 'results', opt['name'])
        paths['log'] = paths['results_root']

    # network
    if is_sr:
        opt['network_G']['scale'] = scale

    # relative learning rate
    if 'train' in opt:
        train_opt = opt['train']
        niter = train_opt['niter']
        # '<name>_rel' holds fractions of niter; store the absolute
        # iteration numbers (truncated to int) under '<name>'.
        for name in ('T_period', 'restarts', 'lr_steps', 'lr_steps_inverse'):
            rel_key = name + '_rel'
            if rel_key in train_opt:
                train_opt[name] = [int(x * niter) for x in train_opt[rel_key]]
        print(train_opt)

    return opt
def dict2str(opt, indent_l=1):
    '''Render a (possibly nested) dict as an indented string for the logger.

    Each entry is written on its own line as 'key: value', indented by
    2 * indent_l spaces. A dict value is rendered recursively one level
    deeper, between a 'key:[' line and a closing ']' line.

    Args:
        opt: The dict to render. Keys of any type are accepted; they are
            converted with str().
        indent_l: Indentation level of the entries of opt.

    Returns:
        The rendered text; an empty string for an empty dict.
    '''
    pad = ' ' * (indent_l * 2)
    parts = []
    for k, v in opt.items():
        # str(k) keeps non-string keys (e.g. integers from YAML) from
        # raising TypeError during concatenation.
        if isinstance(v, dict):
            parts.append(pad + str(k) + ':[\n')
            parts.append(dict2str(v, indent_l + 1))
            parts.append(pad + ']\n')
        else:
            parts.append(pad + str(k) + ': ' + str(v) + '\n')
    return ''.join(parts)
class NoneDict(dict):
    '''A dict whose item lookup gives None for an absent key instead of raising KeyError.'''

    def __missing__(self, key):
        # Hook invoked by dict item lookup when the key is not present;
        # the key is not inserted, only None is handed back.
        return None
# Convert to NoneDict, which returns None for missing keys.
def dict_to_nonedict(opt):
    '''Recursively replace every dict inside opt with a NoneDict.

    Dicts are rebuilt as NoneDict with converted values, lists are rebuilt
    with converted items, and any other value is returned unchanged.

    Args:
        opt: A dict, a list, or any other value.

    Returns:
        The converted structure; the input itself is not modified.
    '''
    if isinstance(opt, dict):
        # Build from key/value pairs rather than **kwargs: keyword unpacking
        # requires string keys and would raise TypeError for e.g. int keys.
        return NoneDict((key, dict_to_nonedict(sub_opt)) for key, sub_opt in opt.items())
    if isinstance(opt, list):
        return [dict_to_nonedict(sub_opt) for sub_opt in opt]
    return opt
def check_resume(opt, resume_iter):
    '''Check resume states and pretrain_model paths.

    Does nothing unless opt['path']['resume_state'] is set to a truthy value.
    When resuming, opt['path']['pretrain_model_G'] is overwritten with
    '<resume_iter>_G.pth' inside opt['path']['models']; if 'gan' occurs in
    opt['model'], opt['path']['pretrain_model_D'] is overwritten with the
    matching '<resume_iter>_D.pth'. A warning is logged when a pretrain
    model path was already configured, because it is replaced.

    Args:
        opt: Option mapping (plain dict or NoneDict); modified in place.
        resume_iter: Iteration number used to build the checkpoint file names.
    '''
    logger = logging.getLogger('base')
    path_opt = opt['path']

    # .get() lets a plain dict without the key (e.g. the direct result of
    # parse()) be treated as "not resuming" instead of raising KeyError.
    if not path_opt.get('resume_state', None):
        return

    if (path_opt.get('pretrain_model_G', None) is not None
            or path_opt.get('pretrain_model_D', None) is not None):
        logger.warning('pretrain_model path will be ignored when resuming training.')

    path_opt['pretrain_model_G'] = osp.join(path_opt['models'],
                                            '{}_G.pth'.format(resume_iter))
    logger.info('Set [pretrain_model_G] to %s', path_opt['pretrain_model_G'])

    # A missing or None model name is treated as "not a GAN" rather than
    # raising KeyError/TypeError on the membership test.
    if 'gan' in (opt.get('model', None) or ''):
        path_opt['pretrain_model_D'] = osp.join(path_opt['models'],
                                                '{}_D.pth'.format(resume_iter))
        logger.info('Set [pretrain_model_D] to %s', path_opt['pretrain_model_D'])
|