import torch


def load(load_path, model, device):
    """Restore model weights from a checkpoint file.

    Args:
        load_path: Location of the checkpoint, passed straight to
            ``torch.load``. When ``None``, nothing is loaded.
        model: Object whose ``load_state_dict`` receives the checkpoint's
            ``'model_state_dict'`` entry.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The value stored under ``'valid_loss'`` in the checkpoint, or
        ``None`` when ``load_path`` is ``None``.

    Raises:
        KeyError: If the checkpoint lacks ``'model_state_dict'`` or
            ``'valid_loss'``.
    """
    # Identity check: ``== None`` would dispatch to a custom ``__eq__``.
    if load_path is None:
        return None

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    state_dict = torch.load(load_path, map_location=device)
    model.load_state_dict(state_dict['model_state_dict'])
    print("[LOAD] Model has been loaded successfully from '{}'".format(load_path))
    return state_dict['valid_loss']