# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------

from pathlib import Path
from typing import List, Dict, Optional, Any
from dataclasses import dataclass

import numpy as np
import torch

from fairseq.data.audio.speech_to_text_dataset import (
    SpeechToTextDataset,
    SpeechToTextDatasetCreator,
    S2TDataConfig,
    _collate_frames,
    get_features_or_waveform,
)
from fairseq.data import Dictionary, data_utils as fairseq_data_utils


@dataclass
class TextToUnitDatasetItem(object):
    index: int
    source: torch.Tensor
    target: Optional[torch.Tensor] = None
    speaker_id: Optional[int] = None
    speaker_emb: Optional[torch.Tensor] = None
    duration: Optional[torch.Tensor] = None
    pitch: Optional[torch.Tensor] = None
    energy: Optional[torch.Tensor] = None


class Text2UnitDataset(SpeechToTextDataset):
    def __init__(
        self,
        split: str,
        is_train_split: bool,
        cfg: S2TDataConfig,
        unit_labels: List[str],
        n_frames: List[int],
        src_texts: Optional[List[str]] = None,
        tgt_texts: Optional[List[str]] = None,
        speakers: Optional[List[str]] = None,
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        tgt_dict: Optional[Dictionary] = None,
        pre_tokenizer=None,
        bpe_tokenizer=None,
        n_frames_per_step=1,
        speaker_to_id=None,
        durations: Optional[List[List[int]]] = None,
        pitches: Optional[List[str]] = None,
        energies: Optional[List[str]] = None,
    ):
        super(Text2UnitDataset, self).__init__(
            split,
            is_train_split,
            cfg,
            unit_labels,
            n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            n_frames_per_step=n_frames_per_step,
            speaker_to_id=speaker_to_id,
        )
        self.durations = durations
        self.pitches = pitches
        self.energies = energies
        self.unit_labels = unit_labels
        self.feature_root = Path(cfg.audio_root)
        self.spk_emb_type = cfg.config.get("speaker_embedding_type", None)
        self.random_spk = cfg.config.get("random_speaker", False)
        if self.spk_emb_type is not None:
            self.spk_emb_choices = [
                i for i in (self.feature_root / self.spk_emb_type).glob("*.npy")
            ]
            self.spk_emb_num = len(self.spk_emb_choices)

    def __getitem__(self, index: int) -> TextToUnitDatasetItem:
        # s2t_item = super().__getitem__(index)
        source = torch.LongTensor(self.unit_labels[index])

        target = None
        if self.tgt_texts is not None:
            tokenized = self.get_tokenized_tgt_text(index)
            target = self.tgt_dict.encode_line(
                tokenized, add_if_not_exist=False, append_eos=self.append_eos
            ).long()
            if self.cfg.prepend_tgt_lang_tag:
                lang_tag_idx = self.get_lang_tag_idx(
                    self.tgt_langs[index], self.tgt_dict
                )
                target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0)

        speaker_id = None
        if self.speaker_to_id is not None:
            speaker_id = self.speaker_to_id[self.speakers[index]]

        speaker_emb = None
        if self.spk_emb_type is not None:
            if self.random_spk:
                spk_emb_path = self.spk_emb_choices[np.random.choice(self.spk_emb_num)]
            else:
                spk_emb_path = (
                    self.feature_root / self.spk_emb_type / f"{self.ids[index]}.npy"
                )
            speaker_emb = get_features_or_waveform(spk_emb_path)
            speaker_emb = torch.from_numpy(speaker_emb).float()

        duration, pitch, energy = None, None, None
        if self.durations is not None:
            duration = torch.tensor(
                self.durations[index] + [0], dtype=torch.long  # pad 0 for EOS
            )
        if self.pitches is not None:
            pitch = get_features_or_waveform(self.pitches[index])
            pitch = torch.from_numpy(
                np.concatenate((pitch, [0]))  # pad 0 for EOS
            ).float()
        if self.energies is not None:
            energy = get_features_or_waveform(self.energies[index])
            energy = torch.from_numpy(
                np.concatenate((energy, [0]))  # pad 0 for EOS
            ).float()

        return TextToUnitDatasetItem(
            index=index,
            source=source,
            target=target,
            speaker_id=speaker_id,
            speaker_emb=speaker_emb,
            duration=duration,
            pitch=pitch,
            energy=energy,
        )

    def collater(self, samples: List[TextToUnitDatasetItem]) -> Dict[str, Any]:
        if len(samples) == 0:
            return {}

        # Sort the batch by tokenized-text length, longest first.
        src_lengths, order = torch.tensor(
            [s.target.shape[0] for s in samples], dtype=torch.long
        ).sort(descending=True)
        id_ = torch.tensor([s.index for s in samples], dtype=torch.long).index_select(
            0, order
        )

        # The discrete unit sequence is the generation target ...
        target = fairseq_data_utils.collate_tokens(
            [s.source for s in samples],
            self.tgt_dict.pad(),
        ).index_select(0, order)
        target_lengths = torch.tensor(
            [s.source.shape[0] for s in samples], dtype=torch.long
        ).index_select(0, order)

        # ... while the tokenized text is fed to the encoder as src_tokens.
        src_tokens = fairseq_data_utils.collate_tokens(
            [s.target for s in samples],
            self.tgt_dict.pad(),
            self.tgt_dict.eos(),
            left_pad=False,
            move_eos_to_beginning=False,
        ).index_select(0, order)

        speaker = None
        if self.speaker_to_id is not None:
            speaker = (
                torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
                .index_select(0, order)
                .view(-1, 1)
            )
        if self.spk_emb_type is not None:
            speaker = torch.stack(
                [s.speaker_emb for s in samples], dim=0
            ).index_select(0, order)

        # Teacher-forcing input: prepend BOS and drop the last target position.
        bsz, _ = target.size()
        prev_output_tokens = torch.cat(
            (target.new_full((bsz, 1), self.tgt_dict.bos()), target[:, :-1]), dim=1
        )

        durations, pitches, energies = None, None, None
        if self.durations is not None:
            durations = fairseq_data_utils.collate_tokens(
                [s.duration for s in samples], 0
            ).index_select(0, order)
            assert src_tokens.shape[1] == durations.shape[1]
        if self.pitches is not None:
            pitches = _collate_frames([s.pitch for s in samples], True)
            pitches = pitches.index_select(0, order)
            assert src_tokens.shape[1] == pitches.shape[1]
        if self.energies is not None:
            energies = _collate_frames([s.energy for s in samples], True)
            energies = energies.index_select(0, order)
            assert src_tokens.shape[1] == energies.shape[1]
        src_texts = [self.tgt_dict.string(samples[i].target) for i in order]

        return {
            "id": id_,
            "net_input": {
                "src_tokens": src_tokens,
                "src_lengths": src_lengths,
                "prev_output_tokens": prev_output_tokens,
            },
            "speaker": speaker,
            "target": target,
            "durations": durations,
            "pitches": pitches,
            "energies": energies,
            "target_lengths": target_lengths,
            "ntokens": sum(target_lengths).item(),
            "nsentences": len(samples),
            "src_texts": src_texts,
        }
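

# ----------------------------------------------------------------------------
# Manifest note (illustrative, not part of the original source): the creator
# below receives rows of a TSV manifest via SpeechToTextDatasetCreator, and
# _from_list() reads the columns named by the KEY_* constants. The example
# values are hypothetical; only the column names come from the code:
#
#   id        utt_00001
#   unit      71 12 12 56 ...           space-separated discrete unit ids
#   n_frames  243
#   tgt_text  text tokenized with pre_tokenizer / bpe_tokenizer
#   speaker   spk_001                   optional
#   duration  3 5 4 ...                 optional, one int per tokenized text
#                                       token (EOS padding added in __getitem__)
#   pitch     pitch/utt_00001.npy       optional, path relative to audio_root
#   energy    energy/utt_00001.npy      optional, path relative to audio_root
# ----------------------------------------------------------------------------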
class Text2UnitDatasetCreator(SpeechToTextDatasetCreator):
    KEY_DURATION = "duration"
    KEY_PITCH = "pitch"
    KEY_ENERGY = "energy"
    KEY_UNIT = "unit"

    @classmethod
    def _from_list(
        cls,
        split_name: str,
        is_train_split,
        samples: List[Dict],
        cfg: S2TDataConfig,
        tgt_dict,
        pre_tokenizer,
        bpe_tokenizer,
        n_frames_per_step,
        speaker_to_id,
    ) -> Text2UnitDataset:
        audio_root = Path(cfg.audio_root)
        ids = [s[cls.KEY_ID] for s in samples]
        # audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
        unit_labels = [s[cls.KEY_UNIT] for s in samples]
        unit_labels = [
            None if dd is None else [int(d) for d in dd.split(" ")]
            for dd in unit_labels
        ]
        n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
        tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
        src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
        speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
        src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
        tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
        durations = [s.get(cls.KEY_DURATION, None) for s in samples]
        durations = [
            None if dd is None else [int(d) for d in dd.split(" ")] for dd in durations
        ]
        durations = None if any(dd is None for dd in durations) else durations
        pitches = [s.get(cls.KEY_PITCH, None) for s in samples]
        pitches = [
            None if pp is None else (audio_root / pp).as_posix() for pp in pitches
        ]
        pitches = None if any(pp is None for pp in pitches) else pitches
        energies = [s.get(cls.KEY_ENERGY, None) for s in samples]
        energies = [
            None if ee is None else (audio_root / ee).as_posix() for ee in energies
        ]
        energies = None if any(ee is None for ee in energies) else energies
        return Text2UnitDataset(
            split_name,
            is_train_split,
            cfg,
            unit_labels,
            n_frames,
            src_texts,
            tgt_texts,
            speakers,
            src_langs,
            tgt_langs,
            ids,
            tgt_dict,
            pre_tokenizer,
            bpe_tokenizer,
            n_frames_per_step,
            speaker_to_id,
            durations,
            pitches,
            energies,
        )
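

# ----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original source): one way a
# task could build this dataset and collate a mini-batch. The data root,
# config/dictionary files and split name are hypothetical; from_tsv() is
# inherited from SpeechToTextDatasetCreator and dispatches to the _from_list()
# override above, and its exact argument list may differ between fairseq
# revisions.
#
#   from fairseq.data import Dictionary
#   from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig
#
#   data_root = "/path/to/text2unit_data"                 # hypothetical
#   cfg = S2TDataConfig(Path(data_root) / "config.yaml")  # hypothetical file
#   tgt_dict = Dictionary.load(f"{data_root}/dict.txt")   # hypothetical file
#   dataset = Text2UnitDatasetCreator.from_tsv(
#       data_root, cfg, "train", tgt_dict,
#       pre_tokenizer=None, bpe_tokenizer=None,
#       is_train_split=True, epoch=1, seed=1,
#   )
#   batch = dataset.collater([dataset[i] for i in range(4)])
#   # batch["net_input"]["src_tokens"]: tokenized text  (B x T_text)
#   # batch["target"]:                  unit id targets (B x T_unit)
# ----------------------------------------------------------------------------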