# DarkIR/archs/__init__.py
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn.parallel import DistributedDataParallel as DDP
from ptflops import get_model_complexity_info
from .DarkIR import DarkIR
def create_model(opt, rank, adapter=False):
    '''
    Creates the DarkIR model and wraps it in DistributedDataParallel.
    opt: dictionary built from the 'network' key of the YAML config.
    rank: GPU rank of the current process; complexity stats are printed on rank 0 only.
    adapter: if True, DDP is built with find_unused_parameters=True.
    '''
name = opt['name']
model = DarkIR(img_channel=opt['img_channels'],
width=opt['width'],
middle_blk_num_enc=opt['middle_blk_num_enc'],
middle_blk_num_dec=opt['middle_blk_num_dec'],
enc_blk_nums=opt['enc_blk_nums'],
dec_blk_nums=opt['dec_blk_nums'],
dilations=opt['dilations'],
extra_depth_wise=opt['extra_depth_wise'])
    if rank == 0:
        print(f'Using {name} network')
        input_size = (3, 256, 256)
        macs, params = get_model_complexity_info(model, input_size, print_per_layer_stat=False)
        print(f'Computational complexity at {input_size}: {macs}')
        print(f'Number of parameters: {params}')
else:
macs, params = None, None
model.to(rank)
model = DDP(model, device_ids=[rank], find_unused_parameters=adapter)
return model, macs, params
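# Illustrative only (not from the original repo): the 'network' section of the YAML
# config is expected to carry the keys read above. A call could look like the
# commented sketch below, with placeholder values.
#
# network_opt = {'name': 'DarkIR', 'img_channels': 3, 'width': 32,
#                'middle_blk_num_enc': 2, 'middle_blk_num_dec': 2,
#                'enc_blk_nums': [1, 2, 3], 'dec_blk_nums': [3, 1, 1],
#                'dilations': [1, 4], 'extra_depth_wise': True}
# model, macs, params = create_model(network_opt, rank=0)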
def create_optim_scheduler(opt, model):
    '''
    Returns the AdamW optimizer and its learning-rate scheduler.
    opt: dictionary built from the 'train' key of the YAML config.
    '''
    optim = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                              lr=opt['lr_initial'],
                              weight_decay=opt['weight_decay'],
                              betas=opt['betas'])
if opt['lr_scheme'] == 'CosineAnnealing':
scheduler = CosineAnnealingLR(optim, T_max=opt['epochs'], eta_min=opt['eta_min'])
    else:
        raise NotImplementedError(f"lr_scheme '{opt['lr_scheme']}' is not implemented")
return optim, scheduler
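# Illustrative only: create_optim_scheduler expects the 'train' section of the YAML
# config; the values below are placeholders, not the repo's defaults.
#
# train_opt = {'lr_initial': 1e-3, 'weight_decay': 1e-4, 'betas': [0.9, 0.99],
#              'lr_scheme': 'CosineAnnealing', 'epochs': 500, 'eta_min': 1e-6}
# optim, scheduler = create_optim_scheduler(train_opt, model)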
def load_weights(model, old_weights):
'''
Loads the weights of a pretrained model, picking only the weights that are
in the new model.
'''
new_weights = model.state_dict()
new_weights.update({k: v for k, v in old_weights.items() if k in new_weights})
model.load_state_dict(new_weights)
return model
def load_optim(optim, optim_weights):
    '''
    Loads the optimizer state, picking only the entries that exist in the new optimizer.
    '''
    optim_new_weights = optim.state_dict()
    optim_new_weights.update({k: v for k, v in optim_weights.items() if k in optim_new_weights})
    # Push the merged state back into the optimizer; without this the update above has no effect
    optim.load_state_dict(optim_new_weights)
    return optim
def resume_model(model,
                 optim,
                 scheduler,
                 path_model,
                 rank, resume: str = None):
    '''
    Loads the model, optimizer and scheduler states from path_model when resume
    is set; otherwise training starts from epoch 0.
    '''
map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
if resume:
checkpoints = torch.load(path_model, map_location=map_location, weights_only=False)
weights = checkpoints['model_state_dict']
model = load_weights(model, old_weights=weights)
optim = load_optim(optim, optim_weights = checkpoints['optimizer_state_dict'])
scheduler.load_state_dict(checkpoints['scheduler_state_dict'])
start_epochs = checkpoints['epoch']
if rank == 0: print('Loaded weights')
    else:
        start_epochs = 0
        if rank == 0: print('Starting training from scratch')
return model, optim, scheduler, start_epochs
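# For reference: resume_model expects a checkpoint with the same keys written by
# save_checkpoint below ('model_state_dict', 'optimizer_state_dict',
# 'scheduler_state_dict', 'epoch'). A resume call could look like this sketch,
# where the path and the resume value are placeholders:
#
# model, optim, scheduler, start_epochs = resume_model(
#     model, optim, scheduler, path_model='path/to/checkpoint.pt',
#     rank=0, resume='latest')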
def find_different_keys(dict1, dict2):
# Finding different keys
different_keys = set(dict1.keys()) ^ set(dict2.keys())
return different_keys
def number_common_keys(dict1, dict2):
# Finding common keys
common_keys = set(dict1.keys()) & set(dict2.keys())
# Counting the number of common keys
common_keys_count = len(common_keys)
return common_keys_count
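# Small sanity example (hypothetical dicts): with dict1 = {'a': 1, 'b': 2} and
# dict2 = {'b': 3, 'c': 4}, find_different_keys returns {'a', 'c'} (symmetric
# difference of the key sets) and number_common_keys returns 1.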
# # Function to add 'modules_list' prefix after the first numeric index
# def add_middle_prefix(state_dict, middle_prefix, target_strings):
# new_state_dict = {}
# for key, value in state_dict.items():
# for target in target_strings:
# if target in key:
# parts = key.split('.')
# # Find the first numeric index after the target string
# for i, part in enumerate(parts):
# if part == target:
# # Insert the middle prefix after the first numeric index
# if i + 1 < len(parts) and parts[i + 1].isdigit():
# parts.insert(i + 2, middle_prefix)
# break
# new_key = '.'.join(parts)
# new_state_dict[new_key] = value
# break
# else:
# new_state_dict[key] = value
# return new_state_dict
# # Function to adjust keys for 'middle_blks.' prefix
# def adjust_middle_blks_keys(state_dict, target_prefix, middle_prefix):
# new_state_dict = {}
# for key, value in state_dict.items():
# if target_prefix in key:
# parts = key.split('.')
# # Find the target prefix and adjust the key
# for i, part in enumerate(parts):
# if part == target_prefix.rstrip('.'):
# if i + 1 < len(parts) and parts[i + 1].isdigit():
# # Swap the numerical part and the middle prefix
# new_key = '.'.join(parts[:i + 1] + [middle_prefix] + parts[i + 1:i + 2] + parts[i + 2:])
# new_state_dict[new_key] = value
# break
# else:
# new_state_dict[key] = value
# return new_state_dict
# def resume_nafnet(model,
# optim,
# scheduler,
# path_adapter,
# path_model,
# rank, resume:str=None):
# '''
# Returns the loaded weights of model and optimizer if resume flag is True
# '''
# map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
# #first load the model weights
# checkpoints = torch.load(path_model, map_location=map_location, weights_only=False)
# weights = checkpoints
# if rank==0:
# print(len(weights), len(model.state_dict().keys()))
# different_keys = find_different_keys(weights, model.state_dict())
# filtered_keys = {item for item in different_keys if 'adapter' not in item}
# print(filtered_keys)
# print(len(filtered_keys))
# model = load_weights(model, old_weights=weights)
# #now if needed load the adapter weights
# if resume:
# checkpoints = torch.load(path_adapter, map_location=map_location, weights_only=False)
# weights = checkpoints
# model = load_weights(model, old_weights=weights)
# # optim = load_optim(optim, optim_weights = checkpoints['optimizer_state_dict'])
# scheduler.load_state_dict(checkpoints['scheduler_state_dict'])
# start_epochs = checkpoints['epoch']
# if rank == 0: print('Loaded weights')
# else:
# start_epochs = 0
# if rank == 0: print('Starting from zero the training')
# return model, optim, scheduler, start_epochs
def save_checkpoint(model, optim, scheduler, metrics_eval, metrics_train, paths, adapter=False, rank=None):
    '''
    Saves a .pt checkpoint of the model after each epoch and keeps a separate
    copy of the best model according to validation PSNR. Returns the updated best PSNR.
    '''
    best_psnr = metrics_train['best_psnr']
    if rank != 0:
        return best_psnr
    # Wrap a flat metrics dict so that metrics_eval is always {dataset_name: metrics}
    if not isinstance(next(iter(metrics_eval.values())), dict):
        metrics_eval = {'metrics': metrics_eval}
weights = model.state_dict()
# Save the model after every epoch
model_to_save = {
'epoch': metrics_train['epoch'],
'model_state_dict': weights,
'optimizer_state_dict': optim.state_dict(),
'loss': metrics_train['train_loss'],
'scheduler_state_dict': scheduler.state_dict()
}
try:
torch.save(model_to_save, paths['new'])
# Save best model if new valid_psnr is higher than the best one
if next(iter(metrics_eval.values()))['valid_psnr'] >= metrics_train['best_psnr']:
torch.save(model_to_save, paths['best'])
metrics_train['best_psnr'] = next(iter(metrics_eval.values()))['valid_psnr'] # update best psnr
except Exception as e:
print(f"Error saving model: {e}")
return metrics_train['best_psnr']
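# Sketch of the expected arguments (key names inferred from the body above,
# concrete values and dataset names are placeholders):
#
# paths = {'new': 'models/last.pt', 'best': 'models/best.pt'}
# metrics_train = {'epoch': 10, 'train_loss': 0.05, 'best_psnr': 27.5}
# metrics_eval = {'some_dataset': {'valid_psnr': 28.1}}
# metrics_train['best_psnr'] = save_checkpoint(model, optim, scheduler, metrics_eval,
#                                              metrics_train, paths, rank=0)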
__all__ = ['create_model', 'resume_model', 'create_optim_scheduler', 'save_checkpoint',
'load_optim', 'load_weights']