# Copyright (c) 2024 Amphion.
#
# This code is modified from https://github.com/imdanboy/jets/blob/main/espnet2/gan_tts/jets/alignments.py
# Licensed under Apache License 2.0

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
from numba import jit
from scipy.stats import betabinom


class AlignmentModule(nn.Module):
    """Alignment Learning Framework proposed for parallel TTS models in:

    https://arxiv.org/abs/2108.10447
    """

    def __init__(self, adim, odim, cache_prior=True):
        """Initialize AlignmentModule.

        Args:
            adim (int): Dimension of attention.
            odim (int): Dimension of feats.
            cache_prior (bool): Whether to cache beta-binomial prior.
        """
        super().__init__()
        self.cache_prior = cache_prior
        self._cache = {}

        self.t_conv1 = nn.Conv1d(adim, adim, kernel_size=3, padding=1)
        self.t_conv2 = nn.Conv1d(adim, adim, kernel_size=1, padding=0)

        self.f_conv1 = nn.Conv1d(odim, adim, kernel_size=3, padding=1)
        self.f_conv2 = nn.Conv1d(adim, adim, kernel_size=3, padding=1)
        self.f_conv3 = nn.Conv1d(adim, adim, kernel_size=1, padding=0)

    def forward(self, text, feats, text_lengths, feats_lengths, x_masks=None):
        """Calculate alignment loss.

        Args:
            text (Tensor): Batched text embedding (B, T_text, adim).
            feats (Tensor): Batched acoustic feature (B, T_feats, odim).
            text_lengths (Tensor): Text length tensor (B,).
            feats_lengths (Tensor): Feature length tensor (B,).
            x_masks (Tensor): Mask tensor (B, T_text).

        Returns:
            Tensor: Log probability of attention matrix (B, T_feats, T_text).
        """
        text = text.transpose(1, 2)
        text = F.relu(self.t_conv1(text))
        text = self.t_conv2(text)
        text = text.transpose(1, 2)

        feats = feats.transpose(1, 2)
        feats = F.relu(self.f_conv1(feats))
        feats = F.relu(self.f_conv2(feats))
        feats = self.f_conv3(feats)
        feats = feats.transpose(1, 2)

        dist = feats.unsqueeze(2) - text.unsqueeze(1)
        dist = torch.norm(dist, p=2, dim=3)
        score = -dist
        if x_masks is not None:
            x_masks = x_masks.unsqueeze(-2)
            score = score.masked_fill(x_masks, -np.inf)
        log_p_attn = F.log_softmax(score, dim=-1)

        # add beta-binomial prior
        bb_prior = self._generate_prior(
            text_lengths,
            feats_lengths,
        ).to(dtype=log_p_attn.dtype, device=log_p_attn.device)
        log_p_attn = log_p_attn + bb_prior

        return log_p_attn

    def _generate_prior(self, text_lengths, feats_lengths, w=1) -> torch.Tensor:
        """Generate alignment prior formulated as beta-binomial distribution.

        Args:
            text_lengths (Tensor): Batch of the lengths of each input (B,).
            feats_lengths (Tensor): Batch of the lengths of each target (B,).
            w (float): Scaling factor; lower -> wider the width.

        Returns:
            Tensor: Batched 2d static prior matrix (B, T_feats, T_text).
        """
        B = len(text_lengths)
        T_text = text_lengths.max()
        T_feats = feats_lengths.max()

        bb_prior = torch.full((B, T_feats, T_text), fill_value=-np.inf)
        for bidx in range(B):
            T = feats_lengths[bidx].item()
            N = text_lengths[bidx].item()

            key = str(T) + "," + str(N)
            if self.cache_prior and key in self._cache:
                prob = self._cache[key]
            else:
                alpha = w * np.arange(1, T + 1, dtype=float)  # (T,)
                beta = w * np.array([T - t + 1 for t in alpha])
                k = np.arange(N)
                batched_k = k[..., None]  # (N,1)
                prob = betabinom.logpmf(batched_k, N, alpha, beta)  # (N,T)

            # store cache
            if self.cache_prior and key not in self._cache:
                self._cache[key] = prob

            prob = torch.from_numpy(prob).transpose(0, 1)  # -> (T,N)
            bb_prior[bidx, :T, :N] = prob

        return bb_prior
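

# Illustrative usage sketch for AlignmentModule; the dimensions and lengths
# below (adim=256, odim=80, batch of 2) are arbitrary example values, not
# anything fixed by this module:
#
#     aligner = AlignmentModule(adim=256, odim=80)
#     text = torch.randn(2, 25, 256)            # (B, T_text, adim) text embeddings
#     feats = torch.randn(2, 200, 80)           # (B, T_feats, odim) mel features
#     text_lengths = torch.tensor([25, 18])
#     feats_lengths = torch.tensor([200, 150])
#     log_p_attn = aligner(text, feats, text_lengths, feats_lengths)
#     # log_p_attn: (B, T_feats, T_text) log attention probabilities,
#     # with the beta-binomial prior already added.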


@jit(nopython=True)
def _monotonic_alignment_search(log_p_attn):
    # monotonic alignment search (MAS); see Glow-TTS: https://arxiv.org/abs/2005.11129
    T_mel = log_p_attn.shape[0]
    T_inp = log_p_attn.shape[1]
    Q = np.full((T_inp, T_mel), fill_value=-np.inf)

    log_prob = log_p_attn.transpose(1, 0)  # -> (T_inp,T_mel)
    # 1. Q <- init first row for all j
    for j in range(T_mel):
        Q[0, j] = log_prob[0, : j + 1].sum()

    # 2. dynamic programming: Q[i, j] is the best cumulative score of a
    # monotonic path that assigns text index i to mel frame j
    for j in range(1, T_mel):
        for i in range(1, min(j + 1, T_inp)):
            Q[i, j] = max(Q[i - 1, j - 1], Q[i, j - 1]) + log_prob[i, j]

    # 3. backtrack the most likely monotonic path A
    A = np.full((T_mel,), fill_value=T_inp - 1)
    for j in range(T_mel - 2, -1, -1):  # T_mel-2, ..., 0
        # 'i' in {A[j+1]-1, A[j+1]}
        i_a = A[j + 1] - 1
        i_b = A[j + 1]
        if i_b == 0:
            argmax_i = 0
        elif Q[i_a, j] >= Q[i_b, j]:
            argmax_i = i_a
        else:
            argmax_i = i_b
        A[j] = argmax_i
    return A
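

# Note on the returned path: _monotonic_alignment_search takes a
# (T_feats, T_text) log-probability matrix and returns an integer array A of
# length T_feats, where A[j] is the text index assigned to mel frame j. The
# path is monotonically non-decreasing, moves by at most one text index per
# frame, and ends at the last text index, so (for finite scores and
# T_feats >= T_text) every text token is assigned at least one frame.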


def viterbi_decode(log_p_attn, text_lengths, feats_lengths):
    """Extract duration from an attention probability matrix.

    Args:
        log_p_attn (Tensor): Batched log probability of attention
            matrix (B, T_feats, T_text).
        text_lengths (Tensor): Text length tensor (B,).
        feats_lengths (Tensor): Feature length tensor (B,).

    Returns:
        Tensor: Batched token duration extracted from `log_p_attn` (B, T_text).
        Tensor: Binarization loss tensor ().
    """
    B = log_p_attn.size(0)
    T_text = log_p_attn.size(2)
    device = log_p_attn.device

    bin_loss = 0
    ds = torch.zeros((B, T_text), device=device)
    for b in range(B):
        cur_log_p_attn = log_p_attn[b, : feats_lengths[b], : text_lengths[b]]
        viterbi = _monotonic_alignment_search(cur_log_p_attn.detach().cpu().numpy())
        _ds = np.bincount(viterbi)
        ds[b, : len(_ds)] = torch.from_numpy(_ds).to(device)

        t_idx = torch.arange(feats_lengths[b])
        bin_loss = bin_loss - cur_log_p_attn[t_idx, viterbi].mean()
    bin_loss = bin_loss / B
    return ds, bin_loss
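

# Illustrative usage sketch (names follow the AlignmentModule example above;
# the values are arbitrary): given log_p_attn from AlignmentModule.forward,
# per-token durations and the binarization loss are obtained with:
#
#     ds, bin_loss = viterbi_decode(log_p_attn, text_lengths, feats_lengths)
#     # ds: (B, T_text) durations; within each item they sum to the
#     # corresponding feats_lengths entry. bin_loss is a scalar tensor.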


def _average_by_duration(ds, xs, text_lengths, feats_lengths):
    """NumPy helper used by average_by_duration to average frame-level features per token."""
    B = ds.shape[0]
    xs_avg = np.zeros_like(ds)
    ds = ds.astype(np.int32)
    for b in range(B):
        t_text = text_lengths[b]
        t_feats = feats_lengths[b]
        d = ds[b, :t_text]
        d_cumsum = d.cumsum()
        d_cumsum = [0] + list(d_cumsum)
        x = xs[b, :t_feats]
        for n, (start, end) in enumerate(zip(d_cumsum[:-1], d_cumsum[1:])):
            if len(x[start:end]) != 0:
                xs_avg[b, n] = x[start:end].mean()
            else:
                xs_avg[b, n] = 0
    return xs_avg


def average_by_duration(ds, xs, text_lengths, feats_lengths):
    """Average frame-level features into token-level according to durations.

    Args:
        ds (Tensor): Batched token duration (B, T_text).
        xs (Tensor): Batched feature sequences to be averaged (B, T_feats).
        text_lengths (Tensor): Text length tensor (B,).
        feats_lengths (Tensor): Feature length tensor (B,).

    Returns:
        Tensor: Batched feature averaged according to the token duration (B, T_text).
    """
    device = ds.device
    args = [ds, xs, text_lengths, feats_lengths]
    args = [arg.detach().cpu().numpy() for arg in args]
    xs_avg = _average_by_duration(*args)
    xs_avg = torch.from_numpy(xs_avg).to(device)
    return xs_avg
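

# Illustrative usage sketch (arbitrary example values): a frame-level contour,
# e.g. pitch, can be reduced to one value per token using the durations
# returned by viterbi_decode:
#
#     pitch = torch.randn(2, 200)               # (B, T_feats) frame-level pitch
#     pitch_avg = average_by_duration(ds, pitch, text_lengths, feats_lengths)
#     # pitch_avg: (B, T_text) token-level pitch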


def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
    """Make mask tensor containing indices of padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
            See the example.
        maxlen (int, optional): Maximum length of the mask.
            If set, `xs` must be None and `maxlen` must be >= max(lengths).

    Returns:
        Tensor: Mask tensor containing indices of padded part.
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]

        With the reference tensor.

        >>> xs = torch.zeros((3, 2, 4))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0],
                 [0, 0, 0, 0]],
                [[0, 0, 0, 1],
                 [0, 0, 0, 1]],
                [[0, 0, 1, 1],
                 [0, 0, 1, 1]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)

        With the reference tensor and dimension indicator.

        >>> xs = torch.zeros((3, 6, 6))
        >>> make_pad_mask(lengths, xs, 1)
        tensor([[[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
        >>> make_pad_mask(lengths, xs, 2)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    if not isinstance(lengths, list):
        lengths = lengths.tolist()
    bs = int(len(lengths))
    if maxlen is None:
        if xs is None:
            maxlen = int(max(lengths))
        else:
            maxlen = xs.size(length_dim)
    else:
        assert xs is None
        assert maxlen >= int(max(lengths))

    seq_range = torch.arange(0, maxlen, dtype=torch.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
    seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand

    if xs is not None:
        assert xs.size(0) == bs, (xs.size(0), bs)

        if length_dim < 0:
            length_dim = xs.dim() + length_dim
        # ind = (:, None, ..., None, :, None, ..., None)
        ind = tuple(
            slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
        )
        mask = mask[ind].expand_as(xs).to(xs.device)
    return mask


def make_non_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of non-padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
            See the example.

    Returns:
        ByteTensor: Mask tensor containing indices of non-padded part.
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]

        With the reference tensor.

        >>> xs = torch.zeros((3, 2, 4))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1],
                 [1, 1, 1, 1]],
                [[1, 1, 1, 0],
                 [1, 1, 1, 0]],
                [[1, 1, 0, 0],
                 [1, 1, 0, 0]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)

        With the reference tensor and dimension indicator.

        >>> xs = torch.zeros((3, 6, 6))
        >>> make_non_pad_mask(lengths, xs, 1)
        tensor([[[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
        >>> make_non_pad_mask(lengths, xs, 2)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
    """
    return ~make_pad_mask(lengths, xs, length_dim)


def get_random_segments(
    x: torch.Tensor,
    x_lengths: torch.Tensor,
    segment_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Get random segments.

    Args:
        x (Tensor): Input tensor (B, C, T).
        x_lengths (Tensor): Length tensor (B,).
        segment_size (int): Segment size.

    Returns:
        Tensor: Segmented tensor (B, C, segment_size).
        Tensor: Start index tensor (B,).
    """
    b, c, t = x.size()
    max_start_idx = x_lengths - segment_size
    start_idxs = (torch.rand([b]).to(x.device) * max_start_idx).to(
        dtype=torch.long,
    )
    segments = get_segments(x, start_idxs, segment_size)
    return segments, start_idxs


def get_segments(
    x: torch.Tensor,
    start_idxs: torch.Tensor,
    segment_size: int,
) -> torch.Tensor:
    """Get segments.

    Args:
        x (Tensor): Input tensor (B, C, T).
        start_idxs (Tensor): Start index tensor (B,).
        segment_size (int): Segment size.

    Returns:
        Tensor: Segmented tensor (B, C, segment_size).
    """
    b, c, t = x.size()
    segments = x.new_zeros(b, c, segment_size)
    for i, start_idx in enumerate(start_idxs):
        segments[i] = x[i, :, start_idx : start_idx + segment_size]
    return segments
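

# Illustrative usage sketch (arbitrary example values): fixed-size windows can
# be cropped at random valid positions, e.g. for a segment-level discriminator:
#
#     z = torch.randn(2, 192, 400)              # (B, C, T) latent frames
#     z_lengths = torch.tensor([400, 350])
#     z_seg, start_idxs = get_random_segments(z, z_lengths, segment_size=32)
#     # z_seg: (B, C, 32). The same windows can be re-extracted from any other
#     # tensor that is frame-aligned with z via get_segments(other, start_idxs, 32).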