Spaces:
Running
on
Zero
Running
on
Zero
Sync from GitHub repo
Browse filesThis Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there
src/f5_tts/model/trainer.py
CHANGED
@@ -51,7 +51,7 @@ class Trainer:
|
|
51 |
mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
|
52 |
is_local_vocoder: bool = False, # use local path vocoder
|
53 |
local_vocoder_path: str = "", # local vocoder path
|
54 |
-
|
55 |
):
|
56 |
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
57 |
|
@@ -73,8 +73,8 @@ class Trainer:
|
|
73 |
else:
|
74 |
init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
|
75 |
|
76 |
-
if not
|
77 |
-
|
78 |
"epochs": epochs,
|
79 |
"learning_rate": learning_rate,
|
80 |
"num_warmup_updates": num_warmup_updates,
|
@@ -85,11 +85,11 @@ class Trainer:
|
|
85 |
"max_grad_norm": max_grad_norm,
|
86 |
"noise_scheduler": noise_scheduler,
|
87 |
}
|
88 |
-
|
89 |
self.accelerator.init_trackers(
|
90 |
project_name=wandb_project,
|
91 |
init_kwargs=init_kwargs,
|
92 |
-
config=
|
93 |
)
|
94 |
|
95 |
elif self.logger == "tensorboard":
|
|
|
51 |
mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
|
52 |
is_local_vocoder: bool = False, # use local path vocoder
|
53 |
local_vocoder_path: str = "", # local vocoder path
|
54 |
+
model_cfg_dict: dict = dict(), # training config
|
55 |
):
|
56 |
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
57 |
|
|
|
73 |
else:
|
74 |
init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
|
75 |
|
76 |
+
if not model_cfg_dict:
|
77 |
+
model_cfg_dict = {
|
78 |
"epochs": epochs,
|
79 |
"learning_rate": learning_rate,
|
80 |
"num_warmup_updates": num_warmup_updates,
|
|
|
85 |
"max_grad_norm": max_grad_norm,
|
86 |
"noise_scheduler": noise_scheduler,
|
87 |
}
|
88 |
+
model_cfg_dict["gpus"] = self.accelerator.num_processes
|
89 |
self.accelerator.init_trackers(
|
90 |
project_name=wandb_project,
|
91 |
init_kwargs=init_kwargs,
|
92 |
+
config=model_cfg_dict,
|
93 |
)
|
94 |
|
95 |
elif self.logger == "tensorboard":
|