|
|
|
|
|
|
|
|
|
|
|
import logging |
|
import os.path as op |
|
from argparse import Namespace |
|
|
|
from fairseq.data import Dictionary, encoders |
|
from fairseq.data.audio.speech_to_text_dataset import ( |
|
S2TDataConfig, |
|
SpeechToTextDataset, |
|
SpeechToTextDatasetCreator, |
|
get_features_or_waveform |
|
) |
|
from fairseq.tasks import LegacyFairseqTask, register_task |
|
|
|
|
|
# Module-scoped logger, named after this module so its records can be
# filtered or configured independently of other fairseq components.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
@register_task("speech_to_text")
class SpeechToTextTask(LegacyFairseqTask):
    """Speech-to-text task: audio features in, target-token sequences out.

    Datasets are read from TSV manifests under the ``data`` root directory
    (via ``SpeechToTextDatasetCreator.from_tsv``). Dataset options such as the
    vocabulary file, tokenizers and input feature dimensions come from a YAML
    config parsed by ``S2TDataConfig``. There is no source dictionary; only
    the target side is tokenized.
    """

    @staticmethod
    def add_args(parser):
        """Register the task's command-line arguments on *parser*."""
        parser.add_argument("data", help="manifest root path")
        parser.add_argument(
            "--config-yaml",
            type=str,
            default="config.yaml",
            help="Configuration YAML filename (under manifest root)",
        )
        parser.add_argument(
            "--max-source-positions",
            default=6000,
            type=int,
            metavar="N",
            help="max number of tokens in the source sequence",
        )
        parser.add_argument(
            "--max-target-positions",
            default=1024,
            type=int,
            metavar="N",
            help="max number of tokens in the target sequence",
        )

    def __init__(self, args, tgt_dict):
        """Store the target dictionary and parse the data config YAML."""
        super().__init__(args)
        self.tgt_dict = tgt_dict
        # Parsed here as well as in setup_task so that a task constructed
        # directly (not via setup_task) still has its data config.
        self.data_cfg = S2TDataConfig(op.join(args.data, args.config_yaml))

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Load the target dictionary named in the data config and build the task.

        Raises:
            FileNotFoundError: if the vocabulary file does not exist.
            ValueError: if any comma-separated split in ``args.train_subset``
                does not start with ``"train"``. ``load_dataset`` uses that
                prefix to decide whether a split is a training split.
        """
        data_cfg = S2TDataConfig(op.join(args.data, args.config_yaml))
        dict_path = op.join(args.data, data_cfg.vocab_filename)
        if not op.isfile(dict_path):
            raise FileNotFoundError(f"Dict not found: {dict_path}")
        tgt_dict = Dictionary.load(dict_path)
        logger.info(
            f"dictionary size ({data_cfg.vocab_filename}): " f"{len(tgt_dict):,}"
        )

        train_subset = getattr(args, "train_subset", None)
        if train_subset is not None:
            if not all(s.startswith("train") for s in train_subset.split(",")):
                raise ValueError('Train splits should be named like "train*".')
        return cls(args, tgt_dict)

    def build_criterion(self, args):
        """Build the training criterion.

        Raises:
            ValueError: if the data config prepends a target-language tag but
                ``ignore_prefix_size`` is not 1. An absent option is treated
                as 0, so the caller gets this explanatory error instead of
                an ``AttributeError``.
        """
        from fairseq import criterions

        if (
            self.data_cfg.prepend_tgt_lang_tag
            and getattr(args, "ignore_prefix_size", 0) != 1
        ):
            raise ValueError(
                'Please set "--ignore-prefix-size 1" since '
                "target language ID token is prepended as BOS."
            )
        return criterions.build_criterion(args, self)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load *split* from its TSV manifest into ``self.datasets[split]``.

        A split counts as a training split when its name starts with "train".
        ``combine`` and extra keyword arguments are accepted for interface
        compatibility but are not used here.
        """
        is_train_split = split.startswith("train")
        pre_tokenizer = self.build_tokenizer(self.args)
        bpe_tokenizer = self.build_bpe(self.args)
        self.datasets[split] = SpeechToTextDatasetCreator.from_tsv(
            self.args.data,
            self.data_cfg,
            split,
            self.tgt_dict,
            pre_tokenizer,
            bpe_tokenizer,
            is_train_split=is_train_split,
            epoch=epoch,
            seed=self.args.seed,
        )

    @property
    def target_dictionary(self):
        """The target-side ``Dictionary`` loaded in ``setup_task``."""
        return self.tgt_dict

    @property
    def source_dictionary(self):
        """Always ``None``: the source side is audio, not tokens."""
        return None

    def max_positions(self):
        """Return ``(max_source_positions, max_target_positions)`` from args."""
        return self.args.max_source_positions, self.args.max_target_positions

    def build_model(self, args):
        """Copy input feature dimensions from the data config onto *args*, then build.

        Note: this mutates *args* in place before delegating to the base class.
        """
        args.input_feat_per_channel = self.data_cfg.input_feat_per_channel
        args.input_channels = self.data_cfg.input_channels
        return super().build_model(args)

    def build_generator(
        self,
        models,
        args,
        seq_gen_cls=None,
        extra_gen_cls_kwargs=None,
    ):
        """Build a sequence generator that strips language-tag tokens from output.

        Both ``seq_gen_cls`` and ``extra_gen_cls_kwargs`` are forwarded to the
        base implementation. The ids of all dictionary symbols recognised as
        language tags are added to ``symbols_to_strip_from_output``, merged
        with any symbols the caller already asked to strip. The caller's
        kwargs dict is copied, not mutated.

        Raises:
            ValueError: if the data config prepends a target-language tag but
                ``prefix_size`` is not 1. An absent option is treated as 0.
        """
        if (
            self.data_cfg.prepend_tgt_lang_tag
            and getattr(args, "prefix_size", 0) != 1
        ):
            raise ValueError(
                'Please set "--prefix-size 1" since '
                "target language ID token is prepended as BOS."
            )
        lang_token_ids = {
            i
            for s, i in self.tgt_dict.indices.items()
            if SpeechToTextDataset.is_lang_tag(s)
        }
        # Copy so we never mutate a dict owned by the caller, and keep any
        # symbols they already requested to strip.
        gen_kwargs = dict(extra_gen_cls_kwargs or {})
        caller_strip = gen_kwargs.get("symbols_to_strip_from_output") or ()
        gen_kwargs["symbols_to_strip_from_output"] = lang_token_ids | set(
            caller_strip
        )
        return super().build_generator(
            models, args, seq_gen_cls=seq_gen_cls, extra_gen_cls_kwargs=gen_kwargs
        )

    def build_tokenizer(self, args):
        """Build the pre-tokenizer described by the data config.

        *args* is ignored; the configuration comes from ``self.data_cfg``.
        """
        logger.info(f"pre-tokenizer: {self.data_cfg.pre_tokenizer}")
        return encoders.build_tokenizer(Namespace(**self.data_cfg.pre_tokenizer))

    def build_bpe(self, args):
        """Build the BPE/subword tokenizer described by the data config.

        *args* is ignored; the configuration comes from ``self.data_cfg``.
        """
        logger.info(f"tokenizer: {self.data_cfg.bpe_tokenizer}")
        return encoders.build_bpe(Namespace(**self.data_cfg.bpe_tokenizer))

    def get_interactive_tokens_and_lengths(self, lines, encode_fn):
        """Treat each interactive input line as a path to audio/feature data.

        Returns the lines unchanged together with the size of the first
        dimension of each loaded array, used as the source length.
        ``encode_fn`` is not used.
        """
        n_frames = [get_features_or_waveform(p).shape[0] for p in lines]
        return lines, n_frames

    def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
        """Wrap interactive inputs in a ``SpeechToTextDataset`` named "interactive".

        The second positional argument (``False``) is passed where the
        dataset takes its training-split flag — presumably ``is_train_split``;
        confirm against ``SpeechToTextDataset``'s signature.
        """
        return SpeechToTextDataset(
            "interactive", False, self.data_cfg, src_tokens, src_lengths
        )
|
|