Spaces:
Running
Running
Staticaliza
commited on
Commit
•
89ea150
1
Parent(s):
dab6fe2
Upload 9 files
Browse files- scripts/scripts_count_max_epoch.py +32 -0
- scripts/scripts_count_params_gflops.py +35 -0
- scripts/scripts_eval_infer_batch.py +199 -0
- scripts/scripts_eval_infer_batch.sh +13 -0
- scripts/scripts_eval_librispeech_test_clean.py +67 -0
- scripts/scripts_eval_seedtts_testset.py +69 -0
- scripts/scripts_prepare_csv_wavs.py +132 -0
- scripts/scripts_prepare_emilia.py +143 -0
- scripts/scripts_prepare_wenetspeech4tts.py +116 -0
scripts/scripts_count_max_epoch.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''ADAPTIVE BATCH SIZE'''
|
2 |
+
print('Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in')
|
3 |
+
print(' -> least padding, gather wavs with accumulated frames in a batch\n')
|
4 |
+
|
5 |
+
# data
|
6 |
+
total_hours = 95282
|
7 |
+
mel_hop_length = 256
|
8 |
+
mel_sampling_rate = 24000
|
9 |
+
|
10 |
+
# target
|
11 |
+
wanted_max_updates = 1000000
|
12 |
+
|
13 |
+
# train params
|
14 |
+
gpus = 8
|
15 |
+
frames_per_gpu = 38400 # 8 * 38400 = 307200
|
16 |
+
grad_accum = 1
|
17 |
+
|
18 |
+
# intermediate
|
19 |
+
mini_batch_frames = frames_per_gpu * grad_accum * gpus
|
20 |
+
mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
|
21 |
+
updates_per_epoch = total_hours / mini_batch_hours
|
22 |
+
steps_per_epoch = updates_per_epoch * grad_accum
|
23 |
+
|
24 |
+
# result
|
25 |
+
epochs = wanted_max_updates / updates_per_epoch
|
26 |
+
print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})")
|
27 |
+
print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
|
28 |
+
print(f" or approx. 0/{steps_per_epoch:.0f} steps")
|
29 |
+
|
30 |
+
# others
|
31 |
+
print(f"total {total_hours:.0f} hours")
|
32 |
+
print(f"mini-batch of {mini_batch_frames:.0f} frames, {mini_batch_hours:.2f} hours per mini-batch")
|
scripts/scripts_count_params_gflops.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys, os
|
2 |
+
sys.path.append(os.getcwd())
|
3 |
+
|
4 |
+
from model import M2_TTS, UNetT, DiT, MMDiT
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import thop
|
8 |
+
|
9 |
+
|
10 |
+
''' ~155M '''
|
11 |
+
# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
|
12 |
+
# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
|
13 |
+
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
|
14 |
+
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4)
|
15 |
+
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
|
16 |
+
# transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
|
17 |
+
|
18 |
+
''' ~335M '''
|
19 |
+
# FLOPs: 622.1 G, Params: 333.2 M
|
20 |
+
# transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
|
21 |
+
# FLOPs: 363.4 G, Params: 335.8 M
|
22 |
+
transformer = DiT(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
|
23 |
+
|
24 |
+
|
25 |
+
model = M2_TTS(transformer=transformer)
|
26 |
+
target_sample_rate = 24000
|
27 |
+
n_mel_channels = 100
|
28 |
+
hop_length = 256
|
29 |
+
duration = 20
|
30 |
+
frame_length = int(duration * target_sample_rate / hop_length)
|
31 |
+
text_length = 150
|
32 |
+
|
33 |
+
flops, params = thop.profile(model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long)))
|
34 |
+
print(f"FLOPs: {flops / 1e9} G")
|
35 |
+
print(f"Params: {params / 1e6} M")
|
scripts/scripts_eval_infer_batch.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys, os
|
2 |
+
sys.path.append(os.getcwd())
|
3 |
+
|
4 |
+
import time
|
5 |
+
import random
|
6 |
+
from tqdm import tqdm
|
7 |
+
import argparse
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torchaudio
|
11 |
+
from accelerate import Accelerator
|
12 |
+
from einops import rearrange
|
13 |
+
from vocos import Vocos
|
14 |
+
|
15 |
+
from model import CFM, UNetT, DiT
|
16 |
+
from model.utils import (
|
17 |
+
load_checkpoint,
|
18 |
+
get_tokenizer,
|
19 |
+
get_seedtts_testset_metainfo,
|
20 |
+
get_librispeech_test_clean_metainfo,
|
21 |
+
get_inference_prompt,
|
22 |
+
)
|
23 |
+
|
24 |
+
accelerator = Accelerator()
|
25 |
+
device = f"cuda:{accelerator.process_index}"
|
26 |
+
|
27 |
+
|
28 |
+
# --------------------- Dataset Settings -------------------- #
|
29 |
+
|
30 |
+
target_sample_rate = 24000
|
31 |
+
n_mel_channels = 100
|
32 |
+
hop_length = 256
|
33 |
+
target_rms = 0.1
|
34 |
+
|
35 |
+
tokenizer = "pinyin"
|
36 |
+
|
37 |
+
|
38 |
+
# ---------------------- infer setting ---------------------- #
|
39 |
+
|
40 |
+
parser = argparse.ArgumentParser(description="batch inference")
|
41 |
+
|
42 |
+
parser.add_argument('-s', '--seed', default=None, type=int)
|
43 |
+
parser.add_argument('-d', '--dataset', default="Emilia_ZH_EN")
|
44 |
+
parser.add_argument('-n', '--expname', required=True)
|
45 |
+
parser.add_argument('-c', '--ckptstep', default=1200000, type=int)
|
46 |
+
|
47 |
+
parser.add_argument('-nfe', '--nfestep', default=32, type=int)
|
48 |
+
parser.add_argument('-o', '--odemethod', default="euler")
|
49 |
+
parser.add_argument('-ss', '--swaysampling', default=-1, type=float)
|
50 |
+
|
51 |
+
parser.add_argument('-t', '--testset', required=True)
|
52 |
+
|
53 |
+
args = parser.parse_args()
|
54 |
+
|
55 |
+
|
56 |
+
seed = args.seed
|
57 |
+
dataset_name = args.dataset
|
58 |
+
exp_name = args.expname
|
59 |
+
ckpt_step = args.ckptstep
|
60 |
+
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
|
61 |
+
|
62 |
+
nfe_step = args.nfestep
|
63 |
+
ode_method = args.odemethod
|
64 |
+
sway_sampling_coef = args.swaysampling
|
65 |
+
|
66 |
+
testset = args.testset
|
67 |
+
|
68 |
+
|
69 |
+
infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended)
|
70 |
+
cfg_strength = 2.
|
71 |
+
speed = 1.
|
72 |
+
use_truth_duration = False
|
73 |
+
no_ref_audio = False
|
74 |
+
|
75 |
+
|
76 |
+
if exp_name == "F5TTS_Base":
|
77 |
+
model_cls = DiT
|
78 |
+
model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
|
79 |
+
|
80 |
+
elif exp_name == "E2TTS_Base":
|
81 |
+
model_cls = UNetT
|
82 |
+
model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
|
83 |
+
|
84 |
+
|
85 |
+
if testset == "ls_pc_test_clean":
|
86 |
+
metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
|
87 |
+
librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
|
88 |
+
metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
|
89 |
+
|
90 |
+
elif testset == "seedtts_test_zh":
|
91 |
+
metalst = "data/seedtts_testset/zh/meta.lst"
|
92 |
+
metainfo = get_seedtts_testset_metainfo(metalst)
|
93 |
+
|
94 |
+
elif testset == "seedtts_test_en":
|
95 |
+
metalst = "data/seedtts_testset/en/meta.lst"
|
96 |
+
metainfo = get_seedtts_testset_metainfo(metalst)
|
97 |
+
|
98 |
+
|
99 |
+
# path to save genereted wavs
|
100 |
+
if seed is None: seed = random.randint(-10000, 10000)
|
101 |
+
output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \
|
102 |
+
f"seed{seed}_{ode_method}_nfe{nfe_step}" \
|
103 |
+
f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}" \
|
104 |
+
f"_cfg{cfg_strength}_speed{speed}" \
|
105 |
+
f"{'_gt-dur' if use_truth_duration else ''}" \
|
106 |
+
f"{'_no-ref-audio' if no_ref_audio else ''}"
|
107 |
+
|
108 |
+
|
109 |
+
# -------------------------------------------------#
|
110 |
+
|
111 |
+
use_ema = True
|
112 |
+
|
113 |
+
prompts_all = get_inference_prompt(
|
114 |
+
metainfo,
|
115 |
+
speed = speed,
|
116 |
+
tokenizer = tokenizer,
|
117 |
+
target_sample_rate = target_sample_rate,
|
118 |
+
n_mel_channels = n_mel_channels,
|
119 |
+
hop_length = hop_length,
|
120 |
+
target_rms = target_rms,
|
121 |
+
use_truth_duration = use_truth_duration,
|
122 |
+
infer_batch_size = infer_batch_size,
|
123 |
+
)
|
124 |
+
|
125 |
+
# Vocoder model
|
126 |
+
local = False
|
127 |
+
if local:
|
128 |
+
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
|
129 |
+
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
|
130 |
+
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
|
131 |
+
vocos.load_state_dict(state_dict)
|
132 |
+
vocos.eval()
|
133 |
+
else:
|
134 |
+
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
135 |
+
|
136 |
+
# Tokenizer
|
137 |
+
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
|
138 |
+
|
139 |
+
# Model
|
140 |
+
model = CFM(
|
141 |
+
transformer = model_cls(
|
142 |
+
**model_cfg,
|
143 |
+
text_num_embeds = vocab_size,
|
144 |
+
mel_dim = n_mel_channels
|
145 |
+
),
|
146 |
+
mel_spec_kwargs = dict(
|
147 |
+
target_sample_rate = target_sample_rate,
|
148 |
+
n_mel_channels = n_mel_channels,
|
149 |
+
hop_length = hop_length,
|
150 |
+
),
|
151 |
+
odeint_kwargs = dict(
|
152 |
+
method = ode_method,
|
153 |
+
),
|
154 |
+
vocab_char_map = vocab_char_map,
|
155 |
+
).to(device)
|
156 |
+
|
157 |
+
model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
|
158 |
+
|
159 |
+
if not os.path.exists(output_dir) and accelerator.is_main_process:
|
160 |
+
os.makedirs(output_dir)
|
161 |
+
|
162 |
+
# start batch inference
|
163 |
+
accelerator.wait_for_everyone()
|
164 |
+
start = time.time()
|
165 |
+
|
166 |
+
with accelerator.split_between_processes(prompts_all) as prompts:
|
167 |
+
|
168 |
+
for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
|
169 |
+
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
|
170 |
+
ref_mels = ref_mels.to(device)
|
171 |
+
ref_mel_lens = torch.tensor(ref_mel_lens, dtype = torch.long).to(device)
|
172 |
+
total_mel_lens = torch.tensor(total_mel_lens, dtype = torch.long).to(device)
|
173 |
+
|
174 |
+
# Inference
|
175 |
+
with torch.inference_mode():
|
176 |
+
generated, _ = model.sample(
|
177 |
+
cond = ref_mels,
|
178 |
+
text = final_text_list,
|
179 |
+
duration = total_mel_lens,
|
180 |
+
lens = ref_mel_lens,
|
181 |
+
steps = nfe_step,
|
182 |
+
cfg_strength = cfg_strength,
|
183 |
+
sway_sampling_coef = sway_sampling_coef,
|
184 |
+
no_ref_audio = no_ref_audio,
|
185 |
+
seed = seed,
|
186 |
+
)
|
187 |
+
# Final result
|
188 |
+
for i, gen in enumerate(generated):
|
189 |
+
gen = gen[ref_mel_lens[i]:total_mel_lens[i], :].unsqueeze(0)
|
190 |
+
gen_mel_spec = rearrange(gen, '1 n d -> 1 d n')
|
191 |
+
generated_wave = vocos.decode(gen_mel_spec.cpu())
|
192 |
+
if ref_rms_list[i] < target_rms:
|
193 |
+
generated_wave = generated_wave * ref_rms_list[i] / target_rms
|
194 |
+
torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)
|
195 |
+
|
196 |
+
accelerator.wait_for_everyone()
|
197 |
+
if accelerator.is_main_process:
|
198 |
+
timediff = time.time() - start
|
199 |
+
print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
|
scripts/scripts_eval_infer_batch.sh
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# e.g. F5-TTS, 16 NFE
|
4 |
+
accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
|
5 |
+
accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
|
6 |
+
accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
|
7 |
+
|
8 |
+
# e.g. Vanilla E2 TTS, 32 NFE
|
9 |
+
accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
|
10 |
+
accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
|
11 |
+
accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
|
12 |
+
|
13 |
+
# etc.
|
scripts/scripts_eval_librispeech_test_clean.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
|
2 |
+
|
3 |
+
import sys, os
|
4 |
+
sys.path.append(os.getcwd())
|
5 |
+
|
6 |
+
import multiprocessing as mp
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from model.utils import (
|
10 |
+
get_librispeech_test,
|
11 |
+
run_asr_wer,
|
12 |
+
run_sim,
|
13 |
+
)
|
14 |
+
|
15 |
+
|
16 |
+
eval_task = "wer" # sim | wer
|
17 |
+
lang = "en"
|
18 |
+
metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
|
19 |
+
librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
|
20 |
+
gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
|
21 |
+
|
22 |
+
gpus = [0,1,2,3,4,5,6,7]
|
23 |
+
test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
|
24 |
+
|
25 |
+
## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
|
26 |
+
## leading to a low similarity for the ground truth in some cases.
|
27 |
+
# test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = True) # eval ground truth
|
28 |
+
|
29 |
+
local = False
|
30 |
+
if local: # use local custom checkpoint dir
|
31 |
+
asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
|
32 |
+
else:
|
33 |
+
asr_ckpt_dir = "" # auto download to cache dir
|
34 |
+
|
35 |
+
wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
|
36 |
+
|
37 |
+
|
38 |
+
# --------------------------- WER ---------------------------
|
39 |
+
|
40 |
+
if eval_task == "wer":
|
41 |
+
wers = []
|
42 |
+
|
43 |
+
with mp.Pool(processes=len(gpus)) as pool:
|
44 |
+
args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
|
45 |
+
results = pool.map(run_asr_wer, args)
|
46 |
+
for wers_ in results:
|
47 |
+
wers.extend(wers_)
|
48 |
+
|
49 |
+
wer = round(np.mean(wers)*100, 3)
|
50 |
+
print(f"\nTotal {len(wers)} samples")
|
51 |
+
print(f"WER : {wer}%")
|
52 |
+
|
53 |
+
|
54 |
+
# --------------------------- SIM ---------------------------
|
55 |
+
|
56 |
+
if eval_task == "sim":
|
57 |
+
sim_list = []
|
58 |
+
|
59 |
+
with mp.Pool(processes=len(gpus)) as pool:
|
60 |
+
args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
|
61 |
+
results = pool.map(run_sim, args)
|
62 |
+
for sim_ in results:
|
63 |
+
sim_list.extend(sim_)
|
64 |
+
|
65 |
+
sim = round(sum(sim_list)/len(sim_list), 3)
|
66 |
+
print(f"\nTotal {len(sim_list)} samples")
|
67 |
+
print(f"SIM : {sim}")
|
scripts/scripts_eval_seedtts_testset.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Evaluate with Seed-TTS testset
|
2 |
+
|
3 |
+
import sys, os
|
4 |
+
sys.path.append(os.getcwd())
|
5 |
+
|
6 |
+
import multiprocessing as mp
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from model.utils import (
|
10 |
+
get_seed_tts_test,
|
11 |
+
run_asr_wer,
|
12 |
+
run_sim,
|
13 |
+
)
|
14 |
+
|
15 |
+
|
16 |
+
eval_task = "wer" # sim | wer
|
17 |
+
lang = "zh" # zh | en
|
18 |
+
metalst = f"data/seedtts_testset/{lang}/meta.lst" # seed-tts testset
|
19 |
+
# gen_wav_dir = f"data/seedtts_testset/{lang}/wavs" # ground truth wavs
|
20 |
+
gen_wav_dir = f"PATH_TO_GENERATED" # generated wavs
|
21 |
+
|
22 |
+
|
23 |
+
# NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
|
24 |
+
# zh 1.254 seems a result of 4 workers wer_seed_tts
|
25 |
+
gpus = [0,1,2,3,4,5,6,7]
|
26 |
+
test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
|
27 |
+
|
28 |
+
local = False
|
29 |
+
if local: # use local custom checkpoint dir
|
30 |
+
if lang == "zh":
|
31 |
+
asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr
|
32 |
+
elif lang == "en":
|
33 |
+
asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
|
34 |
+
else:
|
35 |
+
asr_ckpt_dir = "" # auto download to cache dir
|
36 |
+
|
37 |
+
wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
|
38 |
+
|
39 |
+
|
40 |
+
# --------------------------- WER ---------------------------
|
41 |
+
|
42 |
+
if eval_task == "wer":
|
43 |
+
wers = []
|
44 |
+
|
45 |
+
with mp.Pool(processes=len(gpus)) as pool:
|
46 |
+
args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
|
47 |
+
results = pool.map(run_asr_wer, args)
|
48 |
+
for wers_ in results:
|
49 |
+
wers.extend(wers_)
|
50 |
+
|
51 |
+
wer = round(np.mean(wers)*100, 3)
|
52 |
+
print(f"\nTotal {len(wers)} samples")
|
53 |
+
print(f"WER : {wer}%")
|
54 |
+
|
55 |
+
|
56 |
+
# --------------------------- SIM ---------------------------
|
57 |
+
|
58 |
+
if eval_task == "sim":
|
59 |
+
sim_list = []
|
60 |
+
|
61 |
+
with mp.Pool(processes=len(gpus)) as pool:
|
62 |
+
args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
|
63 |
+
results = pool.map(run_sim, args)
|
64 |
+
for sim_ in results:
|
65 |
+
sim_list.extend(sim_)
|
66 |
+
|
67 |
+
sim = round(sum(sim_list)/len(sim_list), 3)
|
68 |
+
print(f"\nTotal {len(sim_list)} samples")
|
69 |
+
print(f"SIM : {sim}")
|
scripts/scripts_prepare_csv_wavs.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys, os
|
2 |
+
sys.path.append(os.getcwd())
|
3 |
+
|
4 |
+
from pathlib import Path
|
5 |
+
import json
|
6 |
+
import shutil
|
7 |
+
import argparse
|
8 |
+
|
9 |
+
import csv
|
10 |
+
import torchaudio
|
11 |
+
from tqdm import tqdm
|
12 |
+
from datasets.arrow_writer import ArrowWriter
|
13 |
+
|
14 |
+
from model.utils import (
|
15 |
+
convert_char_to_pinyin,
|
16 |
+
)
|
17 |
+
|
18 |
+
PRETRAINED_VOCAB_PATH = Path(__file__).parent.parent / "data/Emilia_ZH_EN_pinyin/vocab.txt"
|
19 |
+
|
20 |
+
def is_csv_wavs_format(input_dataset_dir):
|
21 |
+
fpath = Path(input_dataset_dir)
|
22 |
+
metadata = fpath / "metadata.csv"
|
23 |
+
wavs = fpath / 'wavs'
|
24 |
+
return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir()
|
25 |
+
|
26 |
+
|
27 |
+
def prepare_csv_wavs_dir(input_dir):
|
28 |
+
assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}"
|
29 |
+
input_dir = Path(input_dir)
|
30 |
+
metadata_path = input_dir / "metadata.csv"
|
31 |
+
audio_path_text_pairs = read_audio_text_pairs(metadata_path.as_posix())
|
32 |
+
|
33 |
+
sub_result, durations = [], []
|
34 |
+
vocab_set = set()
|
35 |
+
polyphone = True
|
36 |
+
for audio_path, text in audio_path_text_pairs:
|
37 |
+
if not Path(audio_path).exists():
|
38 |
+
print(f"audio {audio_path} not found, skipping")
|
39 |
+
continue
|
40 |
+
audio_duration = get_audio_duration(audio_path)
|
41 |
+
# assume tokenizer = "pinyin" ("pinyin" | "char")
|
42 |
+
text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
|
43 |
+
sub_result.append({"audio_path": audio_path, "text": text, "duration": audio_duration})
|
44 |
+
durations.append(audio_duration)
|
45 |
+
vocab_set.update(list(text))
|
46 |
+
|
47 |
+
return sub_result, durations, vocab_set
|
48 |
+
|
49 |
+
def get_audio_duration(audio_path):
|
50 |
+
audio, sample_rate = torchaudio.load(audio_path)
|
51 |
+
num_channels = audio.shape[0]
|
52 |
+
return audio.shape[1] / (sample_rate * num_channels)
|
53 |
+
|
54 |
+
def read_audio_text_pairs(csv_file_path):
|
55 |
+
audio_text_pairs = []
|
56 |
+
|
57 |
+
parent = Path(csv_file_path).parent
|
58 |
+
with open(csv_file_path, mode='r', newline='', encoding='utf-8') as csvfile:
|
59 |
+
reader = csv.reader(csvfile, delimiter='|')
|
60 |
+
next(reader) # Skip the header row
|
61 |
+
for row in reader:
|
62 |
+
if len(row) >= 2:
|
63 |
+
audio_file = row[0].strip() # First column: audio file path
|
64 |
+
text = row[1].strip() # Second column: text
|
65 |
+
audio_file_path = parent / audio_file
|
66 |
+
audio_text_pairs.append((audio_file_path.as_posix(), text))
|
67 |
+
|
68 |
+
return audio_text_pairs
|
69 |
+
|
70 |
+
|
71 |
+
def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune):
|
72 |
+
out_dir = Path(out_dir)
|
73 |
+
# save preprocessed dataset to disk
|
74 |
+
out_dir.mkdir(exist_ok=True, parents=True)
|
75 |
+
print(f"\nSaving to {out_dir} ...")
|
76 |
+
|
77 |
+
# dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
|
78 |
+
# dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
|
79 |
+
raw_arrow_path = out_dir / "raw.arrow"
|
80 |
+
with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
|
81 |
+
for line in tqdm(result, desc=f"Writing to raw.arrow ..."):
|
82 |
+
writer.write(line)
|
83 |
+
|
84 |
+
# dup a json separately saving duration in case for DynamicBatchSampler ease
|
85 |
+
dur_json_path = out_dir / "duration.json"
|
86 |
+
with open(dur_json_path.as_posix(), 'w', encoding='utf-8') as f:
|
87 |
+
json.dump({"duration": duration_list}, f, ensure_ascii=False)
|
88 |
+
|
89 |
+
# vocab map, i.e. tokenizer
|
90 |
+
# add alphabets and symbols (optional, if plan to ft on de/fr etc.)
|
91 |
+
# if tokenizer == "pinyin":
|
92 |
+
# text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
|
93 |
+
voca_out_path = out_dir / "vocab.txt"
|
94 |
+
with open(voca_out_path.as_posix(), "w") as f:
|
95 |
+
for vocab in sorted(text_vocab_set):
|
96 |
+
f.write(vocab + "\n")
|
97 |
+
|
98 |
+
if is_finetune:
|
99 |
+
file_vocab_finetune = PRETRAINED_VOCAB_PATH.as_posix()
|
100 |
+
shutil.copy2(file_vocab_finetune, voca_out_path)
|
101 |
+
else:
|
102 |
+
with open(voca_out_path, "w") as f:
|
103 |
+
for vocab in sorted(text_vocab_set):
|
104 |
+
f.write(vocab + "\n")
|
105 |
+
|
106 |
+
dataset_name = out_dir.stem
|
107 |
+
print(f"\nFor {dataset_name}, sample count: {len(result)}")
|
108 |
+
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
|
109 |
+
print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
|
110 |
+
|
111 |
+
|
112 |
+
def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True):
|
113 |
+
if is_finetune:
|
114 |
+
assert PRETRAINED_VOCAB_PATH.exists(), f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}"
|
115 |
+
sub_result, durations, vocab_set = prepare_csv_wavs_dir(inp_dir)
|
116 |
+
save_prepped_dataset(out_dir, sub_result, durations, vocab_set, is_finetune)
|
117 |
+
|
118 |
+
|
119 |
+
def cli():
|
120 |
+
# finetune: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin
|
121 |
+
# pretrain: python scripts/prepare_csv_wavs.py /path/to/output_dir_pinyin --pretrain
|
122 |
+
parser = argparse.ArgumentParser(description="Prepare and save dataset.")
|
123 |
+
parser.add_argument('inp_dir', type=str, help="Input directory containing the data.")
|
124 |
+
parser.add_argument('out_dir', type=str, help="Output directory to save the prepared data.")
|
125 |
+
parser.add_argument('--pretrain', action='store_true', help="Enable for new pretrain, otherwise is a fine-tune")
|
126 |
+
|
127 |
+
args = parser.parse_args()
|
128 |
+
|
129 |
+
prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain)
|
130 |
+
|
131 |
+
if __name__ == "__main__":
|
132 |
+
cli()
|
scripts/scripts_prepare_emilia.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Emilia Dataset: https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07
|
2 |
+
# if use updated new version, i.e. WebDataset, feel free to modify / draft your own script
|
3 |
+
|
4 |
+
# generate audio text map for Emilia ZH & EN
|
5 |
+
# evaluate for vocab size
|
6 |
+
|
7 |
+
import sys, os
|
8 |
+
sys.path.append(os.getcwd())
|
9 |
+
|
10 |
+
from pathlib import Path
|
11 |
+
import json
|
12 |
+
from tqdm import tqdm
|
13 |
+
from concurrent.futures import ProcessPoolExecutor
|
14 |
+
|
15 |
+
from datasets import Dataset
|
16 |
+
from datasets.arrow_writer import ArrowWriter
|
17 |
+
|
18 |
+
from model.utils import (
|
19 |
+
repetition_found,
|
20 |
+
convert_char_to_pinyin,
|
21 |
+
)
|
22 |
+
|
23 |
+
|
24 |
+
out_zh = {"ZH_B00041_S06226", "ZH_B00042_S09204", "ZH_B00065_S09430", "ZH_B00065_S09431", "ZH_B00066_S09327", "ZH_B00066_S09328"}
|
25 |
+
zh_filters = ["い", "て"]
|
26 |
+
# seems synthesized audios, or heavily code-switched
|
27 |
+
out_en = {
|
28 |
+
"EN_B00013_S00913", "EN_B00042_S00120", "EN_B00055_S04111", "EN_B00061_S00693", "EN_B00061_S01494", "EN_B00061_S03375",
|
29 |
+
|
30 |
+
"EN_B00059_S00092", "EN_B00111_S04300", "EN_B00100_S03759", "EN_B00087_S03811", "EN_B00059_S00950", "EN_B00089_S00946", "EN_B00078_S05127", "EN_B00070_S04089", "EN_B00074_S09659", "EN_B00061_S06983", "EN_B00061_S07060", "EN_B00059_S08397", "EN_B00082_S06192", "EN_B00091_S01238", "EN_B00089_S07349", "EN_B00070_S04343", "EN_B00061_S02400", "EN_B00076_S01262", "EN_B00068_S06467", "EN_B00076_S02943", "EN_B00064_S05954", "EN_B00061_S05386", "EN_B00066_S06544", "EN_B00076_S06944", "EN_B00072_S08620", "EN_B00076_S07135", "EN_B00076_S09127", "EN_B00065_S00497", "EN_B00059_S06227", "EN_B00063_S02859", "EN_B00075_S01547", "EN_B00061_S08286", "EN_B00079_S02901", "EN_B00092_S03643", "EN_B00096_S08653", "EN_B00063_S04297", "EN_B00063_S04614", "EN_B00079_S04698", "EN_B00104_S01666", "EN_B00061_S09504", "EN_B00061_S09694", "EN_B00065_S05444", "EN_B00063_S06860", "EN_B00065_S05725", "EN_B00069_S07628", "EN_B00083_S03875", "EN_B00071_S07665", "EN_B00071_S07665", "EN_B00062_S04187", "EN_B00065_S09873", "EN_B00065_S09922", "EN_B00084_S02463", "EN_B00067_S05066", "EN_B00106_S08060", "EN_B00073_S06399", "EN_B00073_S09236", "EN_B00087_S00432", "EN_B00085_S05618", "EN_B00064_S01262", "EN_B00072_S01739", "EN_B00059_S03913", "EN_B00069_S04036", "EN_B00067_S05623", "EN_B00060_S05389", "EN_B00060_S07290", "EN_B00062_S08995",
|
31 |
+
}
|
32 |
+
en_filters = ["ا", "い", "て"]
|
33 |
+
|
34 |
+
|
35 |
+
def deal_with_audio_dir(audio_dir):
|
36 |
+
audio_jsonl = audio_dir.with_suffix(".jsonl")
|
37 |
+
sub_result, durations = [], []
|
38 |
+
vocab_set = set()
|
39 |
+
bad_case_zh = 0
|
40 |
+
bad_case_en = 0
|
41 |
+
with open(audio_jsonl, "r") as f:
|
42 |
+
lines = f.readlines()
|
43 |
+
for line in tqdm(lines, desc=f"{audio_jsonl.stem}"):
|
44 |
+
obj = json.loads(line)
|
45 |
+
text = obj["text"]
|
46 |
+
if obj['language'] == "zh":
|
47 |
+
if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text):
|
48 |
+
bad_case_zh += 1
|
49 |
+
continue
|
50 |
+
else:
|
51 |
+
text = text.translate(str.maketrans({',': ',', '!': '!', '?': '?'})) # not "。" cuz much code-switched
|
52 |
+
if obj['language'] == "en":
|
53 |
+
if obj["wav"].split("/")[1] in out_en or any(f in text for f in en_filters) or repetition_found(text, length=4):
|
54 |
+
bad_case_en += 1
|
55 |
+
continue
|
56 |
+
if tokenizer == "pinyin":
|
57 |
+
text = convert_char_to_pinyin([text], polyphone = polyphone)[0]
|
58 |
+
duration = obj["duration"]
|
59 |
+
sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration})
|
60 |
+
durations.append(duration)
|
61 |
+
vocab_set.update(list(text))
|
62 |
+
return sub_result, durations, vocab_set, bad_case_zh, bad_case_en
|
63 |
+
|
64 |
+
|
65 |
+
def main():
|
66 |
+
assert tokenizer in ["pinyin", "char"]
|
67 |
+
result = []
|
68 |
+
duration_list = []
|
69 |
+
text_vocab_set = set()
|
70 |
+
total_bad_case_zh = 0
|
71 |
+
total_bad_case_en = 0
|
72 |
+
|
73 |
+
# process raw data
|
74 |
+
executor = ProcessPoolExecutor(max_workers=max_workers)
|
75 |
+
futures = []
|
76 |
+
for lang in langs:
|
77 |
+
dataset_path = Path(os.path.join(dataset_dir, lang))
|
78 |
+
[
|
79 |
+
futures.append(executor.submit(deal_with_audio_dir, audio_dir))
|
80 |
+
for audio_dir in dataset_path.iterdir()
|
81 |
+
if audio_dir.is_dir()
|
82 |
+
]
|
83 |
+
for futures in tqdm(futures, total=len(futures)):
|
84 |
+
sub_result, durations, vocab_set, bad_case_zh, bad_case_en = futures.result()
|
85 |
+
result.extend(sub_result)
|
86 |
+
duration_list.extend(durations)
|
87 |
+
text_vocab_set.update(vocab_set)
|
88 |
+
total_bad_case_zh += bad_case_zh
|
89 |
+
total_bad_case_en += bad_case_en
|
90 |
+
executor.shutdown()
|
91 |
+
|
92 |
+
# save preprocessed dataset to disk
|
93 |
+
if not os.path.exists(f"data/{dataset_name}"):
|
94 |
+
os.makedirs(f"data/{dataset_name}")
|
95 |
+
print(f"\nSaving to data/{dataset_name} ...")
|
96 |
+
# dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
|
97 |
+
# dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
|
98 |
+
with ArrowWriter(path=f"data/{dataset_name}/raw.arrow") as writer:
|
99 |
+
for line in tqdm(result, desc=f"Writing to raw.arrow ..."):
|
100 |
+
writer.write(line)
|
101 |
+
|
102 |
+
# dup a json separately saving duration in case for DynamicBatchSampler ease
|
103 |
+
with open(f"data/{dataset_name}/duration.json", 'w', encoding='utf-8') as f:
|
104 |
+
json.dump({"duration": duration_list}, f, ensure_ascii=False)
|
105 |
+
|
106 |
+
# vocab map, i.e. tokenizer
|
107 |
+
# add alphabets and symbols (optional, if plan to ft on de/fr etc.)
|
108 |
+
# if tokenizer == "pinyin":
|
109 |
+
# text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
|
110 |
+
with open(f"data/{dataset_name}/vocab.txt", "w") as f:
|
111 |
+
for vocab in sorted(text_vocab_set):
|
112 |
+
f.write(vocab + "\n")
|
113 |
+
|
114 |
+
print(f"\nFor {dataset_name}, sample count: {len(result)}")
|
115 |
+
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
|
116 |
+
print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
|
117 |
+
if "ZH" in langs: print(f"Bad zh transcription case: {total_bad_case_zh}")
|
118 |
+
if "EN" in langs: print(f"Bad en transcription case: {total_bad_case_en}\n")
|
119 |
+
|
120 |
+
|
121 |
+
if __name__ == "__main__":
|
122 |
+
|
123 |
+
max_workers = 32
|
124 |
+
|
125 |
+
tokenizer = "pinyin" # "pinyin" | "char"
|
126 |
+
polyphone = True
|
127 |
+
|
128 |
+
langs = ["ZH", "EN"]
|
129 |
+
dataset_dir = "<SOME_PATH>/Emilia_Dataset/raw"
|
130 |
+
dataset_name = f"Emilia_{'_'.join(langs)}_{tokenizer}"
|
131 |
+
print(f"\nPrepare for {dataset_name}\n")
|
132 |
+
|
133 |
+
main()
|
134 |
+
|
135 |
+
# Emilia ZH & EN
|
136 |
+
# samples count 37837916 (after removal)
|
137 |
+
# pinyin vocab size 2543 (polyphone)
|
138 |
+
# total duration 95281.87 (hours)
|
139 |
+
# bad zh asr cnt 230435 (samples)
|
140 |
+
# bad eh asr cnt 37217 (samples)
|
141 |
+
|
142 |
+
# vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
|
143 |
+
# please be careful if using pretrained model, make sure the vocab.txt is same
|
scripts/scripts_prepare_wenetspeech4tts.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# generate audio text map for WenetSpeech4TTS
|
2 |
+
# evaluate for vocab size
|
3 |
+
|
4 |
+
import sys, os
|
5 |
+
sys.path.append(os.getcwd())
|
6 |
+
|
7 |
+
import json
|
8 |
+
from tqdm import tqdm
|
9 |
+
from concurrent.futures import ProcessPoolExecutor
|
10 |
+
|
11 |
+
import torchaudio
|
12 |
+
from datasets import Dataset
|
13 |
+
|
14 |
+
from model.utils import convert_char_to_pinyin
|
15 |
+
|
16 |
+
|
17 |
+
def deal_with_sub_path_files(dataset_path, sub_path):
|
18 |
+
print(f"Dealing with: {sub_path}")
|
19 |
+
|
20 |
+
text_dir = os.path.join(dataset_path, sub_path, "txts")
|
21 |
+
audio_dir = os.path.join(dataset_path, sub_path, "wavs")
|
22 |
+
text_files = os.listdir(text_dir)
|
23 |
+
|
24 |
+
audio_paths, texts, durations = [], [], []
|
25 |
+
for text_file in tqdm(text_files):
|
26 |
+
with open(os.path.join(text_dir, text_file), 'r', encoding='utf-8') as file:
|
27 |
+
first_line = file.readline().split("\t")
|
28 |
+
audio_nm = first_line[0]
|
29 |
+
audio_path = os.path.join(audio_dir, audio_nm + ".wav")
|
30 |
+
text = first_line[1].strip()
|
31 |
+
|
32 |
+
audio_paths.append(audio_path)
|
33 |
+
|
34 |
+
if tokenizer == "pinyin":
|
35 |
+
texts.extend(convert_char_to_pinyin([text], polyphone = polyphone))
|
36 |
+
elif tokenizer == "char":
|
37 |
+
texts.append(text)
|
38 |
+
|
39 |
+
audio, sample_rate = torchaudio.load(audio_path)
|
40 |
+
durations.append(audio.shape[-1] / sample_rate)
|
41 |
+
|
42 |
+
return audio_paths, texts, durations
|
43 |
+
|
44 |
+
|
45 |
+
def main():
|
46 |
+
assert tokenizer in ["pinyin", "char"]
|
47 |
+
|
48 |
+
audio_path_list, text_list, duration_list = [], [], []
|
49 |
+
|
50 |
+
executor = ProcessPoolExecutor(max_workers=max_workers)
|
51 |
+
futures = []
|
52 |
+
for dataset_path in dataset_paths:
|
53 |
+
sub_items = os.listdir(dataset_path)
|
54 |
+
sub_paths = [item for item in sub_items if os.path.isdir(os.path.join(dataset_path, item))]
|
55 |
+
for sub_path in sub_paths:
|
56 |
+
futures.append(executor.submit(deal_with_sub_path_files, dataset_path, sub_path))
|
57 |
+
for future in tqdm(futures, total=len(futures)):
|
58 |
+
audio_paths, texts, durations = future.result()
|
59 |
+
audio_path_list.extend(audio_paths)
|
60 |
+
text_list.extend(texts)
|
61 |
+
duration_list.extend(durations)
|
62 |
+
executor.shutdown()
|
63 |
+
|
64 |
+
if not os.path.exists("data"):
|
65 |
+
os.makedirs("data")
|
66 |
+
|
67 |
+
print(f"\nSaving to data/{dataset_name}_{tokenizer} ...")
|
68 |
+
dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
|
69 |
+
dataset.save_to_disk(f"data/{dataset_name}_{tokenizer}/raw", max_shard_size="2GB") # arrow format
|
70 |
+
|
71 |
+
with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'w', encoding='utf-8') as f:
|
72 |
+
json.dump({"duration": duration_list}, f, ensure_ascii=False) # dup a json separately saving duration in case for DynamicBatchSampler ease
|
73 |
+
|
74 |
+
print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...")
|
75 |
+
text_vocab_set = set()
|
76 |
+
for text in tqdm(text_list):
|
77 |
+
text_vocab_set.update(list(text))
|
78 |
+
|
79 |
+
# add alphabets and symbols (optional, if plan to ft on de/fr etc.)
|
80 |
+
if tokenizer == "pinyin":
|
81 |
+
text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
|
82 |
+
|
83 |
+
with open(f"data/{dataset_name}_{tokenizer}/vocab.txt", "w") as f:
|
84 |
+
for vocab in sorted(text_vocab_set):
|
85 |
+
f.write(vocab + "\n")
|
86 |
+
print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
|
87 |
+
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n")
|
88 |
+
|
89 |
+
|
90 |
+
if __name__ == "__main__":
|
91 |
+
|
92 |
+
max_workers = 32
|
93 |
+
|
94 |
+
tokenizer = "pinyin" # "pinyin" | "char"
|
95 |
+
polyphone = True
|
96 |
+
dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic
|
97 |
+
|
98 |
+
dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice-1]
|
99 |
+
dataset_paths = [
|
100 |
+
"<SOME_PATH>/WenetSpeech4TTS/Basic",
|
101 |
+
"<SOME_PATH>/WenetSpeech4TTS/Standard",
|
102 |
+
"<SOME_PATH>/WenetSpeech4TTS/Premium",
|
103 |
+
][-dataset_choice:]
|
104 |
+
print(f"\nChoose Dataset: {dataset_name}\n")
|
105 |
+
|
106 |
+
main()
|
107 |
+
|
108 |
+
# Results (if adding alphabets with accents and symbols):
|
109 |
+
# WenetSpeech4TTS Basic Standard Premium
|
110 |
+
# samples count 3932473 1941220 407494
|
111 |
+
# pinyin vocab size 1349 1348 1344 (no polyphone)
|
112 |
+
# - - 1459 (polyphone)
|
113 |
+
# char vocab size 5264 5219 5042
|
114 |
+
|
115 |
+
# vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
|
116 |
+
# please be careful if using pretrained model, make sure the vocab.txt is same
|