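"""Glow-TTS model definitions.

Contains the text encoder, token duration predictor, and flow-based
mel-spectrogram decoder, composed end-to-end in FlowGenerator. The
building blocks live in the repo-local `modules`, `commons`, `attentions`,
and `monotonic_align` packages imported below.
"""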

import math

import torch
from torch import nn
from torch.nn import functional as F

import modules
import commons
import attentions
import monotonic_align


class DurationPredictor(nn.Module):
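    """Predicts a log-duration for each input token.

    Two masked Conv1d -> ReLU -> LayerNorm -> Dropout stacks, followed by a
    1x1 projection down to a single channel (the log-duration).
    """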

    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_1 = attentions.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_2 = attentions.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask


class TextEncoder(nn.Module):
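    """Transformer text encoder.

    Embeds token ids, optionally applies a convolutional prenet, runs a
    self-attention encoder (optionally with windowed relative attention and
    block masking), and projects to the prior mean (and log-std unless
    `mean_only`). The DurationPredictor head operates on a detached copy of
    the encoder output, concatenated with the speaker embedding if given.
    """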

    def __init__(
        self,
        n_vocab,
        out_channels,
        hidden_channels,
        filter_channels,
        filter_channels_dp,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        window_size=None,
        block_length=None,
        mean_only=False,
        prenet=False,
        gin_channels=0,
    ):
        super().__init__()

        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.filter_channels_dp = filter_channels_dp
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.block_length = block_length
        self.mean_only = mean_only
        self.prenet = prenet
        self.gin_channels = gin_channels

        self.emb = nn.Embedding(n_vocab, hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)

        if prenet:
            self.pre = modules.ConvReluNorm(
                hidden_channels,
                hidden_channels,
                hidden_channels,
                kernel_size=5,
                n_layers=3,
                p_dropout=0.5,
            )
        self.encoder = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            window_size=window_size,
            block_length=block_length,
        )

        self.proj_m = nn.Conv1d(hidden_channels, out_channels, 1)
        if not mean_only:
            self.proj_s = nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj_w = DurationPredictor(
            hidden_channels + gin_channels, filter_channels_dp, kernel_size, p_dropout
        )

    def forward(self, x, x_lengths, g=None):
        x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )

        if self.prenet:
            x = self.pre(x, x_mask)
        x = self.encoder(x, x_mask)

        if g is not None:
            g_exp = g.expand(-1, -1, x.size(-1))
            x_dp = torch.cat([torch.detach(x), g_exp], 1)
        else:
            x_dp = torch.detach(x)

        x_m = self.proj_m(x) * x_mask
        if not self.mean_only:
            x_logs = self.proj_s(x) * x_mask
        else:
            x_logs = torch.zeros_like(x_m)

        logw = self.proj_w(x_dp, x_mask)
        return x_m, x_logs, logw, x_mask


class FlowSpecDecoder(nn.Module):
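    """Flow-based decoder (the invertible half of the model).

    A stack of `n_blocks` flow steps, each ActNorm -> invertible 1x1 conv
    (InvConvNear) -> affine coupling block, operating on channel-squeezed
    mel frames. forward() runs the flow in the normalizing direction and
    accumulates log-determinants; reverse=True runs it backwards to sample.
    """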

    def __init__(
        self,
        in_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_blocks,
        n_layers,
        p_dropout=0.0,
        n_split=4,
        n_sqz=2,
        sigmoid_scale=False,
        gin_channels=0,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_blocks = n_blocks
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        self.n_split = n_split
        self.n_sqz = n_sqz
        self.sigmoid_scale = sigmoid_scale
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for b in range(n_blocks):
            self.flows.append(modules.ActNorm(channels=in_channels * n_sqz))
            self.flows.append(
                modules.InvConvNear(channels=in_channels * n_sqz, n_split=n_split)
            )
            self.flows.append(
                attentions.CouplingBlock(
                    in_channels * n_sqz,
                    hidden_channels,
                    kernel_size=kernel_size,
                    dilation_rate=dilation_rate,
                    n_layers=n_layers,
                    gin_channels=gin_channels,
                    p_dropout=p_dropout,
                    sigmoid_scale=sigmoid_scale,
                )
            )

    def forward(self, x, x_mask, g=None, reverse=False):
        if not reverse:
            flows = self.flows
            logdet_tot = 0
        else:
            flows = reversed(self.flows)
            logdet_tot = None

        if self.n_sqz > 1:
            x, x_mask = commons.squeeze(x, x_mask, self.n_sqz)
        for f in flows:
            if not reverse:
                x, logdet = f(x, x_mask, g=g, reverse=reverse)
                logdet_tot += logdet
            else:
                x, logdet = f(x, x_mask, g=g, reverse=reverse)
        if self.n_sqz > 1:
            x, x_mask = commons.unsqueeze(x, x_mask, self.n_sqz)
        return x, logdet_tot

    def store_inverse(self):
        for f in self.flows:
            f.store_inverse()


class FlowGenerator(nn.Module):
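    """End-to-end Glow-TTS generator.

    Training (gen=False): the flow maps the target mel-spectrogram y to
    latents z, monotonic alignment search aligns z with the encoder prior,
    and the model is trained by maximum likelihood.
    Inference (gen=True): predicted token durations expand the prior along
    time, a latent is sampled, and the reversed flow decodes it to a mel.
    """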

    def __init__(
        self,
        n_vocab,
        hidden_channels,
        filter_channels,
        filter_channels_dp,
        out_channels,
        kernel_size=3,
        n_heads=2,
        n_layers_enc=6,
        p_dropout=0.0,
        n_blocks_dec=12,
        kernel_size_dec=5,
        dilation_rate=5,
        n_block_layers=4,
        p_dropout_dec=0.0,
        n_speakers=0,
        gin_channels=0,
        n_split=4,
        n_sqz=1,
        sigmoid_scale=False,
        window_size=None,
        block_length=None,
        mean_only=False,
        hidden_channels_enc=None,
        hidden_channels_dec=None,
        prenet=False,
        **kwargs
    ):
        super().__init__()

        self.n_vocab = n_vocab
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.filter_channels_dp = filter_channels_dp
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_heads = n_heads
        self.n_layers_enc = n_layers_enc
        self.p_dropout = p_dropout
        self.n_blocks_dec = n_blocks_dec
        self.kernel_size_dec = kernel_size_dec
        self.dilation_rate = dilation_rate
        self.n_block_layers = n_block_layers
        self.p_dropout_dec = p_dropout_dec
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels
        self.n_split = n_split
        self.n_sqz = n_sqz
        self.sigmoid_scale = sigmoid_scale
        self.window_size = window_size
        self.block_length = block_length
        self.mean_only = mean_only
        self.hidden_channels_enc = hidden_channels_enc
        self.hidden_channels_dec = hidden_channels_dec
        self.prenet = prenet

        self.encoder = TextEncoder(
            n_vocab,
            out_channels,
            hidden_channels_enc or hidden_channels,
            filter_channels,
            filter_channels_dp,
            n_heads,
            n_layers_enc,
            kernel_size,
            p_dropout,
            window_size=window_size,
            block_length=block_length,
            mean_only=mean_only,
            prenet=prenet,
            gin_channels=gin_channels,
        )
        self.decoder = FlowSpecDecoder(
            out_channels,
            hidden_channels_dec or hidden_channels,
            kernel_size_dec,
            dilation_rate,
            n_blocks_dec,
            n_block_layers,
            p_dropout=p_dropout_dec,
            n_split=n_split,
            n_sqz=n_sqz,
            sigmoid_scale=sigmoid_scale,
            gin_channels=gin_channels,
        )

        if n_speakers > 1:
            self.emb_g = nn.Embedding(n_speakers, gin_channels)
            nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)

    def forward(
        self,
        x,
        x_lengths,
        y=None,
        y_lengths=None,
        g=None,
        gen=False,
        noise_scale=1.0,
        length_scale=1.0,
    ):
        if g is not None:
            g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]
        x_m, x_logs, logw, x_mask = self.encoder(x, x_lengths, g=g)
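
        # gen=True: synthesis mode -- predicted durations drive the alignment
        # and the flow is run in reverse to sample a mel-spectrogram.
        # gen=False: training mode -- monotonic alignment search (MAS) aligns
        # the encoder prior with the latents of the ground-truth mel y.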
        if gen:
            w = torch.exp(logw) * x_mask * length_scale
            w_ceil = torch.ceil(w)
            y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
            y_max_length = None
        else:
            y_max_length = y.size(2)
        y, y_lengths, y_max_length = self.preprocess(y, y_lengths, y_max_length)
        z_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y_max_length), 1).to(
            x_mask.dtype
        )
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(z_mask, 2)

        if gen:
            attn = commons.generate_path(
                w_ceil.squeeze(1), attn_mask.squeeze(1)
            ).unsqueeze(1)
            z_m = torch.matmul(
                attn.squeeze(1).transpose(1, 2), x_m.transpose(1, 2)
            ).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
            z_logs = torch.matmul(
                attn.squeeze(1).transpose(1, 2), x_logs.transpose(1, 2)
            ).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
            logw_ = torch.log(1e-8 + torch.sum(attn, -1)) * x_mask

            z = (z_m + torch.exp(z_logs) * torch.randn_like(z_m) * noise_scale) * z_mask
            y, logdet = self.decoder(z, z_mask, g=g, reverse=True)
            return (
                (y, z_m, z_logs, logdet, z_mask),
                (x_m, x_logs, x_mask),
                (attn, logw, logw_),
            )
        else:
            z, logdet = self.decoder(y, z_mask, g=g, reverse=False)
            with torch.no_grad():
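                # Per-(token, frame) Gaussian log-likelihood
                # log N(z; x_m, exp(x_logs)^2), expanded into four terms so
                # the [b, t, t'] score matrix comes from two matmuls.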
                x_s_sq_r = torch.exp(-2 * x_logs)
                logp1 = torch.sum(
                    -0.5 * math.log(2 * math.pi) - x_logs, [1]
                ).unsqueeze(-1)  # [b, t, 1]
                logp2 = torch.matmul(
                    x_s_sq_r.transpose(1, 2), -0.5 * (z ** 2)
                )  # [b, t, d] x [b, d, t'] = [b, t, t']
                logp3 = torch.matmul(
                    (x_m * x_s_sq_r).transpose(1, 2), z
                )  # [b, t, d] x [b, d, t'] = [b, t, t']
                logp4 = torch.sum(
                    -0.5 * (x_m ** 2) * x_s_sq_r, [1]
                ).unsqueeze(-1)  # [b, t, 1]
                logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']

                attn = (
                    monotonic_align.maximum_path(logp, attn_mask.squeeze(1))
                    .unsqueeze(1)
                    .detach()
                )

            z_m = torch.matmul(
                attn.squeeze(1).transpose(1, 2), x_m.transpose(1, 2)
            ).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
            z_logs = torch.matmul(
                attn.squeeze(1).transpose(1, 2), x_logs.transpose(1, 2)
            ).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
            logw_ = torch.log(1e-8 + torch.sum(attn, -1)) * x_mask
            return (
                (z, z_m, z_logs, logdet, z_mask),
                (x_m, x_logs, x_mask),
                (attn, logw, logw_),
            )

    def preprocess(self, y, y_lengths, y_max_length):
        if y_max_length is not None:
            y_max_length = (y_max_length // self.n_sqz) * self.n_sqz
            y = y[:, :, :y_max_length]
        y_lengths = (y_lengths // self.n_sqz) * self.n_sqz
        return y, y_lengths, y_max_length

    def store_inverse(self):
        self.decoder.store_inverse()
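

if __name__ == "__main__":
    # Minimal usage sketch: duration-driven synthesis for a toy batch.
    # The hyperparameter values below are illustrative placeholders, not a
    # recommended configuration; adjust them to your vocabulary and features.
    model = FlowGenerator(
        n_vocab=100,
        hidden_channels=192,
        filter_channels=768,
        filter_channels_dp=256,
        out_channels=80,  # mel channels
        n_sqz=2,
    )
    model.eval()
    model.store_inverse()  # cache flow inverses before running in reverse

    x = torch.randint(0, 100, (2, 25))  # a batch of two token sequences
    x_lengths = torch.LongTensor([25, 19])
    with torch.no_grad():
        (y_gen, *_), _, _ = model(x, x_lengths, gen=True)
    print(y_gen.shape)  # [2, 80, t'] with t' the total predicted duration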