File size: 321 Bytes
154ca7b
 
8b850ac
 
154ca7b
 
8b850ac
 
1
2
3
4
5
6
7
8
import torch

def load(load_path, model, device):
    """Load a checkpoint from disk into `model` and return its validation loss.

    Args:
        load_path: Path to the checkpoint file, or None to skip loading.
        model: torch.nn.Module whose parameters are overwritten in place.
        device: Device spec forwarded to torch.load as `map_location`.

    Returns:
        The checkpoint's stored 'valid_loss' value, or None when
        `load_path` is None (no-op).
    """
    # Identity comparison with None (PEP 8), not `== None`.
    if load_path is None:
        return None
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    state_dict = torch.load(load_path, map_location=device)
    model.load_state_dict(state_dict['model_state_dict'])
    print(f"[LOAD] Model has been loaded successfully from '{load_path}'")
    return state_dict['valid_loss']