# NOTE(review): removed non-code extraction artifacts ("Spaces:", "Runtime error")
# that were prepended to this file and would break parsing.
import os
import torch
def save_model(model, optimizer, state, path):
    """Persist a training checkpoint (model + optimizer + loop state) to disk.

    :param model: Model to save; a torch.nn.DataParallel wrapper is unwrapped
                  so the checkpoint stores the inner module's state dict
    :param optimizer: Optimizer whose state is saved alongside the model
    :param state: Training-loop state (was just 'step' in older checkpoints)
    :param path: Destination file path; missing parent directories are created
    """
    if isinstance(model, torch.nn.DataParallel):
        model = model.module  # save state dict of the wrapped module
    checkpoint_dir = os.path.dirname(path)
    if checkpoint_dir:
        # exist_ok avoids the TOCTOU race of the old exists()-then-makedirs check
        os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'state': state,  # state of training loop (was 'step')
    }, path)
def load_model(model, optimizer, path, cuda):
    """Restore model (and optionally optimizer) weights from a checkpoint file.

    :param model: Module to load weights into; a torch.nn.DataParallel wrapper
                  is unwrapped first so keys match the inner module
    :param optimizer: Optimizer to restore, or None to skip optimizer state
    :param path: Checkpoint file path
    :param cuda: If False, remap all tensors to CPU while loading
    :return: Training-loop state dict stored in the checkpoint
    """
    if isinstance(model, torch.nn.DataParallel):
        model = model.module  # load state dict of the wrapped module
    # map_location='cpu' lets CUDA-saved checkpoints load on CPU-only machines
    checkpoint = torch.load(path, map_location=None if cuda else 'cpu')
    try:
        model.load_state_dict(checkpoint['model_state_dict'])
    except RuntimeError:
        # Work-around for checkpoints where the DataParallel wrapper was saved
        # instead of the inner module: strip the 'module.' key prefix and retry.
        # load_state_dict raises RuntimeError on key mismatch (was a bare except).
        from collections import OrderedDict
        prefix = 'module.'
        fixed_state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in checkpoint['model_state_dict'].items()
        )
        model.load_state_dict(fixed_state_dict)
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if 'state' in checkpoint:
        return checkpoint['state']
    # Older checkpoints only stored the step counter; synthesize a state dict.
    return {'step': checkpoint['step']}
def compute_loss(model, inputs, targets, criterion, compute_grad=False):
    """Run the model on a mixture and compute the source-separation loss.

    Optionally backpropagates to accumulate weight gradients. The procedure
    depends on whether the model uses one sub-model per source
    (``model.separate``) or predicts all sources jointly.

    :param model: Model to train with
    :param inputs: Input mixture
    :param targets: Target sources (dict keyed by instrument name)
    :param criterion: Loss function to use (L1, L2, ..)
    :param compute_grad: Whether to call backward() on the loss
    :return: Tuple of (model outputs per instrument, average loss over batch)
    """
    if model.separate:
        # One forward/backward pass per instrument; losses are averaged.
        outputs = {}
        loss_total = 0.0
        source_count = 0
        for source in model.instruments:
            prediction = model(inputs, source)
            source_loss = criterion(prediction[source], targets[source])
            if compute_grad:
                source_loss.backward()
            loss_total += source_loss.item()
            source_count += 1
            outputs[source] = prediction[source].detach().clone()
        return outputs, loss_total / float(source_count)

    # Joint model: a single forward pass yields every source at once.
    outputs = model(inputs)
    joint_loss = sum(criterion(outputs[source], targets[source])
                     for source in outputs.keys())
    if compute_grad:
        joint_loss.backward()
    return outputs, joint_loss.item() / float(len(outputs))
class DataParallel(torch.nn.DataParallel):
    """torch.nn.DataParallel that transparently forwards attribute lookups
    (e.g. custom attributes of the wrapped model) to the inner module."""

    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super().__init__(module, device_ids, output_device, dim)

    def __getattr__(self, name):
        # First give the DataParallel/Module machinery a chance to resolve the
        # name; only on failure fall through to the wrapped module.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)