mrfakename committed on
Commit
43bc5dc
·
verified ·
1 Parent(s): 519ed19

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there.

.pre-commit-config.yaml CHANGED
@@ -1,7 +1,7 @@
1
  repos:
2
  - repo: https://github.com/astral-sh/ruff-pre-commit
3
  # Ruff version.
4
- rev: v0.7.0
5
  hooks:
6
  # Run the linter.
7
  - id: ruff
@@ -9,6 +9,6 @@ repos:
9
  # Run the formatter.
10
  - id: ruff-format
11
  - repo: https://github.com/pre-commit/pre-commit-hooks
12
- rev: v2.3.0
13
  hooks:
14
  - id: check-yaml
 
1
  repos:
2
  - repo: https://github.com/astral-sh/ruff-pre-commit
3
  # Ruff version.
4
+ rev: v0.11.2
5
  hooks:
6
  # Run the linter.
7
  - id: ruff
 
9
  # Run the formatter.
10
  - id: ruff-format
11
  - repo: https://github.com/pre-commit/pre-commit-hooks
12
+ rev: v5.0.0
13
  hooks:
14
  - id: check-yaml
Dockerfile CHANGED
@@ -23,9 +23,8 @@ RUN git clone https://github.com/SWivid/F5-TTS.git \
23
 
24
  ENV SHELL=/bin/bash
25
 
26
- # models are downloaded into this folder, so user should mount it
27
  VOLUME /root/.cache/huggingface/hub/
28
- # port the GUI is exposed on by default, if it is run
29
  EXPOSE 7860
30
 
31
  WORKDIR /workspace/F5-TTS
 
23
 
24
  ENV SHELL=/bin/bash
25
 
 
26
  VOLUME /root/.cache/huggingface/hub/
27
+
28
  EXPOSE 7860
29
 
30
  WORKDIR /workspace/F5-TTS
README_REPO.md CHANGED
@@ -203,7 +203,7 @@ Read [training & finetuning guidance](src/f5_tts/train) for more instructions.
203
 
204
  ## Development
205
 
206
- Use pre-commit to ensure code quality (will run linters and formatters automatically)
207
 
208
  ```bash
209
  pip install pre-commit
@@ -216,7 +216,7 @@ When making a pull request, before each commit, run:
216
  pre-commit run --all-files
217
  ```
218
 
219
- Note: Some model components have linting exceptions for E722 to accommodate tensor notation
220
 
221
 
222
  ## Acknowledgements
 
203
 
204
  ## Development
205
 
206
+ Use pre-commit to ensure code quality (will run linters and formatters automatically):
207
 
208
  ```bash
209
  pip install pre-commit
 
216
  pre-commit run --all-files
217
  ```
218
 
219
+ Note: Some model components have linting exceptions for E722 to accommodate tensor notation.
220
 
221
 
222
  ## Acknowledgements
ckpts/README.md CHANGED
@@ -1,12 +1,3 @@
 
1
 
2
- Pretrained model ckpts. https://huggingface.co/SWivid/F5-TTS
3
-
4
- ```
5
- ckpts/
6
- F5TTS_v1_Base/
7
- model_1250000.safetensors
8
- F5TTS_Base/
9
- model_1200000.safetensors
10
- E2TTS_Base/
11
- model_1200000.safetensors
12
- ```
 
1
+ The pretrained model checkpoints can be reached at https://huggingface.co/SWivid/F5-TTS.
2
 
3
+ Scripts will automatically pull model checkpoints from Huggingface, by default to `~/.cache/huggingface/hub/`.
 
 
 
 
 
 
 
 
 
 
src/f5_tts/api.py CHANGED
@@ -5,6 +5,7 @@ from importlib.resources import files
5
  import soundfile as sf
6
  import tqdm
7
  from cached_path import cached_path
 
8
  from omegaconf import OmegaConf
9
 
