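"""Speech-to-text (S2T) task for fairseq.

Loads TSV manifests described by a data-config YAML, builds the target
dictionary and tokenizers, and optionally attaches auxiliary multitask heads
(e.g. a first-pass text decoder) configured via --multitask-config-yaml.
"""
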
import logging
from argparse import Namespace
from pathlib import Path

from fairseq.data import Dictionary, encoders
from fairseq.data.audio.audio_utils import get_features_or_waveform
from fairseq.data.audio.data_cfg import MultitaskConfig
from fairseq.data.audio.speech_to_text_dataset import (
    S2TDataConfig,
    SpeechToTextDataset,
    SpeechToTextDatasetCreator,
    TextTargetMultitaskData,
)
from fairseq.tasks import LegacyFairseqTask, register_task

logger = logging.getLogger(__name__)


@register_task("speech_to_text")
class SpeechToTextTask(LegacyFairseqTask):
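    """Speech-to-text task: ties together the S2T data config, target
    dictionary, tokenizers, datasets, and sequence generators.

    A hypothetical invocation (the paths, arch, and hyperparameters below are
    placeholders for illustration only):

        fairseq-train /path/to/manifests --task speech_to_text \
            --config-yaml config.yaml --arch s2t_transformer_s \
            --max-source-positions 6000 --max-target-positions 1024
    """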

    @classmethod
    def add_args(cls, parser):
        parser.add_argument("data", help="manifest root path")
        parser.add_argument(
            "--config-yaml",
            type=str,
            default="config.yaml",
            help="Configuration YAML filename (under manifest root)",
        )
        parser.add_argument(
            "--multitask-config-yaml",
            type=str,
            default=None,
            help="Configuration YAML filename for the multitasks (under manifest root)",
        )
        parser.add_argument(
            "--max-source-positions",
            default=6000,
            type=int,
            metavar="N",
            help="max number of tokens in the source sequence",
        )
        parser.add_argument(
            "--max-target-positions",
            default=1024,
            type=int,
            metavar="N",
            help="max number of tokens in the target sequence",
        )

    def __init__(self, args, tgt_dict):
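        """Load the data config, validate the language-tag options, and build
        any auxiliary multitask sub-tasks declared in --multitask-config-yaml."""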
        super().__init__(args)
        self.tgt_dict = tgt_dict
        self.data_cfg = S2TDataConfig(Path(args.data) / args.config_yaml)
        self.speaker_to_id = self._get_speaker_to_id()
        if (
            self.data_cfg.prepend_tgt_lang_tag
            and self.data_cfg.prepend_bos_and_append_tgt_lang_tag
        ):
            raise ValueError(
                "Please set only one of the two options to avoid adding the "
                "target language tag multiple times"
            )

        self.multitask_tasks = {}
        self.tgt_dict_mt = None
        self.eos_token_mt = None
        if getattr(args, "multitask_config_yaml", None) is not None:
            multitask_cfg = MultitaskConfig(
                Path(args.data) / args.multitask_config_yaml
            )
            first_pass_task_idx = multitask_cfg.first_pass_decoder_task_index
            for i, (task_name, task_config) in enumerate(
                multitask_cfg.get_all_tasks().items()
            ):
                task_obj = DummyMultiTask(
                    task_config,
                    task_config.tgt_dict,
                    first_pass=i == first_pass_task_idx,
                )
                self.multitask_tasks[task_name] = task_obj
                if task_obj.is_first_pass_decoder:
                    self.tgt_dict_mt = task_obj.target_dictionary
                    if task_config.prepend_bos_and_append_tgt_lang_tag:
                        self.eos_token_mt = task_config.eos_token
                        assert not isinstance(self.eos_token_mt, list)

                    if not self.eos_token_mt:
                        raise ValueError(
                            "Please provide eos_token in --multitask-config-yaml "
                            "to replace eos in the sequence generator"
                        )

    def _get_speaker_to_id(self):
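        """Read the optional speaker set file (one speaker per line) into a
        {speaker: index} mapping; return None when it is not configured."""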
        speaker_to_id = None
        speaker_set_filename = self.data_cfg.config.get("speaker_set_filename")
        if speaker_set_filename is not None:
            speaker_set_path = Path(self.args.data) / speaker_set_filename
            with open(speaker_set_path) as f:
                speaker_to_id = {r.strip(): i for i, r in enumerate(f)}
        return speaker_to_id

    @classmethod
    def setup_task(cls, args, **kwargs):
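        """Load the data config and the target dictionary it points to, and
        check that all training splits are named "train*"."""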
        data_cfg = S2TDataConfig(Path(args.data) / args.config_yaml)
        dict_path = Path(args.data) / data_cfg.vocab_filename
        if not dict_path.is_file():
            raise FileNotFoundError(f"Dict not found: {dict_path.as_posix()}")
        tgt_dict = Dictionary.load(dict_path.as_posix())
        logger.info(
            f"dictionary size ({data_cfg.vocab_filename}): {len(tgt_dict):,}"
        )

        if getattr(args, "train_subset", None) is not None:
            if not all(s.startswith("train") for s in args.train_subset.split(",")):
                raise ValueError('Train splits should be named like "train*".')
        return cls(args, tgt_dict)

    def build_criterion(self, args):
        from fairseq import criterions

        if self.data_cfg.prepend_tgt_lang_tag and args.ignore_prefix_size != 1:
            raise ValueError(
                'Please set "--ignore-prefix-size 1" since '
                "target language ID token is prepended as BOS."
            )
        return criterions.build_criterion(args, self)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
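        """Build the dataset for `split` from its TSV manifest, wiring in the
        pre-tokenizer, BPE model, and any multitask targets."""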
        is_train_split = split.startswith("train")
        pre_tokenizer = self.build_tokenizer(self.args)
        bpe_tokenizer = self.build_bpe(self.args)
        self.datasets[split] = SpeechToTextDatasetCreator.from_tsv(
            root=self.args.data,
            cfg=self.data_cfg,
            splits=split,
            tgt_dict=self.tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            is_train_split=is_train_split,
            epoch=epoch,
            seed=self.args.seed,
            speaker_to_id=self.speaker_to_id,
            multitask=self.multitask_tasks,
        )

    @property
    def target_dictionary(self):
        return self.tgt_dict

    @property
    def target_dictionary_mt(self):
        return self.tgt_dict_mt

    @property
    def source_dictionary(self):
        return None

    def max_positions(self):
        return self.args.max_source_positions, self.args.max_target_positions

    def build_model(self, args, from_checkpoint=False):
        args.input_feat_per_channel = self.data_cfg.input_feat_per_channel
        args.input_channels = self.data_cfg.input_channels
        args.speaker_to_id = self.speaker_to_id
        return super().build_model(args, from_checkpoint)

    def build_generator_dual_decoder(
        self,
        models,
        args,
        extra_gen_cls_kwargs,
    ):
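        """Build a MultiDecoderSequenceGenerator that decodes with the speech
        decoder and the first-pass text (MT) decoder jointly."""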
        from examples.speech_to_speech.unity.sequence_generator_multi_decoder import (
            MultiDecoderSequenceGenerator,
        )

        lang_token_ids_aux = {
            i
            for s, i in self.tgt_dict_mt.indices.items()
            if TextTargetMultitaskData.is_lang_tag(s)
        }

        extra_gen_cls_kwargs["symbols_to_strip_from_output"].update(lang_token_ids_aux)

        eos_id_mt = (
            self.tgt_dict_mt.index(self.eos_token_mt) if self.eos_token_mt else None
        )
        assert eos_id_mt != self.tgt_dict_mt.unk()
        extra_gen_cls_kwargs["eos_mt"] = eos_id_mt

        return MultiDecoderSequenceGenerator(
            models,
            self.target_dictionary,
            self.target_dictionary_mt,
            beam_size=max(1, getattr(args, "beam", 1)),
            beam_size_mt=max(1, getattr(args, "beam_mt", 1)),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            max_len_a_mt=getattr(args, "max_len_a_mt", 0),
            max_len_b_mt=getattr(args, "max_len_b_mt", 200),
            min_len=getattr(args, "min_len", 1),
            normalize_scores=(not getattr(args, "unnormalized", False)),
            len_penalty=getattr(args, "lenpen", 1),
            len_penalty_mt=getattr(args, "lenpen_mt", 1),
            unk_penalty=getattr(args, "unkpen", 0),
            temperature=getattr(args, "temperature", 1.0),
            match_source_len=getattr(args, "match_source_len", False),
            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
            **extra_gen_cls_kwargs,
        )

    def build_generator(
        self,
        models,
        args,
        seq_gen_cls=None,
        extra_gen_cls_kwargs=None,
    ):
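        """Build the sequence generator: strip language-tag symbols from the
        output, resolve a custom eos token, and dispatch to the dual-decoder
        generator when the model carries a first-pass MT decoder."""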
        if self.data_cfg.prepend_tgt_lang_tag and args.prefix_size != 1:
            raise ValueError(
                'Please set "--prefix-size 1" since '
                "target language ID token is prepended as BOS."
            )
        lang_token_ids = {
            i
            for s, i in self.tgt_dict.indices.items()
            if SpeechToTextDataset.is_lang_tag(s)
        }

        if extra_gen_cls_kwargs is None:
            extra_gen_cls_kwargs = {}
        extra_gen_cls_kwargs["symbols_to_strip_from_output"] = lang_token_ids

        eos_token = (
            args.eos_token
            if "eos_token" in args and args.eos_token is not None
            else self.data_cfg.config.get("eos_token", None)
        )

        if self.data_cfg.prepend_bos_and_append_tgt_lang_tag and not eos_token:
            raise ValueError(
                "Please provide --eos_token to replace eos in the sequence generator"
            )

        eos_id = self.tgt_dict.index(eos_token) if eos_token else None
        extra_gen_cls_kwargs["eos"] = eos_id

        has_dual_decoder = getattr(models[0], "mt_task_name", None) is not None

        if has_dual_decoder:
            return self.build_generator_dual_decoder(
                models,
                args,
                extra_gen_cls_kwargs=extra_gen_cls_kwargs,
            )
        else:
            return super().build_generator(
                models,
                args,
                seq_gen_cls=seq_gen_cls,
                extra_gen_cls_kwargs=extra_gen_cls_kwargs,
            )

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
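        """Refresh each multitask loss weight for this update and put the
        multitask decoders in train mode before running the standard step."""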
        for task_name, task_obj in self.multitask_tasks.items():
            criterion.set_multitask_loss_weight(
                task_name, task_obj.args.get_loss_weight(update_num)
            )
            if task_name in model.multitask_decoders:
                model.multitask_decoders[task_name].train()

        loss, sample_size, logging_output = super().train_step(
            sample, model, criterion, optimizer, update_num, ignore_grad
        )
        return loss, sample_size, logging_output

    def valid_step(self, sample, model, criterion):
        for task_name, task_obj in self.multitask_tasks.items():
            if task_name in model.multitask_decoders:
                model.multitask_decoders[task_name].eval()
        loss, sample_size, logging_output = super().valid_step(
            sample, model, criterion
        )
        return loss, sample_size, logging_output

    def build_tokenizer(self, args):
        logger.info(f"pre-tokenizer: {self.data_cfg.pre_tokenizer}")
        return encoders.build_tokenizer(Namespace(**self.data_cfg.pre_tokenizer))

    def build_bpe(self, args):
        logger.info(f"tokenizer: {self.data_cfg.bpe_tokenizer}")
        return encoders.build_bpe(Namespace(**self.data_cfg.bpe_tokenizer))

    def get_interactive_tokens_and_lengths(self, lines, encode_fn):
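        """In interactive mode each input line is an audio path; return the
        paths along with per-utterance frame counts read from the files."""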
        n_frames = [get_features_or_waveform(p).shape[0] for p in lines]
        return lines, n_frames

    def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
        return SpeechToTextDataset(
            "interactive", False, self.data_cfg, src_tokens, src_lengths
        )


class DummyMultiTask(LegacyFairseqTask):
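    """Minimal stand-in task for an auxiliary multitask head: it holds the
    task config and target dictionary and supports CTC decoding only."""
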
    def __init__(self, args, tgt_dict, first_pass=False):
        super().__init__(args)
        self.tgt_dict = tgt_dict
        self.first_pass = first_pass

    @property
    def target_dictionary(self):
        return self.tgt_dict

    @property
    def is_first_pass_decoder(self):
        return self.first_pass

    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
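        """CTC-only inference: run the (single) model, take logits or log
        probabilities, and let the generator decode the emissions."""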
        if self.args.decoder_type == "ctc":
            model = models[0]  # only a single model is supported
            encoder_out = model(**sample)
            if hasattr(model, "get_logits"):
                emissions = model.get_logits(encoder_out)
            else:
                emissions = model.get_normalized_probs(encoder_out, log_probs=True)
            return generator.decode(
                emissions.transpose(0, 1).float().cpu().contiguous()
            )
        else:
            raise NotImplementedError("only ctc decoder is supported at the moment")

    def build_generator(
        self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None
    ):
        if self.args.decoder_type == "ctc":
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

            return W2lViterbiDecoder(args, self.tgt_dict)
        else:
            raise NotImplementedError("only ctc decoder is supported at the moment")