Spaces:
Running
on
Zero
Running
on
Zero
Sync from GitHub repo
Browse files
This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there.
- .pre-commit-config.yaml +2 -2
- Dockerfile +1 -2
- README_REPO.md +2 -2
- ckpts/README.md +2 -11
- src/f5_tts/api.py +2 -2
- src/f5_tts/eval/eval_infer_batch.py +4 -3
- src/f5_tts/eval/utils_eval.py +3 -3
- src/f5_tts/infer/infer_cli.py +7 -6
- src/f5_tts/infer/speech_edit.py +3 -2
- src/f5_tts/model/trainer.py +1 -1
- src/f5_tts/scripts/count_max_epoch.py +1 -1
- src/f5_tts/socket_server.py +2 -2
- src/f5_tts/train/datasets/prepare_csv_wavs.py +2 -2
- src/f5_tts/train/datasets/prepare_emilia.py +1 -1
- src/f5_tts/train/datasets/prepare_libritts.py +1 -1
- src/f5_tts/train/datasets/prepare_ljspeech.py +1 -1
- src/f5_tts/train/train.py +31 -31
.pre-commit-config.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
repos:
|
2 |
- repo: https://github.com/astral-sh/ruff-pre-commit
|
3 |
# Ruff version.
|
4 |
-
rev: v0.
|
5 |
hooks:
|
6 |
# Run the linter.
|
7 |
- id: ruff
|
@@ -9,6 +9,6 @@ repos:
|
|
9 |
# Run the formatter.
|
10 |
- id: ruff-format
|
11 |
- repo: https://github.com/pre-commit/pre-commit-hooks
|
12 |
-
rev:
|
13 |
hooks:
|
14 |
- id: check-yaml
|
|
|
1 |
repos:
|
2 |
- repo: https://github.com/astral-sh/ruff-pre-commit
|
3 |
# Ruff version.
|
4 |
+
rev: v0.11.2
|
5 |
hooks:
|
6 |
# Run the linter.
|
7 |
- id: ruff
|
|
|
9 |
# Run the formatter.
|
10 |
- id: ruff-format
|
11 |
- repo: https://github.com/pre-commit/pre-commit-hooks
|
12 |
+
rev: v5.0.0
|
13 |
hooks:
|
14 |
- id: check-yaml
|
Dockerfile
CHANGED
@@ -23,9 +23,8 @@ RUN git clone https://github.com/SWivid/F5-TTS.git \
|
|
23 |
|
24 |
ENV SHELL=/bin/bash
|
25 |
|
26 |
-
# models are downloaded into this folder, so user should mount it
|
27 |
VOLUME /root/.cache/huggingface/hub/
|
28 |
-
|
29 |
EXPOSE 7860
|
30 |
|
31 |
WORKDIR /workspace/F5-TTS
|
|
|
23 |
|
24 |
ENV SHELL=/bin/bash
|
25 |
|
|
|
26 |
VOLUME /root/.cache/huggingface/hub/
|
27 |
+
|
28 |
EXPOSE 7860
|
29 |
|
30 |
WORKDIR /workspace/F5-TTS
|
README_REPO.md
CHANGED
@@ -203,7 +203,7 @@ Read [training & finetuning guidance](src/f5_tts/train) for more instructions.
|
|
203 |
|
204 |
## Development
|
205 |
|
206 |
-
Use pre-commit to ensure code quality (will run linters and formatters automatically)
|
207 |
|
208 |
```bash
|
209 |
pip install pre-commit
|
@@ -216,7 +216,7 @@ When making a pull request, before each commit, run:
|
|
216 |
pre-commit run --all-files
|
217 |
```
|
218 |
|
219 |
-
Note: Some model components have linting exceptions for E722 to accommodate tensor notation
|
220 |
|
221 |
|
222 |
## Acknowledgements
|
|
|
203 |
|
204 |
## Development
|
205 |
|
206 |
+
Use pre-commit to ensure code quality (will run linters and formatters automatically):
|
207 |
|
208 |
```bash
|
209 |
pip install pre-commit
|
|
|
216 |
pre-commit run --all-files
|
217 |
```
|
218 |
|
219 |
+
Note: Some model components have linting exceptions for E722 to accommodate tensor notation.
|
220 |
|
221 |
|
222 |
## Acknowledgements
|
ckpts/README.md
CHANGED
@@ -1,12 +1,3 @@
|
|
|
|
1 |
|
2 |
-
|
3 |
-
|
4 |
-
```
|
5 |
-
ckpts/
|
6 |
-
F5TTS_v1_Base/
|
7 |
-
model_1250000.safetensors
|
8 |
-
F5TTS_Base/
|
9 |
-
model_1200000.safetensors
|
10 |
-
E2TTS_Base/
|
11 |
-
model_1200000.safetensors
|
12 |
-
```
|
|
|
1 |
+
The pretrained model checkpoints can be reached at https://huggingface.co/SWivid/F5-TTS.
|
2 |
|
3 |
+
Scripts will automatically pull model checkpoints from Huggingface, by default to `~/.cache/huggingface/hub/`.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/f5_tts/api.py
CHANGED
@@ -5,6 +5,7 @@ from importlib.resources import files
|
|
5 |
import soundfile as sf
|
6 |
import tqdm
|
7 |
from cached_path import cached_path
|
|
|
8 |
from omegaconf import OmegaConf
|
9 |
|
10 |
from f5_tts.infer.utils_infer import (
|
@@ -16,7 +17,6 @@ from f5_tts.infer.utils_infer import (
|
|
16 |
remove_silence_for_generated_wav,
|
17 |
save_spectrogram,
|
18 |
)
|
19 |
-
from f5_tts.model import DiT, UNetT # noqa: F401. used for config
|
20 |
from f5_tts.model.utils import seed_everything
|
21 |
|
22 |
|
@@ -33,7 +33,7 @@ class F5TTS:
|
|
33 |
hf_cache_dir=None,
|
34 |
):
|
35 |
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
|
36 |
-
model_cls =
|
37 |
model_arc = model_cfg.model.arch
|
38 |
|
39 |
self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
|
|
|
5 |
import soundfile as sf
|
6 |
import tqdm
|
7 |
from cached_path import cached_path
|
8 |
+
from hydra.utils import get_class
|
9 |
from omegaconf import OmegaConf
|
10 |
|
11 |
from f5_tts.infer.utils_infer import (
|
|
|
17 |
remove_silence_for_generated_wav,
|
18 |
save_spectrogram,
|
19 |
)
|
|
|
20 |
from f5_tts.model.utils import seed_everything
|
21 |
|
22 |
|
|
|
33 |
hf_cache_dir=None,
|
34 |
):
|
35 |
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
|
36 |
+
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
|
37 |
model_arc = model_cfg.model.arch
|
38 |
|
39 |
self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
|
src/f5_tts/eval/eval_infer_batch.py
CHANGED
@@ -10,6 +10,7 @@ from importlib.resources import files
|
|
10 |
import torch
|
11 |
import torchaudio
|
12 |
from accelerate import Accelerator
|
|
|
13 |
from omegaconf import OmegaConf
|
14 |
from tqdm import tqdm
|
15 |
|
@@ -19,7 +20,7 @@ from f5_tts.eval.utils_eval import (
|
|
19 |
get_seedtts_testset_metainfo,
|
20 |
)
|
21 |
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
|
22 |
-
from f5_tts.model import CFM
|
23 |
from f5_tts.model.utils import get_tokenizer
|
24 |
|
25 |
accelerator = Accelerator()
|
@@ -65,7 +66,7 @@ def main():
|
|
65 |
no_ref_audio = False
|
66 |
|
67 |
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
|
68 |
-
model_cls =
|
69 |
model_arc = model_cfg.model.arch
|
70 |
|
71 |
dataset_name = model_cfg.datasets.name
|
@@ -195,7 +196,7 @@ def main():
|
|
195 |
accelerator.wait_for_everyone()
|
196 |
if accelerator.is_main_process:
|
197 |
timediff = time.time() - start
|
198 |
-
print(f"Done batch inference in {timediff / 60
|
199 |
|
200 |
|
201 |
if __name__ == "__main__":
|
|
|
10 |
import torch
|
11 |
import torchaudio
|
12 |
from accelerate import Accelerator
|
13 |
+
from hydra.utils import get_class
|
14 |
from omegaconf import OmegaConf
|
15 |
from tqdm import tqdm
|
16 |
|
|
|
20 |
get_seedtts_testset_metainfo,
|
21 |
)
|
22 |
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
|
23 |
+
from f5_tts.model import CFM
|
24 |
from f5_tts.model.utils import get_tokenizer
|
25 |
|
26 |
accelerator = Accelerator()
|
|
|
66 |
no_ref_audio = False
|
67 |
|
68 |
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
|
69 |
+
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
|
70 |
model_arc = model_cfg.model.arch
|
71 |
|
72 |
dataset_name = model_cfg.datasets.name
|
|
|
196 |
accelerator.wait_for_everyone()
|
197 |
if accelerator.is_main_process:
|
198 |
timediff = time.time() - start
|
199 |
+
print(f"Done batch inference in {timediff / 60:.2f} minutes.")
|
200 |
|
201 |
|
202 |
if __name__ == "__main__":
|
src/f5_tts/eval/utils_eval.py
CHANGED
@@ -148,9 +148,9 @@ def get_inference_prompt(
|
|
148 |
|
149 |
# deal with batch
|
150 |
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
|
151 |
-
assert (
|
152 |
-
|
153 |
-
)
|
154 |
bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
|
155 |
|
156 |
utts[bucket_i].append(utt)
|
|
|
148 |
|
149 |
# deal with batch
|
150 |
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
|
151 |
+
assert min_tokens <= total_mel_len <= max_tokens, (
|
152 |
+
f"Audio {utt} has duration {total_mel_len * hop_length // target_sample_rate}s out of range [{min_secs}, {max_secs}]."
|
153 |
+
)
|
154 |
bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
|
155 |
|
156 |
utts[bucket_i].append(utt)
|
src/f5_tts/infer/infer_cli.py
CHANGED
@@ -10,6 +10,7 @@ import numpy as np
|
|
10 |
import soundfile as sf
|
11 |
import tomli
|
12 |
from cached_path import cached_path
|
|
|
13 |
from omegaconf import OmegaConf
|
14 |
|
15 |
from f5_tts.infer.utils_infer import (
|
@@ -27,7 +28,6 @@ from f5_tts.infer.utils_infer import (
|
|
27 |
preprocess_ref_audio_text,
|
28 |
remove_silence_for_generated_wav,
|
29 |
)
|
30 |
-
from f5_tts.model import DiT, UNetT # noqa: F401. used for config
|
31 |
|
32 |
|
33 |
parser = argparse.ArgumentParser(
|
@@ -246,13 +246,14 @@ vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_loc
|
|
246 |
|
247 |
model_cfg = OmegaConf.load(
|
248 |
args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
|
249 |
-
)
|
250 |
-
model_cls =
|
|
|
251 |
|
252 |
repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
|
253 |
|
254 |
if model != "F5TTS_Base":
|
255 |
-
assert vocoder_name == model_cfg.mel_spec.mel_spec_type
|
256 |
|
257 |
# override for previous models
|
258 |
if model == "F5TTS_Base":
|
@@ -269,7 +270,7 @@ if not ckpt_file:
|
|
269 |
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
|
270 |
|
271 |
print(f"Using {model}...")
|
272 |
-
ema_model = load_model(model_cls,
|
273 |
|
274 |
|
275 |
# inference process
|
@@ -332,7 +333,7 @@ def main():
|
|
332 |
if len(gen_text_) > 200:
|
333 |
gen_text_ = gen_text_[:200] + " ... "
|
334 |
sf.write(
|
335 |
-
os.path.join(output_chunk_dir, f"{len(generated_audio_segments)-1}_{gen_text_}.wav"),
|
336 |
audio_segment,
|
337 |
final_sample_rate,
|
338 |
)
|
|
|
10 |
import soundfile as sf
|
11 |
import tomli
|
12 |
from cached_path import cached_path
|
13 |
+
from hydra.utils import get_class
|
14 |
from omegaconf import OmegaConf
|
15 |
|
16 |
from f5_tts.infer.utils_infer import (
|
|
|
28 |
preprocess_ref_audio_text,
|
29 |
remove_silence_for_generated_wav,
|
30 |
)
|
|
|
31 |
|
32 |
|
33 |
parser = argparse.ArgumentParser(
|
|
|
246 |
|
247 |
model_cfg = OmegaConf.load(
|
248 |
args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
|
249 |
+
)
|
250 |
+
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
|
251 |
+
model_arc = model_cfg.model.arch
|
252 |
|
253 |
repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
|
254 |
|
255 |
if model != "F5TTS_Base":
|
256 |
+
assert vocoder_name == model_cfg.model.mel_spec.mel_spec_type
|
257 |
|
258 |
# override for previous models
|
259 |
if model == "F5TTS_Base":
|
|
|
270 |
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
|
271 |
|
272 |
print(f"Using {model}...")
|
273 |
+
ema_model = load_model(model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
|
274 |
|
275 |
|
276 |
# inference process
|
|
|
333 |
if len(gen_text_) > 200:
|
334 |
gen_text_ = gen_text_[:200] + " ... "
|
335 |
sf.write(
|
336 |
+
os.path.join(output_chunk_dir, f"{len(generated_audio_segments) - 1}_{gen_text_}.wav"),
|
337 |
audio_segment,
|
338 |
final_sample_rate,
|
339 |
)
|
src/f5_tts/infer/speech_edit.py
CHANGED
@@ -7,10 +7,11 @@ from importlib.resources import files
|
|
7 |
import torch
|
8 |
import torch.nn.functional as F
|
9 |
import torchaudio
|
|
|
10 |
from omegaconf import OmegaConf
|
11 |
|
12 |
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
|
13 |
-
from f5_tts.model import CFM
|
14 |
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
|
15 |
|
16 |
device = (
|
@@ -40,7 +41,7 @@ target_rms = 0.1
|
|
40 |
|
41 |
|
42 |
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
|
43 |
-
model_cls =
|
44 |
model_arc = model_cfg.model.arch
|
45 |
|
46 |
dataset_name = model_cfg.datasets.name
|
|
|
7 |
import torch
|
8 |
import torch.nn.functional as F
|
9 |
import torchaudio
|
10 |
+
from hydra.utils import get_class
|
11 |
from omegaconf import OmegaConf
|
12 |
|
13 |
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
|
14 |
+
from f5_tts.model import CFM
|
15 |
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
|
16 |
|
17 |
device = (
|
|
|
41 |
|
42 |
|
43 |
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
|
44 |
+
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
|
45 |
model_arc = model_cfg.model.arch
|
46 |
|
47 |
dataset_name = model_cfg.datasets.name
|
src/f5_tts/model/trainer.py
CHANGED
@@ -350,7 +350,7 @@ class Trainer:
|
|
350 |
|
351 |
progress_bar = tqdm(
|
352 |
range(math.ceil(len(train_dataloader) / self.grad_accumulation_steps)),
|
353 |
-
desc=f"Epoch {epoch+1}/{self.epochs}",
|
354 |
unit="update",
|
355 |
disable=not self.accelerator.is_local_main_process,
|
356 |
initial=progress_bar_initial,
|
|
|
350 |
|
351 |
progress_bar = tqdm(
|
352 |
range(math.ceil(len(train_dataloader) / self.grad_accumulation_steps)),
|
353 |
+
desc=f"Epoch {epoch + 1}/{self.epochs}",
|
354 |
unit="update",
|
355 |
disable=not self.accelerator.is_local_main_process,
|
356 |
initial=progress_bar_initial,
|
src/f5_tts/scripts/count_max_epoch.py
CHANGED
@@ -24,7 +24,7 @@ updates_per_epoch = total_hours / mini_batch_hours
|
|
24 |
|
25 |
# result
|
26 |
epochs = wanted_max_updates / updates_per_epoch
|
27 |
-
print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})")
|
28 |
print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
|
29 |
# print(f" or approx. 0/{steps_per_epoch:.0f} steps")
|
30 |
|
|
|
24 |
|
25 |
# result
|
26 |
epochs = wanted_max_updates / updates_per_epoch
|
27 |
+
print(f"epochs should be set to: {epochs:.0f} ({epochs / grad_accum:.1f} x gd_acum {grad_accum})")
|
28 |
print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
|
29 |
# print(f" or approx. 0/{steps_per_epoch:.0f} steps")
|
30 |
|
src/f5_tts/socket_server.py
CHANGED
@@ -13,9 +13,9 @@ from importlib.resources import files
|
|
13 |
import torch
|
14 |
import torchaudio
|
15 |
from huggingface_hub import hf_hub_download
|
|
|
16 |
from omegaconf import OmegaConf
|
17 |
|
18 |
-
from f5_tts.model.backbones.dit import DiT # noqa: F401. used for config
|
19 |
from f5_tts.infer.utils_infer import (
|
20 |
chunk_text,
|
21 |
preprocess_ref_audio_text,
|
@@ -80,7 +80,7 @@ class TTSStreamingProcessor:
|
|
80 |
else "cpu"
|
81 |
)
|
82 |
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
|
83 |
-
self.model_cls =
|
84 |
self.model_arc = model_cfg.model.arch
|
85 |
self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
|
86 |
self.sampling_rate = model_cfg.model.mel_spec.target_sample_rate
|
|
|
13 |
import torch
|
14 |
import torchaudio
|
15 |
from huggingface_hub import hf_hub_download
|
16 |
+
from hydra.utils import get_class
|
17 |
from omegaconf import OmegaConf
|
18 |
|
|
|
19 |
from f5_tts.infer.utils_infer import (
|
20 |
chunk_text,
|
21 |
preprocess_ref_audio_text,
|
|
|
80 |
else "cpu"
|
81 |
)
|
82 |
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
|
83 |
+
self.model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
|
84 |
self.model_arc = model_cfg.model.arch
|
85 |
self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
|
86 |
self.sampling_rate = model_cfg.model.mel_spec.target_sample_rate
|
src/f5_tts/train/datasets/prepare_csv_wavs.py
CHANGED
@@ -122,7 +122,7 @@ def prepare_csv_wavs_dir(input_dir, num_workers=None):
|
|
122 |
for future in tqdm(
|
123 |
chunk_futures,
|
124 |
total=len(chunk),
|
125 |
-
desc=f"Processing chunk {i//CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1)//CHUNK_SIZE}",
|
126 |
):
|
127 |
try:
|
128 |
result = future.result()
|
@@ -233,7 +233,7 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
|
|
233 |
dataset_name = out_dir.stem
|
234 |
print(f"\nFor {dataset_name}, sample count: {len(result)}")
|
235 |
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
|
236 |
-
print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
|
237 |
|
238 |
|
239 |
def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True, num_workers: int = None):
|
|
|
122 |
for future in tqdm(
|
123 |
chunk_futures,
|
124 |
total=len(chunk),
|
125 |
+
desc=f"Processing chunk {i // CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1) // CHUNK_SIZE}",
|
126 |
):
|
127 |
try:
|
128 |
result = future.result()
|
|
|
233 |
dataset_name = out_dir.stem
|
234 |
print(f"\nFor {dataset_name}, sample count: {len(result)}")
|
235 |
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
|
236 |
+
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
|
237 |
|
238 |
|
239 |
def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True, num_workers: int = None):
|
src/f5_tts/train/datasets/prepare_emilia.py
CHANGED
@@ -198,7 +198,7 @@ def main():
|
|
198 |
|
199 |
print(f"\nFor {dataset_name}, sample count: {len(result)}")
|
200 |
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
|
201 |
-
print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
|
202 |
if "ZH" in langs:
|
203 |
print(f"Bad zh transcription case: {total_bad_case_zh}")
|
204 |
if "EN" in langs:
|
|
|
198 |
|
199 |
print(f"\nFor {dataset_name}, sample count: {len(result)}")
|
200 |
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
|
201 |
+
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
|
202 |
if "ZH" in langs:
|
203 |
print(f"Bad zh transcription case: {total_bad_case_zh}")
|
204 |
if "EN" in langs:
|
src/f5_tts/train/datasets/prepare_libritts.py
CHANGED
@@ -72,7 +72,7 @@ def main():
|
|
72 |
|
73 |
print(f"\nFor {dataset_name}, sample count: {len(result)}")
|
74 |
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
|
75 |
-
print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
|
76 |
|
77 |
|
78 |
if __name__ == "__main__":
|
|
|
72 |
|
73 |
print(f"\nFor {dataset_name}, sample count: {len(result)}")
|
74 |
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
|
75 |
+
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
|
76 |
|
77 |
|
78 |
if __name__ == "__main__":
|
src/f5_tts/train/datasets/prepare_ljspeech.py
CHANGED
@@ -50,7 +50,7 @@ def main():
|
|
50 |
|
51 |
print(f"\nFor {dataset_name}, sample count: {len(result)}")
|
52 |
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
|
53 |
-
print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
|
54 |
|
55 |
|
56 |
if __name__ == "__main__":
|
|
|
50 |
|
51 |
print(f"\nFor {dataset_name}, sample count: {len(result)}")
|
52 |
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
|
53 |
+
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
|
54 |
|
55 |
|
56 |
if __name__ == "__main__":
|
src/f5_tts/train/train.py
CHANGED
@@ -6,7 +6,7 @@ from importlib.resources import files
|
|
6 |
import hydra
|
7 |
from omegaconf import OmegaConf
|
8 |
|
9 |
-
from f5_tts.model import CFM,
|
10 |
from f5_tts.model.dataset import load_dataset
|
11 |
from f5_tts.model.utils import get_tokenizer
|
12 |
|
@@ -14,60 +14,60 @@ os.chdir(str(files("f5_tts").joinpath("../.."))) # change working directory to
|
|
14 |
|
15 |
|
16 |
@hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
|
17 |
-
def main(
|
18 |
-
model_cls =
|
19 |
-
model_arc =
|
20 |
-
tokenizer =
|
21 |
-
mel_spec_type =
|
22 |
|
23 |
-
exp_name = f"{
|
24 |
wandb_resume_id = None
|
25 |
|
26 |
# set text tokenizer
|
27 |
if tokenizer != "custom":
|
28 |
-
tokenizer_path =
|
29 |
else:
|
30 |
-
tokenizer_path =
|
31 |
vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
|
32 |
|
33 |
# set model
|
34 |
model = CFM(
|
35 |
-
transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=
|
36 |
-
mel_spec_kwargs=
|
37 |
vocab_char_map=vocab_char_map,
|
38 |
)
|
39 |
|
40 |
# init trainer
|
41 |
trainer = Trainer(
|
42 |
model,
|
43 |
-
epochs=
|
44 |
-
learning_rate=
|
45 |
-
num_warmup_updates=
|
46 |
-
save_per_updates=
|
47 |
-
keep_last_n_checkpoints=
|
48 |
-
checkpoint_path=str(files("f5_tts").joinpath(f"../../{
|
49 |
-
batch_size_per_gpu=
|
50 |
-
batch_size_type=
|
51 |
-
max_samples=
|
52 |
-
grad_accumulation_steps=
|
53 |
-
max_grad_norm=
|
54 |
-
logger=
|
55 |
wandb_project="CFM-TTS",
|
56 |
wandb_run_name=exp_name,
|
57 |
wandb_resume_id=wandb_resume_id,
|
58 |
-
last_per_updates=
|
59 |
-
log_samples=
|
60 |
-
bnb_optimizer=
|
61 |
mel_spec_type=mel_spec_type,
|
62 |
-
is_local_vocoder=
|
63 |
-
local_vocoder_path=
|
64 |
-
|
65 |
)
|
66 |
|
67 |
-
train_dataset = load_dataset(
|
68 |
trainer.train(
|
69 |
train_dataset,
|
70 |
-
num_workers=
|
71 |
resumable_with_seed=666, # seed for shuffling dataset
|
72 |
)
|
73 |
|
|
|
6 |
import hydra
|
7 |
from omegaconf import OmegaConf
|
8 |
|
9 |
+
from f5_tts.model import CFM, Trainer
|
10 |
from f5_tts.model.dataset import load_dataset
|
11 |
from f5_tts.model.utils import get_tokenizer
|
12 |
|
|
|
14 |
|
15 |
|
16 |
@hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
|
17 |
+
def main(model_cfg):
|
18 |
+
model_cls = hydra.utils.get_class(f"f5_tts.model.{model_cfg.model.backbone}")
|
19 |
+
model_arc = model_cfg.model.arch
|
20 |
+
tokenizer = model_cfg.model.tokenizer
|
21 |
+
mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
|
22 |
|
23 |
+
exp_name = f"{model_cfg.model.name}_{mel_spec_type}_{model_cfg.model.tokenizer}_{model_cfg.datasets.name}"
|
24 |
wandb_resume_id = None
|
25 |
|
26 |
# set text tokenizer
|
27 |
if tokenizer != "custom":
|
28 |
+
tokenizer_path = model_cfg.datasets.name
|
29 |
else:
|
30 |
+
tokenizer_path = model_cfg.model.tokenizer_path
|
31 |
vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
|
32 |
|
33 |
# set model
|
34 |
model = CFM(
|
35 |
+
transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=model_cfg.model.mel_spec.n_mel_channels),
|
36 |
+
mel_spec_kwargs=model_cfg.model.mel_spec,
|
37 |
vocab_char_map=vocab_char_map,
|
38 |
)
|
39 |
|
40 |
# init trainer
|
41 |
trainer = Trainer(
|
42 |
model,
|
43 |
+
epochs=model_cfg.optim.epochs,
|
44 |
+
learning_rate=model_cfg.optim.learning_rate,
|
45 |
+
num_warmup_updates=model_cfg.optim.num_warmup_updates,
|
46 |
+
save_per_updates=model_cfg.ckpts.save_per_updates,
|
47 |
+
keep_last_n_checkpoints=model_cfg.ckpts.keep_last_n_checkpoints,
|
48 |
+
checkpoint_path=str(files("f5_tts").joinpath(f"../../{model_cfg.ckpts.save_dir}")),
|
49 |
+
batch_size_per_gpu=model_cfg.datasets.batch_size_per_gpu,
|
50 |
+
batch_size_type=model_cfg.datasets.batch_size_type,
|
51 |
+
max_samples=model_cfg.datasets.max_samples,
|
52 |
+
grad_accumulation_steps=model_cfg.optim.grad_accumulation_steps,
|
53 |
+
max_grad_norm=model_cfg.optim.max_grad_norm,
|
54 |
+
logger=model_cfg.ckpts.logger,
|
55 |
wandb_project="CFM-TTS",
|
56 |
wandb_run_name=exp_name,
|
57 |
wandb_resume_id=wandb_resume_id,
|
58 |
+
last_per_updates=model_cfg.ckpts.last_per_updates,
|
59 |
+
log_samples=model_cfg.ckpts.log_samples,
|
60 |
+
bnb_optimizer=model_cfg.optim.bnb_optimizer,
|
61 |
mel_spec_type=mel_spec_type,
|
62 |
+
is_local_vocoder=model_cfg.model.vocoder.is_local,
|
63 |
+
local_vocoder_path=model_cfg.model.vocoder.local_path,
|
64 |
+
model_cfg_dict=OmegaConf.to_container(model_cfg, resolve=True),
|
65 |
)
|
66 |
|
67 |
+
train_dataset = load_dataset(model_cfg.datasets.name, tokenizer, mel_spec_kwargs=model_cfg.model.mel_spec)
|
68 |
trainer.train(
|
69 |
train_dataset,
|
70 |
+
num_workers=model_cfg.datasets.num_workers,
|
71 |
resumable_with_seed=666, # seed for shuffling dataset
|
72 |
)
|
73 |
|