Spaces:
Sleeping
Sleeping
mrfakename
commited on
Sync from GitHub repo
Browse filesThis Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there
- model/dataset.py +17 -3
model/dataset.py
CHANGED
@@ -8,8 +8,10 @@ from torch.utils.data import Dataset, Sampler
|
|
8 |
import torchaudio
|
9 |
from datasets import load_from_disk
|
10 |
from datasets import Dataset as Dataset_
|
|
|
11 |
|
12 |
from model.modules import MelSpec
|
|
|
13 |
|
14 |
|
15 |
class HFDataset(Dataset):
|
@@ -77,15 +79,22 @@ class CustomDataset(Dataset):
|
|
77 |
hop_length=256,
|
78 |
n_mel_channels=100,
|
79 |
preprocessed_mel=False,
|
|
|
80 |
):
|
81 |
self.data = custom_dataset
|
82 |
self.durations = durations
|
83 |
self.target_sample_rate = target_sample_rate
|
84 |
self.hop_length = hop_length
|
85 |
self.preprocessed_mel = preprocessed_mel
|
|
|
86 |
if not preprocessed_mel:
|
87 |
-
self.mel_spectrogram =
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
89 |
)
|
90 |
|
91 |
def get_frame_len(self, index):
|
@@ -201,6 +210,7 @@ def load_dataset(
|
|
201 |
tokenizer: str = "pinyin",
|
202 |
dataset_type: str = "CustomDataset",
|
203 |
audio_type: str = "raw",
|
|
|
204 |
mel_spec_kwargs: dict = dict(),
|
205 |
) -> CustomDataset | HFDataset:
|
206 |
"""
|
@@ -224,7 +234,11 @@ def load_dataset(
|
|
224 |
data_dict = json.load(f)
|
225 |
durations = data_dict["duration"]
|
226 |
train_dataset = CustomDataset(
|
227 |
-
train_dataset,
|
|
|
|
|
|
|
|
|
228 |
)
|
229 |
|
230 |
elif dataset_type == "CustomDatasetPath":
|
|
|
8 |
import torchaudio
|
9 |
from datasets import load_from_disk
|
10 |
from datasets import Dataset as Dataset_
|
11 |
+
from torch import nn
|
12 |
|
13 |
from model.modules import MelSpec
|
14 |
+
from model.utils import default
|
15 |
|
16 |
|
17 |
class HFDataset(Dataset):
|
|
|
79 |
hop_length=256,
|
80 |
n_mel_channels=100,
|
81 |
preprocessed_mel=False,
|
82 |
+
mel_spec_module: nn.Module | None = None,
|
83 |
):
|
84 |
self.data = custom_dataset
|
85 |
self.durations = durations
|
86 |
self.target_sample_rate = target_sample_rate
|
87 |
self.hop_length = hop_length
|
88 |
self.preprocessed_mel = preprocessed_mel
|
89 |
+
|
90 |
if not preprocessed_mel:
|
91 |
+
self.mel_spectrogram = default(
|
92 |
+
mel_spec_module,
|
93 |
+
MelSpec(
|
94 |
+
target_sample_rate=target_sample_rate,
|
95 |
+
hop_length=hop_length,
|
96 |
+
n_mel_channels=n_mel_channels,
|
97 |
+
),
|
98 |
)
|
99 |
|
100 |
def get_frame_len(self, index):
|
|
|
210 |
tokenizer: str = "pinyin",
|
211 |
dataset_type: str = "CustomDataset",
|
212 |
audio_type: str = "raw",
|
213 |
+
mel_spec_module: nn.Module | None = None,
|
214 |
mel_spec_kwargs: dict = dict(),
|
215 |
) -> CustomDataset | HFDataset:
|
216 |
"""
|
|
|
234 |
data_dict = json.load(f)
|
235 |
durations = data_dict["duration"]
|
236 |
train_dataset = CustomDataset(
|
237 |
+
train_dataset,
|
238 |
+
durations=durations,
|
239 |
+
preprocessed_mel=preprocessed_mel,
|
240 |
+
mel_spec_module=mel_spec_module,
|
241 |
+
**mel_spec_kwargs,
|
242 |
)
|
243 |
|
244 |
elif dataset_type == "CustomDatasetPath":
|