from typing import Any, Dict, List, Literal, Optional, Tuple

import torch
import torch.nn.functional as F
from fairseq2.nn.embedding import Embedding, StandardEmbedding
from fairseq2.nn.padding import PaddingMask
from fairseq2.nn.position_encoder import PositionEncoder
from fairseq2.nn.projection import Projection
from fairseq2.typing import DataType, Device
from torch.nn import (
    ELU,
    BatchNorm1d,
    Conv1d,
    ConvTranspose1d,
    Dropout,
    Module,
    ModuleList,
    Parameter,
    Sequential,
    Tanh,
    init,
)
from torch.nn.utils.weight_norm import remove_weight_norm, weight_norm

from seamless_communication.models.generator.ecapa_tdnn import ECAPA_TDNN
from seamless_communication.models.unity.fft_decoder import FeedForwardTransformer
from seamless_communication.models.unity.length_regulator import VarianceAdaptor
from seamless_communication.models.vocoder.hifigan import (
    LRELU_SLOPE,
    ResBlock,
    init_weights,
)

from .streamable import (
    StreamableConv1d,
    StreamableConvTranspose1d,
    StreamableLSTM,
    StreamableResnetBlock,
)

ELU_PARAMS: Dict[str, Any] = {"alpha": 1.0}

class PretsselEncoderFrontend(Module):
    """Represents the encoder frontend, including the prosody encoder and the
    language embedding."""

    prosody_encoder: ECAPA_TDNN
    embed_tokens: Embedding
    embed_positions: PositionEncoder
    pos_emb_alpha: Parameter
    embed_lang: Embedding
    dropout: Dropout

    def __init__(
        self,
        prosody_encoder: ECAPA_TDNN,
        embed_tokens: Embedding,
        embed_positions: PositionEncoder,
        lang_to_index: Dict[str, int],
        lang_embed_dim: Optional[int],
        dropout_p: float,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ):
        super().__init__()

        self.prosody_encoder = prosody_encoder

        self.embed_tokens = embed_tokens

        self.embed_positions = embed_positions
        self.pos_emb_alpha = Parameter(torch.ones(1, device=device, dtype=dtype))

        self.lang_to_index = lang_to_index

        if lang_embed_dim is not None:
            self.embed_lang = StandardEmbedding(
                len(lang_to_index), lang_embed_dim, device=device, dtype=dtype
            )
        else:
            self.register_module("embed_lang", None)

        self.dropout = Dropout(dropout_p)

        self.device = device
        self.dtype = dtype

    def forward(
        self,
        seqs: torch.Tensor,
        padding_mask: Optional[PaddingMask],
        prosody_input_seqs: torch.Tensor,
        prosody_padding_mask: Optional[PaddingMask],
        tgt_lang: str,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        prosody_embs = self.prosody_encoder(
            prosody_input_seqs,
            prosody_padding_mask,
        ).unsqueeze(1)

        if self.embed_lang is not None:
            lang_index = self.lang_to_index[tgt_lang]
            lang_index_tensor = (
                torch.tensor([lang_index]).to(seqs).repeat(seqs.size(0), 1)
            )
            lang_embeds = self.embed_lang(lang_index_tensor)
            prosody_embs = torch.cat([prosody_embs, lang_embeds], dim=-1)

        seqs = self.embed_tokens(seqs)
        seqs += self.pos_emb_alpha * (self.embed_positions(seqs, padding_mask) - seqs)
        seqs = self.dropout(seqs)

        return seqs, prosody_embs

class PretsselDecoderFrontend(Module):
    """Represents the decoder frontend, including the variance adaptor and the
    positional embedding."""

    variance_adaptor: VarianceAdaptor
    embed_positions: PositionEncoder
    pos_emb_alpha: Parameter

    def __init__(
        self,
        variance_adaptor: VarianceAdaptor,
        embed_positions: PositionEncoder,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ):
        super().__init__()

        self.variance_adaptor = variance_adaptor
        self.embed_positions = embed_positions
        self.pos_emb_alpha = Parameter(torch.ones(1, device=device, dtype=dtype))

        self.device = device
        self.dtype = dtype

    def forward(
        self,
        seqs: torch.Tensor,
        padding_mask: PaddingMask,
        durations: Optional[torch.Tensor] = None,
        duration_factor: float = 1.0,
        min_duration: int = 0,
        film_cond_emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, PaddingMask]:
        seqs, padding_mask, _ = self.variance_adaptor(
            seqs, padding_mask, durations, duration_factor, min_duration, film_cond_emb
        )

        seqs += self.pos_emb_alpha * (self.embed_positions(seqs, padding_mask) - seqs)

        return seqs, padding_mask

class PretsselVocoder(Module):
    """The expressivity-preserving vocoder."""

    encoder_frontend: PretsselEncoderFrontend
    encoder: FeedForwardTransformer
    decoder_frontend: PretsselDecoderFrontend
    decoder: FeedForwardTransformer
    final_proj: Projection

    def __init__(
        self,
        encoder_frontend: PretsselEncoderFrontend,
        encoder: FeedForwardTransformer,
        decoder_frontend: PretsselDecoderFrontend,
        decoder: FeedForwardTransformer,
        final_proj: Projection,
        pn_n_channels: int,
        pn_kernel_size: int,
        pn_layers: int,
        pn_dropout: float,
        upsample_rates: List[int],
        upsample_kernel_sizes: List[int],
        upsample_initial_channel: int,
        resblock_kernel_sizes: List[int],
        resblock_dilation_sizes: List[List[int]],
        mel_dim: int = 80,
        add_ups_out_pad: bool = True,
        channels: int = 1,
        dimension: int = 128,
        n_filters: int = 32,
        ratios: List[int] = [8, 5, 4, 2],
        norm: Literal[
            "none", "weight_norm", "spectral_norm", "time_group_norm"
        ] = "none",
        norm_params: Dict[str, Any] = {},
        kernel_size: int = 7,
        last_kernel_size: int = 7,
        residual_kernel_size: int = 3,
        causal: bool = False,
        pad_mode: str = "constant",
        true_skip: bool = True,
        compress: int = 2,
        lstm: int = 0,
        disable_norm_outer_blocks: int = 0,
        trim_right_ratio: float = 1.0,
        gcmvn_mean: Optional[List[float]] = None,
        gcmvn_std: Optional[List[float]] = None,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ):
        super().__init__()
        self.encoder_frontend = encoder_frontend
        self.encoder = encoder
        self.decoder_frontend = decoder_frontend
        self.decoder = decoder
        self.final_proj = final_proj
        mult = 1
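        # Streamable refinement stack (SEANet-style, in the spirit of EnCodec),
        # built first and interleaved into `self.layers` further below. This is
        # the encoder half: it downsamples by `ratios` into a `dimension`-dim
        # latent sequence.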
        stream_layers: List[Module] = [
            StreamableConv1d(
                channels,
                mult * n_filters,
                kernel_size,
                norm="none" if disable_norm_outer_blocks >= 1 else norm,
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
                activation=Tanh(),
                device=device,
                dtype=dtype,
            )
        ]

        for i, ratio in enumerate(list(reversed(ratios))):
            block_norm = "none" if disable_norm_outer_blocks >= i + 2 else norm
            stream_layers.append(
                StreamableResnetBlock(
                    mult * n_filters,
                    kernel_sizes=[residual_kernel_size, 1],
                    dilations=[1, 1],
                    norm=block_norm,
                    norm_params=norm_params,
                    causal=causal,
                    pad_mode=pad_mode,
                    compress=compress,
                    true_skip=true_skip,
                    device=device,
                    dtype=dtype,
                )
            )
            stream_layers.append(ELU(**ELU_PARAMS))
            stream_layers.append(
                StreamableConv1d(
                    mult * n_filters,
                    mult * n_filters * 2,
                    kernel_size=ratio * 2,
                    stride=ratio,
                    norm=block_norm,
                    norm_kwargs=norm_params,
                    causal=causal,
                    pad_mode=pad_mode,
                    device=device,
                    dtype=dtype,
                )
            )
            mult *= 2

        stream_layers.append(
            StreamableLSTM(
                mult * n_filters, num_layers=lstm, device=device, dtype=dtype
            )
        )
        stream_layers.append(ELU(**ELU_PARAMS))
        n_blocks = len(ratios) + 2
        stream_layers.append(
            StreamableConv1d(
                mult * n_filters,
                dimension,
                last_kernel_size,
                norm="none" if disable_norm_outer_blocks == n_blocks else norm,
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
                device=device,
                dtype=dtype,
            )
        )
        stream_layers.append(
            StreamableConv1d(
                dimension,
                mult * n_filters,
                kernel_size,
                norm="none" if disable_norm_outer_blocks == n_blocks else norm,
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
                device=device,
                dtype=dtype,
            )
        )
        stream_layers.append(
            StreamableLSTM(
                mult * n_filters, num_layers=lstm, device=device, dtype=dtype
            )
        )
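        # Decoder half of the streamable stack: transposed convolutions mirror
        # the encoder ratios and upsample back to `channels` waveform channels.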
        for i, ratio in enumerate(ratios):
            block_norm = (
                "none" if disable_norm_outer_blocks >= n_blocks - (i + 1) else norm
            )
            stream_layers.append(ELU(**ELU_PARAMS))
            stream_layers.append(
                StreamableConvTranspose1d(
                    mult * n_filters,
                    mult * n_filters // 2,
                    kernel_size=ratio * 2,
                    stride=ratio,
                    norm=block_norm,
                    norm_kwargs=norm_params,
                    causal=causal,
                    trim_right_ratio=trim_right_ratio,
                    device=device,
                    dtype=dtype,
                )
            )
            stream_layers.append(
                StreamableResnetBlock(
                    mult * n_filters // 2,
                    kernel_sizes=[residual_kernel_size, 1],
                    dilations=[1, 1],
                    norm=block_norm,
                    norm_params=norm_params,
                    activation_params=ELU_PARAMS,
                    causal=causal,
                    pad_mode=pad_mode,
                    compress=compress,
                    true_skip=true_skip,
                    device=device,
                    dtype=dtype,
                )
            )
            mult //= 2

        stream_layers.append(ELU(**ELU_PARAMS))
        stream_layers.append(
            StreamableConv1d(
                n_filters,
                channels,
                last_kernel_size,
                norm="none" if disable_norm_outer_blocks >= 1 else norm,
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
                device=device,
                dtype=dtype,
            )
        )
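        # The streamable layers are not registered as one block: they are split
        # into four chunks of `chunk_size` and interleaved with the PostNet and
        # HiFi-GAN layers below. `forward` and `remove_weight_norm` depend on
        # this exact ordering when computing layer indices.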
        self.n_streams = len(stream_layers)
        chunk_size = self.n_streams // 4
        stream_idx = 0
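        # PostNet (Tacotron 2-style): Conv1d + BatchNorm1d blocks that predict
        # a residual refinement of the mel spectrogram.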
        self.pn_layers = pn_layers
        self.layers = ModuleList()
        assert pn_kernel_size % 2 == 1, "pn_kernel_size must be odd"
        for i in range(pn_layers):
            cur_layers = (
                [
                    Conv1d(
                        mel_dim if i == 0 else pn_n_channels,
                        pn_n_channels if i < pn_layers - 1 else mel_dim,
                        kernel_size=pn_kernel_size,
                        padding="same",
                        device=device,
                        dtype=dtype,
                    ),
                    BatchNorm1d(
                        pn_n_channels if i < pn_layers - 1 else mel_dim,
                        device=device,
                        dtype=dtype,
                    ),
                ]
                + ([Tanh()] if i < pn_layers - 1 else [])
                + [Dropout(pn_dropout)]
            )
            self.layers.append(Sequential(*cur_layers))
        self.reset_parameters()
        self.layers.extend(stream_layers[:chunk_size])
        stream_idx += chunk_size
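        # HiFi-GAN pre-conv: project the mel spectrogram to
        # `upsample_initial_channel` channels before upsampling.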
        self.layers.append(
            weight_norm(
                Conv1d(
                    mel_dim,
                    upsample_initial_channel,
                    7,
                    1,
                    padding="same",
                    device=device,
                    dtype=dtype,
                )
            )
        )
        self.layers.extend(stream_layers[stream_idx : stream_idx + chunk_size])
        stream_idx += chunk_size
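        # HiFi-GAN upsampling stages: each transposed convolution halves the
        # channel count while upsampling time by the corresponding rate.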
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        ups = ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            out_pad = u % 2 if add_ups_out_pad else 0
            ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2 + out_pad,
                        output_padding=out_pad,
                        device=device,
                        dtype=dtype,
                    )
                )
            )
        ups.apply(init_weights)
        self.layers.extend(ups)
        self.layers.extend(stream_layers[stream_idx : stream_idx + chunk_size])
        stream_idx += chunk_size
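        # HiFi-GAN multi-receptive-field fusion: `num_kernels` parallel
        # ResBlocks per upsampling stage, averaged in `forward`.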
        for i in range(self.num_upsamples):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.layers.append(
                    ResBlock(
                        ch,
                        k,
                        d,
                    ).to(device, dtype=dtype)
                )
        self.layers.extend(stream_layers[stream_idx:])

        conv_post = weight_norm(
            Conv1d(ch, 1, 7, 1, padding=3, device=device, dtype=dtype)
        )
        conv_post.apply(init_weights)
        self.layers.append(conv_post)
        for u, k in zip(upsample_rates, upsample_kernel_sizes):
            assert k == 2 * u, (k, u)
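        # Two sets of mel statistics: `mean`/`scale` are registered as buffers
        # (presumably populated from the checkpoint, since they start at zero),
        # while the optional `gcmvn_mean`/`gcmvn_std` hold global CMVN stats
        # used by `gcmvn_denormalize`.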
        mean = torch.zeros((mel_dim,), dtype=torch.float)
        scale = torch.zeros((mel_dim,), dtype=torch.float)
        self.register_buffer("mean", mean)
        self.register_buffer("scale", scale)

        self.gcmvn_mean: Optional[torch.Tensor] = None
        self.gcmvn_std: Optional[torch.Tensor] = None
        if gcmvn_mean is not None:
            self.gcmvn_mean = torch.tensor(gcmvn_mean, device=device, dtype=dtype)
        if gcmvn_std is not None:
            self.gcmvn_std = torch.tensor(gcmvn_std, device=device, dtype=dtype)
    def reset_parameters(self) -> None:
        for i in range(self.pn_layers):
            init.xavier_uniform_(
                self.layers[i][0].weight,
                init.calculate_gain("tanh" if i < self.pn_layers - 1 else "linear"),
            )
    def gcmvn_denormalize(self, x: torch.Tensor) -> torch.Tensor:
        if self.gcmvn_mean is None or self.gcmvn_std is None:
            raise ValueError("gcmvn_mean or gcmvn_std is not set")

        assert (
            x.ndim == 3
            and x.shape[2] == self.gcmvn_mean.shape[0]
            and x.shape[2] == self.gcmvn_std.shape[0]
        )
        gcmvn_mean = self.gcmvn_mean.to(x)
        gcmvn_std = self.gcmvn_std.to(x)
        x = x * gcmvn_std.view(1, 1, -1).expand_as(x)
        return x + gcmvn_mean.view(1, 1, -1).expand_as(x)
    def forward(
        self,
        seqs: torch.Tensor,
        tgt_lang: str,
        prosody_input_seqs: torch.Tensor,
        padding_mask: Optional[PaddingMask] = None,
        prosody_padding_mask: Optional[PaddingMask] = None,
        durations: Optional[torch.Tensor] = None,
        duration_factor: float = 1.0,
        min_duration: int = 0,
        normalize_before: bool = True,
    ) -> List[torch.Tensor]:
        if seqs.ndim < 2:
            seqs = seqs.unsqueeze(0)
        if prosody_input_seqs.ndim < 3:
            prosody_input_seqs = prosody_input_seqs.unsqueeze(0)
        seqs, cond_embs = self.encoder_frontend(
            seqs,
            padding_mask,
            prosody_input_seqs,
            prosody_padding_mask,
            tgt_lang,
        )
        seqs, padding_mask = self.encoder(seqs, padding_mask, cond_embs)
        seqs, padding_mask = self.decoder_frontend(
            seqs, padding_mask, durations, duration_factor, min_duration, cond_embs
        )
        seqs, padding_mask = self.decoder(seqs, padding_mask, cond_embs)
        seqs = self.final_proj(seqs)
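        # PostNet residual refinement in the mel domain; the PostNet occupies
        # indices [0, pn_layers) of `self.layers` and runs channel-first.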
        pn = seqs.transpose(1, 2)
        for i in range(self.pn_layers):
            pn = self.layers[i](pn)
        pn = pn.transpose(1, 2)

        x = seqs + pn
        x = self.gcmvn_denormalize(x)
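        # Vocode each utterance separately. `self.layers` holds, in order:
        # PostNet (pn_layers), stream chunk 1, HiFi-GAN pre-conv, stream
        # chunk 2, upsamplers, stream chunk 3, ResBlocks, stream chunk 4,
        # conv_post. The index arithmetic below walks that layout.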
        wavs: List[torch.Tensor] = []
        chunk_size = self.n_streams // 4
        for idx, _x in enumerate(x):
            if durations is not None:
                _x = _x[: durations[idx].sum()]
            if normalize_before:
                _x = (_x - self.mean) / self.scale

            _x = _x.transpose(1, 0).unsqueeze(0)
            _x = self.layers[self.pn_layers + chunk_size](_x)  # HiFi-GAN pre-conv
            for i in range(self.num_upsamples):
                _x = F.leaky_relu(_x, LRELU_SLOPE)
                _x = self.layers[i + self.pn_layers + 1 + 2 * chunk_size](_x)
                xs: Optional[torch.Tensor] = None
                for j in range(self.num_kernels):
                    resblock = self.layers[
                        i * self.num_kernels
                        + j
                        + self.pn_layers
                        + 3 * chunk_size
                        + self.num_upsamples
                        + 1
                    ]
                    xs = resblock(_x) if xs is None else xs + resblock(_x)
                _x = xs / self.num_kernels
            _x = F.leaky_relu(_x)
            _x = self.layers[  # conv_post, the last layer
                self.pn_layers
                + self.n_streams
                + self.num_upsamples * (1 + self.num_kernels)
                + 1
            ](_x)
            skip_output = _x
            h = skip_output

            # Refine the waveform with the four interleaved streamable chunks.
            for i1 in range(self.pn_layers, self.pn_layers + chunk_size):
                h = self.layers[i1](h)
            i1 += 2  # step past chunk 1 and skip the HiFi-GAN pre-conv
            for i2 in range(i1, i1 + chunk_size):
                h = self.layers[i2](h)
            i2 = i2 + self.num_upsamples + 1  # skip the upsamplers

            for i3 in range(i2, i2 + chunk_size):
                h = self.layers[i3](h)
            i3 = i3 + (self.num_upsamples * self.num_kernels) + 1  # skip ResBlocks
            for i4 in range(i3, i3 + chunk_size):
                h = self.layers[i4](h)
            h = h[:, :, : _x.size(-1)]

            # Blend the streamable refinement with the squashed HiFi-GAN output.
            wavs.append(0.8 * h + torch.tanh(skip_output).squeeze(0))
        return wavs
    def remove_weight_norm(self) -> None:
        # Indices follow the interleaved layout built in __init__ (and used in
        # forward): pre-conv after stream chunk 1, upsamplers after chunk 2,
        # ResBlocks after chunk 3, conv_post at the very end.
        chunk_size = self.n_streams // 4
        remove_weight_norm(self.layers[self.pn_layers + chunk_size])
        ups_start = self.pn_layers + 2 * chunk_size + 1
        for j in range(self.num_upsamples):
            remove_weight_norm(self.layers[ups_start + j])
        res_start = self.pn_layers + 3 * chunk_size + self.num_upsamples + 1
        for k in range(self.num_upsamples * self.num_kernels):
            self.layers[res_start + k].remove_weight_norm()
        remove_weight_norm(
            self.layers[
                self.pn_layers
                + self.n_streams
                + self.num_upsamples * (1 + self.num_kernels)
                + 1
            ]
        )