hieupt committed
Commit cce0a91 · verified
1 Parent(s): f5979b8

Upload waveunet.py

Files changed (1):
  model/waveunet.py +233 -0
model/waveunet.py ADDED
@@ -0,0 +1,233 @@
+ import torch
+ import torch.nn as nn
+
+ from model.crop import centre_crop
+ from model.resample import Resample1d
+ from model.conv import ConvLayer
+
+ class UpsamplingBlock(nn.Module):
+     def __init__(self, n_inputs, n_shortcut, n_outputs, kernel_size, stride, depth, conv_type, res):
+         super(UpsamplingBlock, self).__init__()
+         assert(stride > 1)
+
+         # CONV 1 for UPSAMPLING
+         if res == "fixed":
+             self.upconv = Resample1d(n_inputs, 15, stride, transpose=True)
+         else:
+             self.upconv = ConvLayer(n_inputs, n_inputs, kernel_size, stride, conv_type, transpose=True)
+
+         self.pre_shortcut_convs = nn.ModuleList([ConvLayer(n_inputs, n_outputs, kernel_size, 1, conv_type)] +
+                                                 [ConvLayer(n_outputs, n_outputs, kernel_size, 1, conv_type) for _ in range(depth - 1)])
+
+         # CONVS to combine high- with low-level information (from shortcut)
+         self.post_shortcut_convs = nn.ModuleList([ConvLayer(n_outputs + n_shortcut, n_outputs, kernel_size, 1, conv_type)] +
+                                                  [ConvLayer(n_outputs, n_outputs, kernel_size, 1, conv_type) for _ in range(depth - 1)])
+
+     def forward(self, x, shortcut):
+         # UPSAMPLE HIGH-LEVEL FEATURES
+         upsampled = self.upconv(x)
+
+         for conv in self.pre_shortcut_convs:
+             upsampled = conv(upsampled)
+
+         # Prepare shortcut connection
+         combined = centre_crop(shortcut, upsampled)
+
+         # Combine high- and low-level features
+         for conv in self.post_shortcut_convs:
+             combined = conv(torch.cat([combined, centre_crop(upsampled, combined)], dim=1))
+         return combined
+
+     def get_output_size(self, input_size):
+         curr_size = self.upconv.get_output_size(input_size)
+
+         # Upsampling convs
+         for conv in self.pre_shortcut_convs:
+             curr_size = conv.get_output_size(curr_size)
+
+         # Combine convolutions
+         for conv in self.post_shortcut_convs:
+             curr_size = conv.get_output_size(curr_size)
+
+         return curr_size
+
+ class DownsamplingBlock(nn.Module):
+     def __init__(self, n_inputs, n_shortcut, n_outputs, kernel_size, stride, depth, conv_type, res):
+         super(DownsamplingBlock, self).__init__()
+         assert(stride > 1)
+
+         self.kernel_size = kernel_size
+         self.stride = stride
+
+         # CONV 1
+         self.pre_shortcut_convs = nn.ModuleList([ConvLayer(n_inputs, n_shortcut, kernel_size, 1, conv_type)] +
+                                                 [ConvLayer(n_shortcut, n_shortcut, kernel_size, 1, conv_type) for _ in range(depth - 1)])
+
+         self.post_shortcut_convs = nn.ModuleList([ConvLayer(n_shortcut, n_outputs, kernel_size, 1, conv_type)] +
+                                                  [ConvLayer(n_outputs, n_outputs, kernel_size, 1, conv_type) for _ in range(depth - 1)])
+
+         # CONV 2 with decimation
+         if res == "fixed":
+             self.downconv = Resample1d(n_outputs, 15, stride)  # Resampling with fixed-size sinc lowpass filter
+         else:
+             self.downconv = ConvLayer(n_outputs, n_outputs, kernel_size, stride, conv_type)
+
+     def forward(self, x):
+         # PREPARING SHORTCUT FEATURES
+         shortcut = x
+         for conv in self.pre_shortcut_convs:
+             shortcut = conv(shortcut)
+
+         # PREPARING FOR DOWNSAMPLING
+         out = shortcut
+         for conv in self.post_shortcut_convs:
+             out = conv(out)
+
+         # DOWNSAMPLING
+         out = self.downconv(out)
+
+         return out, shortcut
+
+     def get_input_size(self, output_size):
+         curr_size = self.downconv.get_input_size(output_size)
+
+         for conv in reversed(self.post_shortcut_convs):
+             curr_size = conv.get_input_size(curr_size)
+
+         for conv in reversed(self.pre_shortcut_convs):
+             curr_size = conv.get_input_size(curr_size)
+         return curr_size
+
+ class Waveunet(nn.Module):
+     def __init__(self, num_inputs, num_channels, num_outputs, instruments, kernel_size, target_output_size, conv_type, res, separate=False, depth=1, strides=2):
+         super(Waveunet, self).__init__()
+
+         self.num_levels = len(num_channels)
+         self.strides = strides
+         self.kernel_size = kernel_size
+         self.num_inputs = num_inputs
+         self.num_outputs = num_outputs
+         self.depth = depth
+         self.instruments = instruments
+         self.separate = separate
+
+         # Only odd filter kernels allowed
+         assert(kernel_size % 2 == 1)
+
+         self.waveunets = nn.ModuleDict()
+
+         # Create one Wave-U-Net per source if sources are modelled separately, otherwise a single shared model (model_list=["ALL"])
+         model_list = instruments if separate else ["ALL"]
+         for instrument in model_list:
+             module = nn.Module()
+
+             module.downsampling_blocks = nn.ModuleList()
+             module.upsampling_blocks = nn.ModuleList()
+
+             for i in range(self.num_levels - 1):
+                 in_ch = num_inputs if i == 0 else num_channels[i]
+
+                 module.downsampling_blocks.append(
+                     DownsamplingBlock(in_ch, num_channels[i], num_channels[i+1], kernel_size, strides, depth, conv_type, res))
+
+             for i in range(0, self.num_levels - 1):
+                 module.upsampling_blocks.append(
+                     UpsamplingBlock(num_channels[-1-i], num_channels[-2-i], num_channels[-2-i], kernel_size, strides, depth, conv_type, res))
+
+             module.bottlenecks = nn.ModuleList(
+                 [ConvLayer(num_channels[-1], num_channels[-1], kernel_size, 1, conv_type) for _ in range(depth)])
+
+             # Output conv
+             outputs = num_outputs if separate else num_outputs * len(instruments)
+             module.output_conv = nn.Conv1d(num_channels[0], outputs, 1)
+
+             self.waveunets[instrument] = module
+
+         self.set_output_size(target_output_size)
+
+     def set_output_size(self, target_output_size):
+         self.target_output_size = target_output_size
+
+         self.input_size, self.output_size = self.check_padding(target_output_size)
+         print("Using valid convolutions with " + str(self.input_size) + " inputs and " + str(self.output_size) + " outputs")
+
+         assert((self.input_size - self.output_size) % 2 == 0)
+         self.shapes = {"output_start_frame": (self.input_size - self.output_size) // 2,
+                        "output_end_frame": (self.input_size - self.output_size) // 2 + self.output_size,
+                        "output_frames": self.output_size,
+                        "input_frames": self.input_size}
+
+     def check_padding(self, target_output_size):
+         # Ensure the number of outputs covers a whole number of cycles so each output in the cycle is weighted equally during training
+         bottleneck = 1
+
+         while True:
+             out = self.check_padding_for_bottleneck(bottleneck, target_output_size)
+             if out is not False:
+                 return out
+             bottleneck += 1
+
+     def check_padding_for_bottleneck(self, bottleneck, target_output_size):
+         module = self.waveunets[list(self.waveunets.keys())[0]]
+         try:
+             curr_size = bottleneck
+             for idx, block in enumerate(module.upsampling_blocks):
+                 curr_size = block.get_output_size(curr_size)
+             output_size = curr_size
+
+             # Bottleneck-Conv
+             curr_size = bottleneck
+             for block in reversed(module.bottlenecks):
+                 curr_size = block.get_input_size(curr_size)
+             for idx, block in enumerate(reversed(module.downsampling_blocks)):
+                 curr_size = block.get_input_size(curr_size)
+
+             assert(output_size >= target_output_size)
+             return curr_size, output_size
+         except AssertionError:
+             return False
+
+     def forward_module(self, x, module):
+         '''
+         A forward pass through a single Wave-U-Net (multiple Wave-U-Nets might be used, one for each source)
+         :param x: Input mix
+         :param module: Network module to be used for prediction
+         :return: Source estimates
+         '''
+         shortcuts = []
+         out = x
+
+         # DOWNSAMPLING BLOCKS
+         for block in module.downsampling_blocks:
+             out, short = block(out)
+             shortcuts.append(short)
+
+         # BOTTLENECK CONVOLUTION
+         for conv in module.bottlenecks:
+             out = conv(out)
+
+         # UPSAMPLING BLOCKS
+         for idx, block in enumerate(module.upsampling_blocks):
+             out = block(out, shortcuts[-1 - idx])
+
+         # OUTPUT CONV
+         out = module.output_conv(out)
+         if not self.training:  # At test time, clip predictions to the valid amplitude range
+             out = out.clamp(min=-1.0, max=1.0)
+         return out
+
+     def forward(self, x, inst=None):
+         curr_input_size = x.shape[-1]
+         assert(curr_input_size == self.input_size)  # The caller must feed input of the pre-calculated size to obtain the pre-calculated (NOT the originally requested) output size
+
+         if self.separate:
+             return {inst: self.forward_module(x, self.waveunets[inst])}
+         else:
+             assert(len(self.waveunets) == 1)
+             out = self.forward_module(x, self.waveunets["ALL"])
+
+             out_dict = {}
+             for idx, inst in enumerate(self.instruments):
+                 out_dict[inst] = out[:, idx * self.num_outputs:(idx + 1) * self.num_outputs]
+             return out_dict
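
Below is a minimal usage sketch of the uploaded Waveunet class, not part of this commit. The constructor signature and the shapes dict come from the file above; the channel widths, kernel size, target_output_size, conv_type, and res values are illustrative assumptions, since the valid conv_type options are defined in model/conv.py, which is not included in this commit.

import torch
from model.waveunet import Waveunet

# Assumed example configuration; every hyperparameter value here is illustrative.
model = Waveunet(num_inputs=2,                                     # stereo mixture
                 num_channels=[32, 64, 128, 256],                  # feature channels per level (assumed)
                 num_outputs=2,                                    # stereo estimate per source
                 instruments=["bass", "drums", "other", "vocals"],
                 kernel_size=5,                                    # must be odd
                 target_output_size=88200,                         # ~2 s at 44.1 kHz
                 conv_type="normal",                               # assumed ConvLayer option from model/conv.py
                 res="fixed",                                      # fixed sinc resampling, handled above
                 separate=False)                                   # one shared network for all sources

# The model computes the exact input/output lengths it needs for valid convolutions.
mix = torch.randn(1, 2, model.shapes["input_frames"])
estimates = model(mix)  # dict: instrument -> (batch, num_outputs, output_frames) tensor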