# hieupt's picture
# Upload utils.py
# f5979b8 verified
import os
import torch
def save_model(model, optimizer, state, path):
    """Write a training checkpoint to ``path`` with ``torch.save``.

    The checkpoint is a dict with the keys ``'model_state_dict'``,
    ``'optimizer_state_dict'`` and ``'state'``.

    :param model: Model to save. A ``torch.nn.DataParallel`` wrapper is unwrapped
        so the stored state dict is that of the inner module.
    :param optimizer: Optimizer whose state dict is stored alongside the model.
    :param state: State of the training loop (older checkpoints stored only 'step').
    :param path: Destination file path; missing parent directories are created.
    """
    if isinstance(model, torch.nn.DataParallel):
        model = model.module  # save state dict of wrapped module
    directory = os.path.dirname(path)
    if directory:
        # exist_ok=True avoids the race of a separate exists() check followed by
        # makedirs(), which fails if the directory appears in between.
        os.makedirs(directory, exist_ok=True)
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'state': state,  # state of training loop (was 'step')
    }, path)
def load_model(model, optimizer, path, cuda):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    :param model: Model to load into. A ``torch.nn.DataParallel`` wrapper is
        unwrapped and the weights are loaded into its inner module.
    :param optimizer: Optimizer to restore, or ``None`` to skip optimizer state.
    :param path: Path of a checkpoint file readable by ``torch.load``.
    :param cuda: If false, tensors are loaded with ``map_location='cpu'``;
        otherwise ``torch.load``'s default placement is used.
    :return: Training-loop state: ``checkpoint['state']`` when present, else
        ``{'step': checkpoint['step']}`` for older checkpoints.
    :raises RuntimeError: if the stored state dict does not fit the model, even
        after stripping a ``'module.'`` key prefix.
    :raises KeyError: if a required checkpoint entry is missing.
    """
    if isinstance(model, torch.nn.DataParallel):
        model = model.module  # load state dict of wrapped module
    if cuda:
        checkpoint = torch.load(path)
    else:
        # Remap tensors (possibly saved from a GPU) onto the CPU.
        checkpoint = torch.load(path, map_location='cpu')
    model_state_dict = checkpoint['model_state_dict']
    try:
        model.load_state_dict(model_state_dict)
    except RuntimeError:
        # load_state_dict raises RuntimeError on missing/unexpected keys.
        # Work-around for checkpoints where the DataParallel wrapper was saved
        # instead of the inner module: strip the 'module.' prefix and retry.
        from collections import OrderedDict
        prefix = 'module.'
        model_state_dict_fixed = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in model_state_dict.items()
        )
        model.load_state_dict(model_state_dict_fixed)
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if 'state' in checkpoint:
        return checkpoint['state']
    # older checkpoints only store step, rest of state won't be there
    return {'step': checkpoint['step']}
def compute_loss(model, inputs, targets, criterion, compute_grad=False):
    '''
    Evaluates the loss of the model's per-source outputs against the targets,
    optionally calling backward() to compute gradients for the weights.
    The procedure depends on model.separate:
      - separate: the model is called once per name in model.instruments as
        model(inputs, inst); each source's loss is backpropagated on its own and
        the returned outputs are detached copies.
      - joint: the model is called once as model(inputs); the per-source losses
        are summed and backpropagated together. Outputs are returned as produced.
    :param model: Model to train with (must expose `separate`, and `instruments` in separate mode)
    :param inputs: Input mixture
    :param targets: Target sources, indexed by source name
    :param criterion: Loss function to use (L1, L2, ..), called as criterion(output, target)
    :param compute_grad: Whether to call backward() on the loss(es)
    :return: Model outputs per source, average per-source loss as a Python float
    '''
    if not model.separate:
        joint_outputs = model(inputs)
        # One summed loss so a single backward() covers every source.
        total = sum(criterion(joint_outputs[name], targets[name]) for name in joint_outputs)
        if compute_grad:
            total.backward()
        return joint_outputs, total.item() / float(len(joint_outputs))

    detached = {}
    per_source = []
    for name in model.instruments:
        prediction = model(inputs, name)[name]
        source_loss = criterion(prediction, targets[name])
        if compute_grad:
            # Gradients are not zeroed here, so each backward() adds into .grad.
            source_loss.backward()
        per_source.append(source_loss.item())
        detached[name] = prediction.detach().clone()
    return detached, sum(per_source) / float(len(per_source))
class DataParallel(torch.nn.DataParallel):
    """``torch.nn.DataParallel`` that forwards unknown attribute lookups to the wrapped module.

    Attributes not found on the wrapper itself (e.g. a custom ``instruments``
    attribute of the wrapped model) are read from ``self.module``, so callers do
    not have to unwrap the model first.
    """

    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super().__init__(module, device_ids, output_device, dim)

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            # Never forward the lookup of 'module' itself: if it is not set
            # (e.g. an instance created without __init__), evaluating
            # self.module would re-enter __getattr__('module') and recurse
            # until RecursionError instead of raising AttributeError.
            if name == 'module':
                raise
            return getattr(self.module, name)