valhalla committed · Commit c1c4bf8 · 1 Parent(s): 1678761

Delete pipeline_bddm.py

Files changed (1)
  1. pipeline_bddm.py +0 -304
pipeline_bddm.py DELETED
@@ -1,304 +0,0 @@
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- ########################################################################
- #
- # DiffWave: A Versatile Diffusion Model for Audio Synthesis
- # (https://arxiv.org/abs/2009.09761)
- # Modified from https://github.com/philsyn/DiffWave-Vocoder
- #
- # Author: Max W. Y. Lam ([email protected])
- # Copyright (c) 2021 Tencent. All Rights Reserved
- #
- ########################################################################
-
-
- import math
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import tqdm
-
- from diffusers.modeling_utils import ModelMixin
- from diffusers.configuration_utils import ConfigMixin
- from diffusers.pipeline_utils import DiffusionPipeline
-
-
- def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
-     """
-     Embed a diffusion step $t$ into a higher-dimensional space.
-     E.g. the embedding vector in the 128-dimensional space is
-     [sin(t * 10^(-0*4/63)), ... , sin(t * 10^(-63*4/63)),
-      cos(t * 10^(-0*4/63)), ... , cos(t * 10^(-63*4/63))]
-
-     Parameters:
-         diffusion_steps (torch.long tensor, shape=(batchsize, 1)):
-             diffusion steps for batch data
-         diffusion_step_embed_dim_in (int, default=128):
-             dimensionality of the embedding space for discrete diffusion steps
-     Returns:
-         the embedding vectors (torch.tensor, shape=(batchsize, diffusion_step_embed_dim_in))
-     """
-
-     assert diffusion_step_embed_dim_in % 2 == 0
-
-     half_dim = diffusion_step_embed_dim_in // 2
-     _embed = np.log(10000) / (half_dim - 1)
-     # follow the device of diffusion_steps instead of hardcoding .cuda(),
-     # so the CPU path of the pipeline below also works
-     _embed = torch.exp(torch.arange(half_dim) * -_embed).to(diffusion_steps.device)
-     _embed = diffusion_steps * _embed
-     diffusion_step_embed = torch.cat((torch.sin(_embed),
-                                       torch.cos(_embed)), 1)
-     return diffusion_step_embed
-
-
- """
- The following modules were borrowed from
- https://github.com/philsyn/DiffWave-Vocoder/blob/master/WaveNet.py
- """
-
-
- def swish(x):
-     return x * torch.sigmoid(x)
-
-
- # dilated conv layer with kaiming_normal initialization
- # from https://github.com/ksw0306/FloWaveNet/blob/master/modules.py
- class Conv(nn.Module):
-     def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1):
-         super().__init__()
-         self.padding = dilation * (kernel_size - 1) // 2
-         self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
-                               dilation=dilation, padding=self.padding)
-         self.conv = nn.utils.weight_norm(self.conv)
-         nn.init.kaiming_normal_(self.conv.weight)
-
-     def forward(self, x):
-         out = self.conv(x)
-         return out
-
-
- # conv1x1 layer with zero initialization
- # from https://github.com/ksw0306/FloWaveNet/blob/master/modules.py but the scale parameter is removed
- class ZeroConv1d(nn.Module):
-     def __init__(self, in_channel, out_channel):
-         super().__init__()
-         self.conv = nn.Conv1d(in_channel, out_channel, kernel_size=1, padding=0)
-         self.conv.weight.data.zero_()
-         self.conv.bias.data.zero_()
-
-     def forward(self, x):
-         out = self.conv(x)
-         return out
-
-
- # every residual block (called a "residual layer" in the paper)
- # contains one noncausal dilated conv
- class ResidualBlock(nn.Module):
-     def __init__(self, res_channels, skip_channels, dilation,
-                  diffusion_step_embed_dim_out):
-         super().__init__()
-         self.res_channels = res_channels
-
-         # Use an FC layer for the diffusion step embedding
-         self.fc_t = nn.Linear(diffusion_step_embed_dim_out, self.res_channels)
-
-         # Dilated conv layer
-         self.dilated_conv_layer = Conv(self.res_channels, 2 * self.res_channels,
-                                        kernel_size=3, dilation=dilation)
-
-         # Add mel spectrogram upsampler and conditioner conv1x1 layer
-         self.upsample_conv2d = nn.ModuleList()
-         for s in [16, 16]:
-             conv_trans2d = nn.ConvTranspose2d(1, 1, (3, 2 * s),
-                                               padding=(1, s // 2),
-                                               stride=(1, s))
-             conv_trans2d = nn.utils.weight_norm(conv_trans2d)
-             nn.init.kaiming_normal_(conv_trans2d.weight)
-             self.upsample_conv2d.append(conv_trans2d)
-
-         # 80 is the number of mel bands
-         self.mel_conv = Conv(80, 2 * self.res_channels, kernel_size=1)
-
-         # Residual conv1x1 layer, connect to next residual layer
-         self.res_conv = nn.Conv1d(res_channels, res_channels, kernel_size=1)
-         self.res_conv = nn.utils.weight_norm(self.res_conv)
-         nn.init.kaiming_normal_(self.res_conv.weight)
-
-         # Skip conv1x1 layer, add to all skip outputs through skip connections
-         self.skip_conv = nn.Conv1d(res_channels, skip_channels, kernel_size=1)
-         self.skip_conv = nn.utils.weight_norm(self.skip_conv)
-         nn.init.kaiming_normal_(self.skip_conv.weight)
-
-     def forward(self, input_data):
-         x, mel_spec, diffusion_step_embed = input_data
-         h = x
-         batch_size, n_channels, seq_len = x.shape
-         assert n_channels == self.res_channels
-
-         # Add in the diffusion step embedding
-         part_t = self.fc_t(diffusion_step_embed)
-         part_t = part_t.view([batch_size, self.res_channels, 1])
-         h += part_t
-
-         # Dilated conv layer
-         h = self.dilated_conv_layer(h)
-
-         # Upsample the spectrogram to the length of the audio
-         mel_spec = torch.unsqueeze(mel_spec, dim=1)
-         mel_spec = F.leaky_relu(self.upsample_conv2d[0](mel_spec), 0.4, inplace=False)
-         mel_spec = F.leaky_relu(self.upsample_conv2d[1](mel_spec), 0.4, inplace=False)
-         mel_spec = torch.squeeze(mel_spec, dim=1)
-
-         assert mel_spec.size(2) >= seq_len
-         if mel_spec.size(2) > seq_len:
-             mel_spec = mel_spec[:, :, :seq_len]
-
-         mel_spec = self.mel_conv(mel_spec)
-         h += mel_spec
-
-         # Gated-tanh nonlinearity
-         out = torch.tanh(h[:, :self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels:, :])
-
-         # Residual and skip outputs
-         res = self.res_conv(out)
-         assert x.shape == res.shape
-         skip = self.skip_conv(out)
-
-         # Normalize for training stability
-         return (x + res) * math.sqrt(0.5), skip
-
170
-
171
- class ResidualGroup(nn.Module):
172
- def __init__(self, res_channels, skip_channels, num_res_layers, dilation_cycle,
173
- diffusion_step_embed_dim_in,
174
- diffusion_step_embed_dim_mid,
175
- diffusion_step_embed_dim_out):
176
- super().__init__()
177
- self.num_res_layers = num_res_layers
178
- self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in
179
-
180
- # Use the shared two FC layers for diffusion step embedding
181
- self.fc_t1 = nn.Linear(diffusion_step_embed_dim_in, diffusion_step_embed_dim_mid)
182
- self.fc_t2 = nn.Linear(diffusion_step_embed_dim_mid, diffusion_step_embed_dim_out)
183
-
184
- # Stack all residual blocks with dilations 1, 2, ... , 512, ... , 1, 2, ..., 512
185
- self.residual_blocks = nn.ModuleList()
186
- for n in range(self.num_res_layers):
187
- self.residual_blocks.append(
188
- ResidualBlock(res_channels, skip_channels,
189
- dilation=2 ** (n % dilation_cycle),
190
- diffusion_step_embed_dim_out=diffusion_step_embed_dim_out))
191
-
192
- def forward(self, input_data):
193
- x, mel_spectrogram, diffusion_steps = input_data
194
-
195
- # Embed diffusion step t
196
- diffusion_step_embed = calc_diffusion_step_embedding(
197
- diffusion_steps, self.diffusion_step_embed_dim_in)
198
- diffusion_step_embed = swish(self.fc_t1(diffusion_step_embed))
199
- diffusion_step_embed = swish(self.fc_t2(diffusion_step_embed))
200
-
201
- # Pass all residual layers
202
- h = x
203
- skip = 0
204
- for n in range(self.num_res_layers):
205
- # Use the output from last residual layer
206
- h, skip_n = self.residual_blocks[n]((h, mel_spectrogram, diffusion_step_embed))
207
- # Accumulate all skip outputs
208
- skip += skip_n
209
-
210
- # Normalize for training stability
211
- return skip * math.sqrt(1.0 / self.num_res_layers)
212
-
213
-
214
- class DiffWave(ModelMixin, ConfigMixin):
215
- def __init__(
216
- self,
217
- in_channels=1,
218
- res_channels=128,
219
- skip_channels=128,
220
- out_channels=1,
221
- num_res_layers=30,
222
- dilation_cycle=10,
223
- diffusion_step_embed_dim_in=128,
224
- diffusion_step_embed_dim_mid=512,
225
- diffusion_step_embed_dim_out=512,
226
- ):
227
- super().__init__()
228
-
229
- # register all init arguments with self.register
230
- self.register(
231
- in_channels=in_channels,
232
- res_channels=res_channels,
233
- skip_channels=skip_channels,
234
- out_channels=out_channels,
235
- num_res_layers=num_res_layers,
236
- dilation_cycle=dilation_cycle,
237
- diffusion_step_embed_dim_in=diffusion_step_embed_dim_in,
238
- diffusion_step_embed_dim_mid=diffusion_step_embed_dim_mid,
239
- diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
240
- )
241
-
242
-
243
- # Initial conv1x1 with relu
244
- self.init_conv = nn.Sequential(Conv(in_channels, res_channels, kernel_size=1), nn.ReLU(inplace=False))
245
- # All residual layers
246
- self.residual_layer = ResidualGroup(res_channels,
247
- skip_channels,
248
- num_res_layers,
249
- dilation_cycle,
250
- diffusion_step_embed_dim_in,
251
- diffusion_step_embed_dim_mid,
252
- diffusion_step_embed_dim_out)
253
- # Final conv1x1 -> relu -> zeroconv1x1
254
- self.final_conv = nn.Sequential(Conv(skip_channels, skip_channels, kernel_size=1),
255
- nn.ReLU(inplace=False), ZeroConv1d(skip_channels, out_channels))
256
-
257
- def forward(self, input_data):
258
- audio, mel_spectrogram, diffusion_steps = input_data
259
- x = audio
260
- x = self.init_conv(x).clone()
261
- x = self.residual_layer((x, mel_spectrogram, diffusion_steps))
262
- return self.final_conv(x)
-
-
- class BDDM(DiffusionPipeline):
-     def __init__(self, diffwave, noise_scheduler):
-         super().__init__()
-         noise_scheduler = noise_scheduler.set_format("pt")
-         self.register_modules(diffwave=diffwave, noise_scheduler=noise_scheduler)
-
-     @torch.no_grad()
-     def __call__(self, mel_spectrogram, generator, torch_device=None):
-         if torch_device is None:
-             torch_device = "cuda" if torch.cuda.is_available() else "cpu"
-
-         self.diffwave.to(torch_device)
-
-         mel_spectrogram = mel_spectrogram.to(torch_device)
-         # the two [16, 16] transposed convs upsample the mel frames 256x
-         audio_length = mel_spectrogram.size(-1) * 256
-         audio_size = (1, 1, audio_length)
-
-         # Sample Gaussian noise to begin the loop
-         audio = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)
-
-         timestep_values = self.noise_scheduler.timestep_values
-         num_prediction_steps = len(self.noise_scheduler)
-         for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
-             # 1. predict the noise residual
-             ts = (torch.tensor(timestep_values[t]) * torch.ones((1, 1))).to(torch_device)
-             residual = self.diffwave((audio, mel_spectrogram, ts))
-
-             # 2. predict the previous mean of audio x_t-1
-             pred_prev_audio = self.noise_scheduler.step(residual, audio, t)
-
-             # 3. optionally sample variance
-             variance = 0
-             if t > 0:
-                 noise = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)
-                 variance = self.noise_scheduler.get_variance(t).sqrt() * noise
-
-             # 4. set current audio to prev_audio: x_t -> x_t-1
-             audio = pred_prev_audio + variance
-
-         return audio
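
For reference, here is a minimal sketch of how the deleted pipeline was presumably driven. It is hypothetical and not part of this commit: ToyScheduler is a stand-in that only mirrors the scheduler interface BDDM.__call__ actually touches (set_format, __len__, timestep_values, step, get_variance) rather than the real diffusers scheduler of the time, the DiffWave weights are untrained, and the mel input is random.

import torch


class ToyScheduler:
    """Hypothetical stand-in exposing only the members BDDM.__call__ uses."""

    def __init__(self, num_steps=50):
        self.timestep_values = list(range(num_steps))
        self._betas = torch.linspace(1e-4, 0.02, num_steps)

    def set_format(self, fmt):
        return self  # the real scheduler converts its internal buffers here

    def __len__(self):
        return len(self.timestep_values)

    def step(self, residual, sample, t):
        # placeholder mean update for x_t -> x_t-1; not the real BDDM math
        return sample - self._betas[t].sqrt() * residual

    def get_variance(self, t):
        return self._betas[t]


model = DiffWave()                      # untrained; a real run would load a checkpoint
pipeline = BDDM(model, ToyScheduler())

mel = torch.randn(1, 80, 32)            # dummy mel spectrogram: 80 bands, 32 frames
generator = torch.manual_seed(0)
audio = pipeline(mel, generator, torch_device="cpu")
print(audio.shape)                      # torch.Size([1, 1, 8192]): 32 frames * 256

On this interface, step() returns the predicted previous sample and get_variance(t) its variance, which __call__ combines exactly as in steps 1-4 of the loop above.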