|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
from pathlib import Path |
|
import soundfile as sf |
|
import sys |
|
import torch |
|
import torchaudio |
|
|
|
from fairseq import checkpoint_utils, options, tasks, utils |
|
from fairseq.logging import progress_bar |
|
from fairseq.tasks.text_to_speech import plot_tts_output |
|
from fairseq.data.audio.text_to_speech_dataset import TextToSpeechDataset |
|
|
|
|
|
logging.basicConfig() |
|
logging.root.setLevel(logging.INFO) |
|
logging.basicConfig(level=logging.INFO) |
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def make_parser(): |
|
parser = options.get_speech_generation_parser() |
|
parser.add_argument("--dump-features", action="store_true") |
|
parser.add_argument("--dump-waveforms", action="store_true") |
|
parser.add_argument("--dump-attentions", action="store_true") |
|
parser.add_argument("--dump-eos-probs", action="store_true") |
|
parser.add_argument("--dump-plots", action="store_true") |
|
parser.add_argument("--dump-target", action="store_true") |
|
parser.add_argument("--output-sample-rate", default=22050, type=int) |
|
parser.add_argument("--teacher-forcing", action="store_true") |
|
parser.add_argument( |
|
"--audio-format", type=str, default="wav", choices=["wav", "flac"] |
|
) |
|
return parser |
|
|
|
|
|
def postprocess_results( |
|
dataset: TextToSpeechDataset, sample, hypos, resample_fn, dump_target |
|
): |
|
def to_np(x): |
|
return None if x is None else x.detach().cpu().numpy() |
|
|
|
sample_ids = [dataset.ids[i] for i in sample["id"].tolist()] |
|
texts = sample["src_texts"] |
|
attns = [to_np(hypo["attn"]) for hypo in hypos] |
|
eos_probs = [to_np(hypo.get("eos_prob", None)) for hypo in hypos] |
|
feat_preds = [to_np(hypo["feature"]) for hypo in hypos] |
|
wave_preds = [to_np(resample_fn(h["waveform"])) for h in hypos] |
|
if dump_target: |
|
feat_targs = [to_np(hypo["targ_feature"]) for hypo in hypos] |
|
wave_targs = [to_np(resample_fn(h["targ_waveform"])) for h in hypos] |
|
else: |
|
feat_targs = [None for _ in hypos] |
|
wave_targs = [None for _ in hypos] |
|
|
|
return zip(sample_ids, texts, attns, eos_probs, feat_preds, wave_preds, |
|
feat_targs, wave_targs) |
|
|
|
|
|
def dump_result( |
|
is_na_model, |
|
args, |
|
vocoder, |
|
sample_id, |
|
text, |
|
attn, |
|
eos_prob, |
|
feat_pred, |
|
wave_pred, |
|
feat_targ, |
|
wave_targ, |
|
): |
|
sample_rate = args.output_sample_rate |
|
out_root = Path(args.results_path) |
|
if args.dump_features: |
|
feat_dir = out_root / "feat" |
|
feat_dir.mkdir(exist_ok=True, parents=True) |
|
np.save(feat_dir / f"{sample_id}.npy", feat_pred) |
|
if args.dump_target: |
|
feat_tgt_dir = out_root / "feat_tgt" |
|
feat_tgt_dir.mkdir(exist_ok=True, parents=True) |
|
np.save(feat_tgt_dir / f"{sample_id}.npy", feat_targ) |
|
if args.dump_attentions: |
|
attn_dir = out_root / "attn" |
|
attn_dir.mkdir(exist_ok=True, parents=True) |
|
np.save(attn_dir / f"{sample_id}.npy", attn.numpy()) |
|
if args.dump_eos_probs and not is_na_model: |
|
eos_dir = out_root / "eos" |
|
eos_dir.mkdir(exist_ok=True, parents=True) |
|
np.save(eos_dir / f"{sample_id}.npy", eos_prob) |
|
|
|
if args.dump_plots: |
|
images = [feat_pred.T] if is_na_model else [feat_pred.T, attn] |
|
names = ["output"] if is_na_model else ["output", "alignment"] |
|
if feat_targ is not None: |
|
images = [feat_targ.T] + images |
|
names = [f"target (idx={sample_id})"] + names |
|
if is_na_model: |
|
plot_tts_output(images, names, attn, "alignment", suptitle=text) |
|
else: |
|
plot_tts_output(images, names, eos_prob, "eos prob", suptitle=text) |
|
plot_dir = out_root / "plot" |
|
plot_dir.mkdir(exist_ok=True, parents=True) |
|
plt.savefig(plot_dir / f"{sample_id}.png") |
|
plt.close() |
|
|
|
if args.dump_waveforms: |
|
ext = args.audio_format |
|
if wave_pred is not None: |
|
wav_dir = out_root / f"{ext}_{sample_rate}hz_{vocoder}" |
|
wav_dir.mkdir(exist_ok=True, parents=True) |
|
sf.write(wav_dir / f"{sample_id}.{ext}", wave_pred, sample_rate) |
|
if args.dump_target and wave_targ is not None: |
|
wav_tgt_dir = out_root / f"{ext}_{sample_rate}hz_{vocoder}_tgt" |
|
wav_tgt_dir.mkdir(exist_ok=True, parents=True) |
|
sf.write(wav_tgt_dir / f"{sample_id}.{ext}", wave_targ, sample_rate) |
|
|
|
|
|
def main(args): |
|
assert(args.dump_features or args.dump_waveforms or args.dump_attentions |
|
or args.dump_eos_probs or args.dump_plots) |
|
if args.max_tokens is None and args.batch_size is None: |
|
args.max_tokens = 8000 |
|
logger.info(args) |
|
|
|
use_cuda = torch.cuda.is_available() and not args.cpu |
|
task = tasks.setup_task(args) |
|
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( |
|
[args.path], |
|
task=task, |
|
) |
|
model = models[0].cuda() if use_cuda else models[0] |
|
|
|
task.args.n_frames_per_step = saved_cfg.task.n_frames_per_step |
|
task.load_dataset(args.gen_subset, task_cfg=saved_cfg.task) |
|
|
|
data_cfg = task.data_cfg |
|
sample_rate = data_cfg.config.get("features", {}).get("sample_rate", 22050) |
|
resample_fn = { |
|
False: lambda x: x, |
|
True: lambda x: torchaudio.sox_effects.apply_effects_tensor( |
|
x.detach().cpu().unsqueeze(0), sample_rate, |
|
[['rate', str(args.output_sample_rate)]] |
|
)[0].squeeze(0) |
|
}.get(args.output_sample_rate != sample_rate) |
|
if args.output_sample_rate != sample_rate: |
|
logger.info(f"resampling to {args.output_sample_rate}Hz") |
|
|
|
generator = task.build_generator([model], args) |
|
itr = task.get_batch_iterator( |
|
dataset=task.dataset(args.gen_subset), |
|
max_tokens=args.max_tokens, |
|
max_sentences=args.batch_size, |
|
max_positions=(sys.maxsize, sys.maxsize), |
|
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, |
|
required_batch_size_multiple=args.required_batch_size_multiple, |
|
num_shards=args.num_shards, |
|
shard_id=args.shard_id, |
|
num_workers=args.num_workers, |
|
data_buffer_size=args.data_buffer_size, |
|
).next_epoch_itr(shuffle=False) |
|
|
|
Path(args.results_path).mkdir(exist_ok=True, parents=True) |
|
is_na_model = getattr(model, "NON_AUTOREGRESSIVE", False) |
|
dataset = task.dataset(args.gen_subset) |
|
vocoder = task.args.vocoder |
|
with progress_bar.build_progress_bar(args, itr) as t: |
|
for sample in t: |
|
sample = utils.move_to_cuda(sample) if use_cuda else sample |
|
hypos = generator.generate(model, sample, has_targ=args.dump_target) |
|
for result in postprocess_results( |
|
dataset, sample, hypos, resample_fn, args.dump_target |
|
): |
|
dump_result(is_na_model, args, vocoder, *result) |
|
|
|
|
|
def cli_main(): |
|
parser = make_parser() |
|
args = options.parse_args_and_arch(parser) |
|
main(args) |
|
|
|
|
|
if __name__ == "__main__": |
|
cli_main() |
|
|