10
  from f5_tts.infer.utils_infer import (
@@ -16,7 +17,6 @@ from f5_tts.infer.utils_infer import (
16
  remove_silence_for_generated_wav,
17
  save_spectrogram,
18
  )
19
- from f5_tts.model import DiT, UNetT # noqa: F401. used for config
20
  from f5_tts.model.utils import seed_everything
21
 
22
 
@@ -33,7 +33,7 @@ class F5TTS:
33
  hf_cache_dir=None,
34
  ):
35
  model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
36
- model_cls = globals()[model_cfg.model.backbone]
37
  model_arc = model_cfg.model.arch
38
 
39
  self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
 
5
  import soundfile as sf
6
  import tqdm
7
  from cached_path import cached_path
8
+ from hydra.utils import get_class
9
  from omegaconf import OmegaConf
10
 
11
  from f5_tts.infer.utils_infer import (
 
17
  remove_silence_for_generated_wav,
18
  save_spectrogram,
19
  )
 
20
  from f5_tts.model.utils import seed_everything
21
 
22
 
 
33
  hf_cache_dir=None,
34
  ):
35
  model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
36
+ model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
37
  model_arc = model_cfg.model.arch
38
 
39
  self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
src/f5_tts/eval/eval_infer_batch.py CHANGED
@@ -10,6 +10,7 @@ from importlib.resources import files
10
  import torch
11
  import torchaudio
12
  from accelerate import Accelerator
 
13
  from omegaconf import OmegaConf
14
  from tqdm import tqdm
15
 
@@ -19,7 +20,7 @@ from f5_tts.eval.utils_eval import (
19
  get_seedtts_testset_metainfo,
20
  )
21
  from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
22
- from f5_tts.model import CFM, DiT, UNetT # noqa: F401. used for config
23
  from f5_tts.model.utils import get_tokenizer
24
 
25
  accelerator = Accelerator()
@@ -65,7 +66,7 @@ def main():
65
  no_ref_audio = False
66
 
67
  model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
68
- model_cls = globals()[model_cfg.model.backbone]
69
  model_arc = model_cfg.model.arch
70
 
71
  dataset_name = model_cfg.datasets.name
@@ -195,7 +196,7 @@ def main():
195
  accelerator.wait_for_everyone()
196
  if accelerator.is_main_process:
197
  timediff = time.time() - start
198
- print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
199
 
200
 
201
  if __name__ == "__main__":
 
10
  import torch
11
  import torchaudio
12
  from accelerate import Accelerator
13
+ from hydra.utils import get_class
14
  from omegaconf import OmegaConf
15
  from tqdm import tqdm
16
 
 
20
  get_seedtts_testset_metainfo,
21
  )
22
  from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
23
+ from f5_tts.model import CFM
24
  from f5_tts.model.utils import get_tokenizer
25
 
26
  accelerator = Accelerator()
 
66
  no_ref_audio = False
67
 
68
  model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
69
+ model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
70
  model_arc = model_cfg.model.arch
71
 
72
  dataset_name = model_cfg.datasets.name
 
196
  accelerator.wait_for_everyone()
197
  if accelerator.is_main_process:
198
  timediff = time.time() - start
199
+ print(f"Done batch inference in {timediff / 60:.2f} minutes.")
200
 
201
 
202
  if __name__ == "__main__":
