Spaces:
Running
Running
File size: 954 Bytes
9d61c9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
from dataclasses import dataclass
import os
from typing import Any, List
import torch
from torch.utils.data import Dataset
@dataclass
class PreprocessedData:
    """Record holding one preprocessed sample.

    Every field is annotated ``Any``; what each one actually contains is
    decided by the code that builds these records and is not constrained
    here. The declaration order below is the positional order of the
    generated ``__init__``, so it must not be rearranged.
    """

    # NOTE(review): the per-field descriptions are inferred from names only;
    # confirm against the preprocessing step that produces these records.
    id: Any  # sample identifier (attribute name only; does not affect builtin id())
    raw_text: Any  # presumably the untouched transcript
    speaker: Any  # presumably a speaker label or index
    text: Any  # presumably the encoded/tokenized transcript
    src_len: Any  # presumably the length of `text`
    mel: Any  # presumably a mel-spectrogram
    pitch: Any
    pitch_stat: Any
    mel_len: Any
    lang: Any
    attn_prior: Any
    wav: Any
    energy: Any
# Deliberately NOT decorated with @dataclass: this class has no annotated
# fields, so a generated __eq__ would make every instance compare equal and
# eq=True would set __hash__ to None (making the dataset unhashable).
class PreprocessedDataset(Dataset):
    """Map-style dataset built from every ``*.pt`` file in a cache directory.

    Each ``.pt`` file is read with ``torch.load`` and its contents are
    appended to a single in-memory list via ``list.extend``, so each file is
    expected to hold an iterable of samples. All files are loaded eagerly in
    ``__init__``.

    Args:
        cache_dir: Directory scanned (non-recursively) for ``.pt`` files.

    Attributes:
        cache_dir: The directory passed in.
        data_files: Names of the loaded ``.pt`` files, sorted by name; this is
            also the order in which their samples appear in ``data``.
        data: Concatenation of the samples from all files.

    Raises:
        OSError: Propagated from ``os.listdir`` (e.g. ``FileNotFoundError``)
            when ``cache_dir`` does not exist or is not a directory.
    """

    def __init__(self, cache_dir: str = "datasets_cache/LibriTTS_preprocessed"):
        self.cache_dir = cache_dir
        # os.listdir returns entries in arbitrary order; sorting keeps sample
        # indices stable across machines and filesystems.
        self.data_files = sorted(
            name for name in os.listdir(self.cache_dir) if name.endswith(".pt")
        )
        self.data = []
        # Load each file exactly once. The old second loop re-read the files
        # through an undefined attribute.
        for name in self.data_files:
            # SECURITY NOTE(review): torch.load unpickles the file; only point
            # cache_dir at trusted, locally produced caches.
            # NOTE(review): newer torch versions default to weights_only=True,
            # which rejects arbitrary pickled classes. If these caches store
            # custom objects, this call may need weights_only=False (trusted
            # files only). Confirm against the torch version in use.
            self.data.extend(torch.load(os.path.join(self.cache_dir, name)))

    def __len__(self):
        """Return the total number of samples across all loaded files."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at position ``idx`` of the concatenated list."""
        return self.data[idx]
|