mrfakename commited on
Commit
70988bd
·
verified ·
1 Parent(s): f2868a9

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there

Files changed (2) hide show
  1. finetune-cli.py +11 -6
  2. finetune_gradio.py +73 -32
finetune-cli.py CHANGED
@@ -28,6 +28,7 @@ def parse_args():
28
  parser.add_argument('--num_warmup_updates', type=int, default=5, help='Warmup steps')
29
  parser.add_argument('--save_per_updates', type=int, default=10, help='Save checkpoint every X steps')
30
  parser.add_argument('--last_per_steps', type=int, default=10, help='Save last checkpoint every X steps')
 
31
 
32
  return parser.parse_args()
33
 
@@ -42,17 +43,21 @@ def main():
42
  wandb_resume_id = None
43
  model_cls = DiT
44
  model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
45
- ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
 
46
  elif args.exp_name == "E2TTS_Base":
47
  wandb_resume_id = None
48
  model_cls = UNetT
49
  model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
50
- ckpt_path = str(cached_path(f"hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
 
 
 
 
 
 
 
51
 
52
- path_ckpt = os.path.join("ckpts",args.dataset_name)
53
- if os.path.isdir(path_ckpt)==False:
54
- os.makedirs(path_ckpt,exist_ok=True)
55
- shutil.copy2(ckpt_path,os.path.join(path_ckpt,os.path.basename(ckpt_path)))
56
  checkpoint_path=os.path.join("ckpts",args.dataset_name)
57
 
58
  # Use the dataset_name provided in the command line
 
28
  parser.add_argument('--num_warmup_updates', type=int, default=5, help='Warmup steps')
29
  parser.add_argument('--save_per_updates', type=int, default=10, help='Save checkpoint every X steps')
30
  parser.add_argument('--last_per_steps', type=int, default=10, help='Save last checkpoint every X steps')
31
+ parser.add_argument('--finetune', type=bool, default=True, help='Use Finetune')
32
 
33
  return parser.parse_args()
34
 
 
43
  wandb_resume_id = None
44
  model_cls = DiT
45
  model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
46
+ if args.finetune:
47
+ ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
48
  elif args.exp_name == "E2TTS_Base":
49
  wandb_resume_id = None
50
  model_cls = UNetT
51
  model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
52
+ if args.finetune:
53
+ ckpt_path = str(cached_path(f"hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
54
+
55
+ if args.finetune:
56
+ path_ckpt = os.path.join("ckpts",args.dataset_name)
57
+ if os.path.isdir(path_ckpt)==False:
58
+ os.makedirs(path_ckpt,exist_ok=True)
59
+ shutil.copy2(ckpt_path,os.path.join(path_ckpt,os.path.basename(ckpt_path)))
60
 
 
 
 
 
61
  checkpoint_path=os.path.join("ckpts",args.dataset_name)
62
 
63
  # Use the dataset_name provided in the command line
finetune_gradio.py CHANGED
@@ -9,24 +9,19 @@ from glob import glob
9
  import librosa
10
  import numpy as np
11
  from scipy.io import wavfile
12
- from tqdm import tqdm
13
  import shutil
14
  import time
15
 
16
  import json
17
- from datasets import Dataset
18
  from model.utils import convert_char_to_pinyin
19
  import signal
20
  import psutil
21
  import platform
22
  import subprocess
23
  from datasets.arrow_writer import ArrowWriter
24
- from datasets import load_dataset, load_from_disk
25
 
26
  import json
27
 
28
-
29
-
30
  training_process = None
31
  system = platform.system()
32
  python_executable = sys.executable or "python"
@@ -265,8 +260,20 @@ def start_training(dataset_name="",
265
  finetune=True,
266
  ):
267
 
 
268
  global training_process
269
 
 
 
 
 
 
 
 
 
 
 
 
270
  # Check if a training process is already running
271
  if training_process is not None:
272
  return "Train run already!",gr.update(interactive=False),gr.update(interactive=True)
@@ -274,7 +281,7 @@ def start_training(dataset_name="",
274
  yield "start train",gr.update(interactive=False),gr.update(interactive=False)
275
 
276
  # Command to run the training script with the specified arguments
277
- cmd = f"{python_executable} finetune-cli.py --exp_name {exp_name} " \
278
  f"--learning_rate {learning_rate} " \
279
  f"--batch_size_per_gpu {batch_size_per_gpu} " \
280
  f"--batch_size_type {batch_size_type} " \
@@ -346,6 +353,8 @@ def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Prog
346
  path_project_wavs = os.path.join(path_project,"wavs")
347
  file_metadata = os.path.join(path_project,"metadata.csv")
348
 
 
 
349
  if os.path.isdir(path_project_wavs):
350
  shutil.rmtree(path_project_wavs)
351
 
@@ -356,16 +365,17 @@ def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Prog
356
 
357
  if user:
358
  file_audios = [file for format in ('*.wav', '*.ogg', '*.opus', '*.mp3', '*.flac') for file in glob(os.path.join(path_dataset, format))]
 
359
  else:
360
  file_audios = audio_files
361
-
362
- print([file_audios])
363
 
364
  alpha = 0.5
365
  _max = 1.0
366
  slicer = Slicer(24000)
367
 
368
  num = 0
 
369
  data=""
370
  for file_audio in progress.tqdm(file_audios, desc="transcribe files",total=len((file_audios))):
371
 
@@ -381,18 +391,26 @@ def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Prog
381
  if(tmp_max>1):chunk/=tmp_max
382
  chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
383
  wavfile.write(file_segment,24000, (chunk * 32767).astype(np.int16))
 
 
 
 
384
 
385
- text=transcribe(file_segment,language)
386
- text = text.lower().strip().replace('"',"")
387
 
388
- data+= f"{name_segment}|{text}\n"
 
 
389
 
390
- num+=1
391
-
392
  with open(file_metadata,"w",encoding="utf-8") as f:
393
  f.write(data)
394
-
395
- return f"transcribe complete samples : {num} in path {path_project_wavs}"
 
 
 
 
 
396
 
397
  def format_seconds_to_hms(seconds):
398
  hours = int(seconds / 3600)
@@ -408,6 +426,8 @@ def create_metadata(name_project,progress=gr.Progress()):
408
  file_raw = os.path.join(path_project,"raw.arrow")
409
  file_duration = os.path.join(path_project,"duration.json")
410
  file_vocab = os.path.join(path_project,"vocab.txt")
 
 
411
 
412
  with open(file_metadata,"r",encoding="utf-8") as f:
413
  data=f.read()
@@ -419,11 +439,18 @@ def create_metadata(name_project,progress=gr.Progress()):
419
  count=data.split("\n")
420
  lenght=0
421
  result=[]
 
422
  for line in progress.tqdm(data.split("\n"),total=count):
423
  sp_line=line.split("|")
424
  if len(sp_line)!=2:continue
425
- name_audio,text = sp_line[:2]
 
426
  file_audio = os.path.join(path_project_wavs, name_audio + ".wav")
 
 
 
 
 
427
  duraction = get_audio_duration(file_audio)
428
  if duraction<2 and duraction>15:continue
429
  if len(text)<4:continue
@@ -439,6 +466,10 @@ def create_metadata(name_project,progress=gr.Progress()):
439
 
440
  lenght+=duraction
441
 
 
 
 
 
442
  min_second = round(min(duration_list),2)
443
  max_second = round(max(duration_list),2)
444
 
@@ -450,9 +481,15 @@ def create_metadata(name_project,progress=gr.Progress()):
450
  json.dump({"duration": duration_list}, f, ensure_ascii=False)
451
 
452
  file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
 
453
  shutil.copy2(file_vocab_finetune, file_vocab)
454
-
455
- return f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\n"
 
 
 
 
 
456
 
457
  def check_user(value):
458
  return gr.update(visible=not value),gr.update(visible=value)
@@ -466,15 +503,19 @@ def calculate_train(name_project,batch_size_type,max_samples,learning_rate,num_w
466
  data = json.load(file)
467
 
468
  duration_list = data['duration']
 
469
  samples = len(duration_list)
470
 
471
- gpu_properties = torch.cuda.get_device_properties(0)
472
- total_memory = gpu_properties.total_memory / (1024 ** 3)
 
 
 
473
 
474
  if batch_size_type=="frame":
475
  batch = int(total_memory * 0.5)
476
  batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch)
477
- batch_size_per_gpu = int(36800 / batch )
478
  else:
479
  batch_size_per_gpu = int(total_memory / 8)
480
  batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu)
@@ -509,13 +550,12 @@ def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -
509
  if ema_model_state_dict is not None:
510
  new_checkpoint = {'ema_model_state_dict': ema_model_state_dict}
511
  torch.save(new_checkpoint, new_checkpoint_path)
512
- print(f"New checkpoint saved at: {new_checkpoint_path}")
513
  else:
514
- print("No 'ema_model_state_dict' found in the checkpoint.")
515
 
516
  except Exception as e:
517
- print(f"An error occurred: {e}")
518
-
519
 
520
  def vocab_check(project_name):
521
  name_project = project_name + "_pinyin"
@@ -524,12 +564,17 @@ def vocab_check(project_name):
524
  file_metadata = os.path.join(path_project, "metadata.csv")
525
 
526
  file_vocab="data/Emilia_ZH_EN_pinyin/vocab.txt"
 
 
527
 
528
  with open(file_vocab,"r",encoding="utf-8") as f:
529
  data=f.read()
530
 
531
  vocab = data.split("\n")
532
 
 
 
 
533
  with open(file_metadata,"r",encoding="utf-8") as f:
534
  data=f.read()
535
 
@@ -548,6 +593,7 @@ def vocab_check(project_name):
548
 
549
  if miss_symbols==[]:info ="You can train using your language !"
550
  else:info = f"The following symbols are missing in your language : {len(miss_symbols)}\n\n" + "\n".join(miss_symbols)
 
551
  return info
552
 
553
 
@@ -652,8 +698,9 @@ with gr.Blocks() as app:
652
  with gr.TabItem("reduse checkpoint"):
653
  txt_path_checkpoint = gr.Text(label="path checkpoint :")
654
  txt_path_checkpoint_small = gr.Text(label="path output :")
 
655
  reduse_button = gr.Button("reduse")
656
- reduse_button.click(fn=extract_and_save_ema_model,inputs=[txt_path_checkpoint,txt_path_checkpoint_small])
657
 
658
  with gr.TabItem("vocab check experiment"):
659
  check_button = gr.Button("check vocab")
@@ -680,10 +727,4 @@ def main(port, host, share, api):
680
  )
681
 
682
  if __name__ == "__main__":
683
- name="my_speak"
684
-
685
- #create_data_project(name)
686
- #transcribe_all(name)
687
- #create_metadata(name)
688
-
689
  main()
 
9
  import librosa
10
  import numpy as np
11
  from scipy.io import wavfile
 
12
  import shutil
13
  import time
14
 
15
  import json
 
16
  from model.utils import convert_char_to_pinyin
17
  import signal
18
  import psutil
19
  import platform
20
  import subprocess
21
  from datasets.arrow_writer import ArrowWriter
 
22
 
23
  import json
24
 
 
 
25
  training_process = None
26
  system = platform.system()
27
  python_executable = sys.executable or "python"
 
260
  finetune=True,
261
  ):
262
 
263
+
264
  global training_process
265
 
266
+ path_project = os.path.join(path_data, dataset_name + "_pinyin")
267
+
268
+ if os.path.isdir(path_project)==False:
269
+ yield f"There is not project with name {dataset_name}",gr.update(interactive=True),gr.update(interactive=False)
270
+ return
271
+
272
+ file_raw = os.path.join(path_project,"raw.arrow")
273
+ if os.path.isfile(file_raw)==False:
274
+ yield f"There is no file {file_raw}",gr.update(interactive=True),gr.update(interactive=False)
275
+ return
276
+
277
  # Check if a training process is already running
278
  if training_process is not None:
279
  return "Train run already!",gr.update(interactive=False),gr.update(interactive=True)
 
281
  yield "start train",gr.update(interactive=False),gr.update(interactive=False)
282
 
283
  # Command to run the training script with the specified arguments
284
+ cmd = f"accelerate launch finetune-cli.py --exp_name {exp_name} " \
285
  f"--learning_rate {learning_rate} " \
286
  f"--batch_size_per_gpu {batch_size_per_gpu} " \
287
  f"--batch_size_type {batch_size_type} " \
 
353
  path_project_wavs = os.path.join(path_project,"wavs")
354
  file_metadata = os.path.join(path_project,"metadata.csv")
355
 
356
+ if audio_files is None:return "You need to load an audio file."
357
+
358
  if os.path.isdir(path_project_wavs):
359
  shutil.rmtree(path_project_wavs)
360
 
 
365
 
366
  if user:
367
  file_audios = [file for format in ('*.wav', '*.ogg', '*.opus', '*.mp3', '*.flac') for file in glob(os.path.join(path_dataset, format))]
368
+ if file_audios==[]:return "No audio file was found in the dataset."
369
  else:
370
  file_audios = audio_files
371
+
 
372
 
373
  alpha = 0.5
374
  _max = 1.0
375
  slicer = Slicer(24000)
376
 
377
  num = 0
378
+ error_num = 0
379
  data=""
380
  for file_audio in progress.tqdm(file_audios, desc="transcribe files",total=len((file_audios))):
381
 
 
391
  if(tmp_max>1):chunk/=tmp_max
392
  chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
393
  wavfile.write(file_segment,24000, (chunk * 32767).astype(np.int16))
394
+
395
+ try:
396
+ text=transcribe(file_segment,language)
397
+ text = text.lower().strip().replace('"',"")
398
 
399
+ data+= f"{name_segment}|{text}\n"
 
400
 
401
+ num+=1
402
+ except:
403
+ error_num +=1
404
 
 
 
405
  with open(file_metadata,"w",encoding="utf-8") as f:
406
  f.write(data)
407
+
408
+ if error_num!=[]:
409
+ error_text=f"\nerror files : {error_num}"
410
+ else:
411
+ error_text=""
412
+
413
+ return f"transcribe complete samples : {num}\npath : {path_project_wavs}{error_text}"
414
 
415
  def format_seconds_to_hms(seconds):
416
  hours = int(seconds / 3600)
 
426
  file_raw = os.path.join(path_project,"raw.arrow")
427
  file_duration = os.path.join(path_project,"duration.json")
428
  file_vocab = os.path.join(path_project,"vocab.txt")
429
+
430
+ if os.path.isfile(file_metadata)==False: return "The file was not found in " + file_metadata
431
 
432
  with open(file_metadata,"r",encoding="utf-8") as f:
433
  data=f.read()
 
439
  count=data.split("\n")
440
  lenght=0
441
  result=[]
442
+ error_files=[]
443
  for line in progress.tqdm(data.split("\n"),total=count):
444
  sp_line=line.split("|")
445
  if len(sp_line)!=2:continue
446
+ name_audio,text = sp_line[:2]
447
+
448
  file_audio = os.path.join(path_project_wavs, name_audio + ".wav")
449
+
450
+ if os.path.isfile(file_audio)==False:
451
+ error_files.append(file_audio)
452
+ continue
453
+
454
  duraction = get_audio_duration(file_audio)
455
  if duraction<2 and duraction>15:continue
456
  if len(text)<4:continue
 
466
 
467
  lenght+=duraction
468
 
469
+ if duration_list==[]:
470
+ error_files_text="\n".join(error_files)
471
+ return f"Error: No audio files found in the specified path : \n{error_files_text}"
472
+
473
  min_second = round(min(duration_list),2)
474
  max_second = round(max(duration_list),2)
475
 
 
481
  json.dump({"duration": duration_list}, f, ensure_ascii=False)
482
 
483
  file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
484
+ if os.path.isfile(file_vocab_finetune==False):return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!"
485
  shutil.copy2(file_vocab_finetune, file_vocab)
486
+
487
+ if error_files!=[]:
488
+ error_text="error files\n" + "\n".join(error_files)
489
+ else:
490
+ error_text=""
491
+
492
+ return f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\n{error_text}"
493
 
494
  def check_user(value):
495
  return gr.update(visible=not value),gr.update(visible=value)
 
503
  data = json.load(file)
504
 
505
  duration_list = data['duration']
506
+
507
  samples = len(duration_list)
508
 
509
+ if torch.cuda.is_available():
510
+ gpu_properties = torch.cuda.get_device_properties(0)
511
+ total_memory = gpu_properties.total_memory / (1024 ** 3)
512
+ elif torch.backends.mps.is_available():
513
+ total_memory = psutil.virtual_memory().available / (1024 ** 3)
514
 
515
  if batch_size_type=="frame":
516
  batch = int(total_memory * 0.5)
517
  batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch)
518
+ batch_size_per_gpu = int(38400 / batch )
519
  else:
520
  batch_size_per_gpu = int(total_memory / 8)
521
  batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu)
 
