# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------

import logging
from typing import Any, List, Optional, Union

import torch
from fairseq.data import data_utils, Dictionary
from fairseq.data.audio.hubert_dataset import HubertDataset

logger = logging.getLogger(__name__)

class Speech2cDataset(HubertDataset):
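    """HubertDataset extended for Speech2C pre-training and fine-tuning.

    When ``add_decoder`` is True, ``collater`` additionally builds
    EOS-terminated decoder targets (plus their shifted teacher-forcing
    inputs) from the first label stream, using ``tgt_dict`` for the
    pad/EOS indices.
    """
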
    def __init__(
        self,
        manifest_path: str,
        sample_rate: float,
        label_paths: List[str],
        label_rates: Union[List[float], float],  # -1 for sequence labels
        pad_list: List[str],
        eos_list: List[str],
        label_processors: Optional[List[Any]] = None,
        max_keep_sample_size: Optional[int] = None,
        min_keep_sample_size: Optional[int] = None,
        max_sample_size: Optional[int] = None,
        shuffle: bool = True,
        pad_audio: bool = False,
        normalize: bool = False,
        store_labels: bool = True,
        random_crop: bool = False,
        single_target: bool = False,
        tgt_dict: Optional[Dictionary] = None,
        add_decoder: bool = False,
        fine_tuning: bool = False,
    ):
        super().__init__(
            manifest_path,
            sample_rate,
            label_paths,
            label_rates,
            pad_list,
            eos_list,
            label_processors,
            max_keep_sample_size,
            min_keep_sample_size,
            max_sample_size,
            shuffle,
            pad_audio,
            normalize,
            store_labels,
            random_crop,
            single_target,
        )
        # Extra state for the optional decoder target stream.
        self.tgt_dict = tgt_dict
        self.add_decoder = add_decoder
        self.fine_tuning = fine_tuning

    def collater(self, samples):
        # When padding audio, the target length is max(sizes) and random_crop
        # is not used; otherwise it is capped at max_sample_size and
        # random_crop decides where long utterances are cropped.
        samples = [s for s in samples if s["source"] is not None]
        if len(samples) == 0:
            return {}

        audios = [s["source"] for s in samples]
        audio_sizes = [len(s) for s in audios]
        if self.pad_audio:
            audio_size = min(max(audio_sizes), self.max_sample_size)
        else:
            audio_size = min(min(audio_sizes), self.max_sample_size)
        collated_audios, padding_mask, audio_starts = self.collater_audio(
            audios, audio_size
        )

        targets_by_label = [
            [s["label_list"][i] for s in samples] for i in range(self.num_labels)
        ]
        targets_list, lengths_list, ntokens_list = self.collater_label(
            targets_by_label, audio_size, audio_starts
        )
        if self.add_decoder:
            if self.fine_tuning:
                # Fine-tuning: decoder targets are the transcript labels with EOS appended.
                decoder_label = [
                    torch.cat((targets_list[0][i, :lengths_list[0][i]], torch.tensor([self.tgt_dict.eos()])), 0).long()
                    for i in range(targets_list[0].size(0))
                ]
            else:
                # Pre-training: merge consecutive duplicate hidden units before appending EOS.
                decoder_label = [
                    torch.cat((targets_list[0][i, :lengths_list[0][i]].unique_consecutive(), torch.tensor([self.tgt_dict.eos()])), 0).long()
                    for i in range(targets_list[0].size(0))
                ]
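            # Example (pre-training): frame-level units [5, 5, 5, 9, 9, 4]
            # become the target [5, 9, 4, EOS] after unique_consecutive()
            # and EOS appending.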
            dec_ntokens = sum(x.size(0) for x in decoder_label)
            decoder_target = data_utils.collate_tokens(
                decoder_label,
                self.tgt_dict.pad(),
                self.tgt_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=False,
            )
            decoder_target_lengths = torch.tensor(
                [x.size(0) for x in decoder_label], dtype=torch.long
            )
            prev_output_tokens = data_utils.collate_tokens(
                decoder_label,
                self.tgt_dict.pad(),
                self.tgt_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=True,
            )
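            # With move_eos_to_beginning=True, each target is shifted right by
            # one position, e.g. [a, b, c, EOS] -> [EOS, a, b, c], which is the
            # teacher-forcing input for the decoder.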
            net_input = {
                "source": collated_audios,
                "padding_mask": padding_mask,
                "prev_output_tokens": prev_output_tokens,
            }
            batch = {
                "id": torch.LongTensor([s["id"] for s in samples]),
                "net_input": net_input,
                "decoder_target": decoder_target,
                "decoder_target_lengths": decoder_target_lengths,
                "dec_ntokens": dec_ntokens,
            }
        else:
            net_input = {"source": collated_audios, "padding_mask": padding_mask}
            batch = {
                "id": torch.LongTensor([s["id"] for s in samples]),
                "net_input": net_input,
            }

        if self.single_target:
            batch["target_lengths"] = lengths_list[0]
            batch["ntokens"] = ntokens_list[0]
            batch["target"] = targets_list[0]
        else:
            batch["target_lengths_list"] = lengths_list
            batch["ntokens_list"] = ntokens_list
            batch["target_list"] = targets_list
        return batch
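

# Illustrative usage sketch (not part of the original module). The manifest,
# label, and dictionary paths below are hypothetical placeholders; the
# argument values mirror typical HuBERT-style setups.
#
# dictionary = Dictionary.load("dict.km.txt")
# dataset = Speech2cDataset(
#     manifest_path="train.tsv",
#     sample_rate=16000,
#     label_paths=["train.km"],
#     label_rates=50.0,
#     pad_list=[dictionary.pad()],
#     eos_list=[dictionary.eos()],
#     max_sample_size=250000,
#     tgt_dict=dictionary,
#     add_decoder=True,
#     fine_tuning=False,
# )
# batch = dataset.collater([dataset[i] for i in range(4)])
# # batch["net_input"]["prev_output_tokens"] feeds the decoder;
# # batch["decoder_target"] holds the EOS-terminated unit sequences.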