hieupt committed
Commit f5979b8 · verified · 1 Parent(s): 57599d7

Upload utils.py

Files changed (1)
  1. model/utils.py +97 -0
model/utils.py ADDED
@@ -0,0 +1,97 @@
+ import os
+ import torch
+
+ def save_model(model, optimizer, state, path):
+     if isinstance(model, torch.nn.DataParallel):
+         model = model.module  # save state dict of wrapped module
+     if len(os.path.dirname(path)) > 0 and not os.path.exists(os.path.dirname(path)):
+         os.makedirs(os.path.dirname(path))
+     torch.save({
+         'model_state_dict': model.state_dict(),
+         'optimizer_state_dict': optimizer.state_dict(),
+         'state': state,  # state of training loop (was 'step')
+     }, path)
+
+
+ def load_model(model, optimizer, path, cuda):
+     if isinstance(model, torch.nn.DataParallel):
+         model = model.module  # load state dict of wrapped module
+     if cuda:
+         checkpoint = torch.load(path)
+     else:
+         checkpoint = torch.load(path, map_location='cpu')
+     try:
+         model.load_state_dict(checkpoint['model_state_dict'])
+     except:
+         # work-around for loading checkpoints where DataParallel was saved instead of inner module
+         from collections import OrderedDict
+         model_state_dict_fixed = OrderedDict()
+         prefix = 'module.'
+         for k, v in checkpoint['model_state_dict'].items():
+             if k.startswith(prefix):
+                 k = k[len(prefix):]
+             model_state_dict_fixed[k] = v
+         model.load_state_dict(model_state_dict_fixed)
+     if optimizer is not None:
+         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+     if 'state' in checkpoint:
+         state = checkpoint['state']
+     else:
+         # older checkpoints only store step, rest of state won't be there
+         state = {'step': checkpoint['step']}
+     return state
+
+
+ def compute_loss(model, inputs, targets, criterion, compute_grad=False):
+     '''
+     Computes the loss of the model for the given inputs, targets and loss function.
+     Optionally backpropagates to compute gradients for the weights.
+     Procedure depends on whether we have one model per source or a single joint model.
+     :param model: Model to train with
+     :param inputs: Input mixture
+     :param targets: Target sources
+     :param criterion: Loss function to use (L1, L2, ..)
+     :param compute_grad: Whether to compute gradients
+     :return: Model outputs, average loss over the batch
+     '''
+     all_outputs = {}
+
+     if model.separate:
+         avg_loss = 0.0
+         num_sources = 0
+         for inst in model.instruments:
+             output = model(inputs, inst)
+             loss = criterion(output[inst], targets[inst])
+
+             if compute_grad:
+                 loss.backward()
+
+             avg_loss += loss.item()
+             num_sources += 1
+
+             all_outputs[inst] = output[inst].detach().clone()
+
+         avg_loss /= float(num_sources)
+     else:
+         loss = 0
+         all_outputs = model(inputs)
+         for inst in all_outputs.keys():
+             loss += criterion(all_outputs[inst], targets[inst])
+
+         if compute_grad:
+             loss.backward()
+
+         avg_loss = loss.item() / float(len(all_outputs))
+
+     return all_outputs, avg_loss
+
+
+ class DataParallel(torch.nn.DataParallel):
+     def __init__(self, module, device_ids=None, output_device=None, dim=0):
+         super(DataParallel, self).__init__(module, device_ids, output_device, dim)
+
+     def __getattr__(self, name):
+         try:
+             return super().__getattr__(name)
+         except AttributeError:
+             return getattr(self.module, name)
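
For reference, a minimal checkpoint round trip with save_model and load_model could look like the sketch below. The two-layer network, the optimizer settings, the checkpoint path and the contents of the state dict are placeholders made up for this example, and the import line assumes that model/ is importable as a package; none of this is part of the uploaded file.

import torch
import torch.nn as nn

from model.utils import save_model, load_model  # assumes model/ is an importable package

# Placeholder network and optimizer; any nn.Module / optimizer pair works here.
net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

# Arbitrary training-loop state; older checkpoints stored only a 'step' value.
state = {'step': 1000, 'best_loss': 0.05}

# save_model creates the checkpoint directory if it does not exist yet.
save_model(net, optimizer, state, 'checkpoints/demo/checkpoint_1000')

# Later, e.g. when resuming training, restore the weights, optimizer state and loop state.
state = load_model(net, optimizer, 'checkpoints/demo/checkpoint_1000', cuda=False)
print(state['step'])  # 1000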
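
compute_loss expects the (possibly DataParallel-wrapped) model to expose separate and instruments attributes and to return a dict mapping instrument names to output tensors; the DataParallel subclass above exists so that such attributes stay reachable through the wrapper. The toy separator in the following sketch is a made-up stand-in for the repository's real network, used only to illustrate a single optimisation step; the instrument names, layer sizes and tensor shapes are arbitrary.

import torch
import torch.nn as nn

from model.utils import compute_loss, DataParallel  # assumes model/ is an importable package

class ToySeparator(nn.Module):
    # Hypothetical stand-in for the real separator: a single joint model (separate = False)
    # that returns one output tensor per instrument.
    def __init__(self, instruments):
        super().__init__()
        self.instruments = instruments
        self.separate = False
        self.conv = nn.Conv1d(1, len(instruments), kernel_size=1)

    def forward(self, mix):
        out = self.conv(mix)  # (batch, num_instruments, time)
        return {inst: out[:, i:i + 1] for i, inst in enumerate(self.instruments)}

instruments = ['bass', 'drums', 'other', 'vocals']
net = ToySeparator(instruments)
if torch.cuda.is_available():
    net = net.cuda()
model = DataParallel(net)  # .separate / .instruments still resolve via the custom __getattr__
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

device = next(model.parameters()).device
mix = torch.randn(4, 1, 1024, device=device)                                      # batch of mixtures
targets = {inst: torch.randn(4, 1, 1024, device=device) for inst in instruments}  # one target per source

optimizer.zero_grad()
outputs, avg_loss = compute_loss(model, mix, targets, criterion, compute_grad=True)
optimizer.step()
print(avg_loss)  # average criterion value over the four sources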