550
  if ema_model_state_dict is not None:
551
  new_checkpoint = {'ema_model_state_dict': ema_model_state_dict}
552
  torch.save(new_checkpoint, new_checkpoint_path)
553
+ return f"New checkpoint saved at: {new_checkpoint_path}"
554
  else:
555
+ return "No 'ema_model_state_dict' found in the checkpoint."
556
 
557
  except Exception as e:
558
+ return f"An error occurred: {e}"
 
559
 
560
  def vocab_check(project_name):
561
  name_project = project_name + "_pinyin"
 
564
  file_metadata = os.path.join(path_project, "metadata.csv")
565
 
566
  file_vocab="data/Emilia_ZH_EN_pinyin/vocab.txt"
567
+ if os.path.isfile(file_vocab)==False:
568
+ return f"the file {file_vocab} not found !"
569
 
570
  with open(file_vocab,"r",encoding="utf-8") as f:
571
  data=f.read()
572
 
573
  vocab = data.split("\n")
574
 
575
+ if os.path.isfile(file_metadata)==False:
576
+ return f"the file {file_metadata} not found !"
577
+
578
  with open(file_metadata,"r",encoding="utf-8") as f:
579
  data=f.read()
580
 
 
593
 
594
  if miss_symbols==[]:info ="You can train using your language !"
595
  else:info = f"The following symbols are missing in your language : {len(miss_symbols)}\n\n" + "\n".join(miss_symbols)
596
+
597
  return info
598
 
599
 
 
698
  with gr.TabItem("reduse checkpoint"):
699
  txt_path_checkpoint = gr.Text(label="path checkpoint :")
700
  txt_path_checkpoint_small = gr.Text(label="path output :")
701
+ txt_info_reduse = gr.Text(label="info",value="")
702
  reduse_button = gr.Button("reduse")
703
+ reduse_button.click(fn=extract_and_save_ema_model,inputs=[txt_path_checkpoint,txt_path_checkpoint_small],outputs=[txt_info_reduse])
704
 
705
  with gr.TabItem("vocab check experiment"):
706
  check_button = gr.Button("check vocab")
 
727
  )
728
 
729
  if __name__ == "__main__":
 
 
 
 
 
 
730
  main()