# --------------------------------------------------------
# The YiTrans End-to-End Speech Translation System for IWSLT 2022 Offline Shared Task (https://arxiv.org/abs/2206.05777)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/YiTrans
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------

import itertools
import logging
import os
import sys
from typing import Any, List, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F

from fairseq.data import data_utils, Dictionary
from fairseq.data.audio.hubert_dataset import HubertDataset

logger = logging.getLogger(__name__)


class Speech2cDataset(HubertDataset):
    """HubertDataset variant that, in addition to the usual encoder targets,
    builds decoder targets and prev_output_tokens for joint encoder-decoder
    training, optionally with an mBART-style target-language token."""

    def __init__(
        self,
        manifest_path: str,
        sample_rate: float,
        label_paths: List[str],
        label_rates: Union[List[float], float],  # -1 for sequence labels
        pad_list: List[str],
        eos_list: List[str],
        label_processors: Optional[List[Any]] = None,
        max_keep_sample_size: Optional[int] = None,
        min_keep_sample_size: Optional[int] = None,
        max_sample_size: Optional[int] = None,
        shuffle: bool = True,
        pad_audio: bool = False,
        normalize: bool = False,
        store_labels: bool = True,
        random_crop: bool = False,
        single_target: bool = False,
        tgt_dict: Optional[Dictionary] = None,
        add_decoder: bool = False,
        fine_tuning: bool = False,
        tokenizer=None,
        tgt_lang_idx: Optional[int] = None,
        mbart_style_lang_id: bool = False,
        retry_times: int = 5,
        reduce_label_for_dec: bool = True,
    ):
        super().__init__(
            manifest_path,
            sample_rate,
            label_paths,
            label_rates,
            pad_list,
            eos_list,
            label_processors,
            max_keep_sample_size,
            min_keep_sample_size,
            max_sample_size,
            shuffle,
            pad_audio,
            normalize,
            store_labels,
            random_crop,
            single_target,
        )
        self.tgt_dict = tgt_dict
        self.add_decoder = add_decoder
        self.fine_tuning = fine_tuning
        self.tokenizer = tokenizer
        self.tgt_lang_idx = tgt_lang_idx
        self.mbart_style_lang_id = mbart_style_lang_id
        self.retry_times = retry_times
        self.reduce_label_for_dec = reduce_label_for_dec
        logger.info(
            f"tgt_lang_idx={self.tgt_lang_idx}, reduce_label_for_dec={reduce_label_for_dec}, "
            f"mbart_style_lang_id={mbart_style_lang_id}"
        )
        self.sizes = np.array(self.sizes)

    def get_label(self, index, label_idx):
        if self.store_labels:
            label = self.label_list[label_idx][index]
        else:
            # Labels are read lazily from disk via precomputed byte offsets.
            with open(self.label_paths[label_idx]) as f:
                offset_s, offset_e = self.label_offsets_list[label_idx][index]
                f.seek(offset_s)
                label = f.read(offset_e - offset_s)

        if self.tokenizer is not None and self.fine_tuning:
            label = self.tokenizer.encode(label)

        if self.label_processors is not None:
            label = self.label_processors[label_idx](label)
        return label

    def collater(self, samples):
        # target = max(sizes) -> random_crop not used
        # target = max_sample_size -> random_crop used for long utterances
        samples = [s for s in samples if s["source"] is not None]
        if len(samples) == 0:
            return {}

        audios = [s["source"] for s in samples]
        audio_sizes = [len(s) for s in audios]
        if self.pad_audio:
            audio_size = min(max(audio_sizes), self.max_sample_size)
        else:
            audio_size = min(min(audio_sizes), self.max_sample_size)
        collated_audios, padding_mask, audio_starts = self.collater_audio(
            audios, audio_size
        )

        targets_by_label = [
            [s["label_list"][i] for s in samples] for i in range(self.num_labels)
        ]
        targets_list, lengths_list, ntokens_list = self.collater_label(
            targets_by_label, audio_size, audio_starts
        )
        if self.add_decoder:
            if self.fine_tuning:
                # Fine-tuning: targets are already text tokens; just append <eos>.
                decoder_label = [
                    torch.cat(
                        (
                            targets_list[0][i, : lengths_list[0][i]],
                            torch.tensor([self.tgt_dict.eos()]),
                        ),
                        0,
                    ).long()
                    for i in range(targets_list[0].size(0))
                ]
            else:
                if self.tokenizer is not None:
                    # Pre-training with a SentencePiece tokenizer: map each unit id
                    # to a printable character (offset by 48 so that no id encodes
                    # to '\n'), re-tokenize the resulting string, and append <eos>.
                    decoder_label = [
                        torch.cat(
                            (
                                torch.tensor(
                                    self.tokenizer.sp.Encode(
                                        "".join(
                                            chr(j + 48)
                                            for j in (
                                                targets_list[0][i, : lengths_list[0][i]].unique_consecutive()
                                                if self.reduce_label_for_dec
                                                else targets_list[0][i, : lengths_list[0][i]]
                                            ).tolist()
                                        ),
                                        out_type=int,
                                    )
                                ),
                                torch.tensor([self.tgt_dict.eos()]),
                            ),
                            dim=0,
                        ).long()
                        for i in range(targets_list[0].size(0))
                    ]
                else:
                    # Pre-training without a tokenizer: optionally collapse repeated
                    # units (unique_consecutive), then append <eos>.
                    decoder_label = [
                        torch.cat(
                            (
                                targets_list[0][i, : lengths_list[0][i]].unique_consecutive()
                                if self.reduce_label_for_dec
                                else targets_list[0][i, : lengths_list[0][i]],
                                torch.tensor([self.tgt_dict.eos()]),
                            ),
                            0,
                        ).long()
                        for i in range(targets_list[0].size(0))
                    ]

            if self.mbart_style_lang_id:
                # mBART convention: the target-language id token follows <eos>.
                decoder_label = [
                    torch.cat(
                        (decoder_label[i], torch.tensor([self.tgt_lang_idx])), 0
                    ).long()
                    for i in range(targets_list[0].size(0))
                ]

            dec_ntokens = sum(x.size(0) for x in decoder_label)
            decoder_target = data_utils.collate_tokens(
                decoder_label,
                self.tgt_dict.pad(),
                self.tgt_dict.eos() if not self.mbart_style_lang_id else self.tgt_lang_idx,
                left_pad=False,
                move_eos_to_beginning=False,
            )
            decoder_target_lengths = torch.tensor(
                [x.size(0) for x in decoder_label], dtype=torch.long
            )
            # Teacher-forcing inputs: the same tokens shifted right, with the
            # final <eos> (or language id) moved to the first position.
            prev_output_tokens = data_utils.collate_tokens(
                decoder_label,
                self.tgt_dict.pad(),
                self.tgt_dict.eos() if not self.mbart_style_lang_id else self.tgt_lang_idx,
                left_pad=False,
                move_eos_to_beginning=True,
            )

            if self.tgt_lang_idx is not None and not self.mbart_style_lang_id:
                assert (prev_output_tokens[:, 0] != self.tgt_dict.eos()).sum() == 0
                prev_output_tokens[:, 0] = self.tgt_lang_idx

            net_input = {
                "source": collated_audios,
                "padding_mask": padding_mask,
                "prev_output_tokens": prev_output_tokens,
            }
            batch = {
                "id": torch.LongTensor([s["id"] for s in samples]),
                "net_input": net_input,
                "decoder_target": decoder_target,
                "decoder_target_lengths": decoder_target_lengths,
                "dec_ntokens": dec_ntokens,
                "lang_idx": self.tgt_lang_idx,
            }
        else:
            net_input = {"source": collated_audios, "padding_mask": padding_mask}
            batch = {
                "id": torch.LongTensor([s["id"] for s in samples]),
                "net_input": net_input,
            }

        if self.single_target:
            batch["target_lengths"] = lengths_list[0]
            batch["ntokens"] = ntokens_list[0]
            batch["target"] = targets_list[0]
        else:
            batch["target_lengths_list"] = lengths_list
            batch["ntokens_list"] = ntokens_list
            batch["target_list"] = targets_list
        return batch

    # @property
    # def sizes(self):
    #     return np.array(self.sizes)
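

# ------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). The dictionary,
# manifest, and label paths below are hypothetical placeholders; the file
# layout is assumed to follow the usual fairseq HuBERT conventions (a .tsv
# manifest of audio files plus one label file with a space-separated label
# line per utterance). The LabelEncoder mirrors the processor fairseq's
# HuBERT pre-training task supplies to HubertDataset.
# ------------------------------------------------------------------
if __name__ == "__main__":

    class LabelEncoder:
        """Map a label string to a tensor of dictionary indices."""

        def __init__(self, dictionary: Dictionary):
            self.dictionary = dictionary

        def __call__(self, label: str) -> torch.Tensor:
            return self.dictionary.encode_line(
                label, append_eos=False, add_if_not_exist=False
            )

    tgt_dict = Dictionary.load("dict.km.txt")  # hypothetical dictionary file
    dataset = Speech2cDataset(
        manifest_path="train.tsv",  # hypothetical manifest
        sample_rate=16000,
        label_paths=["train.km"],  # hypothetical label file
        label_rates=[50.0],  # 50 label frames per second of audio
        pad_list=[tgt_dict.pad()],
        eos_list=[tgt_dict.eos()],
        label_processors=[LabelEncoder(tgt_dict)],
        max_sample_size=250000,
        normalize=True,
        single_target=True,
        tgt_dict=tgt_dict,
        add_decoder=True,
        fine_tuning=False,
    )
    # Collate a couple of samples and inspect the batch produced above.
    batch = dataset.collater([dataset[i] for i in range(2)])
    print(batch["net_input"]["source"].shape, batch["dec_ntokens"])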