Spaces:
Build error
Build error
""" | |
@Date: 2021/07/17
@description: Base nn.Module with checkpoint save/load and best-accuracy bookkeeping.
""" | |
import os | |
import torch | |
import torch.nn as nn | |
import datetime | |
class BaseModule(nn.Module):
    """Base ``nn.Module`` with checkpoint save/load and accuracy bookkeeping.

    Checkpoints are ``.pkl`` files stored in ``ckpt_dir``.  :meth:`save`
    writes files whose names contain ``_last_`` (periodic snapshot) or
    ``_best_`` (best accuracy so far); :meth:`load` looks for those markers.

    Attributes set by ``__init__``:
        ckpt_dir: directory holding the checkpoints (may be ``None``).
        model_lst: sorted ``.pkl`` file names found in ``ckpt_dir`` at
            construction time (empty when there is no directory or it was
            just created).
        last_model_path / best_model_path: paths of the most recently
            loaded or saved "last" / "best" checkpoint, or ``None``.
        best_accuracy: best accuracy seen so far (``-inf`` initially).
        acc_d: ``{metric: {'acc': value, 'epoch': epoch}}`` with the best
            value recorded per metric.
    """

    def __init__(self, ckpt_dir=None):
        """Create the checkpoint directory if needed and index existing files.

        :param ckpt_dir: checkpoint directory; falsy means "no directory".
        """
        super().__init__()

        self.ckpt_dir = ckpt_dir
        # Always defined, so load() also works when ckpt_dir is None or the
        # directory has just been created (previously an AttributeError).
        self.model_lst = []

        if ckpt_dir:
            if not os.path.exists(ckpt_dir):
                os.makedirs(ckpt_dir)
            else:
                self.model_lst = [
                    x for x in sorted(os.listdir(self.ckpt_dir))
                    if x.endswith('.pkl')
                ]

        self.last_model_path = None
        self.best_model_path = None
        self.best_accuracy = -float('inf')
        self.acc_d = {}

    def show_parameter_number(self, logger):
        """Log the total and trainable parameter counts of this module.

        :param logger: object providing ``info(msg)``.
        """
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        logger.info('{} parameter total:{:,}, trainable:{:,}'.format(self._get_name(), total, trainable))

    def load(self, device, logger, optimizer=None, best=False):
        """Restore weights (and optionally optimizer state) from ``ckpt_dir``.

        NOTE: ``torch.load`` unpickles the file, so only checkpoints from a
        trusted source should be placed in ``ckpt_dir``.

        :param device: device passed to ``torch.device`` for ``map_location``;
            restored optimizer tensors are moved to it as well.
        :param logger: object providing ``info`` / ``error``.
        :param optimizer: optimizer whose state is restored from the "last"
            checkpoint; ignored when ``best`` is true.
        :param best: load the newest ``_best_`` checkpoint instead of the
            newest ``_last_`` one (falls back to "last" if no best exists).
        :return: the epoch to resume from (stored epoch + 1); ``0`` when the
            folder is empty or only holds a bare state-dict file; ``None``
            when no suitable checkpoint exists (only ``_best_`` files are
            present and ``best`` is false).
        """
        if not self.model_lst:
            logger.info('*'*50)
            logger.info("Empty model folder! Using initial weights")
            logger.info('*'*50)
            return 0

        last_model_lst = [n for n in self.model_lst if '_last_' in n]
        best_model_lst = [n for n in self.model_lst if '_best_' in n]

        if not last_model_lst and not best_model_lst:
            # No file follows the last/best naming scheme: treat the first
            # .pkl as a bare state dict (no 'net'/'epoch' wrapper).
            logger.info('*'*50)
            ckpt_path = os.path.join(self.ckpt_dir, self.model_lst[0])
            logger.info(f"Load: {ckpt_path}")
            checkpoint = torch.load(ckpt_path, map_location=torch.device(device))
            self.load_state_dict(checkpoint, strict=False)
            logger.info('*'*50)
            return 0

        checkpoint = None
        best_checkpoint = None
        if last_model_lst:
            # model_lst is sorted and names start with a timestamp, so the
            # final entry is the newest snapshot.
            self.last_model_path = os.path.join(self.ckpt_dir, last_model_lst[-1])
            checkpoint = torch.load(self.last_model_path, map_location=torch.device(device))
            self.best_accuracy = checkpoint['accuracy']
            self.acc_d = checkpoint['acc_d']

        if best_model_lst:
            # The best checkpoint, when present, overrides the bookkeeping
            # taken from the last checkpoint.
            self.best_model_path = os.path.join(self.ckpt_dir, best_model_lst[-1])
            best_checkpoint = torch.load(self.best_model_path, map_location=torch.device(device))
            self.best_accuracy = best_checkpoint['accuracy']
            self.acc_d = best_checkpoint['acc_d']
            if best:
                checkpoint = best_checkpoint

        # save() stores acc_d exactly as passed in, which may be None.
        self.acc_d = self.acc_d or {}

        # Checkpoints store plain floats per metric; convert them to the
        # {'acc', 'epoch'} records that update_acc() expects.  When no
        # checkpoint is going to be loaded (only "best" files, best=False)
        # the epoch comes from the best checkpoint instead of crashing.
        epoch_source = checkpoint if checkpoint is not None else best_checkpoint
        for k in self.acc_d:
            if isinstance(self.acc_d[k], float):
                self.acc_d[k] = {
                    'acc': self.acc_d[k],
                    'epoch': epoch_source['epoch']
                }

        if checkpoint is None:
            logger.error("Invalid checkpoint")
            return

        self.load_state_dict(checkpoint['net'], strict=False)
        # When loading the best weights a fresh optimizer is used (e.g.
        # switching from Adam to SGD), so its state is not restored.
        if optimizer and not best:
            logger.info('Load optimizer')
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Move the restored optimizer tensors to the target device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.to(device)

        logger.info('*'*50)
        if best:
            logger.info(f"Lode best: {self.best_model_path}")
        else:
            logger.info(f"Lode last: {self.last_model_path}")

        logger.info(f"Best accuracy: {self.best_accuracy}")
        logger.info(f"Last epoch: {checkpoint['epoch'] + 1}")
        logger.info('*'*50)
        return checkpoint['epoch'] + 1

    def update_acc(self, acc_d, epoch, logger):
        """Record, per metric, the highest value seen and the epoch it occurred.

        :param acc_d: ``{metric: value}`` for the current epoch.
        :param epoch: current epoch number.
        :param logger: object providing ``info(msg)``.
        """
        logger.info("-" * 100)
        for k in acc_d:
            if k not in self.acc_d or acc_d[k] > self.acc_d[k]['acc']:
                self.acc_d[k] = {
                    'acc': acc_d[k],
                    'epoch': epoch
                }
            # Logged for every metric: best value (best epoch - current epoch).
            logger.info(f"Update ACC: {k} {self.acc_d[k]['acc']:.4f}({self.acc_d[k]['epoch']}-{epoch})")
        logger.info("-" * 100)

    def save(self, optim, epoch, accuracy, logger, replace=True, acc_d=None, config=None):
        """Write a "last" snapshot and, on a new best accuracy, a "best" one.

        :param optim: optimizer whose ``state_dict()`` is stored.
        :param epoch: current epoch, stored in the checkpoint and file name.
        :param accuracy: main accuracy, compared against ``best_accuracy``.
        :param logger: object providing ``info(msg)``.
        :param replace: delete the previously tracked last/best file before
            writing the new one.
        :param acc_d: other evaluation metrics (visible_2/3d, full_2/3d,
            rmse, ...); passed to :meth:`update_acc` and stored as given.
        :param config: object exposing ``TRAIN.SAVE_FREQ``; the "last"
            snapshot is written when ``epoch % SAVE_FREQ == 0``.  When
            ``None`` a snapshot is written on every call.
        :return: None
        """
        if acc_d:
            self.update_acc(acc_d, epoch, logger)

        timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
        name = f"model_{timestamp}_last_{accuracy:.4f}_{epoch}.pkl"
        checkpoint = {
            'net': self.state_dict(),
            'optimizer': optim.state_dict(),
            'epoch': epoch,
            'accuracy': accuracy,
            'acc_d': acc_d
        }

        # The signature defaults config to None; fall back to saving on
        # every call instead of failing on ``None.TRAIN``.
        save_freq = 1 if config is None else config.TRAIN.SAVE_FREQ

        # FIXME: config.MODEL.SAVE_LAST is intentionally not consulted yet;
        # the "last" snapshot is always enabled.
        if epoch % save_freq == 0:
            if replace and self.last_model_path and os.path.exists(self.last_model_path):
                os.remove(self.last_model_path)
            self.last_model_path = os.path.join(self.ckpt_dir, name)
            torch.save(checkpoint, self.last_model_path)
            logger.info(f"Saved last model: {self.last_model_path}")

        if accuracy > self.best_accuracy:
            self.best_accuracy = accuracy
            # FIXME: config.MODEL.SAVE_BEST is intentionally not consulted
            # yet; the "best" snapshot is always enabled.
            if replace and self.best_model_path and os.path.exists(self.best_model_path):
                os.remove(self.best_model_path)
            self.best_model_path = os.path.join(self.ckpt_dir, name.replace('_last_', '_best_', 1))
            torch.save(checkpoint, self.best_model_path)
            logger.info("#" * 100)
            logger.info(f"Saved best model: {self.best_model_path}")
            logger.info("#" * 100)