import json
import logging
import math
from argparse import Namespace
from pathlib import Path
from typing import List

import torch
import torch.nn as nn

from fairseq import utils
from fairseq.data import Dictionary
from fairseq.data.audio.data_cfg import MultitaskConfig, S2SDataConfig
from fairseq.data.audio.speech_to_speech_dataset import SpeechToSpeechDatasetCreator
from fairseq.data.audio.speech_to_text_dataset import (
    SpeechToTextDataset,
    TextTargetMultitaskData,
)
from fairseq.tasks import LegacyFairseqTask, register_task
from fairseq.tasks.speech_to_text import DummyMultiTask
from fairseq.tasks.text_to_speech import batch_mel_cepstral_distortion

logger = logging.getLogger(__name__)

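# Greedy decoder for "stacked" discrete-unit targets: each step predicts
# n_frames_per_step units at once and packs them back into a single vocabulary
# index before feeding the next decoder step.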
class StackUnitSequenceGenerator(nn.Module):
    def __init__(self, tgt_dict, vocab_size):
        super().__init__()
        self.pad = tgt_dict.pad()
        self.eos = tgt_dict.eos()
        self.unk = tgt_dict.unk()
        self.offset = len(tgt_dict) - vocab_size
        self.vocab_size = vocab_size

    def pack_units(self, input: torch.Tensor, n_frames_per_step) -> torch.Tensor:
        if n_frames_per_step <= 1:
            return input

        bsz, _, n = input.shape
        assert n == n_frames_per_step

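        # Pack the n_frames_per_step predictions into one index by treating the
        # offset-shifted units as digits of a base-vocab_size number; positions
        # holding special symbols (below the offset) are masked out.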
        scale = [
            pow(self.vocab_size, n_frames_per_step - 1 - i)
            for i in range(n_frames_per_step)
        ]
        scale = torch.LongTensor(scale).squeeze(0).to(input.device)
        mask = input >= self.offset
        res = ((input - self.offset) * scale * mask).sum(dim=2) + self.offset
        return res

    @torch.no_grad()
    def generate(self, models, sample, **kwargs):
        model = models[0]
        model.eval()

        max_len = model.max_decoder_positions()

        src_tokens = sample["net_input"]["src_tokens"]
        src_lengths = sample["net_input"]["src_lengths"]
        bsz, src_len, _ = src_tokens.size()
        n_frames_per_step = model.decoder.n_frames_per_step

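        # Run the encoder once, then unroll the decoder greedily with incremental
        # state until every sample has emitted <eos> or max_len is reached.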
        encoder_out = model.forward_encoder(
            src_tokens, src_lengths, speaker=sample["speaker"]
        )
        incremental_state = {}
        pred_out, attn, scores = [], [], []
        finished = src_tokens.new_zeros((bsz,)).bool()

        prev_output_tokens = src_lengths.new_zeros((bsz, 1)).long().fill_(self.eos)
        for _ in range(max_len):
            cur_out, cur_extra = model.forward_decoder(
                prev_output_tokens,
                encoder_out=encoder_out,
                incremental_state=incremental_state,
            )

            lprobs = model.get_normalized_probs([cur_out], log_probs=True)
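            # never select pad or unk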
            lprobs[:, :, self.pad] = -math.inf
            lprobs[:, :, self.unk] = -math.inf

            cur_pred_lprob, cur_pred_out = torch.max(lprobs, dim=2)
            scores.append(cur_pred_lprob)
            pred_out.append(cur_pred_out)

            prev_output_tokens = torch.cat(
                (
                    prev_output_tokens,
                    self.pack_units(
                        cur_pred_out.view(bsz, 1, n_frames_per_step), n_frames_per_step
                    ),
                ),
                dim=1,
            )

            attn.append(cur_extra["attn"][0])

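            # a sample is finished once any of its stacked predictions is <eos>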
            cur_finished = torch.any(cur_pred_out.squeeze(1) == self.eos, dim=1)
            finished = finished | cur_finished
            if finished.sum().item() == bsz:
                break

        pred_out = torch.cat(pred_out, dim=1).view(bsz, -1)
        attn = torch.cat(attn, dim=2)
        alignment = attn.max(dim=1)[1]
        attn = attn.repeat_interleave(n_frames_per_step, dim=2)
        alignment = alignment.repeat_interleave(n_frames_per_step, dim=1)
        scores = torch.cat(scores, dim=1)
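        # truncate each hypothesis at its first emitted <eos>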
        eos_idx = (pred_out == self.eos).nonzero(as_tuple=True)
        out_lens = src_lengths.new_zeros((bsz,)).long().fill_(max_len)
        for b, l in zip(eos_idx[0], eos_idx[1]):
            out_lens[b] = min(l, out_lens[b])

        hypos = [
            [
                {
                    "tokens": pred_out[b, :out_len],
                    "attn": attn[b, :, :out_len],
                    "alignment": alignment[b, :out_len],
                    "positional_scores": scores[b, :out_len],
                    "score": utils.item(scores[b, :out_len].sum().data),
                }
            ]
            for b, out_len in zip(range(bsz), out_lens)
        ]

        return hypos

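# Speech-to-speech translation task. Targets are either discrete acoustic units
# (speech-to-unit, --target-is-code) or spectrogram frames (speech-to-spectrogram);
# auxiliary decoders can be attached via --multitask-config-yaml.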
@register_task("speech_to_speech")
class SpeechToSpeechTask(LegacyFairseqTask):
    @classmethod
    def add_args(cls, parser):
        parser.add_argument("data", help="manifest root path")
        parser.add_argument(
            "--config-yaml",
            type=str,
            default="config.yaml",
            help="Configuration YAML filename (under manifest root)",
        )
        parser.add_argument(
            "--multitask-config-yaml",
            type=str,
            default=None,
            help="Configuration YAML filename for the multitasks (under manifest root)",
        )
        parser.add_argument(
            "--max-source-positions",
            default=6000,
            type=int,
            metavar="N",
            help="max number of tokens in the source sequence",
        )
        parser.add_argument(
            "--max-target-positions",
            default=1024,
            type=int,
            metavar="N",
            help="max number of tokens in the target sequence",
        )
        parser.add_argument(
            "--target-is-code",
            action="store_true",
            help="set if target is discrete unit instead of spectrogram",
        )
        parser.add_argument(
            "--target-code-size", type=int, default=None, help="# discrete units"
        )
        parser.add_argument(
            "--n-frames-per-step",
            type=int,
            default=1,
            help="# stacked frames, use 0 for reduced discrete unit sequence",
        )
        parser.add_argument("--eval-inference", action="store_true")
        parser.add_argument(
            "--eval-args",
            type=str,
            default="{}",
            help='generation args for speech-to-unit model, e.g., \'{"beam": 5, "max_len_a": 1}\', as JSON string',
        )
        parser.add_argument("--eos-prob-threshold", type=float, default=0.5)
        parser.add_argument(
            "--mcd-normalize-type",
            type=str,
            default="targ",
            choices=["targ", "pred", "path"],
        )
        parser.add_argument(
            "--vocoder",
            type=str,
            default="griffin_lim",
            choices=["griffin_lim", "hifigan", "code_hifigan"],
        )
        parser.add_argument("--spec-bwd-max-iter", type=int, default=8)
        parser.add_argument(
            "--infer-target-lang",
            type=str,
            default="",
            help="target language for inference",
        )

    def __init__(self, args, tgt_dict, infer_tgt_lang_id=None):
        super().__init__(args)
        self.tgt_dict = tgt_dict
        self.data_cfg = S2SDataConfig(Path(args.data) / args.config_yaml)

        self.multitask_tasks = {}
        self.tgt_dict_mt = None
        self.eos_token_mt = None
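        # Each auxiliary task is wrapped in a dummy task object holding its target
        # dictionary; the first-pass decoder task additionally provides the MT
        # dictionary and, when language tags are used, the eos token that replaces
        # <eos> in the first-pass sequence generator.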
if getattr(args, "multitask_config_yaml", None) is not None: |
|
multitask_cfg = MultitaskConfig( |
|
Path(args.data) / args.multitask_config_yaml |
|
) |
|
first_pass_task_idx = multitask_cfg.first_pass_decoder_task_index |
|
for i, (task_name, task_config) in enumerate( |
|
multitask_cfg.get_all_tasks().items() |
|
): |
|
task_obj = DummyMultiTask( |
|
task_config, |
|
task_config.tgt_dict, |
|
first_pass=i == first_pass_task_idx, |
|
) |
|
self.multitask_tasks[task_name] = task_obj |
|
if task_obj.is_first_pass_decoder: |
|
self.tgt_dict_mt = task_obj.target_dictionary |
|
if task_config.prepend_bos_and_append_tgt_lang_tag: |
|
self.eos_token_mt = task_config.eos_token |
|
assert not isinstance(self.eos_token_mt, List) |
|
|
|
if not self.eos_token_mt: |
|
raise Warning( |
|
"Please provide eos_token in --multitask-config-yaml to replace eos in sequence generator" |
|
) |
|
|
|
self._infer_tgt_lang_id = infer_tgt_lang_id |
|
|
|
    @classmethod
    def setup_task(cls, args, **kwargs):
        data_cfg = S2SDataConfig(Path(args.data) / args.config_yaml)
        tgt_dict = None
        infer_tgt_lang_id = None
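        # For unit targets, the dictionary is either loaded from the configured
        # vocab file (required when prepending target-language tags) or built from
        # --target-code-size consecutive integer unit symbols.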
        if args.target_is_code:
            if data_cfg.prepend_tgt_lang_tag_as_bos:
                dict_path = Path(args.data) / data_cfg.vocab_filename
                if not dict_path.is_file():
                    raise FileNotFoundError(
                        f"Dict has to be provided when setting prepend_tgt_lang_tag_as_bos: true, but dict not found: {dict_path}"
                    )
                tgt_dict = Dictionary.load(dict_path.as_posix())

                if args.infer_target_lang != "":
                    tgt_lang_tag = SpeechToTextDataset.LANG_TAG_TEMPLATE.format(
                        args.infer_target_lang
                    )
                    infer_tgt_lang_id = tgt_dict.index(tgt_lang_tag)
                    assert infer_tgt_lang_id != tgt_dict.unk()
            else:
                assert args.target_code_size is not None

                tgt_dict = Dictionary()
                for i in range(args.target_code_size):
                    tgt_dict.add_symbol(str(i))
            logger.info(f"dictionary size: {len(tgt_dict):,}")

        if getattr(args, "train_subset", None) is not None:
            if not all(s.startswith("train") for s in args.train_subset.split(",")):
                raise ValueError('Train splits should be named like "train*".')

        assert args.n_frames_per_step >= 1
        assert (
            not args.eval_inference
            or (args.target_is_code and args.vocoder == "code_hifigan")
            or (not args.target_is_code and args.vocoder != "code_hifigan")
        )

        return cls(args, tgt_dict, infer_tgt_lang_id=infer_tgt_lang_id)

    def build_criterion(self, args):
        from fairseq import criterions

        if len(self.multitask_tasks) > 0:
            if self.args.target_is_code and not args._name.startswith("speech_to_unit"):
                raise ValueError(
                    "set --criterion speech_to_unit for speech-to-unit loss with multitask"
                )
            elif not self.args.target_is_code and not args._name.startswith(
                "speech_to_spectrogram"
            ):
                raise ValueError(
                    "set --criterion speech_to_spectrogram for speech-to-spectrogram loss with multitask"
                )

        return criterions.build_criterion(args, self)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        self.datasets[split] = SpeechToSpeechDatasetCreator.from_tsv(
            root=self.args.data,
            data_cfg=self.data_cfg,
            splits=split,
            is_train_split=split.startswith("train"),
            epoch=epoch,
            seed=self.args.seed,
            target_is_code=self.args.target_is_code,
            tgt_dict=self.target_dictionary,
            n_frames_per_step=self.args.n_frames_per_step,
            multitask=self.multitask_tasks,
        )

    @property
    def target_dictionary(self):
        return self.tgt_dict

    @property
    def target_dictionary_mt(self):
        return self.tgt_dict_mt

    @property
    def source_dictionary(self):
        return None

    def max_positions(self):
        return self.args.max_source_positions, self.args.max_target_positions

    def build_model(self, args, from_checkpoint=False):
        args.input_feat_per_channel = self.data_cfg.input_feat_per_channel
        args.input_channels = self.data_cfg.input_transformed_channels
        args.target_speaker_embed = self.data_cfg.target_speaker_embed is not None
        args.n_frames_per_step = self.args.n_frames_per_step

        model = super().build_model(args, from_checkpoint)

        if len(self.multitask_tasks) > 0:
            from fairseq.models.speech_to_speech.s2s_transformer import (
                S2STransformerMultitaskModelBase,
            )

            assert isinstance(model, S2STransformerMultitaskModelBase)

        if self.args.eval_inference:
            self.eval_gen_args = json.loads(self.args.eval_args)
            self.generator = self.build_generator(
                [model], Namespace(**self.eval_gen_args)
            )

        return model

    def build_generator_dual_decoder(
        self,
        models,
        args,
        extra_gen_cls_kwargs=None,
    ):
        from examples.speech_to_speech.unity.sequence_generator_multi_decoder import (
            MultiDecoderSequenceGenerator,
        )

        if extra_gen_cls_kwargs is None:
            extra_gen_cls_kwargs = {}

        return MultiDecoderSequenceGenerator(
            models,
            self.target_dictionary,
            self.target_dictionary_mt,
            beam_size=max(1, getattr(args, "beam", 1)),
            beam_size_mt=max(1, getattr(args, "beam_mt", 1)),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            max_len_a_mt=getattr(args, "max_len_a_mt", 0),
            max_len_b_mt=getattr(args, "max_len_b_mt", 200),
            min_len=getattr(args, "min_len", 1),
            normalize_scores=(not getattr(args, "unnormalized", False)),
            len_penalty=getattr(args, "lenpen", 1),
            unk_penalty=getattr(args, "unkpen", 0),
            temperature=getattr(args, "temperature", 1.0),
            match_source_len=getattr(args, "match_source_len", False),
            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
            **extra_gen_cls_kwargs,
        )

    def build_generator(
        self,
        models,
        args,
        seq_gen_cls=None,
        extra_gen_cls_kwargs=None,
    ):
        if not self.args.target_is_code or self.args.eval_inference:
            from fairseq.models.text_to_speech.vocoder import get_vocoder

            self.vocoder = get_vocoder(self.args, self.data_cfg)
            self.vocoder = (
                self.vocoder.cuda()
                if torch.cuda.is_available() and not self.args.cpu
                else self.vocoder.cpu()
            )

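        # models exposing mt_task_name have a first-pass text decoder and need the
        # dual-decoder (UnitY-style) generators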
        has_dual_decoder = getattr(models[0], "mt_task_name", None) is not None

        if self.args.target_is_code:
            if self.args.n_frames_per_step == 1:
                if has_dual_decoder:
                    seq_generator = self.build_generator_dual_decoder(
                        models,
                        args,
                        extra_gen_cls_kwargs=extra_gen_cls_kwargs,
                    )
                else:
                    seq_generator = super().build_generator(
                        models,
                        args,
                        seq_gen_cls=None,
                        extra_gen_cls_kwargs=extra_gen_cls_kwargs,
                    )
            else:
                assert (
                    getattr(args, "beam", 1) == 1 and getattr(args, "nbest", 1) == 1
                ), "only support viterbi search for stacked units"
                seq_generator = StackUnitSequenceGenerator(
                    self.tgt_dict,
                    self.args.target_code_size,
                )
        else:
            if has_dual_decoder:
                if getattr(args, "teacher_forcing", False):
                    raise NotImplementedError
                else:
                    from fairseq.speech_generator import MultiDecoderSpeechGenerator

                    generator = MultiDecoderSpeechGenerator

                lang_token_ids_aux = {
                    i
                    for s, i in self.tgt_dict_mt.indices.items()
                    if TextTargetMultitaskData.is_lang_tag(s)
                }

                if extra_gen_cls_kwargs is None:
                    extra_gen_cls_kwargs = {}
                extra_gen_cls_kwargs[
                    "symbols_to_strip_from_output"
                ] = lang_token_ids_aux

                eos_id_mt = (
                    self.tgt_dict_mt.index(self.eos_token_mt)
                    if self.eos_token_mt
                    else None
                )
                assert eos_id_mt != self.tgt_dict_mt.unk()
                extra_gen_cls_kwargs["eos_mt"] = eos_id_mt

                seq_generator = generator(
                    models,
                    args,
                    self.vocoder,
                    self.data_cfg,
                    self.target_dictionary_mt,
                    max_iter=self.args.max_target_positions,
                    eos_prob_threshold=self.args.eos_prob_threshold,
                    **extra_gen_cls_kwargs,
                )
            else:
                if getattr(args, "teacher_forcing", False):
                    from fairseq.speech_generator import (
                        TeacherForcingAutoRegressiveSpeechGenerator,
                    )

                    generator = TeacherForcingAutoRegressiveSpeechGenerator
                    logger.info("Teacher forcing mode for generation")
                else:
                    from fairseq.speech_generator import AutoRegressiveSpeechGenerator

                    generator = AutoRegressiveSpeechGenerator

                seq_generator = generator(
                    models[0],
                    self.vocoder,
                    self.data_cfg,
                    max_iter=self.args.max_target_positions,
                    eos_prob_threshold=self.args.eos_prob_threshold,
                )

        return seq_generator

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        for task_name, task_obj in self.multitask_tasks.items():
            criterion.set_multitask_loss_weight(
                task_name, task_obj.args.get_loss_weight(update_num)
            )
            if task_name in model.multitask_decoders:
                model.multitask_decoders[task_name].train()

        loss, sample_size, logging_output = super().train_step(
            sample, model, criterion, optimizer, update_num, ignore_grad
        )
        return loss, sample_size, logging_output

    def valid_step(self, sample, model, criterion):
        for task_name in self.multitask_tasks.keys():
            if task_name in model.multitask_decoders:
                model.multitask_decoders[task_name].eval()
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)

        if self.args.eval_inference:
            hypos, inference_losses = self.valid_step_with_inference(
                sample, model, self.generator
            )
            for k, v in inference_losses.items():
                assert k not in logging_output
                logging_output[k] = v

        return loss, sample_size, logging_output

    def valid_step_with_inference(self, sample, model, generator):
        if self.args.target_is_code:
            hypos = generator.generate([model], sample)
            tgt_lens = (
                sample["target_lengths"] - 1
            ) * self.args.n_frames_per_step
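            # Resynthesize waveforms with the code vocoder; unit ids are recovered
            # by subtracting the 4 special symbols (<bos>, <pad>, <eos>, <unk>)
            # that prefix the fairseq Dictionary.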
            for b, (f, l) in enumerate(zip(sample["target"], tgt_lens)):
                hypos[b][0]["targ_waveform"] = self.vocoder(
                    {"code": f[:l] - 4},
                    dur_prediction=self.eval_gen_args.get("dur_prediction", False),
                )
                if len(hypos[b][0]["tokens"]) > 0:
                    hypos[b][0]["waveform"] = self.vocoder(
                        {"code": hypos[b][0]["tokens"] - 4},
                        dur_prediction=self.eval_gen_args.get("dur_prediction", False),
                    )
                else:
                    hypos[b][0]["waveform"] = torch.flip(
                        hypos[b][0]["targ_waveform"], dims=[0]
                    )
        else:
            hypos = [
                [hypo] for hypo in generator.generate(model, sample, has_targ=True)
            ]

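        # Accumulate mel-cepstral distortion (MCD) and DTW path statistics between
        # predicted and reference waveforms; nins/ndel count insertions/deletions
        # along the alignment path.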
        losses = {
            "mcd_loss": 0.0,
            "targ_frames": 0.0,
            "pred_frames": 0.0,
            "path_frames": 0.0,
            "nins": 0.0,
            "ndel": 0.0,
        }
        rets = batch_mel_cepstral_distortion(
            [hypo[0]["targ_waveform"] for hypo in hypos],
            [hypo[0]["waveform"] for hypo in hypos],
            self.data_cfg.output_sample_rate,
            normalize_type=None,
        )
        for d, extra in rets:
            pathmap = extra[-1]
            losses["mcd_loss"] += d.item()
            losses["targ_frames"] += pathmap.size(0)
            losses["pred_frames"] += pathmap.size(1)
            losses["path_frames"] += pathmap.sum().item()
            losses["nins"] += (pathmap.sum(dim=1) - 1).sum().item()
            losses["ndel"] += (pathmap.sum(dim=0) - 1).sum().item()
        losses["norm_frames"] = losses[
            f"{getattr(self.args, 'mcd_normalize_type', 'targ')}_frames"
        ]

        return hypos, losses

    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
        with torch.no_grad():
            if self._infer_tgt_lang_id is not None:
                return generator.generate(
                    models,
                    sample,
                    prefix_tokens=prefix_tokens,
                    constraints=constraints,
                    bos_token=self._infer_tgt_lang_id,
                )
            else:
                return super().inference_step(
                    generator,
                    models,
                    sample,
                    prefix_tokens=prefix_tokens,
                    constraints=constraints,
                )