lilingxi01's picture
[ERCBCM] Optimize the model interfaces and prints.
8b850ac
raw
history blame contribute delete
321 Bytes
import torch
def load(load_path, model, device):
    """Restore model weights from a checkpoint file and return its validation loss.

    The checkpoint is expected to be a dict containing at least the keys
    ``'model_state_dict'`` (passed to ``model.load_state_dict``) and
    ``'valid_loss'`` (returned to the caller).

    Args:
        load_path: Path to the checkpoint file, or ``None`` to skip loading.
        model: Module whose parameters are overwritten in place via
            ``load_state_dict``.
        device: Passed as ``map_location`` to ``torch.load`` so tensors are
            placed on this device regardless of where they were saved.

    Returns:
        The checkpoint's ``'valid_loss'`` entry, or ``None`` when
        ``load_path`` is ``None``.

    Raises:
        KeyError: If the checkpoint lacks ``'model_state_dict'`` or
            ``'valid_loss'``.
    """
    if load_path is None:
        return None

    # NOTE: torch.load unpickles the file, so only load checkpoints from
    # trusted sources.
    checkpoint = torch.load(load_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"[LOAD] Model has been loaded successfully from '{load_path}'")
    return checkpoint['valid_loss']