# PeechTTSv22050/training/datasets/hifi_libri_dataset.py
from dataclasses import asdict, dataclass
import os
from pathlib import Path
import tempfile
from typing import Dict, List, Literal, Optional, Tuple
# Disable tqdm progress bars globally: set the env var before importing tqdm,
# then monkey-patch tqdm.tqdm with a no-op passthrough that returns the iterable.
os.environ["TQDM_DISABLE"] = "1"
import tqdm


def nop(it, *a, **k):
    return it


tqdm.tqdm = nop
from lhotse import CutSet, RecordingSet, SupervisionSet
from lhotse.cut import MonoCut
from lhotse.recipes import hifitts, libritts
import numpy as np
import soundfile as sf
import torch
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from voicefixer import VoiceFixer
from models.config import PreprocessingConfigHifiGAN as PreprocessingConfig
from models.config import get_lang_map, lang2id
from training.preprocess import PreprocessLibriTTS
from training.tools import pad_1D, pad_2D, pad_3D
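
# Number of parallel jobs for dataset preparation: all CPU cores but one.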
NUM_JOBS = (os.cpu_count() or 2) - 1
# The selected speakers from the HiFiTTS dataset
speakers_hifi_ids = [
"Cori Samuel", # 92,
"Tony Oliva", # 6671,
"John Van Stan", # 9017,
"Helen Taylor", # 9136,
# "Phil Benson", # 6097,
# "Mike Pelton", # 6670,
# "Maria Kasper", # 8051,
# "Sylviamb", # 11614,
# "Celine Major", # 11697,
# "LikeManyWaters", # 12787,
]
# The selected speakers from the LibriTTS dataset
speakers_libri_ids = list(
map(
str,
[
# train-clean-100
40,
1088,
# train-clean-360
3307,
5935,
# train-other-500
215,
6594,
3867,
5733,
5181,
],
),
)
# Map each selected speaker id to a contiguous model speaker index
selected_speakers_ids = {
v: k
for k, v in enumerate(
speakers_hifi_ids + speakers_libri_ids,
)
}
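# With the speaker lists above this yields, for example:
# {"Cori Samuel": 0, ..., "Helen Taylor": 3, "40": 4, "1088": 5, ..., "5181": 12}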
def prep_2_cutset(prep: Dict[str, Dict[str, RecordingSet | SupervisionSet]]) -> CutSet:
    r"""Convert a prepared dataset into a single CutSet.

    Args:
        prep (Dict[str, Dict[str, RecordingSet | SupervisionSet]]): The prepared
            dataset, mapping each part to its recordings and supervisions.

    Returns:
        CutSet: The dataset prepared for the model.
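
    Example:
        A minimal sketch (the cache path is illustrative)::

            root = hifitts.download_hifitts("datasets_cache/hifitts")
            cuts = prep_2_cutset(hifitts.prepare_hifitts(root))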
"""
    recordings = RecordingSet()
    supervisions = SupervisionSet()
    for part in prep.values():
        record = part["recordings"]
        supervision = part["supervisions"]
        # Separate the recordings and supervisions
        if isinstance(record, RecordingSet):
            recordings += record
        if isinstance(supervision, SupervisionSet):
            supervisions += supervision
    # Combine the accumulated manifests into a CutSet
    return CutSet.from_manifests(
        recordings=recordings,
        supervisions=supervisions,
    )
DATASET_TYPES = Literal["hifitts", "libritts"]
@dataclass
class HifiLibriItem:
    """Dataset row for the combined HiFiTTS and LibriTTS datasets.

    Attributes:
id (str): The ID of the item.
wav (Tensor): The waveform of the audio.
mel (Tensor): The mel spectrogram.
pitch (Tensor): The pitch.
text (Tensor): The text.
attn_prior (Tensor): The attention prior.
energy (Tensor): The energy.
raw_text (str): The raw text.
normalized_text (str): The normalized text.
speaker (int): The speaker ID.
pitch_is_normalized (bool): Whether the pitch is normalized.
lang (int): The language ID.
dataset_type (DATASET_TYPES): The type of dataset.
"""
id: str
wav: Tensor
mel: Tensor
pitch: Tensor
text: Tensor
attn_prior: Tensor
energy: Tensor
raw_text: str
normalized_text: str
speaker: int
pitch_is_normalized: bool
lang: int
dataset_type: DATASET_TYPES
class HifiLibriDataset(Dataset):
    r"""A PyTorch dataset for loading delightful TTS data: selected HiFiTTS and LibriTTS-R speakers combined into a single lhotse CutSet."""
def __init__(
self,
lang: str = "en",
root: str = "datasets_cache",
sampling_rate: int = 44100,
hifitts_path: str = "hifitts",
hifi_cutset_file_name: str = "hifi.json.gz",
libritts_path: str = "librittsr",
libritts_cutset_file_name: str = "libri.json.gz",
libritts_subsets: List[str] | str = "all",
cache: bool = False,
cache_dir: str = "/dev/shm",
num_jobs: int = NUM_JOBS,
min_seconds: Optional[float] = None,
max_seconds: Optional[float] = None,
include_libri: bool = True,
libri_speakers: List[str] = speakers_libri_ids,
hifi_speakers: List[str] = speakers_hifi_ids,
):
r"""Initializes the dataset.
Args:
lang (str, optional): The language of the dataset. Defaults to "en".
root (str, optional): The root directory of the dataset. Defaults to "datasets_cache".
sampling_rate (int, optional): The sampling rate of the audio. Defaults to 44100.
hifitts_path (str, optional): The path to the HiFiTTS dataset. Defaults to "hifitts".
hifi_cutset_file_name (str, optional): The file name of the HiFiTTS cutset. Defaults to "hifi.json.gz".
libritts_path (str, optional): The path to the LibriTTS dataset. Defaults to "librittsr".
libritts_cutset_file_name (str, optional): The file name of the LibriTTS cutset. Defaults to "libri.json.gz".
            libritts_subsets (List[str] | str, optional): The subsets of the LibriTTS dataset to use. Defaults to "all".
cache (bool, optional): Whether to cache the dataset. Defaults to False.
cache_dir (str, optional): The directory to cache the dataset in. Defaults to "/dev/shm".
num_jobs (int, optional): The number of jobs to use for preparing the dataset. Defaults to NUM_JOBS.
            min_seconds (Optional[float], optional): The minimum audio duration in seconds. Defaults to the preprocessing config value.
            max_seconds (Optional[float], optional): The maximum audio duration in seconds. Defaults to the preprocessing config value.
            include_libri (bool, optional): Whether to include the LibriTTS dataset. Defaults to True.
            libri_speakers (List[str], optional): The selected speakers from the LibriTTS dataset. Defaults to speakers_libri_ids.
            hifi_speakers (List[str], optional): The selected speakers from the HiFiTTS dataset. Defaults to speakers_hifi_ids.
"""
lang_map = get_lang_map(lang)
processing_lang_type = lang_map.processing_lang_type
self.preprocess_config = PreprocessingConfig(
processing_lang_type,
sampling_rate=sampling_rate,
)
self.min_seconds = min_seconds or self.preprocess_config.min_seconds
self.max_seconds = max_seconds or self.preprocess_config.max_seconds
        self.dur_filter = (
            lambda duration: self.min_seconds <= duration <= self.max_seconds
        )
self.preprocess_libtts = PreprocessLibriTTS(
self.preprocess_config,
lang,
)
self.root_dir = Path(root)
self.voicefixer = VoiceFixer()
        # Store the selected speaker ids as sets for fast membership checks
self.selected_speakers_libri_ids_ = set(libri_speakers)
self.selected_speakers_hi_fi_ids_ = set(hifi_speakers)
self.cache = cache
self.cache_dir = Path(cache_dir) / f"cache-{hifitts_path}-{libritts_path}"
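        # On Linux /dev/shm is a RAM-backed tmpfs, so cached items avoid disk
        # I/O across epochs; point cache_dir elsewhere if memory is limited.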
# Prepare the HiFiTTS dataset
self.hifitts_path = self.root_dir / hifitts_path
hifi_cutset_file_path = self.root_dir / hifi_cutset_file_name
# Initialize the cutset
self.cutset = CutSet()
# Check if the HiFiTTS dataset has been prepared
if hifi_cutset_file_path.exists():
self.cutset_hifi = CutSet.from_file(hifi_cutset_file_path)
else:
hifitts_root = hifitts.download_hifitts(self.hifitts_path)
prepared_hifi = hifitts.prepare_hifitts(
hifitts_root,
num_jobs=num_jobs,
)
# Add the recordings and supervisions to the CutSet
self.cutset_hifi = prep_2_cutset(prepared_hifi)
# Save the prepared HiFiTTS dataset cutset
self.cutset_hifi.to_file(hifi_cutset_file_path)
# Filter the HiFiTTS cutset to only include the selected speakers
self.cutset_hifi = self.cutset_hifi.filter(
lambda cut: isinstance(cut, MonoCut)
and str(cut.supervisions[0].speaker) in self.selected_speakers_hi_fi_ids_
and self.dur_filter(cut.duration),
).to_eager()
# Add the HiFiTTS cutset to the final cutset
self.cutset += self.cutset_hifi
if include_libri:
# Prepare the LibriTTS dataset
self.libritts_path = self.root_dir / libritts_path
libritts_cutset_file_path = self.root_dir / libritts_cutset_file_name
# Check if the LibriTTS dataset has been prepared
if libritts_cutset_file_path.exists():
self.cutset_libri = CutSet.from_file(libritts_cutset_file_path)
else:
libritts_root = libritts.download_librittsr(
self.libritts_path,
dataset_parts=libritts_subsets,
)
prepared_libri = libritts.prepare_librittsr(
libritts_root / "LibriTTS_R",
dataset_parts=libritts_subsets,
num_jobs=num_jobs,
)
# Add the recordings and supervisions to the CutSet
self.cutset_libri = prep_2_cutset(prepared_libri)
# Save the prepared cutset for LibriTTS
self.cutset_libri.to_file(libritts_cutset_file_path)
# Filter the libri cutset to only include the selected speakers
self.cutset_libri = self.cutset_libri.filter(
lambda cut: isinstance(cut, MonoCut)
and str(cut.supervisions[0].speaker)
in self.selected_speakers_libri_ids_
and self.dur_filter(cut.duration),
).to_eager()
# Add the LibriTTS cutset to the final cutset
self.cutset += self.cutset_libri
        # to_eager() evaluates all lazy operations on this manifest
self.cutset = self.cutset.to_eager()
def get_cache_subdir_path(self, idx: int) -> Path:
        r"""Calculate the path to the cache subdirectory for a dataset index.

        Args:
            idx (int): The dataset item index.

        Returns:
            Path: The path to the cache subdirectory.
"""
        # Bucket cache files into subdirectories of 1000 items each:
        # idx 0-999 -> "1000", idx 1000-1999 -> "2000", and so on.
        return self.cache_dir / str(((idx // 1000) + 1) * 1000)
def get_cache_file_path(self, idx: int) -> Path:
        r"""Calculate the path to the cache file for a dataset index.

        Args:
            idx (int): The dataset item index.

        Returns:
            Path: The path to the cache file.
"""
return self.get_cache_subdir_path(idx) / f"{idx}.pt"
def __len__(self) -> int:
r"""Returns the length of the dataset.
Returns:
int: The length of the dataset.
"""
return len(self.cutset)
def __getitem__(self, idx: int) -> HifiLibriItem:
r"""Returns the item at the specified index.
Args:
idx (int): The index of the item.
Returns:
HifiLibriItem: The item at the specified index.
"""
cache_file = self.get_cache_file_path(idx)
if self.cache and cache_file.exists():
cached_data: Dict = torch.load(cache_file)
            # Rebuild the HifiLibriItem from the cached dict
result = HifiLibriItem(**cached_data)
return result
cutset = self.cutset[idx]
if isinstance(cutset, MonoCut) and cutset.recording is not None:
dataset_speaker_id = str(cutset.supervisions[0].speaker)
# Map the dataset speaker id to the speaker id in the model
speaker_id = selected_speakers_ids.get(
dataset_speaker_id,
len(selected_speakers_ids) + 1,
)
# Run voicefixer only for the libri speakers
            if dataset_speaker_id in self.selected_speakers_libri_ids_:
audio_path = cutset.recording.sources[0].source
# Restore LibriTTS-R audio
with tempfile.NamedTemporaryFile(
suffix=".wav",
delete=True,
) as out_file:
                    self.voicefixer.restore(
                        input=audio_path,  # path to the degraded input audio
                        output=out_file.name,  # path to write the restored audio
                        cuda=False,  # run on CPU; set True for GPU acceleration
                        mode=0,
                    )
audio, _ = sf.read(out_file.name)
# Convert the np audio to a tensor
audio = torch.from_numpy(audio).float().unsqueeze(0)
else:
# Load the audio from the cutset
audio = torch.from_numpy(cutset.load_audio())
text: str = str(cutset.supervisions[0].text)
fileid = str(cutset.supervisions[0].recording_id)
split_fileid = fileid.split("_")
chapter_id = split_fileid[1]
utterance_id = split_fileid[-1]
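            # Assemble the row format expected by PreprocessLibriTTS.acoustic:
            # (waveform, sample_rate, original_text, normalized_text, speaker_id,
            # chapter_id, utterance_id); the transcript serves as both the raw
            # and the normalized text here.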
libri_row = (
audio,
cutset.sampling_rate,
text,
text,
speaker_id,
chapter_id,
utterance_id,
)
data = self.preprocess_libtts.acoustic(libri_row)
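            # Preprocessing may reject a cut and return None; retry with a
            # random other index so the DataLoader always receives a valid item.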
if data is None:
rand_idx = int(
torch.randint(
0,
self.__len__(),
(1,),
).item(),
)
return self.__getitem__(rand_idx)
data.wav = data.wav.unsqueeze(0)
result = HifiLibriItem(
id=data.utterance_id,
wav=data.wav,
mel=data.mel,
pitch=data.pitch,
text=data.phones,
attn_prior=data.attn_prior,
energy=data.energy,
raw_text=data.raw_text,
normalized_text=data.normalized_text,
speaker=speaker_id,
pitch_is_normalized=data.pitch_is_normalized,
lang=lang2id["en"],
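                # HiFi cuts precede LibriTTS cuts in self.cutset, so the index
                # position identifies the source dataset.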
dataset_type="hifitts" if idx < len(self.cutset_hifi) else "libritts",
)
if self.cache:
                # Create the cache subdirectory if it doesn't exist
                self.get_cache_subdir_path(idx).mkdir(parents=True, exist_ok=True)
# Save the preprocessed data to the cache
torch.save(asdict(result), cache_file)
return result
else:
            raise FileNotFoundError(
                f"Cut at index {idx} is not a MonoCut with a recording.",
            )
    def __iter__(self):
        r"""Makes the dataset iterable, yielding items in index order via
        `__getitem__`.

        Yields:
            HifiLibriItem: The next item in the dataset.
        """
        for idx in range(len(self)):
            yield self[idx]
def collate_fn(self, data: List[HifiLibriItem]) -> List:
        r"""Collates a batch of data samples into padded, batched tensors.

        Args:
            data (List[HifiLibriItem]): A list of data samples.

        Returns:
            List: The collated batch in order: ids, raw_texts, speakers, texts,
            src_lens, mels, pitches, pitches_stat, mel_lens, langs, attn_priors,
            wavs, energy.
"""
data_size = len(data)
idxs = list(range(data_size))
# Initialize empty lists to store extracted values
empty_lists: List[List] = [[] for _ in range(12)]
(
ids,
speakers,
texts,
raw_texts,
mels,
pitches,
attn_priors,
langs,
src_lens,
mel_lens,
wavs,
energy,
) = empty_lists
# Extract fields from data dictionary and populate the lists
for idx in idxs:
data_entry = data[idx]
ids.append(data_entry.id)
speakers.append(data_entry.speaker)
texts.append(data_entry.text)
raw_texts.append(data_entry.raw_text)
mels.append(data_entry.mel)
pitches.append(data_entry.pitch)
attn_priors.append(data_entry.attn_prior)
langs.append(data_entry.lang)
src_lens.append(data_entry.text.shape[0])
mel_lens.append(data_entry.mel.shape[1])
wavs.append(data_entry.wav)
energy.append(data_entry.energy)
# Convert langs, src_lens, and mel_lens to numpy arrays
langs = np.array(langs)
src_lens = np.array(src_lens)
mel_lens = np.array(mel_lens)
        # NOTE: pitch statistics are computed per batch rather than over the
        # whole dataset; only the min and max values are kept.
        pitches_stat = list(self.normalize_pitch(pitches)[:2])
texts = pad_1D(texts)
mels = pad_2D(mels)
pitches = pad_1D(pitches)
attn_priors = pad_3D(attn_priors, len(idxs), max(src_lens), max(mel_lens))
speakers = np.repeat(
np.expand_dims(np.array(speakers), axis=1),
texts.shape[1],
axis=1,
)
langs = np.repeat(
np.expand_dims(np.array(langs), axis=1),
texts.shape[1],
axis=1,
)
wavs = pad_2D(wavs)
energy = pad_2D(energy)
return [
ids,
raw_texts,
torch.from_numpy(speakers),
texts.int(),
torch.from_numpy(src_lens),
mels,
pitches,
pitches_stat,
torch.from_numpy(mel_lens),
torch.from_numpy(langs),
attn_priors,
wavs,
energy,
]
def normalize_pitch(
self,
pitches: List[torch.Tensor],
) -> Tuple[float, float, float, float]:
        r"""Computes pitch statistics over a batch.

        Args:
            pitches (List[torch.Tensor]): A list of pitch tensors.

        Returns:
            Tuple: The (min, max, mean, std) of the concatenated pitch values.
"""
pitches_t = torch.concatenate(pitches)
min_value = torch.min(pitches_t).item()
max_value = torch.max(pitches_t).item()
mean = torch.mean(pitches_t).item()
std = torch.std(pitches_t).item()
return min_value, max_value, mean, std
def train_dataloader(
batch_size: int = 6,
num_workers: int = 5,
sampling_rate: int = 22050,
shuffle: bool = False,
lang: str = "en",
root: str = "datasets_cache",
hifitts_path: str = "hifitts",
hifi_cutset_file_name: str = "hifi.json.gz",
libritts_path: str = "librittsr",
libritts_cutset_file_name: str = "libri.json.gz",
libritts_subsets: List[str] | str = "all",
cache: bool = False,
cache_dir: str = "/dev/shm",
include_libri: bool = True,
libri_speakers: List[str] = speakers_libri_ids,
hifi_speakers: List[str] = speakers_hifi_ids,
) -> DataLoader:
r"""Returns the training dataloader, that is using the HifiLibriDataset dataset.
Args:
batch_size (int): The batch size.
num_workers (int): The number of workers.
sampling_rate (int): The sampling rate of the audio. Defaults to 22050.
shuffle (bool): Whether to shuffle the dataset.
lang (str): The language of the dataset.
root (str): The root directory of the dataset.
hifitts_path (str): The path to the HiFiTTS dataset.
hifi_cutset_file_name (str): The file name of the HiFiTTS cutset.
libritts_path (str): The path to the LibriTTS dataset.
libritts_cutset_file_name (str): The file name of the LibriTTS cutset.
libritts_subsets (List[str] | str): The subsets of the LibriTTS dataset to use.
cache (bool): Whether to cache the dataset.
cache_dir (str): The directory to cache the dataset in.
include_libri (bool): Whether to include the LibriTTS dataset.
libri_speakers (List[str]): The selected speakers from the LibriTTS dataset.
hifi_speakers (List[str]): The selected speakers from the HiFiTTS dataset.
Returns:
DataLoader: The training dataloader.
"""
dataset = HifiLibriDataset(
root=root,
hifitts_path=hifitts_path,
sampling_rate=sampling_rate,
hifi_cutset_file_name=hifi_cutset_file_name,
libritts_path=libritts_path,
libritts_cutset_file_name=libritts_cutset_file_name,
libritts_subsets=libritts_subsets,
cache=cache,
cache_dir=cache_dir,
lang=lang,
include_libri=include_libri,
libri_speakers=libri_speakers,
hifi_speakers=hifi_speakers,
)
train_loader = DataLoader(
dataset,
        # Batch-size notes from tuning on 4x80GB GPUs: batch_size=20 fits audio
        # up to ~10 s per clip; the default of 6 fits up to ~20.4 s.
        batch_size=batch_size,
# TODO: find the optimal num_workers
num_workers=num_workers,
persistent_workers=True,
pin_memory=True,
shuffle=shuffle,
collate_fn=dataset.collate_fn,
)
return train_loader
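

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the training pipeline: build the
    # dataloader with small, illustrative settings and inspect one collated
    # batch. Assumes the datasets are available under the default cache paths.
    loader = train_dataloader(batch_size=2, num_workers=1)
    ids, raw_texts, speakers, texts, src_lens, mels, *rest = next(iter(loader))
    print("ids:", ids)
    print("speakers:", tuple(speakers.shape), "texts:", tuple(texts.shape))
    print("mels:", tuple(mels.shape))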