Staticaliza committed on
Commit 89ea150
1 Parent(s): dab6fe2

Upload 9 files

scripts/scripts_count_max_epoch.py ADDED
@@ -0,0 +1,32 @@
+ '''ADAPTIVE BATCH SIZE'''
+ print('Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed and fed in')
+ print(' -> least padding; gather wavs until accumulated frames fill a batch\n')
+
+ # data
+ total_hours = 95282
+ mel_hop_length = 256
+ mel_sampling_rate = 24000
+
+ # target
+ wanted_max_updates = 1000000
+
+ # train params
+ gpus = 8
+ frames_per_gpu = 38400  # 8 * 38400 = 307200
+ grad_accum = 1
+
+ # intermediate
+ mini_batch_frames = frames_per_gpu * grad_accum * gpus
+ mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
+ updates_per_epoch = total_hours / mini_batch_hours
+ steps_per_epoch = updates_per_epoch * grad_accum
+
+ # result
+ epochs = wanted_max_updates / updates_per_epoch
+ print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x grad_accum {grad_accum})")
+ print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
+ print(f"                          or approx. 0/{steps_per_epoch:.0f} steps")
+
+ # others
+ print(f"total {total_hours:.0f} hours")
+ print(f"mini-batch of {mini_batch_frames:.0f} frames, {mini_batch_hours:.2f} hours per mini-batch")
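Note: for a quick sanity check of the arithmetic above, the committed defaults work out roughly as follows (a hand-recomputed sketch, not part of the uploaded file):

frames = 8 * 38400                               # 307,200 frames per optimizer update
hours_per_update = frames * 256 / 24000 / 3600   # ≈ 0.91 h of audio per update
updates_per_epoch = 95282 / hours_per_update     # ≈ 104,680 updates to cover 95,282 h once
print(1_000_000 / updates_per_epoch)             # ≈ 9.6, so the script prints "epochs should be set to: 10"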
scripts/scripts_count_params_gflops.py ADDED
@@ -0,0 +1,35 @@
+ import sys, os
+ sys.path.append(os.getcwd())
+
+ from model import M2_TTS, UNetT, DiT, MMDiT
+
+ import torch
+ import thop
+
+
+ ''' ~155M '''
+ # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
+ # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
+ # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
+ # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4)
+ # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
+ # transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
+
+ ''' ~335M '''
+ # FLOPs: 622.1 G, Params: 333.2 M
+ # transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
+ # FLOPs: 363.4 G, Params: 335.8 M
+ transformer = DiT(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
+
+
+ model = M2_TTS(transformer=transformer)
+ target_sample_rate = 24000
+ n_mel_channels = 100
+ hop_length = 256
+ duration = 20
+ frame_length = int(duration * target_sample_rate / hop_length)
+ text_length = 150
+
+ flops, params = thop.profile(model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long)))
+ print(f"FLOPs: {flops / 1e9} G")
+ print(f"Params: {params / 1e6} M")
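Note: as a reference for the dummy inputs profiled above, 20 s of 24 kHz audio at hop length 256 is 1875 mel frames, so thop measures one forward pass over a (1, 1875, 100) mel tensor plus (1, 150) text tokens. A minimal check of that frame arithmetic:

duration, target_sample_rate, hop_length = 20, 24000, 256
frame_length = int(duration * target_sample_rate / hop_length)
assert frame_length == 1875  # 20 * 24000 / 256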
scripts/scripts_eval_infer_batch.py ADDED
@@ -0,0 +1,199 @@
+ import sys, os
+ sys.path.append(os.getcwd())
+
+ import time
+ import random
+ from tqdm import tqdm
+ import argparse
+
+ import torch
+ import torchaudio
+ from accelerate import Accelerator
+ from einops import rearrange
+ from vocos import Vocos
+
+ from model import CFM, UNetT, DiT
+ from model.utils import (
+     load_checkpoint,
+     get_tokenizer,
+     get_seedtts_testset_metainfo,
+     get_librispeech_test_clean_metainfo,
+     get_inference_prompt,
+ )
+
+ accelerator = Accelerator()
+ device = f"cuda:{accelerator.process_index}"
+
+
+ # --------------------- Dataset Settings -------------------- #
+
+ target_sample_rate = 24000
+ n_mel_channels = 100
+ hop_length = 256
+ target_rms = 0.1
+
+ tokenizer = "pinyin"
+
+
+ # ---------------------- infer setting ---------------------- #
+
+ parser = argparse.ArgumentParser(description="batch inference")
+
+ parser.add_argument('-s', '--seed', default=None, type=int)
+ parser.add_argument('-d', '--dataset', default="Emilia_ZH_EN")
+ parser.add_argument('-n', '--expname', required=True)
+ parser.add_argument('-c', '--ckptstep', default=1200000, type=int)
+
+ parser.add_argument('-nfe', '--nfestep', default=32, type=int)
+ parser.add_argument('-o', '--odemethod', default="euler")
+ parser.add_argument('-ss', '--swaysampling', default=-1, type=float)
+
+ parser.add_argument('-t', '--testset', required=True)
+
+ args = parser.parse_args()
+
+
+ seed = args.seed
+ dataset_name = args.dataset
+ exp_name = args.expname
+ ckpt_step = args.ckptstep
+ ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
+
+ nfe_step = args.nfestep
+ ode_method = args.odemethod
+ sway_sampling_coef = args.swaysampling
+
+ testset = args.testset
+
+
+ infer_batch_size = 1  # max frames. 1 for ddp single inference (recommended)
+ cfg_strength = 2.
+ speed = 1.
+ use_truth_duration = False
+ no_ref_audio = False
+
+
+ if exp_name == "F5TTS_Base":
+     model_cls = DiT
+     model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
+
+ elif exp_name == "E2TTS_Base":
+     model_cls = UNetT
+     model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
+
+
+ if testset == "ls_pc_test_clean":
+     metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
+     librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean"  # test-clean path
+     metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
+
+ elif testset == "seedtts_test_zh":
+     metalst = "data/seedtts_testset/zh/meta.lst"
+     metainfo = get_seedtts_testset_metainfo(metalst)
+
+ elif testset == "seedtts_test_en":
+     metalst = "data/seedtts_testset/en/meta.lst"
+     metainfo = get_seedtts_testset_metainfo(metalst)
+
+
+ # path to save generated wavs
+ if seed is None: seed = random.randint(-10000, 10000)
+ output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \
+              f"seed{seed}_{ode_method}_nfe{nfe_step}" \
+              f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}" \
+              f"_cfg{cfg_strength}_speed{speed}" \
+              f"{'_gt-dur' if use_truth_duration else ''}" \
+              f"{'_no-ref-audio' if no_ref_audio else ''}"
+
+
+ # -------------------------------------------------#
+
+ use_ema = True
+
+ prompts_all = get_inference_prompt(
+     metainfo,
+     speed = speed,
+     tokenizer = tokenizer,
+     target_sample_rate = target_sample_rate,
+     n_mel_channels = n_mel_channels,
+     hop_length = hop_length,
+     target_rms = target_rms,
+     use_truth_duration = use_truth_duration,
+     infer_batch_size = infer_batch_size,
+ )
+
+ # Vocoder model
+ local = False
+ if local:
+     vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
+     vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
+     state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
+     vocos.load_state_dict(state_dict)
+     vocos.eval()
+ else:
+     vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
+
+ # Tokenizer
+ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
+
+ # Model
+ model = CFM(
+     transformer = model_cls(
+         **model_cfg,
+         text_num_embeds = vocab_size,
+         mel_dim = n_mel_channels
+     ),
+     mel_spec_kwargs = dict(
+         target_sample_rate = target_sample_rate,
+         n_mel_channels = n_mel_channels,
+         hop_length = hop_length,
+     ),
+     odeint_kwargs = dict(
+         method = ode_method,
+     ),
+     vocab_char_map = vocab_char_map,
+ ).to(device)
+
+ model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
+
+ if not os.path.exists(output_dir) and accelerator.is_main_process:
+     os.makedirs(output_dir)
+
+ # start batch inference
+ accelerator.wait_for_everyone()
+ start = time.time()
+
+ with accelerator.split_between_processes(prompts_all) as prompts:
+
+     for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
+         utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
+         ref_mels = ref_mels.to(device)
+         ref_mel_lens = torch.tensor(ref_mel_lens, dtype = torch.long).to(device)
+         total_mel_lens = torch.tensor(total_mel_lens, dtype = torch.long).to(device)
+
+         # Inference
+         with torch.inference_mode():
+             generated, _ = model.sample(
+                 cond = ref_mels,
+                 text = final_text_list,
+                 duration = total_mel_lens,
+                 lens = ref_mel_lens,
+                 steps = nfe_step,
+                 cfg_strength = cfg_strength,
+                 sway_sampling_coef = sway_sampling_coef,
+                 no_ref_audio = no_ref_audio,
+                 seed = seed,
+             )
+         # Final result
+         for i, gen in enumerate(generated):
+             gen = gen[ref_mel_lens[i]:total_mel_lens[i], :].unsqueeze(0)
+             gen_mel_spec = rearrange(gen, '1 n d -> 1 d n')
+             generated_wave = vocos.decode(gen_mel_spec.cpu())
+             if ref_rms_list[i] < target_rms:
+                 generated_wave = generated_wave * ref_rms_list[i] / target_rms
+             torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)
+
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+     timediff = time.time() - start
+     print(f"Done batch inference in {timediff / 60:.2f} minutes.")
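Note: with the defaults above, results land under a self-describing directory; the pattern below is a sketch (the seed is drawn randomly when -s is omitted):

# results/F5TTS_Base_1200000/seedtts_test_zh/seed<seed>_euler_nfe32_ss-1.0_cfg2.0_speed1.0
# passing -ss 0 drops the _ss suffix entirely, since 0 is falsy in the f-string that builds output_dir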
scripts/scripts_eval_infer_batch.sh ADDED
@@ -0,0 +1,13 @@
+ #!/bin/bash
+
+ # e.g. F5-TTS, 16 NFE
+ accelerate launch scripts/scripts_eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
+ accelerate launch scripts/scripts_eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
+ accelerate launch scripts/scripts_eval_infer_batch.py -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
+
+ # e.g. Vanilla E2 TTS, 32 NFE
+ accelerate launch scripts/scripts_eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
+ accelerate launch scripts/scripts_eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
+ accelerate launch scripts/scripts_eval_infer_batch.py -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
+
+ # etc.
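Note: these commands presume accelerate has been configured once beforehand; a typical setup step would be (assumed, not part of the commit):

# accelerate config   # pick machine type and GPU count once
# each launch above then shards its test set across the spawned processes via split_between_processes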
scripts/scripts_eval_librispeech_test_clean.py ADDED
@@ -0,0 +1,67 @@
+ # Evaluate with LibriSpeech test-clean: ~3s prompt to generate 4-10s audio (the VALL-E / Voicebox evaluation protocol)
+
+ import sys, os
+ sys.path.append(os.getcwd())
+
+ import multiprocessing as mp
+ import numpy as np
+
+ from model.utils import (
+     get_librispeech_test,
+     run_asr_wer,
+     run_sim,
+ )
+
+
+ eval_task = "wer"  # sim | wer
+ lang = "en"
+ metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
+ librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean"  # test-clean path
+ gen_wav_dir = "PATH_TO_GENERATED"  # generated wavs
+
+ gpus = [0,1,2,3,4,5,6,7]
+ test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
+
+ ## In LibriSpeech, some speakers used varying voice characteristics for different characters in the book,
+ ## leading to low ground-truth similarity in some cases.
+ # test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = True)  # eval ground truth
+
+ local = False
+ if local:  # use local custom checkpoint dir
+     asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
+ else:
+     asr_ckpt_dir = ""  # auto download to cache dir
+
+ wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
+
+
+ # --------------------------- WER ---------------------------
+
+ if eval_task == "wer":
+     wers = []
+
+     with mp.Pool(processes=len(gpus)) as pool:
+         args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
+         results = pool.map(run_asr_wer, args)
+         for wers_ in results:
+             wers.extend(wers_)
+
+     wer = round(np.mean(wers)*100, 3)
+     print(f"\nTotal {len(wers)} samples")
+     print(f"WER : {wer}%")
+
+
+ # --------------------------- SIM ---------------------------
+
+ if eval_task == "sim":
+     sim_list = []
+
+     with mp.Pool(processes=len(gpus)) as pool:
+         args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
+         results = pool.map(run_sim, args)
+         for sim_ in results:
+             sim_list.extend(sim_)
+
+     sim = round(sum(sim_list)/len(sim_list), 3)
+     print(f"\nTotal {len(sim_list)} samples")
+     print(f"SIM : {sim}")
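Note: the assumed workflow is to generate the wavs first and then point gen_wav_dir at that output (directory name illustrative), running the script once per eval_task:

# accelerate launch scripts/scripts_eval_infer_batch.py -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
# gen_wav_dir = "results/F5TTS_Base_1200000/ls_pc_test_clean/seed..."   # then run with eval_task = "wer", then "sim"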
scripts/scripts_eval_seedtts_testset.py ADDED
@@ -0,0 +1,69 @@
+ # Evaluate with the Seed-TTS testset
+
+ import sys, os
+ sys.path.append(os.getcwd())
+
+ import multiprocessing as mp
+ import numpy as np
+
+ from model.utils import (
+     get_seed_tts_test,
+     run_asr_wer,
+     run_sim,
+ )
+
+
+ eval_task = "wer"  # sim | wer
+ lang = "zh"  # zh | en
+ metalst = f"data/seedtts_testset/{lang}/meta.lst"  # seed-tts testset
+ # gen_wav_dir = f"data/seedtts_testset/{lang}/wavs"  # ground truth wavs
+ gen_wav_dir = "PATH_TO_GENERATED"  # generated wavs
+
+
+ # NOTE: the paraformer-zh result will differ slightly with the number of gpus, because the batch size differs
+ # zh 1.254 seems to be the result of 4 workers for wer_seed_tts
+ gpus = [0,1,2,3,4,5,6,7]
+ test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
+
+ local = False
+ if local:  # use local custom checkpoint dir
+     if lang == "zh":
+         asr_ckpt_dir = "../checkpoints/funasr"  # paraformer-zh dir under funasr
+     elif lang == "en":
+         asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
+ else:
+     asr_ckpt_dir = ""  # auto download to cache dir
+
+ wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
+
+
+ # --------------------------- WER ---------------------------
+
+ if eval_task == "wer":
+     wers = []
+
+     with mp.Pool(processes=len(gpus)) as pool:
+         args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
+         results = pool.map(run_asr_wer, args)
+         for wers_ in results:
+             wers.extend(wers_)
+
+     wer = round(np.mean(wers)*100, 3)
+     print(f"\nTotal {len(wers)} samples")
+     print(f"WER : {wer}%")
+
+
+ # --------------------------- SIM ---------------------------
+
+ if eval_task == "sim":
+     sim_list = []
+
+     with mp.Pool(processes=len(gpus)) as pool:
+         args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
+         results = pool.map(run_sim, args)
+         for sim_ in results:
+             sim_list.extend(sim_)
+
+     sim = round(sum(sim_list)/len(sim_list), 3)
+     print(f"\nTotal {len(sim_list)} samples")
+     print(f"SIM : {sim}")
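Note: configuration here is edit-in-place rather than via CLI; the two knobs that matter, as wired above:

# lang = "zh" -> data/seedtts_testset/zh/meta.lst, paraformer-zh ASR (under funasr)
# lang = "en" -> data/seedtts_testset/en/meta.lst, faster-whisper-large-v3 ASR
# eval_task = "wer" | "sim" -- run the script once per task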
scripts/scripts_prepare_csv_wavs.py ADDED
@@ -0,0 +1,128 @@
+ import sys, os
+ sys.path.append(os.getcwd())
+
+ from pathlib import Path
+ import json
+ import shutil
+ import argparse
+
+ import csv
+ import torchaudio
+ from tqdm import tqdm
+ from datasets.arrow_writer import ArrowWriter
+
+ from model.utils import (
+     convert_char_to_pinyin,
+ )
+
+ PRETRAINED_VOCAB_PATH = Path(__file__).parent.parent / "data/Emilia_ZH_EN_pinyin/vocab.txt"
+
+ def is_csv_wavs_format(input_dataset_dir):
+     fpath = Path(input_dataset_dir)
+     metadata = fpath / "metadata.csv"
+     wavs = fpath / "wavs"
+     return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir()
+
+
+ def prepare_csv_wavs_dir(input_dir):
+     assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}"
+     input_dir = Path(input_dir)
+     metadata_path = input_dir / "metadata.csv"
+     audio_path_text_pairs = read_audio_text_pairs(metadata_path.as_posix())
+
+     sub_result, durations = [], []
+     vocab_set = set()
+     polyphone = True
+     for audio_path, text in audio_path_text_pairs:
+         if not Path(audio_path).exists():
+             print(f"audio {audio_path} not found, skipping")
+             continue
+         audio_duration = get_audio_duration(audio_path)
+         # assume tokenizer = "pinyin" ("pinyin" | "char")
+         text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
+         sub_result.append({"audio_path": audio_path, "text": text, "duration": audio_duration})
+         durations.append(audio_duration)
+         vocab_set.update(list(text))
+
+     return sub_result, durations, vocab_set
+
+ def get_audio_duration(audio_path):
+     audio, sample_rate = torchaudio.load(audio_path)
+     # torchaudio.load returns (channels, samples); duration is samples / rate, independent of channel count
+     return audio.shape[1] / sample_rate
+
+ def read_audio_text_pairs(csv_file_path):
+     audio_text_pairs = []
+
+     parent = Path(csv_file_path).parent
+     with open(csv_file_path, mode='r', newline='', encoding='utf-8') as csvfile:
+         reader = csv.reader(csvfile, delimiter='|')
+         next(reader)  # skip the header row
+         for row in reader:
+             if len(row) >= 2:
+                 audio_file = row[0].strip()  # first column: audio file path
+                 text = row[1].strip()  # second column: text
+                 audio_file_path = parent / audio_file
+                 audio_text_pairs.append((audio_file_path.as_posix(), text))
+
+     return audio_text_pairs
+
+
+ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune):
+     out_dir = Path(out_dir)
+     # save preprocessed dataset to disk
+     out_dir.mkdir(exist_ok=True, parents=True)
+     print(f"\nSaving to {out_dir} ...")
+
+     # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})  # oom
+     # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
+     raw_arrow_path = out_dir / "raw.arrow"
+     with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
+         for line in tqdm(result, desc="Writing to raw.arrow ..."):
+             writer.write(line)
+
+     # dup a json separately saving durations, for ease of DynamicBatchSampler
+     dur_json_path = out_dir / "duration.json"
+     with open(dur_json_path.as_posix(), 'w', encoding='utf-8') as f:
+         json.dump({"duration": duration_list}, f, ensure_ascii=False)
+
+     # vocab map, i.e. tokenizer
+     # add alphabets and symbols (optional, if planning to ft on de/fr etc.)
+     # if tokenizer == "pinyin":
+     #     text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
+     voca_out_path = out_dir / "vocab.txt"
+     if is_finetune:
+         # reuse the pretrained vocab so token ids stay aligned with the checkpoint
+         shutil.copy2(PRETRAINED_VOCAB_PATH.as_posix(), voca_out_path)
+     else:
+         with open(voca_out_path, "w") as f:
+             for vocab in sorted(text_vocab_set):
+                 f.write(vocab + "\n")
+
+     dataset_name = out_dir.stem
+     print(f"\nFor {dataset_name}, sample count: {len(result)}")
+     print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
+     print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
+
+
+ def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True):
+     if is_finetune:
+         assert PRETRAINED_VOCAB_PATH.exists(), f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}"
+     sub_result, durations, vocab_set = prepare_csv_wavs_dir(inp_dir)
+     save_prepped_dataset(out_dir, sub_result, durations, vocab_set, is_finetune)
+
+
+ def cli():
+     # finetune: python scripts/scripts_prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin
+     # pretrain: python scripts/scripts_prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin --pretrain
+     parser = argparse.ArgumentParser(description="Prepare and save dataset.")
+     parser.add_argument('inp_dir', type=str, help="Input directory containing the data.")
+     parser.add_argument('out_dir', type=str, help="Output directory to save the prepared data.")
+     parser.add_argument('--pretrain', action='store_true', help="Enable for a new pretrain; otherwise treated as fine-tuning")
+
+     args = parser.parse_args()
+
+     prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain)
+
+ if __name__ == "__main__":
+     cli()
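Note: the expected input layout, inferred from is_csv_wavs_format and the '|'-delimited reader above (file names illustrative):

# input_dir/
#   metadata.csv    # header line first, then rows of: <audio_file>|<text>
#   wavs/
#     utt_0001.wav
# example metadata.csv row:  wavs/utt_0001.wav|some transcript here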
scripts/scripts_prepare_emilia.py ADDED
@@ -0,0 +1,141 @@
+ # Emilia Dataset: https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07
+ # if using the updated new version, i.e. WebDataset format, feel free to modify / draft your own script
+
+ # generate audio-text map for Emilia ZH & EN
+ # evaluate vocab size
+
+ import sys, os
+ sys.path.append(os.getcwd())
+
+ from pathlib import Path
+ import json
+ from tqdm import tqdm
+ from concurrent.futures import ProcessPoolExecutor
+
+ from datasets import Dataset
+ from datasets.arrow_writer import ArrowWriter
+
+ from model.utils import (
+     repetition_found,
+     convert_char_to_pinyin,
+ )
+
+
+ out_zh = {"ZH_B00041_S06226", "ZH_B00042_S09204", "ZH_B00065_S09430", "ZH_B00065_S09431", "ZH_B00066_S09327", "ZH_B00066_S09328"}
+ zh_filters = ["い", "て"]
+ # seemingly synthesized audios, or heavily code-switched
+ out_en = {
+ "EN_B00013_S00913", "EN_B00042_S00120", "EN_B00055_S04111", "EN_B00061_S00693", "EN_B00061_S01494", "EN_B00061_S03375",
+
31
+ }
32
+ en_filters = ["ا", "い", "て"]
33
+
34
+
35
+ def deal_with_audio_dir(audio_dir):
36
+ audio_jsonl = audio_dir.with_suffix(".jsonl")
37
+ sub_result, durations = [], []
38
+ vocab_set = set()
39
+ bad_case_zh = 0
40
+ bad_case_en = 0
41
+ with open(audio_jsonl, "r") as f:
42
+ lines = f.readlines()
43
+ for line in tqdm(lines, desc=f"{audio_jsonl.stem}"):
44
+ obj = json.loads(line)
45
+ text = obj["text"]
46
+ if obj['language'] == "zh":
47
+ if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text):
48
+ bad_case_zh += 1
49
+ continue
50
+ else:
51
+ text = text.translate(str.maketrans({',': ',', '!': '!', '?': '?'})) # not "。" cuz much code-switched
52
+ if obj['language'] == "en":
53
+ if obj["wav"].split("/")[1] in out_en or any(f in text for f in en_filters) or repetition_found(text, length=4):
54
+ bad_case_en += 1
55
+ continue
56
+ if tokenizer == "pinyin":
57
+ text = convert_char_to_pinyin([text], polyphone = polyphone)[0]
58
+ duration = obj["duration"]
59
+ sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration})
60
+ durations.append(duration)
61
+ vocab_set.update(list(text))
62
+ return sub_result, durations, vocab_set, bad_case_zh, bad_case_en
63
+
64
+
65
+ def main():
66
+ assert tokenizer in ["pinyin", "char"]
67
+ result = []
68
+ duration_list = []
69
+ text_vocab_set = set()
70
+ total_bad_case_zh = 0
71
+ total_bad_case_en = 0
72
+
73
+ # process raw data
74
+ executor = ProcessPoolExecutor(max_workers=max_workers)
75
+ futures = []
76
+ for lang in langs:
77
+ dataset_path = Path(os.path.join(dataset_dir, lang))
78
+ [
79
+ futures.append(executor.submit(deal_with_audio_dir, audio_dir))
80
+ for audio_dir in dataset_path.iterdir()
81
+ if audio_dir.is_dir()
82
+ ]
83
+ for futures in tqdm(futures, total=len(futures)):
84
+ sub_result, durations, vocab_set, bad_case_zh, bad_case_en = futures.result()
85
+ result.extend(sub_result)
86
+ duration_list.extend(durations)
87
+ text_vocab_set.update(vocab_set)
88
+ total_bad_case_zh += bad_case_zh
89
+ total_bad_case_en += bad_case_en
90
+ executor.shutdown()
91
+
92
+ # save preprocessed dataset to disk
93
+ if not os.path.exists(f"data/{dataset_name}"):
94
+ os.makedirs(f"data/{dataset_name}")
95
+ print(f"\nSaving to data/{dataset_name} ...")
96
+ # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
97
+ # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
98
+ with ArrowWriter(path=f"data/{dataset_name}/raw.arrow") as writer:
99
+ for line in tqdm(result, desc=f"Writing to raw.arrow ..."):
100
+ writer.write(line)
101
+
102
+ # dup a json separately saving duration in case for DynamicBatchSampler ease
103
+ with open(f"data/{dataset_name}/duration.json", 'w', encoding='utf-8') as f:
104
+ json.dump({"duration": duration_list}, f, ensure_ascii=False)
105
+
106
+ # vocab map, i.e. tokenizer
107
+ # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
108
+ # if tokenizer == "pinyin":
109
+ # text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
110
+ with open(f"data/{dataset_name}/vocab.txt", "w") as f:
111
+ for vocab in sorted(text_vocab_set):
112
+ f.write(vocab + "\n")
113
+
114
+ print(f"\nFor {dataset_name}, sample count: {len(result)}")
115
+ print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
116
+ print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
117
+ if "ZH" in langs: print(f"Bad zh transcription case: {total_bad_case_zh}")
118
+ if "EN" in langs: print(f"Bad en transcription case: {total_bad_case_en}\n")
119
+
120
+
121
+ if __name__ == "__main__":
122
+
123
+ max_workers = 32
124
+
125
+ tokenizer = "pinyin" # "pinyin" | "char"
126
+ polyphone = True
127
+
128
+ langs = ["ZH", "EN"]
129
+ dataset_dir = "<SOME_PATH>/Emilia_Dataset/raw"
130
+ dataset_name = f"Emilia_{'_'.join(langs)}_{tokenizer}"
131
+ print(f"\nPrepare for {dataset_name}\n")
132
+
133
+ main()
134
+
135
+ # Emilia ZH & EN
136
+ # samples count 37837916 (after removal)
137
+ # pinyin vocab size 2543 (polyphone)
138
+ # total duration 95281.87 (hours)
139
+ # bad zh asr cnt 230435 (samples)
140
+ # bad eh asr cnt 37217 (samples)
141
+
142
+ # vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
143
+ # please be careful if using pretrained model, make sure the vocab.txt is same
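Note: for orientation, a sketch of one .jsonl record as the loop in deal_with_audio_dir reads it (field names taken from the code; values and path layout assumed):

# {"wav": ".../....wav", "text": "...", "duration": 3.5, "language": "zh"}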
scripts/scripts_prepare_wenetspeech4tts.py ADDED
@@ -0,0 +1,116 @@
+ # generate audio-text map for WenetSpeech4TTS
+ # evaluate vocab size
+
+ import sys, os
+ sys.path.append(os.getcwd())
+
+ import json
+ from tqdm import tqdm
+ from concurrent.futures import ProcessPoolExecutor
+
+ import torchaudio
+ from datasets import Dataset
+
+ from model.utils import convert_char_to_pinyin
+
+
+ def deal_with_sub_path_files(dataset_path, sub_path):
+     print(f"Dealing with: {sub_path}")
+
+     text_dir = os.path.join(dataset_path, sub_path, "txts")
+     audio_dir = os.path.join(dataset_path, sub_path, "wavs")
+     text_files = os.listdir(text_dir)
+
+     audio_paths, texts, durations = [], [], []
+     for text_file in tqdm(text_files):
+         with open(os.path.join(text_dir, text_file), 'r', encoding='utf-8') as file:
+             first_line = file.readline().split("\t")
+         audio_nm = first_line[0]
+         audio_path = os.path.join(audio_dir, audio_nm + ".wav")
+         text = first_line[1].strip()
+
+         audio_paths.append(audio_path)
+
+         if tokenizer == "pinyin":
+             texts.extend(convert_char_to_pinyin([text], polyphone = polyphone))
+         elif tokenizer == "char":
+             texts.append(text)
+
+         audio, sample_rate = torchaudio.load(audio_path)
+         durations.append(audio.shape[-1] / sample_rate)
+
+     return audio_paths, texts, durations
+
+
+ def main():
+     assert tokenizer in ["pinyin", "char"]
+
+     audio_path_list, text_list, duration_list = [], [], []
+
+     executor = ProcessPoolExecutor(max_workers=max_workers)
+     futures = []
+     for dataset_path in dataset_paths:
+         sub_items = os.listdir(dataset_path)
+         sub_paths = [item for item in sub_items if os.path.isdir(os.path.join(dataset_path, item))]
+         for sub_path in sub_paths:
+             futures.append(executor.submit(deal_with_sub_path_files, dataset_path, sub_path))
+     for future in tqdm(futures, total=len(futures)):
+         audio_paths, texts, durations = future.result()
+         audio_path_list.extend(audio_paths)
+         text_list.extend(texts)
+         duration_list.extend(durations)
+     executor.shutdown()
+
+     if not os.path.exists("data"):
+         os.makedirs("data")
+
+     print(f"\nSaving to data/{dataset_name}_{tokenizer} ...")
+     dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
+     dataset.save_to_disk(f"data/{dataset_name}_{tokenizer}/raw", max_shard_size="2GB")  # arrow format
+
+     with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'w', encoding='utf-8') as f:
+         json.dump({"duration": duration_list}, f, ensure_ascii=False)  # dup a json separately saving durations, for ease of DynamicBatchSampler
+
+     print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...")
+     text_vocab_set = set()
+     for text in tqdm(text_list):
+         text_vocab_set.update(list(text))
+
+     # add alphabets and symbols (optional, if planning to ft on de/fr etc.)
+     if tokenizer == "pinyin":
+         text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
+
+     with open(f"data/{dataset_name}_{tokenizer}/vocab.txt", "w") as f:
+         for vocab in sorted(text_vocab_set):
+             f.write(vocab + "\n")
+     print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
+     print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n")
+
+
+ if __name__ == "__main__":
+
+     max_workers = 32
+
+     tokenizer = "pinyin"  # "pinyin" | "char"
+     polyphone = True
+     dataset_choice = 1  # 1: Premium, 2: Standard, 3: Basic
+
+     dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice-1]
+     dataset_paths = [
+         "<SOME_PATH>/WenetSpeech4TTS/Basic",
+         "<SOME_PATH>/WenetSpeech4TTS/Standard",
+         "<SOME_PATH>/WenetSpeech4TTS/Premium",
+     ][-dataset_choice:]  # tiers nest (Basic ⊃ Standard ⊃ Premium), so take the last N directories
+     print(f"\nChoose Dataset: {dataset_name}\n")
+
+     main()
+
+ # Results (if adding alphabets with accents and symbols):
+ # WenetSpeech4TTS       Basic    Standard    Premium
+ # samples count       3932473     1941220     407494
+ # pinyin vocab size      1349        1348       1344    (no polyphone)
+ #                           -           -       1459    (polyphone)
+ # char   vocab size      5264        5219       5042
+
+ # vocab size may differ slightly due to the jieba tokenizer and pypinyin (e.g. handling of polyphones)
+ # please be careful if using a pretrained model: make sure the vocab.txt is the same
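Note: each file under txts/ is parsed from its first line only; the inferred format (tab-separated) is:

# <audio_name>\t<transcript>[\t...]
# the corresponding audio is then read from wavs/<audio_name>.wav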