mrfakename commited on
Commit
ebe57a6
·
verified ·
1 Parent(s): 6e26246

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there

Files changed (1) hide show
  1. src/f5_tts/model/trainer.py +5 -5
src/f5_tts/model/trainer.py CHANGED
@@ -51,7 +51,7 @@ class Trainer:
51
  mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
52
  is_local_vocoder: bool = False, # use local path vocoder
53
  local_vocoder_path: str = "", # local vocoder path
54
- cfg_dict: dict = dict(), # training config
55
  ):
56
  ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
57
 
@@ -73,8 +73,8 @@ class Trainer:
73
  else:
74
  init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
75
 
76
- if not cfg_dict:
77
- cfg_dict = {
78
  "epochs": epochs,
79
  "learning_rate": learning_rate,
80
  "num_warmup_updates": num_warmup_updates,
@@ -85,11 +85,11 @@ class Trainer:
85
  "max_grad_norm": max_grad_norm,
86
  "noise_scheduler": noise_scheduler,
87
  }
88
- cfg_dict["gpus"] = self.accelerator.num_processes
89
  self.accelerator.init_trackers(
90
  project_name=wandb_project,
91
  init_kwargs=init_kwargs,
92
- config=cfg_dict,
93
  )
94
 
95
  elif self.logger == "tensorboard":
 
51
  mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
52
  is_local_vocoder: bool = False, # use local path vocoder
53
  local_vocoder_path: str = "", # local vocoder path
54
+ model_cfg_dict: dict = dict(), # training config
55
  ):
56
  ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
57
 
 
73
  else:
74
  init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
75
 
76
+ if not model_cfg_dict:
77
+ model_cfg_dict = {
78
  "epochs": epochs,
79
  "learning_rate": learning_rate,
80
  "num_warmup_updates": num_warmup_updates,
 
85
  "max_grad_norm": max_grad_norm,
86
  "noise_scheduler": noise_scheduler,
87
  }
88
+ model_cfg_dict["gpus"] = self.accelerator.num_processes
89
  self.accelerator.init_trackers(
90
  project_name=wandb_project,
91
  init_kwargs=init_kwargs,
92
+ config=model_cfg_dict,
93
  )
94
 
95
  elif self.logger == "tensorboard":