|
|
|
|
|
|
|
|
|
|
|
import logging |
|
import os |
|
import os.path as op |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
import numpy as np |
|
|
|
from fairseq.data.audio.text_to_speech_dataset import TextToSpeechDatasetCreator |
|
from fairseq.tasks import register_task |
|
from fairseq.tasks.speech_to_text import SpeechToTextTask |
|
from fairseq.speech_generator import ( |
|
AutoRegressiveSpeechGenerator, |
|
NonAutoregressiveSpeechGenerator, |
|
TeacherForcingAutoRegressiveSpeechGenerator, |
|
) |
|
|
|
logging.basicConfig( |
|
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", |
|
datefmt="%Y-%m-%d %H:%M:%S", |
|
level=logging.INFO, |
|
) |
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
try: |
|
from tensorboardX import SummaryWriter |
|
except ImportError: |
|
logger.info("Please install tensorboardX: pip install tensorboardX") |
|
SummaryWriter = None |
|
|
|
|
|
@register_task("text_to_speech") |
|
class TextToSpeechTask(SpeechToTextTask): |
|
@staticmethod |
|
def add_args(parser): |
|
parser.add_argument("data", help="manifest root path") |
|
parser.add_argument( |
|
"--config-yaml", |
|
type=str, |
|
default="config.yaml", |
|
help="Configuration YAML filename (under manifest root)", |
|
) |
|
parser.add_argument( |
|
"--max-source-positions", |
|
default=1024, |
|
type=int, |
|
metavar="N", |
|
help="max number of tokens in the source sequence", |
|
) |
|
parser.add_argument( |
|
"--max-target-positions", |
|
default=1200, |
|
type=int, |
|
metavar="N", |
|
help="max number of tokens in the target sequence", |
|
) |
|
parser.add_argument("--n-frames-per-step", type=int, default=1) |
|
parser.add_argument("--eos-prob-threshold", type=float, default=0.5) |
|
parser.add_argument("--eval-inference", action="store_true") |
|
parser.add_argument("--eval-tb-nsample", type=int, default=8) |
|
parser.add_argument("--vocoder", type=str, default="griffin_lim") |
|
parser.add_argument("--spec-bwd-max-iter", type=int, default=8) |
|
|
|
def __init__(self, args, src_dict): |
|
super().__init__(args, src_dict) |
|
self.src_dict = src_dict |
|
self.sr = self.data_cfg.config.get("features").get("sample_rate") |
|
|
|
self.tensorboard_writer = None |
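        # The SummaryWriter itself is created lazily in log_tensorboard();
        # here we only record the output directory.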
|
self.tensorboard_dir = "" |
|
if args.tensorboard_logdir and SummaryWriter is not None: |
|
self.tensorboard_dir = os.path.join(args.tensorboard_logdir, "valid_extra") |
|
|
|
def load_dataset(self, split, epoch=1, combine=False, **kwargs): |
|
is_train_split = split.startswith("train") |
|
pre_tokenizer = self.build_tokenizer(self.args) |
|
bpe_tokenizer = self.build_bpe(self.args) |
|
self.datasets[split] = TextToSpeechDatasetCreator.from_tsv( |
|
self.args.data, |
|
self.data_cfg, |
|
split, |
|
self.src_dict, |
|
pre_tokenizer, |
|
bpe_tokenizer, |
|
is_train_split=is_train_split, |
|
epoch=epoch, |
|
seed=self.args.seed, |
|
n_frames_per_step=self.args.n_frames_per_step, |
|
speaker_to_id=self.speaker_to_id, |
|
) |
|
|
|
@property |
|
def target_dictionary(self): |
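        # TTS targets are continuous spectrogram frames, not tokens, so
        # there is no target dictionary.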
|
return None |
|
|
|
@property |
|
def source_dictionary(self): |
|
return self.src_dict |
|
|
|
def get_speaker_embeddings_path(self): |
|
speaker_emb_path = None |
|
if self.data_cfg.config.get("speaker_emb_filename") is not None: |
|
speaker_emb_path = op.join( |
|
self.args.data, self.data_cfg.config.get("speaker_emb_filename") |
|
) |
|
return speaker_emb_path |
|
|
|
@classmethod |
|
def get_speaker_embeddings(cls, args): |
|
embed_speaker = None |
|
if args.speaker_to_id is not None: |
|
if args.speaker_emb_path is None: |
|
embed_speaker = torch.nn.Embedding( |
|
len(args.speaker_to_id), args.speaker_embed_dim |
|
) |
|
else: |
|
speaker_emb_mat = np.load(args.speaker_emb_path) |
|
assert speaker_emb_mat.shape[1] == args.speaker_embed_dim |
|
embed_speaker = torch.nn.Embedding.from_pretrained( |
|
torch.from_numpy(speaker_emb_mat), |
|
freeze=True, |
|
) |
|
logger.info( |
|
f"load speaker embeddings from {args.speaker_emb_path}. " |
|
f"train embedding? {embed_speaker.weight.requires_grad}\n" |
|
f"embeddings:\n{speaker_emb_mat}" |
|
) |
|
return embed_speaker |
|
|
|
def build_model(self, cfg, from_checkpoint=False): |
|
cfg.pitch_min = self.data_cfg.config["features"].get("pitch_min", None) |
|
cfg.pitch_max = self.data_cfg.config["features"].get("pitch_max", None) |
|
cfg.energy_min = self.data_cfg.config["features"].get("energy_min", None) |
|
cfg.energy_max = self.data_cfg.config["features"].get("energy_max", None) |
|
cfg.speaker_emb_path = self.get_speaker_embeddings_path() |
|
model = super().build_model(cfg, from_checkpoint) |
|
self.generator = None |
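        # Build the speech generator eagerly so valid_step() can synthesize
        # and score waveforms when --eval-inference is enabled.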
|
if getattr(cfg, "eval_inference", False): |
|
self.generator = self.build_generator([model], cfg) |
|
return model |
|
|
|
def build_generator(self, models, cfg, vocoder=None, **unused): |
|
if vocoder is None: |
|
vocoder = self.build_default_vocoder() |
|
model = models[0] |
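        # Non-autoregressive models emit the full spectrogram in one pass;
        # autoregressive ones decode frame by frame until EOS.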
|
if getattr(model, "NON_AUTOREGRESSIVE", False): |
|
return NonAutoregressiveSpeechGenerator(model, vocoder, self.data_cfg) |
|
else: |
|
generator = AutoRegressiveSpeechGenerator |
|
if getattr(cfg, "teacher_forcing", False): |
|
generator = TeacherForcingAutoRegressiveSpeechGenerator |
|
logger.info("Teacher forcing mode for generation") |
|
return generator( |
|
model, |
|
vocoder, |
|
self.data_cfg, |
|
max_iter=self.args.max_target_positions, |
|
eos_prob_threshold=self.args.eos_prob_threshold, |
|
) |
|
|
|
def build_default_vocoder(self): |
|
from fairseq.models.text_to_speech.vocoder import get_vocoder |
|
|
|
vocoder = get_vocoder(self.args, self.data_cfg) |
|
if torch.cuda.is_available() and not self.args.cpu: |
|
vocoder = vocoder.cuda() |
|
else: |
|
vocoder = vocoder.cpu() |
|
return vocoder |
|
|
|
def valid_step(self, sample, model, criterion): |
|
loss, sample_size, logging_output = super().valid_step(sample, model, criterion) |
|
|
|
if getattr(self.args, "eval_inference", False): |
|
hypos, inference_losses = self.valid_step_with_inference( |
|
sample, model, self.generator |
|
) |
|
for k, v in inference_losses.items(): |
|
assert k not in logging_output |
|
logging_output[k] = v |
|
|
|
picked_id = 0 |
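            # Only log plots/audio when the batch contains utterance 0, so
            # the same samples are visualized across validation runs.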
|
if self.tensorboard_dir and (sample["id"] == picked_id).any(): |
|
self.log_tensorboard( |
|
sample, |
|
hypos[: self.args.eval_tb_nsample], |
|
model._num_updates, |
|
is_na_model=getattr(model, "NON_AUTOREGRESSIVE", False), |
|
) |
|
return loss, sample_size, logging_output |
|
|
|
def valid_step_with_inference(self, sample, model, generator): |
|
hypos = generator.generate(model, sample, has_targ=True) |
|
|
|
losses = { |
|
"mcd_loss": 0.0, |
|
"targ_frames": 0.0, |
|
"pred_frames": 0.0, |
|
"nins": 0.0, |
|
"ndel": 0.0, |
|
} |
|
rets = batch_mel_cepstral_distortion( |
|
[hypo["targ_waveform"] for hypo in hypos], |
|
[hypo["waveform"] for hypo in hypos], |
|
self.sr, |
|
normalize_type=None, |
|
) |
|
for d, extra in rets: |
|
pathmap = extra[-1] |
|
losses["mcd_loss"] += d.item() |
|
losses["targ_frames"] += pathmap.size(0) |
|
losses["pred_frames"] += pathmap.size(1) |
|
losses["nins"] += (pathmap.sum(dim=1) - 1).sum().item() |
|
losses["ndel"] += (pathmap.sum(dim=0) - 1).sum().item() |
|
|
|
return hypos, losses |
|
|
|
def log_tensorboard(self, sample, hypos, num_updates, is_na_model=False): |
|
if self.tensorboard_writer is None: |
|
self.tensorboard_writer = SummaryWriter(self.tensorboard_dir) |
|
tb_writer = self.tensorboard_writer |
|
for b in range(len(hypos)): |
|
idx = sample["id"][b] |
|
text = sample["src_texts"][b] |
|
targ = hypos[b]["targ_feature"] |
|
pred = hypos[b]["feature"] |
|
attn = hypos[b]["attn"] |
|
|
|
if is_na_model: |
|
data = plot_tts_output( |
|
[targ.transpose(0, 1), pred.transpose(0, 1)], |
|
[f"target (idx={idx})", "output"], |
|
attn, |
|
"alignment", |
|
ret_np=True, |
|
suptitle=text, |
|
) |
|
else: |
|
eos_prob = hypos[b]["eos_prob"] |
|
data = plot_tts_output( |
|
[targ.transpose(0, 1), pred.transpose(0, 1), attn], |
|
[f"target (idx={idx})", "output", "alignment"], |
|
eos_prob, |
|
"eos prob", |
|
ret_np=True, |
|
suptitle=text, |
|
) |
|
|
|
tb_writer.add_image( |
|
f"inference_sample_{b}", data, num_updates, dataformats="HWC" |
|
) |
|
|
|
if hypos[b]["waveform"] is not None: |
|
targ_wave = hypos[b]["targ_waveform"].detach().cpu().float() |
|
pred_wave = hypos[b]["waveform"].detach().cpu().float() |
|
tb_writer.add_audio( |
|
f"inference_targ_{b}", targ_wave, num_updates, sample_rate=self.sr |
|
) |
|
tb_writer.add_audio( |
|
f"inference_pred_{b}", pred_wave, num_updates, sample_rate=self.sr |
|
) |
|
|
|
|
|
def save_figure_to_numpy(fig): |
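    """Render a drawn Matplotlib figure into an (H, W, 3) uint8 RGB array."""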
|
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)  # np.fromstring is removed in modern NumPy
|
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) |
|
return data |
|
|
|
|
|
DEFAULT_V_MIN = np.log(1e-5) |
|
|
|
|
|
def plot_tts_output( |
|
data_2d, |
|
title_2d, |
|
data_1d, |
|
title_1d, |
|
figsize=(24, 4), |
|
v_min=DEFAULT_V_MIN, |
|
v_max=3, |
|
ret_np=False, |
|
suptitle="", |
|
): |
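    """Plot a row of 2D maps (e.g. spectrograms, attention) followed by one
    1D curve (e.g. EOS probability); optionally return the rendered figure
    as a numpy image for TensorBoard."""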
|
try: |
|
import matplotlib.pyplot as plt |
|
from mpl_toolkits.axes_grid1 import make_axes_locatable |
|
except ImportError: |
|
raise ImportError("Please install Matplotlib: pip install matplotlib") |
|
|
|
data_2d = [ |
|
x.detach().cpu().float().numpy() if isinstance(x, torch.Tensor) else x |
|
for x in data_2d |
|
] |
|
fig, axes = plt.subplots(1, len(data_2d) + 1, figsize=figsize) |
|
if suptitle: |
|
fig.suptitle(suptitle[:400]) |
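    # plt.subplots() returns a bare Axes rather than an array when only one
    # panel is requested (len(data_2d) == 0); normalize to a list either way.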
|
axes = [axes] if len(data_2d) == 0 else axes |
|
for ax, x, name in zip(axes, data_2d, title_2d): |
|
ax.set_title(name) |
|
divider = make_axes_locatable(ax) |
|
cax = divider.append_axes("right", size="5%", pad=0.05) |
|
im = ax.imshow( |
|
x, |
|
origin="lower", |
|
aspect="auto", |
|
vmin=max(x.min(), v_min), |
|
vmax=min(x.max(), v_max), |
|
) |
|
fig.colorbar(im, cax=cax, orientation="vertical") |
|
|
|
if isinstance(data_1d, torch.Tensor): |
|
data_1d = data_1d.detach().cpu().numpy() |
|
axes[-1].plot(data_1d) |
|
axes[-1].set_title(title_1d) |
|
plt.tight_layout() |
|
|
|
if ret_np: |
|
fig.canvas.draw() |
|
data = save_figure_to_numpy(fig) |
|
plt.close(fig) |
|
return data |
|
|
|
|
|
def antidiag_indices(offset, min_i=0, max_i=None, min_j=0, max_j=None): |
|
""" |
|
for a (3, 4) matrix with min_i=1, max_i=3, min_j=1, max_j=4, outputs |
|
|
|
    offset=2 (1, 1)
|
offset=3 (2, 1), (1, 2) |
|
offset=4 (2, 2), (1, 3) |
|
offset=5 (2, 3) |
|
|
|
constraints: |
|
i + j = offset |
|
min_j <= j < max_j |
|
min_i <= offset - j < max_i |
|
""" |
|
if max_i is None: |
|
max_i = offset + 1 |
|
if max_j is None: |
|
max_j = offset + 1 |
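    # Clamp the j-range so that both j and i = offset - j stay within
    # [min_j, max_j) and [min_i, max_i) respectively.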
|
min_j = max(min_j, offset - max_i + 1, 0) |
|
max_j = min(max_j, offset - min_i + 1, offset + 1) |
|
j = torch.arange(min_j, max_j) |
|
i = offset - j |
|
return torch.stack([i, j]) |
|
|
|
|
|
def batch_dynamic_time_warping(distance, shapes=None): |
|
"""full batched DTW without any constraints |
|
|
|
distance: (batchsize, max_M, max_N) matrix |
|
    shapes: (batchsize, 2) tensor specifying (M, N) for each entry
|
""" |
|
|
|
ptr2dij = {0: (0, -1), 1: (-1, -1), 2: (-1, 0)} |
|
|
|
bsz, m, n = distance.size() |
|
cumdist = torch.zeros_like(distance) |
|
    backptr = torch.full_like(distance, -1, dtype=torch.int32)
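    # Base cases: cells in the first row are reachable only by horizontal
    # moves and cells in the first column only by vertical moves, so their
    # cumulative distances are plain running sums.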
|
|
|
|
|
cumdist[:, 0, :] = distance[:, 0, :].cumsum(dim=-1) |
|
cumdist[:, :, 0] = distance[:, :, 0].cumsum(dim=-1) |
|
backptr[:, 0, :] = 0 |
|
backptr[:, :, 0] = 2 |
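    # Sweep anti-diagonals i + j = offset: every cell on one anti-diagonal
    # depends only on the two previous ones, so each sweep updates a whole
    # diagonal for the whole batch in one vectorized step.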
|
|
|
|
|
for offset in range(2, m + n - 1): |
|
ind = antidiag_indices(offset, 1, m, 1, n) |
|
c = torch.stack( |
|
[ |
|
cumdist[:, ind[0], ind[1] - 1], |
|
cumdist[:, ind[0] - 1, ind[1] - 1], |
|
cumdist[:, ind[0] - 1, ind[1]], |
|
], |
|
dim=2, |
|
) |
|
        v, b = c.min(dim=-1)
|
backptr[:, ind[0], ind[1]] = b.int() |
|
cumdist[:, ind[0], ind[1]] = v + distance[:, ind[0], ind[1]] |
|
|
|
|
|
pathmap = torch.zeros_like(backptr) |
|
for b in range(bsz): |
|
i = m - 1 if shapes is None else (shapes[b][0] - 1).item() |
|
j = n - 1 if shapes is None else (shapes[b][1] - 1).item() |
|
dtwpath = [(i, j)] |
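        # The length cap is a safety guard against a malformed backpointer
        # table producing an endless backtrace.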
|
while (i != 0 or j != 0) and len(dtwpath) < 10000: |
|
assert i >= 0 and j >= 0 |
|
di, dj = ptr2dij[backptr[b, i, j].item()] |
|
i, j = i + di, j + dj |
|
dtwpath.append((i, j)) |
|
dtwpath = dtwpath[::-1] |
|
indices = torch.from_numpy(np.array(dtwpath)) |
|
pathmap[b, indices[:, 0], indices[:, 1]] = 1 |
|
|
|
return cumdist, backptr, pathmap |
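
# Illustrative example (not used by the task; shapes chosen for brevity):
#   dist = torch.tensor([[[0.0, 1.0, 2.0], [1.0, 0.0, 1.0]]])
#   cumdist, backptr, pathmap = batch_dynamic_time_warping(dist)
#   # pathmap[0] == [[1, 0, 0], [0, 1, 1]], i.e. the optimal alignment is
#   # (0, 0) -> (1, 1) -> (1, 2).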
|
|
|
|
|
def compute_l2_dist(x1, x2): |
|
"""compute an (m, n) L2 distance matrix from (m, d) and (n, d) matrices""" |
|
return torch.cdist(x1.unsqueeze(0), x2.unsqueeze(0), p=2).squeeze(0).pow(2) |
|
|
|
|
|
def compute_rms_dist(x1, x2): |
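    # Per frame pair: mean of squared differences over the d feature
    # dimensions, then a square root, i.e. an RMS distance.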
|
l2_dist = compute_l2_dist(x1, x2) |
|
return (l2_dist / x1.size(1)).pow(0.5) |
|
|
|
|
|
def get_divisor(pathmap, normalize_type): |
|
if normalize_type is None: |
|
return 1 |
|
elif normalize_type == "len1": |
|
return pathmap.size(0) |
|
elif normalize_type == "len2": |
|
return pathmap.size(1) |
|
elif normalize_type == "path": |
|
return pathmap.sum().item() |
|
else: |
|
raise ValueError(f"normalize_type {normalize_type} not supported") |
|
|
|
|
|
def batch_compute_distortion(y1, y2, sr, feat_fn, dist_fn, normalize_type): |
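    """Compute feature-space distortion between paired waveforms via DTW.

    For each pair, features are extracted with feat_fn and a pairwise
    distance matrix is built with dist_fn; DTW then finds the
    minimum-distortion alignment. Returns a list of
    (distortion, (x1, x2, dist, cumdist, backptr, pathmap)) tuples.
    """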
|
d, s, x1, x2 = [], [], [], [] |
|
for cur_y1, cur_y2 in zip(y1, y2): |
|
assert cur_y1.ndim == 1 and cur_y2.ndim == 1 |
|
cur_x1 = feat_fn(cur_y1) |
|
cur_x2 = feat_fn(cur_y2) |
|
x1.append(cur_x1) |
|
x2.append(cur_x2) |
|
|
|
cur_d = dist_fn(cur_x1, cur_x2) |
|
d.append(cur_d) |
|
s.append(d[-1].size()) |
|
max_m = max(ss[0] for ss in s) |
|
max_n = max(ss[1] for ss in s) |
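    # Zero-pad all distance matrices to a common (max_m, max_n) so the whole
    # batch is warped at once; the true shapes are passed to DTW so each
    # backtrace starts from its own bottom-right corner.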
|
d = torch.stack( |
|
[F.pad(dd, (0, max_n - dd.size(1), 0, max_m - dd.size(0))) for dd in d] |
|
) |
|
s = torch.LongTensor(s).to(d.device) |
|
cumdists, backptrs, pathmaps = batch_dynamic_time_warping(d, s) |
|
|
|
rets = [] |
|
itr = zip(s, x1, x2, d, cumdists, backptrs, pathmaps) |
|
for (m, n), cur_x1, cur_x2, dist, cumdist, backptr, pathmap in itr: |
|
cumdist = cumdist[:m, :n] |
|
backptr = backptr[:m, :n] |
|
pathmap = pathmap[:m, :n] |
|
divisor = get_divisor(pathmap, normalize_type) |
|
|
|
distortion = cumdist[-1, -1] / divisor |
|
ret = distortion, (cur_x1, cur_x2, dist, cumdist, backptr, pathmap) |
|
rets.append(ret) |
|
return rets |
|
|
|
|
|
def batch_mel_cepstral_distortion(y1, y2, sr, normalize_type="path", mfcc_fn=None): |
|
""" |
|
https://arxiv.org/pdf/2011.03568.pdf |
|
|
|
The root mean squared error computed on 13-dimensional MFCC using DTW for |
|
alignment. MFCC features are computed from an 80-channel log-mel |
|
spectrogram using a 50ms Hann window and hop of 12.5ms. |
|
|
|
    y1: list of reference waveforms (1-D tensors)
    y2: list of predicted waveforms (1-D tensors)
    sr: sampling rate
    normalize_type: how the accumulated distortion is normalized
        (see get_divisor); defaults to path length
    mfcc_fn: optional pre-built torchaudio MFCC transform to reuse
    """
|
|
|
try: |
|
import torchaudio |
|
except ImportError: |
|
raise ImportError("Please install torchaudio: pip install torchaudio") |
|
|
|
if mfcc_fn is None or mfcc_fn.sample_rate != sr: |
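        # 13-dim MFCCs over 80 mel bins with a 50 ms Hann window
        # (0.05 * sr samples) and a 12.5 ms hop, matching the recipe above.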
|
melkwargs = { |
|
"n_fft": int(0.05 * sr), |
|
"win_length": int(0.05 * sr), |
|
"hop_length": int(0.0125 * sr), |
|
"f_min": 20, |
|
"n_mels": 80, |
|
"window_fn": torch.hann_window, |
|
} |
|
mfcc_fn = torchaudio.transforms.MFCC( |
|
sr, n_mfcc=13, log_mels=True, melkwargs=melkwargs |
|
).to(y1[0].device) |
|
return batch_compute_distortion( |
|
y1, |
|
y2, |
|
sr, |
|
lambda y: mfcc_fn(y).transpose(-1, -2), |
|
compute_rms_dist, |
|
normalize_type, |
|
) |
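
# Usage sketch (hypothetical data; assumes torchaudio is installed):
#   y1 = [torch.randn(16000)]  # reference waveforms at 16 kHz
#   y2 = [torch.randn(18000)]  # predicted waveforms
#   mcd, extra = batch_mel_cepstral_distortion(y1, y2, sr=16000)[0]
#   # extra = (x1, x2, dist, cumdist, backptr, pathmap)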
|
|