src/f5_tts/eval/utils_eval.py CHANGED
@@ -148,9 +148,9 @@ def get_inference_prompt(
148
 
149
  # deal with batch
150
  assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
151
- assert (
152
- min_tokens <= total_mel_len <= max_tokens
153
- ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
154
  bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
155
 
156
  utts[bucket_i].append(utt)
 
148
 
149
  # deal with batch
150
  assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
151
+ assert min_tokens <= total_mel_len <= max_tokens, (
152
+ f"Audio {utt} has duration {total_mel_len * hop_length // target_sample_rate}s out of range [{min_secs}, {max_secs}]."
153
+ )
154
  bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
155
 
156
  utts[bucket_i].append(utt)
src/f5_tts/infer/infer_cli.py CHANGED
@@ -10,6 +10,7 @@ import numpy as np
10
  import soundfile as sf
11
  import tomli
12
  from cached_path import cached_path
 
13
  from omegaconf import OmegaConf
14
 
15
  from f5_tts.infer.utils_infer import (
@@ -27,7 +28,6 @@ from f5_tts.infer.utils_infer import (
27
  preprocess_ref_audio_text,
28
  remove_silence_for_generated_wav,
29
  )
30
- from f5_tts.model import DiT, UNetT # noqa: F401. used for config
31
 
32
 
33
  parser = argparse.ArgumentParser(
@@ -246,13 +246,14 @@ vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_loc
246
 
247
  model_cfg = OmegaConf.load(
248
  args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
249
- ).model
250
- model_cls = globals()[model_cfg.backbone]
 
251
 
252
  repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
253
 
254
  if model != "F5TTS_Base":
255
- assert vocoder_name == model_cfg.mel_spec.mel_spec_type
256
 
257
  # override for previous models
258
  if model == "F5TTS_Base":
@@ -269,7 +270,7 @@ if not ckpt_file:
269
  ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
270
 
271
  print(f"Using {model}...")
272
- ema_model = load_model(model_cls, model_cfg.arch, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
273
 
274
 
275
  # inference process
@@ -332,7 +333,7 @@ def main():
332
  if len(gen_text_) > 200:
333
  gen_text_ = gen_text_[:200] + " ... "
334
  sf.write(
335
- os.path.join(output_chunk_dir, f"{len(generated_audio_segments)-1}_{gen_text_}.wav"),
336
  audio_segment,
337
  final_sample_rate,
338
  )
 
10
  import soundfile as sf
11
  import tomli
12
  from cached_path import cached_path
13
+ from hydra.utils import get_class
14
  from omegaconf import OmegaConf
15
 
16
  from f5_tts.infer.utils_infer import (
 
28
  preprocess_ref_audio_text,
29
  remove_silence_for_generated_wav,
30
  )
 
31
 
32
 
33
  parser = argparse.ArgumentParser(
 
246
 
247
  model_cfg = OmegaConf.load(
248
  args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
249
+ )
250
+ model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
251
+ model_arc = model_cfg.model.arch
252
 
253
  repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
254
 
255
  if model != "F5TTS_Base":
256
+ assert vocoder_name == model_cfg.model.mel_spec.mel_spec_type
257
 
258
  # override for previous models
259
  if model == "F5TTS_Base":
 
270
  ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
271
 
272
  print(f"Using {model}...")
273
+ ema_model = load_model(model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
274
 
275
 
276
  # inference process
 
333
  if len(gen_text_) > 200:
334
  gen_text_ = gen_text_[:200] + " ... "
335
  sf.write(
336
+ os.path.join(output_chunk_dir, f"{len(generated_audio_segments) - 1}_{gen_text_}.wav"),
337
  audio_segment,
338
  final_sample_rate,
339
  )
src/f5_tts/infer/speech_edit.py CHANGED
@@ -7,10 +7,11 @@ from importlib.resources import files
7
  import torch
8
  import torch.nn.functional as F
9
  import torchaudio
 
10
  from omegaconf import OmegaConf
11
 
12
  from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
13
- from f5_tts.model import CFM, DiT, UNetT # noqa: F401. used for config
14
  from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
15
 
16
  device = (
@@ -40,7 +41,7 @@ target_rms = 0.1
40
 
41
 
42
  model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
43
- model_cls = globals()[model_cfg.model.backbone]
44
  model_arc = model_cfg.model.arch
45
 
46
  dataset_name = model_cfg.datasets.name
 
7
  import torch
8
  import torch.nn.functional as F
9
  import torchaudio
10
+ from hydra.utils import get_class
11
  from omegaconf import OmegaConf
12
 
13
  from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
14
+ from f5_tts.model import CFM
15
  from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
16
 
17
  device = (
 
41
 
42
 
43
  model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
44
+ model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
45
  model_arc = model_cfg.model.arch
46
 
47
  dataset_name = model_cfg.datasets.name
src/f5_tts/model/trainer.py CHANGED
@@ -350,7 +350,7 @@ class Trainer:
350
 
351
  progress_bar = tqdm(
352
  range(math.ceil(len(train_dataloader) / self.grad_accumulation_steps)),
353
- desc=f"Epoch {epoch+1}/{self.epochs}",
354
  unit="update",
355
  disable=not self.accelerator.is_local_main_process,
356
  initial=progress_bar_initial,
 
350
 
351
  progress_bar = tqdm(
352
  range(math.ceil(len(train_dataloader) / self.grad_accumulation_steps)),
353
+ desc=f"Epoch {epoch + 1}/{self.epochs}",
354
  unit="update",
355
  disable=not self.accelerator.is_local_main_process,
356
  initial=progress_bar_initial,
src/f5_tts/scripts/count_max_epoch.py CHANGED
@@ -24,7 +24,7 @@ updates_per_epoch = total_hours / mini_batch_hours
24
 
25
  # result
26
  epochs = wanted_max_updates / updates_per_epoch
27
- print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})")
28
  print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
29
  # print(f" or approx. 0/{steps_per_epoch:.0f} steps")
30
 
 
24
 
25
  # result
26
  epochs = wanted_max_updates / updates_per_epoch
27
+ print(f"epochs should be set to: {epochs:.0f} ({epochs / grad_accum:.1f} x gd_acum {grad_accum})")
28
  print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
29
  # print(f" or approx. 0/{steps_per_epoch:.0f} steps")
30
 
src/f5_tts/socket_server.py CHANGED
@@ -13,9 +13,9 @@ from importlib.resources import files
13
  import torch
14
  import torchaudio
15
  from huggingface_hub import hf_hub_download
 
16
  from omegaconf import OmegaConf
17
 
18
- from f5_tts.model.backbones.dit import DiT # noqa: F401. used for config
19
  from f5_tts.infer.utils_infer import (
20
  chunk_text,
21
  preprocess_ref_audio_text,
@@ -80,7 +80,7 @@ class TTSStreamingProcessor:
80
  else "cpu"
81
  )
82
  model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
83
- self.model_cls = globals()[model_cfg.model.backbone]
84
  self.model_arc = model_cfg.model.arch
85
  self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
86
  self.sampling_rate = model_cfg.model.mel_spec.target_sample_rate
 
13
  import torch
14
  import torchaudio
15
  from huggingface_hub import hf_hub_download
16
+ from hydra.utils import get_class
17
  from omegaconf import OmegaConf
18
 
 
19
  from f5_tts.infer.utils_infer import (
20
  chunk_text,
21
  preprocess_ref_audio_text,
 
80
  else "cpu"
81
  )
82
  model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
83
+ self.model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
84
  self.model_arc = model_cfg.model.arch
85
  self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
86
  self.sampling_rate = model_cfg.model.mel_spec.target_sample_rate
src/f5_tts/train/datasets/prepare_csv_wavs.py CHANGED
@@ -122,7 +122,7 @@ def prepare_csv_wavs_dir(input_dir, num_workers=None):
122
  for future in tqdm(
123
  chunk_futures,
124
  total=len(chunk),
125
- desc=f"Processing chunk {i//CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1)//CHUNK_SIZE}",
126
  ):
127
  try:
128
  result = future.result()
@@ -233,7 +233,7 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
233
  dataset_name = out_dir.stem
234
  print(f"\nFor {dataset_name}, sample count: {len(result)}")
235
  print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
236
- print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
237
 
238
 
239
  def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True, num_workers: int = None):
 
122
  for future in tqdm(
123
  chunk_futures,
124
  total=len(chunk),
125
+ desc=f"Processing chunk {i // CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1) // CHUNK_SIZE}",
126
  ):
127
  try:
128
  result = future.result()
 
233
  dataset_name = out_dir.stem
234
  print(f"\nFor {dataset_name}, sample count: {len(result)}")
235
  print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
236
+ print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
237
 
238
 
239
  def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True, num_workers: int = None):
src/f5_tts/train/datasets/prepare_emilia.py CHANGED
@@ -198,7 +198,7 @@ def main():
198
 
199
  print(f"\nFor {dataset_name}, sample count: {len(result)}")
200
  print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
201
- print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
202
  if "ZH" in langs:
203
  print(f"Bad zh transcription case: {total_bad_case_zh}")
204
  if "EN" in langs:
 
198
 
199
  print(f"\nFor {dataset_name}, sample count: {len(result)}")
200
  print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
201
+ print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
202
  if "ZH" in langs:
203
  print(f"Bad zh transcription case: {total_bad_case_zh}")
204
  if "EN" in langs:
src/f5_tts/train/datasets/prepare_libritts.py CHANGED
@@ -72,7 +72,7 @@ def main():
72
 
73
  print(f"\nFor {dataset_name}, sample count: {len(result)}")
74
  print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
75
- print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
76
 
77
 
78
  if __name__ == "__main__":
 
72
 
73
  print(f"\nFor {dataset_name}, sample count: {len(result)}")
74
  print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
75
+ print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
76
 
77
 
78
  if __name__ == "__main__":
src/f5_tts/train/datasets/prepare_ljspeech.py CHANGED
@@ -50,7 +50,7 @@ def main():
50
 
51
  print(f"\nFor {dataset_name}, sample count: {len(result)}")
52
  print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
53
- print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
54
 
55
 
56
  if __name__ == "__main__":
 
50
 
51
  print(f"\nFor {dataset_name}, sample count: {len(result)}")
52
  print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
53
+ print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
54
 
55
 
56
  if __name__ == "__main__":
src/f5_tts/train/train.py CHANGED
@@ -6,7 +6,7 @@ from importlib.resources import files
6
  import hydra
7
  from omegaconf import OmegaConf
8
 
9
- from f5_tts.model import CFM, DiT, UNetT, Trainer # noqa: F401. used for config
10
  from f5_tts.model.dataset import load_dataset
11
  from f5_tts.model.utils import get_tokenizer
12
 
@@ -14,60 +14,60 @@ os.chdir(str(files("f5_tts").joinpath("../.."))) # change working directory to
14
 
15
 
16
  @hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
17
- def main(cfg):
18
- model_cls = globals()[cfg.model.backbone]
19
- model_arc = cfg.model.arch
20
- tokenizer = cfg.model.tokenizer
21
- mel_spec_type = cfg.model.mel_spec.mel_spec_type
22
 
23
- exp_name = f"{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}"
24
  wandb_resume_id = None
25
 
26
  # set text tokenizer
27
  if tokenizer != "custom":
28
- tokenizer_path = cfg.datasets.name
29
  else:
30
- tokenizer_path = cfg.model.tokenizer_path
31
  vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
32
 
33
  # set model
34
  model = CFM(
35
- transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels),
36
- mel_spec_kwargs=cfg.model.mel_spec,
37
  vocab_char_map=vocab_char_map,
38
  )
39
 
40
  # init trainer
41
  trainer = Trainer(
42
  model,
43
- epochs=cfg.optim.epochs,
44
- learning_rate=cfg.optim.learning_rate,
45
- num_warmup_updates=cfg.optim.num_warmup_updates,
46
- save_per_updates=cfg.ckpts.save_per_updates,
47
- keep_last_n_checkpoints=cfg.ckpts.keep_last_n_checkpoints,
48
- checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")),
49
- batch_size_per_gpu=cfg.datasets.batch_size_per_gpu,
50
- batch_size_type=cfg.datasets.batch_size_type,
51
- max_samples=cfg.datasets.max_samples,
52
- grad_accumulation_steps=cfg.optim.grad_accumulation_steps,
53
- max_grad_norm=cfg.optim.max_grad_norm,
54
- logger=cfg.ckpts.logger,
55
  wandb_project="CFM-TTS",
56
  wandb_run_name=exp_name,
57
  wandb_resume_id=wandb_resume_id,
58
- last_per_updates=cfg.ckpts.last_per_updates,
59
- log_samples=cfg.ckpts.log_samples,
60
- bnb_optimizer=cfg.optim.bnb_optimizer,
61
  mel_spec_type=mel_spec_type,
62
- is_local_vocoder=cfg.model.vocoder.is_local,
63
- local_vocoder_path=cfg.model.vocoder.local_path,
64
- cfg_dict=OmegaConf.to_container(cfg, resolve=True),
65
  )
66
 
67
- train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
68
  trainer.train(
69
  train_dataset,
70
- num_workers=cfg.datasets.num_workers,
71
  resumable_with_seed=666, # seed for shuffling dataset
72
  )
73
 
 
6
  import hydra
7
  from omegaconf import OmegaConf
8
 
9
+ from f5_tts.model import CFM, Trainer
10
  from f5_tts.model.dataset import load_dataset
11
  from f5_tts.model.utils import get_tokenizer
12
 
 
14
 
15
 
16
  @hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
17
+ def main(model_cfg):
18
+ model_cls = hydra.utils.get_class(f"f5_tts.model.{model_cfg.model.backbone}")
19
+ model_arc = model_cfg.model.arch
20
+ tokenizer = model_cfg.model.tokenizer
21
+ mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
22
 
23
+ exp_name = f"{model_cfg.model.name}_{mel_spec_type}_{model_cfg.model.tokenizer}_{model_cfg.datasets.name}"
24
  wandb_resume_id = None
25
 
26
  # set text tokenizer
27
  if tokenizer != "custom":
28
+ tokenizer_path = model_cfg.datasets.name
29
  else:
30
+ tokenizer_path = model_cfg.model.tokenizer_path
31
  vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
32
 
33
  # set model
34
  model = CFM(
35
+ transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=model_cfg.model.mel_spec.n_mel_channels),
36
+ mel_spec_kwargs=model_cfg.model.mel_spec,
37
  vocab_char_map=vocab_char_map,
38
  )
39
 
40
  # init trainer
41
  trainer = Trainer(
42
  model,
43
+ epochs=model_cfg.optim.epochs,
44
+ learning_rate=model_cfg.optim.learning_rate,
45
+ num_warmup_updates=model_cfg.optim.num_warmup_updates,
46
+ save_per_updates=model_cfg.ckpts.save_per_updates,
47
+ keep_last_n_checkpoints=model_cfg.ckpts.keep_last_n_checkpoints,
48
+ checkpoint_path=str(files("f5_tts").joinpath(f"../../{model_cfg.ckpts.save_dir}")),
49
+ batch_size_per_gpu=model_cfg.datasets.batch_size_per_gpu,
50
+ batch_size_type=model_cfg.datasets.batch_size_type,
51
+ max_samples=model_cfg.datasets.max_samples,
52
+ grad_accumulation_steps=model_cfg.optim.grad_accumulation_steps,
53
+ max_grad_norm=model_cfg.optim.max_grad_norm,
54
+ logger=model_cfg.ckpts.logger,
55
  wandb_project="CFM-TTS",
56
  wandb_run_name=exp_name,
57
  wandb_resume_id=wandb_resume_id,
58
+ last_per_updates=model_cfg.ckpts.last_per_updates,
59
+ log_samples=model_cfg.ckpts.log_samples,
60
+ bnb_optimizer=model_cfg.optim.bnb_optimizer,
61
  mel_spec_type=mel_spec_type,
62
+ is_local_vocoder=model_cfg.model.vocoder.is_local,
63
+ local_vocoder_path=model_cfg.model.vocoder.local_path,
64
+ model_cfg_dict=OmegaConf.to_container(model_cfg, resolve=True),
65
  )
66
 
67
+ train_dataset = load_dataset(model_cfg.datasets.name, tokenizer, mel_spec_kwargs=model_cfg.model.mel_spec)
68
  trainer.train(
69
  train_dataset,
70
+ num_workers=model_cfg.datasets.num_workers,
71
  resumable_with_seed=666, # seed for shuffling dataset
72
  )
73