valhalla committed
Commit 2dbecff · 1 Parent(s): 6094838

Create pipeline_bddm.py

Files changed (1)
  1. pipeline_bddm.py +304 -0
pipeline_bddm.py ADDED
@@ -0,0 +1,304 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
# DiffWave: A Versatile Diffusion Model for Audio Synthesis
# (https://arxiv.org/abs/2009.09761)
# Modified from https://github.com/philsyn/DiffWave-Vocoder
#
# Author: Max W. Y. Lam ([email protected])
# Copyright (c) 2021 Tencent. All Rights Reserved
#
########################################################################


import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm

from ..modeling_utils import ModelMixin
from ..configuration_utils import ConfigMixin
from ..pipeline_utils import DiffusionPipeline

def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
    """
    Embed a diffusion step $t$ into a higher-dimensional space.
    E.g. the embedding vector in the 128-dimensional space is
    [sin(t * 10^(-0*4/63)), ..., sin(t * 10^(-63*4/63)),
     cos(t * 10^(-0*4/63)), ..., cos(t * 10^(-63*4/63))]

    Parameters:
        diffusion_steps (torch.long tensor, shape=(batchsize, 1)):
            diffusion steps for batch data
        diffusion_step_embed_dim_in (int, default=128):
            dimensionality of the embedding space for discrete diffusion steps

    Returns:
        the embedding vectors (torch.tensor, shape=(batchsize, diffusion_step_embed_dim_in))
    """
    assert diffusion_step_embed_dim_in % 2 == 0

    half_dim = diffusion_step_embed_dim_in // 2
    _embed = np.log(10000) / (half_dim - 1)
    # Build the frequency table on the same device as the input steps;
    # the original hardcoded .cuda() here, which breaks CPU inference.
    _embed = torch.exp(torch.arange(half_dim) * -_embed).to(diffusion_steps.device)
    _embed = diffusion_steps * _embed
    diffusion_step_embed = torch.cat((torch.sin(_embed), torch.cos(_embed)), 1)
    return diffusion_step_embed


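# A minimal usage sketch of calc_diffusion_step_embedding (not part of the
# original file): for a batch of two steps, each row is the 128-dimensional
# sinusoidal vector [sin(...), cos(...)] for its step index.
#
#     steps = torch.tensor([[5], [10]], dtype=torch.long)
#     emb = calc_diffusion_step_embedding(steps, diffusion_step_embed_dim_in=128)
#     assert emb.shape == (2, 128)

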
"""
The modules below were borrowed from
https://github.com/philsyn/DiffWave-Vocoder/blob/master/WaveNet.py
"""


def swish(x):
    return x * torch.sigmoid(x)


# Dilated conv layer with kaiming_normal initialization
# from https://github.com/ksw0306/FloWaveNet/blob/master/modules.py
class Conv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1):
        super().__init__()
        self.padding = dilation * (kernel_size - 1) // 2
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              dilation=dilation, padding=self.padding)
        self.conv = nn.utils.weight_norm(self.conv)
        nn.init.kaiming_normal_(self.conv.weight)

    def forward(self, x):
        out = self.conv(x)
        return out


# Conv1x1 layer with zero initialization
# from https://github.com/ksw0306/FloWaveNet/blob/master/modules.py but the scale parameter is removed
class ZeroConv1d(nn.Module):
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.conv = nn.Conv1d(in_channel, out_channel, kernel_size=1, padding=0)
        self.conv.weight.data.zero_()
        self.conv.bias.data.zero_()

    def forward(self, x):
        out = self.conv(x)
        return out


# Every residual block (named residual layer in the paper)
# contains one noncausal dilated conv
class ResidualBlock(nn.Module):
    def __init__(self, res_channels, skip_channels, dilation,
                 diffusion_step_embed_dim_out):
        super().__init__()
        self.res_channels = res_channels

        # Use a FC layer for diffusion step embedding
        self.fc_t = nn.Linear(diffusion_step_embed_dim_out, self.res_channels)

        # Dilated conv layer
        self.dilated_conv_layer = Conv(self.res_channels, 2 * self.res_channels,
                                       kernel_size=3, dilation=dilation)

        # Add mel spectrogram upsampler and conditioner conv1x1 layer
        self.upsample_conv2d = nn.ModuleList()
        for s in [16, 16]:
            conv_trans2d = nn.ConvTranspose2d(1, 1, (3, 2 * s),
                                              padding=(1, s // 2),
                                              stride=(1, s))
            conv_trans2d = nn.utils.weight_norm(conv_trans2d)
            nn.init.kaiming_normal_(conv_trans2d.weight)
            self.upsample_conv2d.append(conv_trans2d)

        # 80 is the number of mel bands
        self.mel_conv = Conv(80, 2 * self.res_channels, kernel_size=1)

        # Residual conv1x1 layer, connects to the next residual layer
        self.res_conv = nn.Conv1d(res_channels, res_channels, kernel_size=1)
        self.res_conv = nn.utils.weight_norm(self.res_conv)
        nn.init.kaiming_normal_(self.res_conv.weight)

        # Skip conv1x1 layer, added to all skip outputs through skip connections
        self.skip_conv = nn.Conv1d(res_channels, skip_channels, kernel_size=1)
        self.skip_conv = nn.utils.weight_norm(self.skip_conv)
        nn.init.kaiming_normal_(self.skip_conv.weight)

    def forward(self, input_data):
        x, mel_spec, diffusion_step_embed = input_data
        h = x
        batch_size, n_channels, seq_len = x.shape
        assert n_channels == self.res_channels

        # Add in diffusion step embedding (out of place rather than `h += part_t`,
        # so the residual connection below still sees the original input x)
        part_t = self.fc_t(diffusion_step_embed)
        part_t = part_t.view([batch_size, self.res_channels, 1])
        h = h + part_t

        # Dilated conv layer
        h = self.dilated_conv_layer(h)

        # Upsample spectrogram to the size of the audio
        mel_spec = torch.unsqueeze(mel_spec, dim=1)
        mel_spec = F.leaky_relu(self.upsample_conv2d[0](mel_spec), 0.4, inplace=False)
        mel_spec = F.leaky_relu(self.upsample_conv2d[1](mel_spec), 0.4, inplace=False)
        mel_spec = torch.squeeze(mel_spec, dim=1)

        assert mel_spec.size(2) >= seq_len
        if mel_spec.size(2) > seq_len:
            mel_spec = mel_spec[:, :, :seq_len]

        mel_spec = self.mel_conv(mel_spec)
        h += mel_spec

        # Gated-tanh nonlinearity
        out = torch.tanh(h[:, :self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels:, :])

        # Residual and skip outputs
        res = self.res_conv(out)
        assert x.shape == res.shape
        skip = self.skip_conv(out)

        # Normalize for training stability
        return (x + res) * math.sqrt(0.5), skip


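# Shape sketch for the conditioner path above (hedged; assumes the usual
# 80-band mel input and hop length 256, matching the 16 x 16 upsampling):
#     mel_spec: (B, 80, F) -> unsqueeze ->    (B, 1, 80, F)
#     two ConvTranspose2d, stride (1, 16) ->  (B, 1, 80, 256 * F)
#     squeeze, trim to seq_len, mel_conv ->   (B, 2 * res_channels, seq_len)

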
class ResidualGroup(nn.Module):
    def __init__(self, res_channels, skip_channels, num_res_layers, dilation_cycle,
                 diffusion_step_embed_dim_in,
                 diffusion_step_embed_dim_mid,
                 diffusion_step_embed_dim_out):
        super().__init__()
        self.num_res_layers = num_res_layers
        self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in

        # Use the shared two FC layers for diffusion step embedding
        self.fc_t1 = nn.Linear(diffusion_step_embed_dim_in, diffusion_step_embed_dim_mid)
        self.fc_t2 = nn.Linear(diffusion_step_embed_dim_mid, diffusion_step_embed_dim_out)

        # Stack all residual blocks with dilations 1, 2, ..., 512, ..., 1, 2, ..., 512
        self.residual_blocks = nn.ModuleList()
        for n in range(self.num_res_layers):
            self.residual_blocks.append(
                ResidualBlock(res_channels, skip_channels,
                              dilation=2 ** (n % dilation_cycle),
                              diffusion_step_embed_dim_out=diffusion_step_embed_dim_out))

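    # With the defaults (num_res_layers=30, dilation_cycle=10), the dilation
    # pattern above is three full cycles of powers of two:
    #     [2 ** (n % 10) for n in range(30)]
    #     == [1, 2, 4, ..., 512, 1, 2, 4, ..., 512, 1, 2, 4, ..., 512]
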
    def forward(self, input_data):
        x, mel_spectrogram, diffusion_steps = input_data

        # Embed diffusion step t
        diffusion_step_embed = calc_diffusion_step_embedding(
            diffusion_steps, self.diffusion_step_embed_dim_in)
        diffusion_step_embed = swish(self.fc_t1(diffusion_step_embed))
        diffusion_step_embed = swish(self.fc_t2(diffusion_step_embed))

        # Pass through all residual layers
        h = x
        skip = 0
        for n in range(self.num_res_layers):
            # Use the output from the previous residual layer
            h, skip_n = self.residual_blocks[n]((h, mel_spectrogram, diffusion_step_embed))
            # Accumulate all skip outputs
            skip += skip_n

        # Normalize for training stability
        return skip * math.sqrt(1.0 / self.num_res_layers)


class DiffWave(ModelMixin, ConfigMixin):
    def __init__(
        self,
        in_channels=1,
        res_channels=128,
        skip_channels=128,
        out_channels=1,
        num_res_layers=30,
        dilation_cycle=10,
        diffusion_step_embed_dim_in=128,
        diffusion_step_embed_dim_mid=512,
        diffusion_step_embed_dim_out=512,
    ):
        super().__init__()

        # Register all init arguments with self.register
        self.register(
            in_channels=in_channels,
            res_channels=res_channels,
            skip_channels=skip_channels,
            out_channels=out_channels,
            num_res_layers=num_res_layers,
            dilation_cycle=dilation_cycle,
            diffusion_step_embed_dim_in=diffusion_step_embed_dim_in,
            diffusion_step_embed_dim_mid=diffusion_step_embed_dim_mid,
            diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
        )

        # Initial conv1x1 with relu
        self.init_conv = nn.Sequential(Conv(in_channels, res_channels, kernel_size=1), nn.ReLU(inplace=False))
        # All residual layers
        self.residual_layer = ResidualGroup(res_channels,
                                            skip_channels,
                                            num_res_layers,
                                            dilation_cycle,
                                            diffusion_step_embed_dim_in,
                                            diffusion_step_embed_dim_mid,
                                            diffusion_step_embed_dim_out)
        # Final conv1x1 -> relu -> zeroconv1x1
        self.final_conv = nn.Sequential(Conv(skip_channels, skip_channels, kernel_size=1),
                                        nn.ReLU(inplace=False),
                                        ZeroConv1d(skip_channels, out_channels))

    def forward(self, input_data):
        audio, mel_spectrogram, diffusion_steps = input_data
        x = audio
        x = self.init_conv(x).clone()
        x = self.residual_layer((x, mel_spectrogram, diffusion_steps))
        return self.final_conv(x)


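# Shape sketch for DiffWave.forward (hedged; inferred from the code above and
# from BDDM.__call__ below, which uses a hop length of 256):
#     audio:           (1, 1, T)          noisy waveform x_t
#     mel_spectrogram: (1, 80, T // 256)  conditioning mel frames
#     diffusion_steps: (1, 1)             diffusion step indices
# and the returned noise prediction has the same shape as `audio`.

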
class BDDM(DiffusionPipeline):
    def __init__(self, diffwave, noise_scheduler):
        super().__init__()
        noise_scheduler = noise_scheduler.set_format("pt")
        self.register_modules(diffwave=diffwave, noise_scheduler=noise_scheduler)

    @torch.no_grad()
    def __call__(self, mel_spectrogram, generator, torch_device=None):
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        self.diffwave.to(torch_device)

        mel_spectrogram = mel_spectrogram.to(torch_device)
        audio_length = mel_spectrogram.size(-1) * 256
        audio_size = (1, 1, audio_length)

        # Sample Gaussian noise to begin the loop
        audio = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)

        timestep_values = self.noise_scheduler.timestep_values
        num_prediction_steps = len(self.noise_scheduler)
        for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
            # 1. predict noise residual
            ts = (torch.tensor(timestep_values[t]) * torch.ones((1, 1))).to(torch_device)
            residual = self.diffwave((audio, mel_spectrogram, ts))

            # 2. predict previous mean of audio x_t-1
            pred_prev_audio = self.noise_scheduler.step(residual, audio, t)

            # 3. optionally sample variance
            variance = 0
            if t > 0:
                noise = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)
                variance = self.noise_scheduler.get_variance(t).sqrt() * noise

            # 4. set current audio to prev_audio: x_t -> x_t-1
            audio = pred_prev_audio + variance

        return audio


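# A minimal end-to-end sketch (hedged: trained weights, the concrete noise
# scheduler object, and the mel extraction are assumptions, not part of this
# file):
#
#     model = DiffWave()  # or load trained weights via the usual mixin helpers
#     pipeline = BDDM(diffwave=model, noise_scheduler=noise_scheduler)
#     generator = torch.manual_seed(0)
#     mel = ...  # (1, 80, num_frames) mel spectrogram
#     audio = pipeline(mel, generator)  # -> (1, 1, 256 * num_frames)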