import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn.parallel import DistributedDataParallel as DDP
from ptflops import get_model_complexity_info

from .DarkIR import DarkIR

def create_model(opt, rank, adapter=False):
    '''
    Creates the model.
    opt: the 'network' section of the yaml config, as a dictionary
    '''
    name = opt['name']

    model = DarkIR(img_channel=opt['img_channels'],
                   width=opt['width'],
                   middle_blk_num_enc=opt['middle_blk_num_enc'],
                   middle_blk_num_dec=opt['middle_blk_num_dec'],
                   enc_blk_nums=opt['enc_blk_nums'],
                   dec_blk_nums=opt['dec_blk_nums'],
                   dilations=opt['dilations'],
                   extra_depth_wise=opt['extra_depth_wise'])

    if rank == 0:
        # report model complexity only once, from the main process
        print(f'Using {name} network')
        input_size = (3, 256, 256)
        macs, params = get_model_complexity_info(model, input_size, print_per_layer_stat=False)
        print(f'Computational complexity at {input_size}: {macs}')
        print('Number of parameters: ', params)
    else:
        macs, params = None, None

    model.to(rank)
    # wrap for distributed training; find_unused_parameters is enabled when only the adapter is trained
    model = DDP(model, device_ids=[rank], find_unused_parameters=adapter)

    return model, macs, params

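# Illustrative sketch (not part of the pipeline): how create_model is typically
# called from a DDP worker. The option values below are placeholders, not the
# repo's defaults, and torch.distributed must already be initialised (e.g. via
# init_process_group) because the model is wrapped in DDP.
#
#   network_opt = {
#       'name': 'DarkIR', 'img_channels': 3, 'width': 32,
#       'middle_blk_num_enc': 2, 'middle_blk_num_dec': 2,
#       'enc_blk_nums': [1, 2, 3], 'dec_blk_nums': [3, 1, 1],
#       'dilations': [1, 4], 'extra_depth_wise': True,
#   }
#   model, macs, params = create_model(network_opt, rank=rank)
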
def create_optim_scheduler(opt, model):
    '''
    Returns the optimizer and its scheduler.
    opt: the 'train' section of the yaml config, as a dictionary
    '''
    # only optimize parameters that require gradients (e.g. adapter-only fine-tuning)
    optim = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                              lr=opt['lr_initial'],
                              weight_decay=opt['weight_decay'],
                              betas=opt['betas'])
    if opt['lr_scheme'] == 'CosineAnnealing':
        scheduler = CosineAnnealingLR(optim, T_max=opt['epochs'], eta_min=opt['eta_min'])
    else:
        raise NotImplementedError('scheduler not implemented')

    return optim, scheduler

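# Illustrative sketch: the 'train' section of the yaml config is expected to
# provide at least the keys read above. The concrete numbers are assumptions
# for the example, not recommended settings.
#
#   train_opt = {
#       'lr_initial': 1e-3, 'weight_decay': 1e-3, 'betas': [0.9, 0.9],
#       'lr_scheme': 'CosineAnnealing', 'epochs': 300, 'eta_min': 1e-6,
#   }
#   optim, scheduler = create_optim_scheduler(train_opt, model)
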
def load_weights(model, old_weights):
    '''
    Loads the weights of a pretrained model, picking only the weights that are
    in the new model.
    '''
    new_weights = model.state_dict()
    new_weights.update({k: v for k, v in old_weights.items() if k in new_weights})
    model.load_state_dict(new_weights)
    return model

def load_optim(optim, optim_weights):
    '''
    Loads the state of the optimizer, picking only the entries that are also in
    the new optimizer's state dict.
    '''
    optim_new_weights = optim.state_dict()
    optim_new_weights.update({k: v for k, v in optim_weights.items() if k in optim_new_weights})
    # push the merged state back into the optimizer; without this the update above has no effect
    optim.load_state_dict(optim_new_weights)
    return optim

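# Illustrative sketch: restoring a pretrained checkpoint outside of
# resume_model, e.g. for fine-tuning. The checkpoint layout matches the one
# produced by save_checkpoint below; the path is a placeholder.
#
#   ckpt = torch.load('./models/darkir_pretrained.pt', map_location='cpu', weights_only=False)
#   model = load_weights(model, old_weights=ckpt['model_state_dict'])
#   optim = load_optim(optim, optim_weights=ckpt['optimizer_state_dict'])
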
def resume_model(model,
                 optim,
                 scheduler,
                 path_model,
                 rank, resume: str = None):
    '''
    Returns the model, optimizer and scheduler with their states restored from
    the checkpoint at path_model if the resume flag is set; otherwise training
    starts from epoch 0.
    '''
    # remap tensors saved on cuda:0 to this process' device
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    if resume:
        checkpoints = torch.load(path_model, map_location=map_location, weights_only=False)
        weights = checkpoints['model_state_dict']
        model = load_weights(model, old_weights=weights)
        optim = load_optim(optim, optim_weights=checkpoints['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoints['scheduler_state_dict'])
        start_epochs = checkpoints['epoch']
        if rank == 0: print('Loaded weights')
    else:
        start_epochs = 0
        if rank == 0: print('Starting training from scratch')
    return model, optim, scheduler, start_epochs

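# Illustrative sketch: resuming a run from the checkpoint written by
# save_checkpoint. The path and the resume flag would normally come from the
# yaml config; the values here are placeholders.
#
#   model, optim, scheduler, start_epoch = resume_model(
#       model, optim, scheduler,
#       path_model='./models/darkir_last.pt',
#       rank=rank, resume=True)
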
def find_different_keys(dict1, dict2):
    # Finding different keys
    different_keys = set(dict1.keys()) ^ set(dict2.keys())
    return different_keys

def number_common_keys(dict1, dict2):
    # Finding common keys
    common_keys = set(dict1.keys()) & set(dict2.keys())
    # Counting the number of common keys
    common_keys_count = len(common_keys)
    return common_keys_count

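# Illustrative sketch: these helpers can be used to inspect how well a saved
# state dict matches the current architecture before a partial load. The
# checkpoint path is a placeholder.
#
#   old = torch.load('./models/darkir_pretrained.pt', map_location='cpu', weights_only=False)['model_state_dict']
#   new = model.state_dict()
#   print(find_different_keys(old, new))   # keys present in only one of the two dicts
#   print(number_common_keys(old, new))    # how many keys a partial load would actually copy
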
# # Function to add 'modules_list' prefix after the first numeric index
# def add_middle_prefix(state_dict, middle_prefix, target_strings):
#     new_state_dict = {}
#     for key, value in state_dict.items():
#         for target in target_strings:
#             if target in key:
#                 parts = key.split('.')
#                 # Find the first numeric index after the target string
#                 for i, part in enumerate(parts):
#                     if part == target:
#                         # Insert the middle prefix after the first numeric index
#                         if i + 1 < len(parts) and parts[i + 1].isdigit():
#                             parts.insert(i + 2, middle_prefix)
#                         break
#                 new_key = '.'.join(parts)
#                 new_state_dict[new_key] = value
#                 break
#         else:
#             new_state_dict[key] = value
#     return new_state_dict

# # Function to adjust keys for 'middle_blks.' prefix
# def adjust_middle_blks_keys(state_dict, target_prefix, middle_prefix):
#     new_state_dict = {}
#     for key, value in state_dict.items():
#         if target_prefix in key:
#             parts = key.split('.')
#             # Find the target prefix and adjust the key
#             for i, part in enumerate(parts):
#                 if part == target_prefix.rstrip('.'):
#                     if i + 1 < len(parts) and parts[i + 1].isdigit():
#                         # Swap the numerical part and the middle prefix
#                         new_key = '.'.join(parts[:i + 1] + [middle_prefix] + parts[i + 1:i + 2] + parts[i + 2:])
#                         new_state_dict[new_key] = value
#                     break
#         else:
#             new_state_dict[key] = value
#     return new_state_dict

# def resume_nafnet(model,
#                   optim,
#                   scheduler,
#                   path_adapter,
#                   path_model,
#                   rank, resume: str = None):
#     '''
#     Returns the loaded weights of model and optimizer if resume flag is True
#     '''
#     map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
#     # first load the model weights
#     checkpoints = torch.load(path_model, map_location=map_location, weights_only=False)
#     weights = checkpoints
#     if rank == 0:
#         print(len(weights), len(model.state_dict().keys()))
#         different_keys = find_different_keys(weights, model.state_dict())
#         filtered_keys = {item for item in different_keys if 'adapter' not in item}
#         print(filtered_keys)
#         print(len(filtered_keys))
#     model = load_weights(model, old_weights=weights)
#     # now if needed load the adapter weights
#     if resume:
#         checkpoints = torch.load(path_adapter, map_location=map_location, weights_only=False)
#         weights = checkpoints
#         model = load_weights(model, old_weights=weights)
#         # optim = load_optim(optim, optim_weights=checkpoints['optimizer_state_dict'])
#         scheduler.load_state_dict(checkpoints['scheduler_state_dict'])
#         start_epochs = checkpoints['epoch']
#         if rank == 0: print('Loaded weights')
#     else:
#         start_epochs = 0
#         if rank == 0: print('Starting from zero the training')
#     return model, optim, scheduler, start_epochs

def save_checkpoint(model, optim, scheduler, metrics_eval, metrics_train, paths, adapter=False, rank=None):
    '''
    Saves the .pt of the model after each epoch.
    '''
    best_psnr = metrics_train['best_psnr']
    if rank != 0:
        return best_psnr
    # normalize metrics_eval to a dict of dicts (one entry per validation set)
    if not isinstance(next(iter(metrics_eval.values())), dict):
        metrics_eval = {'metrics': metrics_eval}

    weights = model.state_dict()
    # Save the model after every epoch
    model_to_save = {
        'epoch': metrics_train['epoch'],
        'model_state_dict': weights,
        'optimizer_state_dict': optim.state_dict(),
        'loss': metrics_train['train_loss'],
        'scheduler_state_dict': scheduler.state_dict()
    }

    try:
        torch.save(model_to_save, paths['new'])
        # Save best model if new valid_psnr is higher than the best one
        if next(iter(metrics_eval.values()))['valid_psnr'] >= metrics_train['best_psnr']:
            torch.save(model_to_save, paths['best'])
            metrics_train['best_psnr'] = next(iter(metrics_eval.values()))['valid_psnr']  # update best psnr
    except Exception as e:
        print(f"Error saving model: {e}")

    return metrics_train['best_psnr']

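# Illustrative sketch of the dictionaries save_checkpoint expects. metrics_eval
# may be either a flat dict containing 'valid_psnr' or a dict of such dicts
# (one per validation set); the dataset name, values and paths below are
# placeholders, not repo defaults.
#
#   metrics_train = {'epoch': epoch, 'train_loss': train_loss, 'best_psnr': best_psnr}
#   metrics_eval = {'LOLBlur': {'valid_psnr': 26.5}}
#   paths = {'new': './models/darkir_last.pt', 'best': './models/darkir_best.pt'}
#   best_psnr = save_checkpoint(model, optim, scheduler, metrics_eval,
#                               metrics_train, paths, rank=rank)
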
__all__ = ['create_model', 'resume_model', 'create_optim_scheduler', 'save_checkpoint',
           'load_optim', 'load_weights']