from dataclasses import asdict, dataclass
import os
from pathlib import Path
import tempfile
from typing import Dict, List, Literal, Optional, Tuple

# Disable tqdm progress bars globally before any dependency imports tqdm.
os.environ["TQDM_DISABLE"] = "1"

import tqdm


def nop(it, *a, **k):
    r"""No-op stand-in for tqdm.tqdm: returns the iterable unchanged."""
    return it


# Also replace tqdm.tqdm outright so no progress bars ever render.
tqdm.tqdm = nop
from lhotse import CutSet, RecordingSet, SupervisionSet
from lhotse.cut import MonoCut
from lhotse.recipes import hifitts, libritts
import numpy as np
import soundfile as sf
import torch
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from voicefixer import VoiceFixer

from models.config import PreprocessingConfigHifiGAN as PreprocessingConfig
from models.config import get_lang_map, lang2id
from training.preprocess import PreprocessLibriTTS
from training.tools import pad_1D, pad_2D, pad_3D

NUM_JOBS = (os.cpu_count() or 2) - 1
# The selected speakers from the HiFiTTS dataset
speakers_hifi_ids = [
    "Cori Samuel",  # 92,
    "Tony Oliva",  # 6671,
    "John Van Stan",  # 9017,
    "Helen Taylor",  # 9136,
    # "Phil Benson", # 6097,
    # "Mike Pelton", # 6670,
    # "Maria Kasper", # 8051,
    # "Sylviamb", # 11614,
    # "Celine Major", # 11697,
    # "LikeManyWaters", # 12787,
]

# The selected speakers from the LibriTTS dataset
speakers_libri_ids = list(
    map(
        str,
        [
            # train-clean-100
            40,
            1088,
            # train-clean-360
            3307,
            5935,
            # train-other-500
            215,
            6594,
            3867,
            5733,
            5181,
        ],
    ),
)
# Map each selected speaker id to a sequential model speaker index
selected_speakers_ids = {
    v: k
    for k, v in enumerate(
        speakers_hifi_ids + speakers_libri_ids,
    )
}
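# Resulting mapping (sketch): HiFiTTS speaker names first, then LibriTTS ids, e.g.
#   {"Cori Samuel": 0, "Tony Oliva": 1, "John Van Stan": 2, "Helen Taylor": 3,
#    "40": 4, "1088": 5, ..., "5181": 12}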
def prep_2_cutset(prep: Dict[str, Dict[str, RecordingSet | SupervisionSet]]) -> CutSet:
    r"""Convert a prepared dataset into a CutSet for the model.

    Args:
        prep (Dict[str, Dict[str, RecordingSet | SupervisionSet]]): The prepared
            dataset, as returned by the lhotse recipe ``prepare_*`` functions.

    Returns:
        CutSet: The dataset prepared for the model.
    """
    recordings = RecordingSet()
    supervisions = SupervisionSet()

    for row in prep.values():
        record = row["recordings"]
        supervision = row["supervisions"]

        # Separate the recordings and supervisions
        if isinstance(record, RecordingSet):
            recordings += record
        if isinstance(supervision, SupervisionSet):
            supervisions += supervision

    # Combine the recordings and supervisions into a CutSet
    return CutSet.from_manifests(
        recordings=recordings,
        supervisions=supervisions,
    )
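# Usage sketch (hypothetical paths; assumes the recipe has already downloaded the data):
#
#     prepared = hifitts.prepare_hifitts("datasets_cache/hifitts", num_jobs=NUM_JOBS)
#     cuts = prep_2_cutset(prepared)  # one CutSet covering every prepared split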
DATASET_TYPES = Literal["hifitts", "libritts"]


@dataclass
class HifiLibriItem:
    """Dataset row for the HiFiTTS and LibriTTS datasets combined in this code.

    Attributes:
        id (str): The ID of the item.
        wav (Tensor): The waveform of the audio.
        mel (Tensor): The mel spectrogram.
        pitch (Tensor): The pitch.
        text (Tensor): The text.
        attn_prior (Tensor): The attention prior.
        energy (Tensor): The energy.
        raw_text (str): The raw text.
        normalized_text (str): The normalized text.
        speaker (int): The speaker ID.
        pitch_is_normalized (bool): Whether the pitch is normalized.
        lang (int): The language ID.
        dataset_type (DATASET_TYPES): The type of dataset.
    """

    id: str
    wav: Tensor
    mel: Tensor
    pitch: Tensor
    text: Tensor
    attn_prior: Tensor
    energy: Tensor
    raw_text: str
    normalized_text: str
    speaker: int
    pitch_is_normalized: bool
    lang: int
    dataset_type: DATASET_TYPES
class HifiLibriDataset(Dataset):
    r"""A PyTorch dataset combining HiFiTTS and LibriTTS-R data for DelightfulTTS training."""

    def __init__(
        self,
        lang: str = "en",
        root: str = "datasets_cache",
        sampling_rate: int = 44100,
        hifitts_path: str = "hifitts",
        hifi_cutset_file_name: str = "hifi.json.gz",
        libritts_path: str = "librittsr",
        libritts_cutset_file_name: str = "libri.json.gz",
        libritts_subsets: List[str] | str = "all",
        cache: bool = False,
        cache_dir: str = "/dev/shm",
        num_jobs: int = NUM_JOBS,
        min_seconds: Optional[float] = None,
        max_seconds: Optional[float] = None,
        include_libri: bool = True,
        libri_speakers: List[str] = speakers_libri_ids,
        hifi_speakers: List[str] = speakers_hifi_ids,
    ):
        r"""Initializes the dataset.

        Args:
            lang (str, optional): The language of the dataset. Defaults to "en".
            root (str, optional): The root directory of the dataset. Defaults to "datasets_cache".
            sampling_rate (int, optional): The sampling rate of the audio. Defaults to 44100.
            hifitts_path (str, optional): The path to the HiFiTTS dataset. Defaults to "hifitts".
            hifi_cutset_file_name (str, optional): The file name of the HiFiTTS cutset. Defaults to "hifi.json.gz".
            libritts_path (str, optional): The path to the LibriTTS dataset. Defaults to "librittsr".
            libritts_cutset_file_name (str, optional): The file name of the LibriTTS cutset. Defaults to "libri.json.gz".
            libritts_subsets (List[str] | str, optional): The subsets of the LibriTTS dataset to use. Defaults to "all".
            cache (bool, optional): Whether to cache preprocessed items. Defaults to False.
            cache_dir (str, optional): The directory to cache the dataset in. Defaults to "/dev/shm".
            num_jobs (int, optional): The number of jobs to use for preparing the dataset. Defaults to NUM_JOBS.
            min_seconds (Optional[float], optional): The minimum duration of the audio. Defaults from the preprocess config.
            max_seconds (Optional[float], optional): The maximum duration of the audio. Defaults from the preprocess config.
            include_libri (bool, optional): Whether to include the LibriTTS dataset. Defaults to True.
            libri_speakers (List[str], optional): The selected speakers from the LibriTTS dataset. Defaults to speakers_libri_ids.
            hifi_speakers (List[str], optional): The selected speakers from the HiFiTTS dataset. Defaults to speakers_hifi_ids.
        """
        lang_map = get_lang_map(lang)
        processing_lang_type = lang_map.processing_lang_type

        self.preprocess_config = PreprocessingConfig(
            processing_lang_type,
            sampling_rate=sampling_rate,
        )
        self.min_seconds = min_seconds or self.preprocess_config.min_seconds
        self.max_seconds = max_seconds or self.preprocess_config.max_seconds
        self.dur_filter = (
            lambda duration: self.min_seconds <= duration <= self.max_seconds
        )

        self.preprocess_libritts = PreprocessLibriTTS(
            self.preprocess_config,
            lang,
        )

        self.root_dir = Path(root)
        self.voicefixer = VoiceFixer()

        # Store the selected speaker ids as sets for fast membership checks
        self.selected_speakers_libri_ids_ = set(libri_speakers)
        self.selected_speakers_hi_fi_ids_ = set(hifi_speakers)

        self.cache = cache
        self.cache_dir = Path(cache_dir) / f"cache-{hifitts_path}-{libritts_path}"

        # Prepare the HiFiTTS dataset
        self.hifitts_path = self.root_dir / hifitts_path
        hifi_cutset_file_path = self.root_dir / hifi_cutset_file_name

        # Initialize the combined cutset
        self.cutset = CutSet()

        # Check if the HiFiTTS dataset has already been prepared
        if hifi_cutset_file_path.exists():
            self.cutset_hifi = CutSet.from_file(hifi_cutset_file_path)
        else:
            hifitts_root = hifitts.download_hifitts(self.hifitts_path)
            prepared_hifi = hifitts.prepare_hifitts(
                hifitts_root,
                num_jobs=num_jobs,
            )
            # Add the recordings and supervisions to the CutSet
            self.cutset_hifi = prep_2_cutset(prepared_hifi)
            # Save the prepared HiFiTTS cutset
            self.cutset_hifi.to_file(hifi_cutset_file_path)

        # Filter the HiFiTTS cutset to only include the selected speakers
        self.cutset_hifi = self.cutset_hifi.filter(
            lambda cut: isinstance(cut, MonoCut)
            and str(cut.supervisions[0].speaker) in self.selected_speakers_hi_fi_ids_
            and self.dur_filter(cut.duration),
        ).to_eager()

        # Add the HiFiTTS cutset to the final cutset
        self.cutset += self.cutset_hifi
        if include_libri:
            # Prepare the LibriTTS dataset
            self.libritts_path = self.root_dir / libritts_path
            libritts_cutset_file_path = self.root_dir / libritts_cutset_file_name

            # Check if the LibriTTS dataset has already been prepared
            if libritts_cutset_file_path.exists():
                self.cutset_libri = CutSet.from_file(libritts_cutset_file_path)
            else:
                libritts_root = libritts.download_librittsr(
                    self.libritts_path,
                    dataset_parts=libritts_subsets,
                )
                prepared_libri = libritts.prepare_librittsr(
                    libritts_root / "LibriTTS_R",
                    dataset_parts=libritts_subsets,
                    num_jobs=num_jobs,
                )
                # Add the recordings and supervisions to the CutSet
                self.cutset_libri = prep_2_cutset(prepared_libri)
                # Save the prepared LibriTTS cutset
                self.cutset_libri.to_file(libritts_cutset_file_path)

            # Filter the LibriTTS cutset to only include the selected speakers
            self.cutset_libri = self.cutset_libri.filter(
                lambda cut: isinstance(cut, MonoCut)
                and str(cut.supervisions[0].speaker)
                in self.selected_speakers_libri_ids_
                and self.dur_filter(cut.duration),
            ).to_eager()

            # Add the LibriTTS cutset to the final cutset
            self.cutset += self.cutset_libri

        # to_eager() evaluates all lazy operations on this manifest
        self.cutset = self.cutset.to_eager()
    def get_cache_subdir_path(self, idx: int) -> Path:
        r"""Calculate the path to the cache subdirectory.

        Args:
            idx (int): The index of the dataset item.

        Returns:
            Path: The path to the cache subdirectory.
        """
        return self.cache_dir / str(((idx // 1000) + 1) * 1000)

    def get_cache_file_path(self, idx: int) -> Path:
        r"""Calculate the path to the cache file.

        Args:
            idx (int): The index of the dataset item.

        Returns:
            Path: The path to the cache file.
        """
        return self.get_cache_subdir_path(idx) / f"{idx}.pt"
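    # Cache layout sketch: items are bucketed into subdirectories of 1000 files,
    # each named after the bucket's upper bound, e.g. with the default cache_dir:
    #   idx=0    -> /dev/shm/cache-hifitts-librittsr/1000/0.pt
    #   idx=999  -> /dev/shm/cache-hifitts-librittsr/1000/999.pt
    #   idx=1000 -> /dev/shm/cache-hifitts-librittsr/2000/1000.pt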
    def __len__(self) -> int:
        r"""Returns the length of the dataset.

        Returns:
            int: The length of the dataset.
        """
        return len(self.cutset)
    def __getitem__(self, idx: int) -> HifiLibriItem:
        r"""Returns the item at the specified index.

        Args:
            idx (int): The index of the item.

        Returns:
            HifiLibriItem: The item at the specified index.
        """
        cache_file = self.get_cache_file_path(idx)

        if self.cache and cache_file.exists():
            cached_data: Dict = torch.load(cache_file)
            # Rebuild the HifiLibriItem from the cached dict
            return HifiLibriItem(**cached_data)

        cutset = self.cutset[idx]

        if isinstance(cutset, MonoCut) and cutset.recording is not None:
            dataset_speaker_id = str(cutset.supervisions[0].speaker)

            # Map the dataset speaker id to the speaker id in the model
            speaker_id = selected_speakers_ids.get(
                dataset_speaker_id,
                len(selected_speakers_ids) + 1,
            )

            # Run voicefixer only for the LibriTTS speakers
            if dataset_speaker_id in self.selected_speakers_libri_ids_:
                audio_path = cutset.recording.sources[0].source

                # Restore LibriTTS-R audio
                with tempfile.NamedTemporaryFile(
                    suffix=".wav",
                    delete=True,
                ) as out_file:
                    self.voicefixer.restore(
                        input=audio_path,  # low quality .wav/.flac file
                        output=out_file.name,  # save file path
                        cuda=False,  # run on CPU; set True for GPU acceleration
                        mode=0,
                    )
                    audio, _ = sf.read(out_file.name)

                # Convert the numpy audio to a tensor
                audio = torch.from_numpy(audio).float().unsqueeze(0)
            else:
                # Load the audio from the cutset
                audio = torch.from_numpy(cutset.load_audio())
            text: str = str(cutset.supervisions[0].text)

            fileid = str(cutset.supervisions[0].recording_id)
            split_fileid = fileid.split("_")
            chapter_id = split_fileid[1]
            utterance_id = split_fileid[-1]

            libri_row = (
                audio,
                cutset.sampling_rate,
                text,  # original text
                text,  # normalized text
                speaker_id,
                chapter_id,
                utterance_id,
            )

            data = self.preprocess_libritts.acoustic(libri_row)

            # If preprocessing rejects this sample, retry with a random one
            if data is None:
                rand_idx = int(
                    torch.randint(
                        0,
                        self.__len__(),
                        (1,),
                    ).item(),
                )
                return self.__getitem__(rand_idx)
            # Add a channel dimension to the waveform
            data.wav = data.wav.unsqueeze(0)

            result = HifiLibriItem(
                id=data.utterance_id,
                wav=data.wav,
                mel=data.mel,
                pitch=data.pitch,
                text=data.phones,
                attn_prior=data.attn_prior,
                energy=data.energy,
                raw_text=data.raw_text,
                normalized_text=data.normalized_text,
                speaker=speaker_id,
                pitch_is_normalized=data.pitch_is_normalized,
                lang=lang2id["en"],
                dataset_type="hifitts" if idx < len(self.cutset_hifi) else "libritts",
            )

            if self.cache:
                # Create the cache subdirectory if it doesn't exist
                self.get_cache_subdir_path(idx).mkdir(
                    parents=True,
                    exist_ok=True,
                )
                # Save the preprocessed data to the cache
                torch.save(asdict(result), cache_file)

            return result
        else:
            raise FileNotFoundError(f"Cut not found at index {idx}.")
    def __iter__(self):
        r"""Makes the dataset iterable: yields each item by index via `__getitem__`.

        Yields:
            HifiLibriItem: The next item in the dataset.
        """
        for item in range(self.__len__()):
            yield self.__getitem__(item)
    def collate_fn(self, data: List[HifiLibriItem]) -> List:
        r"""Collates a batch of data samples.

        Args:
            data (List[HifiLibriItem]): A list of data samples.

        Returns:
            List: A list of padded, batched tensors and metadata.
        """
        data_size = len(data)
        idxs = list(range(data_size))

        # Initialize empty lists to store extracted values
        empty_lists: List[List] = [[] for _ in range(12)]
        (
            ids,
            speakers,
            texts,
            raw_texts,
            mels,
            pitches,
            attn_priors,
            langs,
            src_lens,
            mel_lens,
            wavs,
            energy,
        ) = empty_lists

        # Extract the fields from each item and populate the lists
        for idx in idxs:
            data_entry = data[idx]

            ids.append(data_entry.id)
            speakers.append(data_entry.speaker)
            texts.append(data_entry.text)
            raw_texts.append(data_entry.raw_text)
            mels.append(data_entry.mel)
            pitches.append(data_entry.pitch)
            attn_priors.append(data_entry.attn_prior)
            langs.append(data_entry.lang)
            src_lens.append(data_entry.text.shape[0])
            mel_lens.append(data_entry.mel.shape[1])
            wavs.append(data_entry.wav)
            energy.append(data_entry.energy)

        # Convert langs, src_lens, and mel_lens to numpy arrays
        langs = np.array(langs)
        src_lens = np.array(src_lens)
        mel_lens = np.array(mel_lens)

        # NOTE: pitch statistics are computed per batch rather than over the whole dataset
        # Take only the min and max values for pitch
        pitches_stat = list(self.normalize_pitch(pitches)[:2])

        texts = pad_1D(texts)
        mels = pad_2D(mels)
        pitches = pad_1D(pitches)
        attn_priors = pad_3D(attn_priors, len(idxs), max(src_lens), max(mel_lens))

        # Repeat the speaker and language ids for every input token
        speakers = np.repeat(
            np.expand_dims(np.array(speakers), axis=1),
            texts.shape[1],
            axis=1,
        )
        langs = np.repeat(
            np.expand_dims(np.array(langs), axis=1),
            texts.shape[1],
            axis=1,
        )

        wavs = pad_2D(wavs)
        energy = pad_2D(energy)

        return [
            ids,
            raw_texts,
            torch.from_numpy(speakers),
            texts.int(),
            torch.from_numpy(src_lens),
            mels,
            pitches,
            pitches_stat,
            torch.from_numpy(mel_lens),
            torch.from_numpy(langs),
            attn_priors,
            wavs,
            energy,
        ]
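    # Positional layout of the returned batch (a sketch for downstream consumers):
    #   [ids, raw_texts, speakers, texts, src_lens, mels, pitches,
    #    pitches_stat (batch [min, max]), mel_lens, langs, attn_priors, wavs, energy]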
    def normalize_pitch(
        self,
        pitches: List[torch.Tensor],
    ) -> Tuple[float, float, float, float]:
        r"""Computes pitch statistics over a list of pitch tensors.

        Despite the name, no normalization is applied here; the statistics are
        returned for the caller to use.

        Args:
            pitches (List[torch.Tensor]): A list of pitch values.

        Returns:
            Tuple[float, float, float, float]: The (min, max, mean, std) of the pitch values.
        """
        pitches_t = torch.concatenate(pitches)

        min_value = torch.min(pitches_t).item()
        max_value = torch.max(pitches_t).item()
        mean = torch.mean(pitches_t).item()
        std = torch.std(pitches_t).item()

        return min_value, max_value, mean, std
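    # Example (sketch, hypothetical values; torch.std uses the unbiased n - 1 denominator):
    #
    #     dataset.normalize_pitch([torch.tensor([100.0, 200.0]), torch.tensor([150.0])])
    #     # -> (100.0, 200.0, 150.0, 50.0)  i.e. (min, max, mean, std)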
def train_dataloader(
    batch_size: int = 6,
    num_workers: int = 5,
    sampling_rate: int = 22050,
    shuffle: bool = False,
    lang: str = "en",
    root: str = "datasets_cache",
    hifitts_path: str = "hifitts",
    hifi_cutset_file_name: str = "hifi.json.gz",
    libritts_path: str = "librittsr",
    libritts_cutset_file_name: str = "libri.json.gz",
    libritts_subsets: List[str] | str = "all",
    cache: bool = False,
    cache_dir: str = "/dev/shm",
    include_libri: bool = True,
    libri_speakers: List[str] = speakers_libri_ids,
    hifi_speakers: List[str] = speakers_hifi_ids,
) -> DataLoader:
    r"""Builds the training dataloader backed by HifiLibriDataset.

    Args:
        batch_size (int): The batch size.
        num_workers (int): The number of workers.
        sampling_rate (int): The sampling rate of the audio. Defaults to 22050.
        shuffle (bool): Whether to shuffle the dataset.
        lang (str): The language of the dataset.
        root (str): The root directory of the dataset.
        hifitts_path (str): The path to the HiFiTTS dataset.
        hifi_cutset_file_name (str): The file name of the HiFiTTS cutset.
        libritts_path (str): The path to the LibriTTS dataset.
        libritts_cutset_file_name (str): The file name of the LibriTTS cutset.
        libritts_subsets (List[str] | str): The subsets of the LibriTTS dataset to use.
        cache (bool): Whether to cache the dataset.
        cache_dir (str): The directory to cache the dataset in.
        include_libri (bool): Whether to include the LibriTTS dataset.
        libri_speakers (List[str]): The selected speakers from the LibriTTS dataset.
        hifi_speakers (List[str]): The selected speakers from the HiFiTTS dataset.

    Returns:
        DataLoader: The training dataloader.
    """
    dataset = HifiLibriDataset(
        root=root,
        hifitts_path=hifitts_path,
        sampling_rate=sampling_rate,
        hifi_cutset_file_name=hifi_cutset_file_name,
        libritts_path=libritts_path,
        libritts_cutset_file_name=libritts_cutset_file_name,
        libritts_subsets=libritts_subsets,
        cache=cache,
        cache_dir=cache_dir,
        lang=lang,
        include_libri=include_libri,
        libri_speakers=libri_speakers,
        hifi_speakers=hifi_speakers,
    )

    train_loader = DataLoader(
        dataset,
        # 4x80Gb max 10 sec audio
        # batch_size=20, # self.train_config.batch_size,
        # 4*80Gb max ~20.4 sec audio
        batch_size=batch_size,
        # TODO: find the optimal num_workers
        num_workers=num_workers,
        persistent_workers=True,
        pin_memory=True,
        shuffle=shuffle,
        collate_fn=dataset.collate_fn,
    )

    return train_loader
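# Usage sketch (hypothetical settings; assumes the datasets are available under
# `root` or can be downloaded by the lhotse recipes):
#
#     loader = train_dataloader(batch_size=4, num_workers=2, cache=True)
#     (
#         ids, raw_texts, speakers, texts, src_lens, mels, pitches,
#         pitches_stat, mel_lens, langs, attn_priors, wavs, energy,
#     ) = next(iter(loader))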