Create pipeline_bddm.py
pipeline_bddm.py (ADDED, +304 -0)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
# DiffWave: A Versatile Diffusion Model for Audio Synthesis
# (https://arxiv.org/abs/2009.09761)
# Modified from https://github.com/philsyn/DiffWave-Vocoder
#
# Author: Max W. Y. Lam ([email protected])
# Copyright (c) 2021 Tencent. All Rights Reserved
#
########################################################################

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm

from ..modeling_utils import ModelMixin
from ..configuration_utils import ConfigMixin
from ..pipeline_utils import DiffusionPipeline

def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
    """
    Embed a diffusion step $t$ into a higher-dimensional space.

    E.g. the embedding vector in the 128-dimensional space is
    [sin(t * 10^(-0*4/63)), ..., sin(t * 10^(-63*4/63)),
     cos(t * 10^(-0*4/63)), ..., cos(t * 10^(-63*4/63))]

    Parameters:
        diffusion_steps (torch.long tensor, shape=(batchsize, 1)):
            diffusion steps for batch data
        diffusion_step_embed_dim_in (int, default=128):
            dimensionality of the embedding space for discrete diffusion steps

    Returns:
        the embedding vectors (torch.tensor, shape=(batchsize, diffusion_step_embed_dim_in))
    """
    assert diffusion_step_embed_dim_in % 2 == 0

    half_dim = diffusion_step_embed_dim_in // 2
    _embed = np.log(10000) / (half_dim - 1)
    # Build the embedding on the same device as the input steps instead of
    # hard-coding .cuda(), so the function also works on CPU.
    _embed = torch.exp(torch.arange(half_dim, device=diffusion_steps.device) * -_embed)
    _embed = diffusion_steps * _embed
    diffusion_step_embed = torch.cat((torch.sin(_embed), torch.cos(_embed)), 1)
    return diffusion_step_embed


"""
The scripts below were borrowed from
https://github.com/philsyn/DiffWave-Vocoder/blob/master/WaveNet.py
"""


def swish(x):
    return x * torch.sigmoid(x)


# Dilated conv layer with kaiming_normal initialization,
# from https://github.com/ksw0306/FloWaveNet/blob/master/modules.py
class Conv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1):
        super().__init__()
        self.padding = dilation * (kernel_size - 1) // 2
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              dilation=dilation, padding=self.padding)
        self.conv = nn.utils.weight_norm(self.conv)
        nn.init.kaiming_normal_(self.conv.weight)

    def forward(self, x):
        out = self.conv(x)
        return out


# Conv1x1 layer with zero initialization,
# from https://github.com/ksw0306/FloWaveNet/blob/master/modules.py
# but with the scale parameter removed
class ZeroConv1d(nn.Module):
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.conv = nn.Conv1d(in_channel, out_channel, kernel_size=1, padding=0)
        self.conv.weight.data.zero_()
        self.conv.bias.data.zero_()

    def forward(self, x):
        out = self.conv(x)
        return out


# Every residual block (named "residual layer" in the paper)
# contains one noncausal dilated conv.
class ResidualBlock(nn.Module):
    def __init__(self, res_channels, skip_channels, dilation,
                 diffusion_step_embed_dim_out):
        super().__init__()
        self.res_channels = res_channels

        # Use a FC layer for the diffusion step embedding
        self.fc_t = nn.Linear(diffusion_step_embed_dim_out, self.res_channels)

        # Dilated conv layer
        self.dilated_conv_layer = Conv(self.res_channels, 2 * self.res_channels,
                                       kernel_size=3, dilation=dilation)

        # Add mel spectrogram upsampler and conditioner conv1x1 layer
        # (two transposed convs, each upsampling the time axis 16x, i.e. 256x overall)
        self.upsample_conv2d = nn.ModuleList()
        for s in [16, 16]:
            conv_trans2d = nn.ConvTranspose2d(1, 1, (3, 2 * s),
                                              padding=(1, s // 2),
                                              stride=(1, s))
            conv_trans2d = nn.utils.weight_norm(conv_trans2d)
            nn.init.kaiming_normal_(conv_trans2d.weight)
            self.upsample_conv2d.append(conv_trans2d)

        # 80 is the number of mel bands
        self.mel_conv = Conv(80, 2 * self.res_channels, kernel_size=1)

        # Residual conv1x1 layer, connects to the next residual layer
        self.res_conv = nn.Conv1d(res_channels, res_channels, kernel_size=1)
        self.res_conv = nn.utils.weight_norm(self.res_conv)
        nn.init.kaiming_normal_(self.res_conv.weight)

        # Skip conv1x1 layer, added to all skip outputs through skip connections
        self.skip_conv = nn.Conv1d(res_channels, skip_channels, kernel_size=1)
        self.skip_conv = nn.utils.weight_norm(self.skip_conv)
        nn.init.kaiming_normal_(self.skip_conv.weight)

    def forward(self, input_data):
        x, mel_spec, diffusion_step_embed = input_data
        h = x
        batch_size, n_channels, seq_len = x.shape
        assert n_channels == self.res_channels

        # Add in the diffusion step embedding (out of place, so the block
        # input x stays intact for the residual connection below)
        part_t = self.fc_t(diffusion_step_embed)
        part_t = part_t.view([batch_size, self.res_channels, 1])
        h = h + part_t

        # Dilated conv layer
        h = self.dilated_conv_layer(h)

        # Upsample the spectrogram to the size of the audio
        mel_spec = torch.unsqueeze(mel_spec, dim=1)
        mel_spec = F.leaky_relu(self.upsample_conv2d[0](mel_spec), 0.4, inplace=False)
        mel_spec = F.leaky_relu(self.upsample_conv2d[1](mel_spec), 0.4, inplace=False)
        mel_spec = torch.squeeze(mel_spec, dim=1)

        assert mel_spec.size(2) >= seq_len
        if mel_spec.size(2) > seq_len:
            mel_spec = mel_spec[:, :, :seq_len]

        mel_spec = self.mel_conv(mel_spec)
        h = h + mel_spec

        # Gated-tanh nonlinearity
        out = torch.tanh(h[:, :self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels:, :])

        # Residual and skip outputs
        res = self.res_conv(out)
        assert x.shape == res.shape
        skip = self.skip_conv(out)

        # Normalize for training stability
        return (x + res) * math.sqrt(0.5), skip


class ResidualGroup(nn.Module):
    def __init__(self, res_channels, skip_channels, num_res_layers, dilation_cycle,
                 diffusion_step_embed_dim_in,
                 diffusion_step_embed_dim_mid,
                 diffusion_step_embed_dim_out):
        super().__init__()
        self.num_res_layers = num_res_layers
        self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in

        # Use the shared two FC layers for the diffusion step embedding
        self.fc_t1 = nn.Linear(diffusion_step_embed_dim_in, diffusion_step_embed_dim_mid)
        self.fc_t2 = nn.Linear(diffusion_step_embed_dim_mid, diffusion_step_embed_dim_out)

        # Stack all residual blocks with dilations 1, 2, ..., 512, ..., 1, 2, ..., 512
        self.residual_blocks = nn.ModuleList()
        for n in range(self.num_res_layers):
            self.residual_blocks.append(
                ResidualBlock(res_channels, skip_channels,
                              dilation=2 ** (n % dilation_cycle),
                              diffusion_step_embed_dim_out=diffusion_step_embed_dim_out))

    def forward(self, input_data):
        x, mel_spectrogram, diffusion_steps = input_data

        # Embed diffusion step t
        diffusion_step_embed = calc_diffusion_step_embedding(
            diffusion_steps, self.diffusion_step_embed_dim_in)
        diffusion_step_embed = swish(self.fc_t1(diffusion_step_embed))
        diffusion_step_embed = swish(self.fc_t2(diffusion_step_embed))

        # Pass through all residual layers
        h = x
        skip = 0
        for n in range(self.num_res_layers):
            # Use the output from the previous residual layer
            h, skip_n = self.residual_blocks[n]((h, mel_spectrogram, diffusion_step_embed))
            # Accumulate all skip outputs
            skip += skip_n

        # Normalize for training stability
        return skip * math.sqrt(1.0 / self.num_res_layers)


class DiffWave(ModelMixin, ConfigMixin):
    def __init__(
        self,
        in_channels=1,
        res_channels=128,
        skip_channels=128,
        out_channels=1,
        num_res_layers=30,
        dilation_cycle=10,
        diffusion_step_embed_dim_in=128,
        diffusion_step_embed_dim_mid=512,
        diffusion_step_embed_dim_out=512,
    ):
        super().__init__()

        # Register all init arguments with self.register
        self.register(
            in_channels=in_channels,
            res_channels=res_channels,
            skip_channels=skip_channels,
            out_channels=out_channels,
            num_res_layers=num_res_layers,
            dilation_cycle=dilation_cycle,
            diffusion_step_embed_dim_in=diffusion_step_embed_dim_in,
            diffusion_step_embed_dim_mid=diffusion_step_embed_dim_mid,
            diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
        )

        # Initial conv1x1 with relu
        self.init_conv = nn.Sequential(Conv(in_channels, res_channels, kernel_size=1),
                                       nn.ReLU(inplace=False))
        # All residual layers
        self.residual_layer = ResidualGroup(res_channels,
                                            skip_channels,
                                            num_res_layers,
                                            dilation_cycle,
                                            diffusion_step_embed_dim_in,
                                            diffusion_step_embed_dim_mid,
                                            diffusion_step_embed_dim_out)
        # Final conv1x1 -> relu -> zeroconv1x1
        self.final_conv = nn.Sequential(Conv(skip_channels, skip_channels, kernel_size=1),
                                        nn.ReLU(inplace=False),
                                        ZeroConv1d(skip_channels, out_channels))

    def forward(self, input_data):
        audio, mel_spectrogram, diffusion_steps = input_data
        x = audio
        x = self.init_conv(x).clone()
        x = self.residual_layer((x, mel_spectrogram, diffusion_steps))
        return self.final_conv(x)


class BDDM(DiffusionPipeline):
    def __init__(self, diffwave, noise_scheduler):
        super().__init__()
        noise_scheduler = noise_scheduler.set_format("pt")
        self.register_modules(diffwave=diffwave, noise_scheduler=noise_scheduler)

    @torch.no_grad()
    def __call__(self, mel_spectrogram, generator, torch_device=None):
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        self.diffwave.to(torch_device)

        mel_spectrogram = mel_spectrogram.to(torch_device)
        # 256 is the total upsampling factor of the two transposed convs in
        # ResidualBlock, i.e. the number of audio samples per mel frame.
        audio_length = mel_spectrogram.size(-1) * 256
        audio_size = (1, 1, audio_length)

        # Sample gaussian noise to begin the loop
        audio = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)

        timestep_values = self.noise_scheduler.timestep_values
        num_prediction_steps = len(self.noise_scheduler)
        for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
            # 1. predict the noise residual
            ts = (torch.tensor(timestep_values[t]) * torch.ones((1, 1))).to(torch_device)
            residual = self.diffwave((audio, mel_spectrogram, ts))

            # 2. predict the previous mean of audio x_t-1
            pred_prev_audio = self.noise_scheduler.step(residual, audio, t)

            # 3. optionally sample variance
            variance = 0
            if t > 0:
                noise = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)
                variance = self.noise_scheduler.get_variance(t).sqrt() * noise

            # 4. set current audio to prev_audio: x_t -> x_t-1
            audio = pred_prev_audio + variance

        return audio
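
A quick sanity check of calc_diffusion_step_embedding, as a minimal standalone sketch (the step values are arbitrary; the shapes follow the docstring above):

import torch

# A batch of 4 integer diffusion steps, each embedded into 128 dimensions:
# the first 64 entries are the sin half, the last 64 the cos half.
steps = torch.randint(0, 1000, (4, 1))
embed = calc_diffusion_step_embedding(steps, diffusion_step_embed_dim_in=128)
assert embed.shape == (4, 128)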
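
The 256x factor in BDDM.__call__ above follows from the upsampler configuration in ResidualBlock: each ConvTranspose2d with stride (1, s) stretches the time axis s-fold, so the two s=16 layers give 16 * 16 = 256 samples per mel frame. A standalone shape check (randomly initialized layers; only the sizes matter here):

import torch
import torch.nn as nn

up = nn.ModuleList(
    nn.ConvTranspose2d(1, 1, (3, 2 * s), padding=(1, s // 2), stride=(1, s))
    for s in (16, 16)
)
mel = torch.randn(1, 1, 80, 100)  # (batch, 1, mel bands, frames)
for layer in up:
    mel = layer(mel)
print(mel.shape)  # torch.Size([1, 1, 80, 25600]), i.e. 100 frames * 256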
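
End to end, the pipeline maps a mel spectrogram to a waveform. A hypothetical usage sketch: diffwave and noise_scheduler are placeholders for pretrained components, and the scheduler must expose the timestep_values, len(), step() and get_variance() interface that __call__ relies on:

import torch

# `diffwave` and `noise_scheduler` are assumed to be loaded elsewhere,
# e.g. via the library's pretrained-model helpers.
pipeline = BDDM(diffwave=diffwave, noise_scheduler=noise_scheduler)

generator = torch.Generator().manual_seed(0)
mel = torch.randn(1, 80, 100)     # (batch, mel bands, frames)
audio = pipeline(mel, generator)  # waveform of shape (1, 1, 100 * 256)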