Spaces:
Running
on
Zero
Running
on
Zero
mrfakename
commited on
Sync from GitHub repo
Browse filesThis Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there
- finetune-cli.py +11 -6
- finetune_gradio.py +73 -32
finetune-cli.py
CHANGED
@@ -28,6 +28,7 @@ def parse_args():
|
|
28 |
parser.add_argument('--num_warmup_updates', type=int, default=5, help='Warmup steps')
|
29 |
parser.add_argument('--save_per_updates', type=int, default=10, help='Save checkpoint every X steps')
|
30 |
parser.add_argument('--last_per_steps', type=int, default=10, help='Save last checkpoint every X steps')
|
|
|
31 |
|
32 |
return parser.parse_args()
|
33 |
|
@@ -42,17 +43,21 @@ def main():
|
|
42 |
wandb_resume_id = None
|
43 |
model_cls = DiT
|
44 |
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
45 |
-
|
|
|
46 |
elif args.exp_name == "E2TTS_Base":
|
47 |
wandb_resume_id = None
|
48 |
model_cls = UNetT
|
49 |
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
-
path_ckpt = os.path.join("ckpts",args.dataset_name)
|
53 |
-
if os.path.isdir(path_ckpt)==False:
|
54 |
-
os.makedirs(path_ckpt,exist_ok=True)
|
55 |
-
shutil.copy2(ckpt_path,os.path.join(path_ckpt,os.path.basename(ckpt_path)))
|
56 |
checkpoint_path=os.path.join("ckpts",args.dataset_name)
|
57 |
|
58 |
# Use the dataset_name provided in the command line
|
|
|
28 |
parser.add_argument('--num_warmup_updates', type=int, default=5, help='Warmup steps')
|
29 |
parser.add_argument('--save_per_updates', type=int, default=10, help='Save checkpoint every X steps')
|
30 |
parser.add_argument('--last_per_steps', type=int, default=10, help='Save last checkpoint every X steps')
|
31 |
+
parser.add_argument('--finetune', type=bool, default=True, help='Use Finetune')
|
32 |
|
33 |
return parser.parse_args()
|
34 |
|
|
|
43 |
wandb_resume_id = None
|
44 |
model_cls = DiT
|
45 |
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
46 |
+
if args.finetune:
|
47 |
+
ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
|
48 |
elif args.exp_name == "E2TTS_Base":
|
49 |
wandb_resume_id = None
|
50 |
model_cls = UNetT
|
51 |
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
52 |
+
if args.finetune:
|
53 |
+
ckpt_path = str(cached_path(f"hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
|
54 |
+
|
55 |
+
if args.finetune:
|
56 |
+
path_ckpt = os.path.join("ckpts",args.dataset_name)
|
57 |
+
if os.path.isdir(path_ckpt)==False:
|
58 |
+
os.makedirs(path_ckpt,exist_ok=True)
|
59 |
+
shutil.copy2(ckpt_path,os.path.join(path_ckpt,os.path.basename(ckpt_path)))
|
60 |
|
|
|
|
|
|
|
|
|
61 |
checkpoint_path=os.path.join("ckpts",args.dataset_name)
|
62 |
|
63 |
# Use the dataset_name provided in the command line
|
finetune_gradio.py
CHANGED
@@ -9,24 +9,19 @@ from glob import glob
|
|
9 |
import librosa
|
10 |
import numpy as np
|
11 |
from scipy.io import wavfile
|
12 |
-
from tqdm import tqdm
|
13 |
import shutil
|
14 |
import time
|
15 |
|
16 |
import json
|
17 |
-
from datasets import Dataset
|
18 |
from model.utils import convert_char_to_pinyin
|
19 |
import signal
|
20 |
import psutil
|
21 |
import platform
|
22 |
import subprocess
|
23 |
from datasets.arrow_writer import ArrowWriter
|
24 |
-
from datasets import load_dataset, load_from_disk
|
25 |
|
26 |
import json
|
27 |
|
28 |
-
|
29 |
-
|
30 |
training_process = None
|
31 |
system = platform.system()
|
32 |
python_executable = sys.executable or "python"
|
@@ -265,8 +260,20 @@ def start_training(dataset_name="",
|
|
265 |
finetune=True,
|
266 |
):
|
267 |
|
|
|
268 |
global training_process
|
269 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
270 |
# Check if a training process is already running
|
271 |
if training_process is not None:
|
272 |
return "Train run already!",gr.update(interactive=False),gr.update(interactive=True)
|
@@ -274,7 +281,7 @@ def start_training(dataset_name="",
|
|
274 |
yield "start train",gr.update(interactive=False),gr.update(interactive=False)
|
275 |
|
276 |
# Command to run the training script with the specified arguments
|
277 |
-
cmd = f"
|
278 |
f"--learning_rate {learning_rate} " \
|
279 |
f"--batch_size_per_gpu {batch_size_per_gpu} " \
|
280 |
f"--batch_size_type {batch_size_type} " \
|
@@ -346,6 +353,8 @@ def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Prog
|
|
346 |
path_project_wavs = os.path.join(path_project,"wavs")
|
347 |
file_metadata = os.path.join(path_project,"metadata.csv")
|
348 |
|
|
|
|
|
349 |
if os.path.isdir(path_project_wavs):
|
350 |
shutil.rmtree(path_project_wavs)
|
351 |
|
@@ -356,16 +365,17 @@ def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Prog
|
|
356 |
|
357 |
if user:
|
358 |
file_audios = [file for format in ('*.wav', '*.ogg', '*.opus', '*.mp3', '*.flac') for file in glob(os.path.join(path_dataset, format))]
|
|
|
359 |
else:
|
360 |
file_audios = audio_files
|
361 |
-
|
362 |
-
print([file_audios])
|
363 |
|
364 |
alpha = 0.5
|
365 |
_max = 1.0
|
366 |
slicer = Slicer(24000)
|
367 |
|
368 |
num = 0
|
|
|
369 |
data=""
|
370 |
for file_audio in progress.tqdm(file_audios, desc="transcribe files",total=len((file_audios))):
|
371 |
|
@@ -381,18 +391,26 @@ def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Prog
|
|
381 |
if(tmp_max>1):chunk/=tmp_max
|
382 |
chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
|
383 |
wavfile.write(file_segment,24000, (chunk * 32767).astype(np.int16))
|
|
|
|
|
|
|
|
|
384 |
|
385 |
-
|
386 |
-
text = text.lower().strip().replace('"',"")
|
387 |
|
388 |
-
|
|
|
|
|
389 |
|
390 |
-
num+=1
|
391 |
-
|
392 |
with open(file_metadata,"w",encoding="utf-8") as f:
|
393 |
f.write(data)
|
394 |
-
|
395 |
-
|
|
|
|
|
|
|
|
|
|
|
396 |
|
397 |
def format_seconds_to_hms(seconds):
|
398 |
hours = int(seconds / 3600)
|
@@ -408,6 +426,8 @@ def create_metadata(name_project,progress=gr.Progress()):
|
|
408 |
file_raw = os.path.join(path_project,"raw.arrow")
|
409 |
file_duration = os.path.join(path_project,"duration.json")
|
410 |
file_vocab = os.path.join(path_project,"vocab.txt")
|
|
|
|
|
411 |
|
412 |
with open(file_metadata,"r",encoding="utf-8") as f:
|
413 |
data=f.read()
|
@@ -419,11 +439,18 @@ def create_metadata(name_project,progress=gr.Progress()):
|
|
419 |
count=data.split("\n")
|
420 |
lenght=0
|
421 |
result=[]
|
|
|
422 |
for line in progress.tqdm(data.split("\n"),total=count):
|
423 |
sp_line=line.split("|")
|
424 |
if len(sp_line)!=2:continue
|
425 |
-
name_audio,text = sp_line[:2]
|
|
|
426 |
file_audio = os.path.join(path_project_wavs, name_audio + ".wav")
|
|
|
|
|
|
|
|
|
|
|
427 |
duraction = get_audio_duration(file_audio)
|
428 |
if duraction<2 and duraction>15:continue
|
429 |
if len(text)<4:continue
|
@@ -439,6 +466,10 @@ def create_metadata(name_project,progress=gr.Progress()):
|
|
439 |
|
440 |
lenght+=duraction
|
441 |
|
|
|
|
|
|
|
|
|
442 |
min_second = round(min(duration_list),2)
|
443 |
max_second = round(max(duration_list),2)
|
444 |
|
@@ -450,9 +481,15 @@ def create_metadata(name_project,progress=gr.Progress()):
|
|
450 |
json.dump({"duration": duration_list}, f, ensure_ascii=False)
|
451 |
|
452 |
file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
|
|
|
453 |
shutil.copy2(file_vocab_finetune, file_vocab)
|
454 |
-
|
455 |
-
|
|
|
|
|
|
|
|
|
|
|
456 |
|
457 |
def check_user(value):
|
458 |
return gr.update(visible=not value),gr.update(visible=value)
|
@@ -466,15 +503,19 @@ def calculate_train(name_project,batch_size_type,max_samples,learning_rate,num_w
|
|
466 |
data = json.load(file)
|
467 |
|
468 |
duration_list = data['duration']
|
|
|
469 |
samples = len(duration_list)
|
470 |
|
471 |
-
|
472 |
-
|
|
|
|
|
|
|
473 |
|
474 |
if batch_size_type=="frame":
|
475 |
batch = int(total_memory * 0.5)
|
476 |
batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch)
|
477 |
-
batch_size_per_gpu = int(
|
478 |
else:
|
479 |
batch_size_per_gpu = int(total_memory / 8)
|
480 |
batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu)
|
@@ -509,13 +550,12 @@ def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -
|
|
509 |
if ema_model_state_dict is not None:
|
510 |
new_checkpoint = {'ema_model_state_dict': ema_model_state_dict}
|
511 |
torch.save(new_checkpoint, new_checkpoint_path)
|
512 |
-
|
513 |
else:
|
514 |
-
|
515 |
|
516 |
except Exception as e:
|
517 |
-
|
518 |
-
|
519 |
|
520 |
def vocab_check(project_name):
|
521 |
name_project = project_name + "_pinyin"
|
@@ -524,12 +564,17 @@ def vocab_check(project_name):
|
|
524 |
file_metadata = os.path.join(path_project, "metadata.csv")
|
525 |
|
526 |
file_vocab="data/Emilia_ZH_EN_pinyin/vocab.txt"
|
|
|
|
|
527 |
|
528 |
with open(file_vocab,"r",encoding="utf-8") as f:
|
529 |
data=f.read()
|
530 |
|
531 |
vocab = data.split("\n")
|
532 |
|
|
|
|
|
|
|
533 |
with open(file_metadata,"r",encoding="utf-8") as f:
|
534 |
data=f.read()
|
535 |
|
@@ -548,6 +593,7 @@ def vocab_check(project_name):
|
|
548 |
|
549 |
if miss_symbols==[]:info ="You can train using your language !"
|
550 |
else:info = f"The following symbols are missing in your language : {len(miss_symbols)}\n\n" + "\n".join(miss_symbols)
|
|
|
551 |
return info
|
552 |
|
553 |
|
@@ -652,8 +698,9 @@ with gr.Blocks() as app:
|
|
652 |
with gr.TabItem("reduse checkpoint"):
|
653 |
txt_path_checkpoint = gr.Text(label="path checkpoint :")
|
654 |
txt_path_checkpoint_small = gr.Text(label="path output :")
|
|
|
655 |
reduse_button = gr.Button("reduse")
|
656 |
-
reduse_button.click(fn=extract_and_save_ema_model,inputs=[txt_path_checkpoint,txt_path_checkpoint_small])
|
657 |
|
658 |
with gr.TabItem("vocab check experiment"):
|
659 |
check_button = gr.Button("check vocab")
|
@@ -680,10 +727,4 @@ def main(port, host, share, api):
|
|
680 |
)
|
681 |
|
682 |
if __name__ == "__main__":
|
683 |
-
name="my_speak"
|
684 |
-
|
685 |
-
#create_data_project(name)
|
686 |
-
#transcribe_all(name)
|
687 |
-
#create_metadata(name)
|
688 |
-
|
689 |
main()
|
|
|
9 |
import librosa
|
10 |
import numpy as np
|
11 |
from scipy.io import wavfile
|
|
|
12 |
import shutil
|
13 |
import time
|
14 |
|
15 |
import json
|
|
|
16 |
from model.utils import convert_char_to_pinyin
|
17 |
import signal
|
18 |
import psutil
|
19 |
import platform
|
20 |
import subprocess
|
21 |
from datasets.arrow_writer import ArrowWriter
|
|
|
22 |
|
23 |
import json
|
24 |
|
|
|
|
|
25 |
training_process = None
|
26 |
system = platform.system()
|
27 |
python_executable = sys.executable or "python"
|
|
|
260 |
finetune=True,
|
261 |
):
|
262 |
|
263 |
+
|
264 |
global training_process
|
265 |
|
266 |
+
path_project = os.path.join(path_data, dataset_name + "_pinyin")
|
267 |
+
|
268 |
+
if os.path.isdir(path_project)==False:
|
269 |
+
yield f"There is not project with name {dataset_name}",gr.update(interactive=True),gr.update(interactive=False)
|
270 |
+
return
|
271 |
+
|
272 |
+
file_raw = os.path.join(path_project,"raw.arrow")
|
273 |
+
if os.path.isfile(file_raw)==False:
|
274 |
+
yield f"There is no file {file_raw}",gr.update(interactive=True),gr.update(interactive=False)
|
275 |
+
return
|
276 |
+
|
277 |
# Check if a training process is already running
|
278 |
if training_process is not None:
|
279 |
return "Train run already!",gr.update(interactive=False),gr.update(interactive=True)
|
|
|
281 |
yield "start train",gr.update(interactive=False),gr.update(interactive=False)
|
282 |
|
283 |
# Command to run the training script with the specified arguments
|
284 |
+
cmd = f"accelerate launch finetune-cli.py --exp_name {exp_name} " \
|
285 |
f"--learning_rate {learning_rate} " \
|
286 |
f"--batch_size_per_gpu {batch_size_per_gpu} " \
|
287 |
f"--batch_size_type {batch_size_type} " \
|
|
|
353 |
path_project_wavs = os.path.join(path_project,"wavs")
|
354 |
file_metadata = os.path.join(path_project,"metadata.csv")
|
355 |
|
356 |
+
if audio_files is None:return "You need to load an audio file."
|
357 |
+
|
358 |
if os.path.isdir(path_project_wavs):
|
359 |
shutil.rmtree(path_project_wavs)
|
360 |
|
|
|
365 |
|
366 |
if user:
|
367 |
file_audios = [file for format in ('*.wav', '*.ogg', '*.opus', '*.mp3', '*.flac') for file in glob(os.path.join(path_dataset, format))]
|
368 |
+
if file_audios==[]:return "No audio file was found in the dataset."
|
369 |
else:
|
370 |
file_audios = audio_files
|
371 |
+
|
|
|
372 |
|
373 |
alpha = 0.5
|
374 |
_max = 1.0
|
375 |
slicer = Slicer(24000)
|
376 |
|
377 |
num = 0
|
378 |
+
error_num = 0
|
379 |
data=""
|
380 |
for file_audio in progress.tqdm(file_audios, desc="transcribe files",total=len((file_audios))):
|
381 |
|
|
|
391 |
if(tmp_max>1):chunk/=tmp_max
|
392 |
chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
|
393 |
wavfile.write(file_segment,24000, (chunk * 32767).astype(np.int16))
|
394 |
+
|
395 |
+
try:
|
396 |
+
text=transcribe(file_segment,language)
|
397 |
+
text = text.lower().strip().replace('"',"")
|
398 |
|
399 |
+
data+= f"{name_segment}|{text}\n"
|
|
|
400 |
|
401 |
+
num+=1
|
402 |
+
except:
|
403 |
+
error_num +=1
|
404 |
|
|
|
|
|
405 |
with open(file_metadata,"w",encoding="utf-8") as f:
|
406 |
f.write(data)
|
407 |
+
|
408 |
+
if error_num!=[]:
|
409 |
+
error_text=f"\nerror files : {error_num}"
|
410 |
+
else:
|
411 |
+
error_text=""
|
412 |
+
|
413 |
+
return f"transcribe complete samples : {num}\npath : {path_project_wavs}{error_text}"
|
414 |
|
415 |
def format_seconds_to_hms(seconds):
|
416 |
hours = int(seconds / 3600)
|
|
|
426 |
file_raw = os.path.join(path_project,"raw.arrow")
|
427 |
file_duration = os.path.join(path_project,"duration.json")
|
428 |
file_vocab = os.path.join(path_project,"vocab.txt")
|
429 |
+
|
430 |
+
if os.path.isfile(file_metadata)==False: return "The file was not found in " + file_metadata
|
431 |
|
432 |
with open(file_metadata,"r",encoding="utf-8") as f:
|
433 |
data=f.read()
|
|
|
439 |
count=data.split("\n")
|
440 |
lenght=0
|
441 |
result=[]
|
442 |
+
error_files=[]
|
443 |
for line in progress.tqdm(data.split("\n"),total=count):
|
444 |
sp_line=line.split("|")
|
445 |
if len(sp_line)!=2:continue
|
446 |
+
name_audio,text = sp_line[:2]
|
447 |
+
|
448 |
file_audio = os.path.join(path_project_wavs, name_audio + ".wav")
|
449 |
+
|
450 |
+
if os.path.isfile(file_audio)==False:
|
451 |
+
error_files.append(file_audio)
|
452 |
+
continue
|
453 |
+
|
454 |
duraction = get_audio_duration(file_audio)
|
455 |
if duraction<2 and duraction>15:continue
|
456 |
if len(text)<4:continue
|
|
|
466 |
|
467 |
lenght+=duraction
|
468 |
|
469 |
+
if duration_list==[]:
|
470 |
+
error_files_text="\n".join(error_files)
|
471 |
+
return f"Error: No audio files found in the specified path : \n{error_files_text}"
|
472 |
+
|
473 |
min_second = round(min(duration_list),2)
|
474 |
max_second = round(max(duration_list),2)
|
475 |
|
|
|
481 |
json.dump({"duration": duration_list}, f, ensure_ascii=False)
|
482 |
|
483 |
file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
|
484 |
+
if os.path.isfile(file_vocab_finetune==False):return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!"
|
485 |
shutil.copy2(file_vocab_finetune, file_vocab)
|
486 |
+
|
487 |
+
if error_files!=[]:
|
488 |
+
error_text="error files\n" + "\n".join(error_files)
|
489 |
+
else:
|
490 |
+
error_text=""
|
491 |
+
|
492 |
+
return f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\n{error_text}"
|
493 |
|
494 |
def check_user(value):
|
495 |
return gr.update(visible=not value),gr.update(visible=value)
|
|
|
503 |
data = json.load(file)
|
504 |
|
505 |
duration_list = data['duration']
|
506 |
+
|
507 |
samples = len(duration_list)
|
508 |
|
509 |
+
if torch.cuda.is_available():
|
510 |
+
gpu_properties = torch.cuda.get_device_properties(0)
|
511 |
+
total_memory = gpu_properties.total_memory / (1024 ** 3)
|
512 |
+
elif torch.backends.mps.is_available():
|
513 |
+
total_memory = psutil.virtual_memory().available / (1024 ** 3)
|
514 |
|
515 |
if batch_size_type=="frame":
|
516 |
batch = int(total_memory * 0.5)
|
517 |
batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch)
|
518 |
+
batch_size_per_gpu = int(38400 / batch )
|
519 |
else:
|
520 |
batch_size_per_gpu = int(total_memory / 8)
|
521 |
batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu)
|
|
|
550 |
if ema_model_state_dict is not None:
|
551 |
new_checkpoint = {'ema_model_state_dict': ema_model_state_dict}
|
552 |
torch.save(new_checkpoint, new_checkpoint_path)
|
553 |
+
return f"New checkpoint saved at: {new_checkpoint_path}"
|
554 |
else:
|
555 |
+
return "No 'ema_model_state_dict' found in the checkpoint."
|
556 |
|
557 |
except Exception as e:
|
558 |
+
return f"An error occurred: {e}"
|
|
|
559 |
|
560 |
def vocab_check(project_name):
|
561 |
name_project = project_name + "_pinyin"
|
|
|
564 |
file_metadata = os.path.join(path_project, "metadata.csv")
|
565 |
|
566 |
file_vocab="data/Emilia_ZH_EN_pinyin/vocab.txt"
|
567 |
+
if os.path.isfile(file_vocab)==False:
|
568 |
+
return f"the file {file_vocab} not found !"
|
569 |
|
570 |
with open(file_vocab,"r",encoding="utf-8") as f:
|
571 |
data=f.read()
|
572 |
|
573 |
vocab = data.split("\n")
|
574 |
|
575 |
+
if os.path.isfile(file_metadata)==False:
|
576 |
+
return f"the file {file_metadata} not found !"
|
577 |
+
|
578 |
with open(file_metadata,"r",encoding="utf-8") as f:
|
579 |
data=f.read()
|
580 |
|
|
|
593 |
|
594 |
if miss_symbols==[]:info ="You can train using your language !"
|
595 |
else:info = f"The following symbols are missing in your language : {len(miss_symbols)}\n\n" + "\n".join(miss_symbols)
|
596 |
+
|
597 |
return info
|
598 |
|
599 |
|
|
|
698 |
with gr.TabItem("reduse checkpoint"):
|
699 |
txt_path_checkpoint = gr.Text(label="path checkpoint :")
|
700 |
txt_path_checkpoint_small = gr.Text(label="path output :")
|
701 |
+
txt_info_reduse = gr.Text(label="info",value="")
|
702 |
reduse_button = gr.Button("reduse")
|
703 |
+
reduse_button.click(fn=extract_and_save_ema_model,inputs=[txt_path_checkpoint,txt_path_checkpoint_small],outputs=[txt_info_reduse])
|
704 |
|
705 |
with gr.TabItem("vocab check experiment"):
|
706 |
check_button = gr.Button("check vocab")
|
|
|
727 |
)
|
728 |
|
729 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
730 |
main()
|