Gregniuki committed on
Commit 2c9ddc1
1 Parent(s): 175e3f9

Upload model_utils.py

Files changed (1)
  1. model/model_utils.py +580 -0
model/model_utils.py ADDED
@@ -0,0 +1,580 @@
from __future__ import annotations

import os
import re
import math
import random
import string
from tqdm import tqdm
from collections import defaultdict

import matplotlib
matplotlib.use("Agg")
import matplotlib.pylab as plt

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import torchaudio

import einx
from einops import rearrange, reduce

import jieba
from pypinyin import lazy_pinyin, Style

from model.ecapa_tdnn import ECAPA_TDNN_SMALL
from model.modules import MelSpec


# seed everything

def seed_everything(seed = 0):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# helpers

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

# tensor helpers

def lens_to_mask(
    t: int['b'],
    length: int | None = None
) -> bool['b n']:

    if not exists(length):
        length = t.amax()

    seq = torch.arange(length, device = t.device)
    return einx.less('n, b -> b n', seq, t)

def mask_from_start_end_indices(
    seq_len: int['b'],
    start: int['b'],
    end: int['b']
):
    max_seq_len = seq_len.max().item()
    seq = torch.arange(max_seq_len, device = start.device).long()
    return einx.greater_equal('n, b -> b n', seq, start) & einx.less('n, b -> b n', seq, end)

def mask_from_frac_lengths(
    seq_len: int['b'],
    frac_lengths: float['b']
):
    lengths = (frac_lengths * seq_len).long()
    max_start = seq_len - lengths

    rand = torch.rand_like(frac_lengths)
    start = (max_start * rand).long().clamp(min = 0)
    end = start + lengths

    return mask_from_start_end_indices(seq_len, start, end)
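
# Quick sketch of the mask helpers (illustrative values, not part of the module):
#   lens_to_mask(torch.tensor([2, 3]))
#   # -> tensor([[ True,  True, False],
#   #            [ True,  True,  True]])
#   mask_from_frac_lengths(torch.tensor([10]), torch.tensor([0.5]))
#   # -> a (1, 10) bool mask with one random contiguous span of 5 True positions,
#   #    i.e. a random span covering that fraction of the sequence.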

def maybe_masked_mean(
    t: float['b n d'],
    mask: bool['b n'] = None
) -> float['b d']:

    if not exists(mask):
        return t.mean(dim = 1)

    t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
    num = reduce(t, 'b n d -> b d', 'sum')
    den = reduce(mask.float(), 'b n -> b', 'sum')

    return einx.divide('b d, b -> b d', num, den.clamp(min = 1.))


# simple utf-8 tokenizer, since paper went character based
def list_str_to_tensor(
    text: list[str],
    padding_value = -1
) -> int['b nt']:
    list_tensors = [torch.tensor([*bytes(t, 'UTF-8')]) for t in text]  # ByT5 style
    text = pad_sequence(list_tensors, padding_value = padding_value, batch_first = True)
    return text

# char tokenizer, based on custom dataset's extracted .txt file
def list_str_to_idx(
    text: list[str] | list[list[str]],
    vocab_char_map: dict[str, int],  # {char: idx}
    padding_value = -1
) -> int['b nt']:
    list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text]  # pinyin or char style
    text = pad_sequence(list_idx_tensors, padding_value = padding_value, batch_first = True)
    return text
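
# Byte-level sketch (illustrative): ragged rows are padded with padding_value,
#   list_str_to_tensor(["ab", "c"])
#   # -> tensor([[97, 98],
#   #            [99, -1]])
# list_str_to_idx instead looks chars up in vocab_char_map, falling back to 0
# for out-of-vocabulary chars, which is why vocab.txt must keep index 0 for " "
# (see the assert in get_tokenizer below).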


# Get tokenizer

def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
    '''
    tokenizer   - "pinyin" do g2p for only chinese characters, need .txt vocab_file
                - "char" for char-wise tokenizer, need .txt vocab_file
                - "byte" for utf-8 tokenizer
                - "custom" if you're directly passing in a path to the vocab.txt you want to use
    vocab_size  - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
                - if use "char", derived from unfiltered character & symbol counts of custom dataset
                - if use "byte", set to 256 (unicode byte range)
    '''
    if tokenizer in ["pinyin", "char"]:
        with open(f"data/{dataset_name}_{tokenizer}/vocab.txt", "r", encoding="utf-8") as f:
            vocab_char_map = {}
            for i, char in enumerate(f):
                vocab_char_map[char[:-1]] = i
        vocab_size = len(vocab_char_map)
        assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, since 0 is used for unknown char"

    elif tokenizer == "byte":
        vocab_char_map = None
        vocab_size = 256
    elif tokenizer == "custom":
        with open(dataset_name, "r", encoding="utf-8") as f:
            vocab_char_map = {}
            for i, char in enumerate(f):
                vocab_char_map[char[:-1]] = i
        vocab_size = len(vocab_char_map)

    return vocab_char_map, vocab_size
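
# Usage sketch (hedged: "Emilia_ZH_EN" is a placeholder dataset name; the vocab
# is then expected at data/Emilia_ZH_EN_pinyin/vocab.txt, one token per line):
#   vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
#   ids = list_str_to_idx(convert_char_to_pinyin(["你好"]), vocab_char_map)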


# convert char to pinyin

def convert_char_to_pinyin(text_list, polyphone = True):
    final_text_list = []
    god_knows_why_en_testset_contains_zh_quote = str.maketrans({'“': '"', '”': '"', '‘': "'", '’': "'"})  # in case librispeech (orig no-pc) test-clean
    custom_trans = str.maketrans({';': ','})  # add custom trans here, to address oov
    for text in text_list:
        char_list = []
        text = text.translate(god_knows_why_en_testset_contains_zh_quote)
        text = text.translate(custom_trans)
        for seg in jieba.cut(text):
            seg_byte_len = len(bytes(seg, 'UTF-8'))
            if seg_byte_len == len(seg):  # if pure alphabets and symbols
                if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
                    char_list.append(" ")
                char_list.extend(seg)
            elif polyphone and seg_byte_len == 3 * len(seg):  # if pure chinese characters
                seg = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
                for c in seg:
                    if c not in "。，、；：？！《》【】—…":
                        char_list.append(" ")
                    char_list.append(c)
            else:  # if mixed chinese characters, alphabets and symbols
                for c in seg:
                    if ord(c) < 256:
                        char_list.extend(c)
                    else:
                        if c not in "。，、；：？！《》【】—…":
                            char_list.append(" ")
                            char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
                        else:  # if is zh punc
                            char_list.append(c)
        final_text_list.append(char_list)

    return final_text_list
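
# Illustrative result (exact output depends on jieba's segmentation):
#   convert_char_to_pinyin(["你好, hello"])
#   # -> [[' ', 'ni3', ' ', 'hao3', ',', ' ', 'h', 'e', 'l', 'l', 'o']]
# Chinese segments become TONE3 pinyin tokens, each preceded by a space;
# pure-ASCII segments are kept character by character.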


# save spectrogram
def save_spectrogram(spectrogram, path):
    plt.figure(figsize=(12, 4))
    plt.imshow(spectrogram, origin='lower', aspect='auto')
    plt.colorbar()
    plt.savefig(path)
    plt.close()


# seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
def get_seedtts_testset_metainfo(metalst):
    with open(metalst) as f:
        lines = f.readlines()
    metainfo = []
    for line in lines:
        if len(line.strip().split('|')) == 5:
            utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
        elif len(line.strip().split('|')) == 4:
            utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
            gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
        if not os.path.isabs(prompt_wav):
            prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
        metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
    return metainfo
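
# Expected metalst line format (pipe-separated, matching the parsing above;
# field values here are placeholders):
#   utt0001|prompt text|prompts/utt0001.wav|target text|wavs/utt0001.wav
# 4-field lines omit gt_wav, which is then assumed at wavs/<utt>.wav next to
# the metalst file.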


# librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
    with open(metalst) as f:
        lines = f.readlines()
    metainfo = []
    for line in lines:
        ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')

        # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.'  # if use librispeech test-clean (no-pc)
        ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
        ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')

        # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.'  # if use librispeech test-clean (no-pc)
        gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
        gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')

        metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))

    return metainfo


# padded to max length mel batch
def padded_mel_batch(ref_mels):
    max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
    padded_ref_mels = []
    for mel in ref_mels:
        padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value = 0)
        padded_ref_mels.append(padded_ref_mel)
    padded_ref_mels = torch.stack(padded_ref_mels)
    padded_ref_mels = rearrange(padded_ref_mels, 'b d n -> b n d')
    return padded_ref_mels
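
# Shape sketch (illustrative): given mels of shape (100, 250) and (100, 300),
# the result is a (2, 300, 100) batch, zero-padded along time and rearranged
# from (batch, mel, time) to (batch, time, mel).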


# get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav

def get_inference_prompt(
    metainfo,
    speed = 1., tokenizer = "pinyin", polyphone = True,
    target_sample_rate = 24000, n_mel_channels = 100, hop_length = 256, target_rms = 0.1,
    use_truth_duration = False,
    infer_batch_size = 1, num_buckets = 200, min_secs = 3, max_secs = 40,
):
    prompts_all = []

    min_tokens = min_secs * target_sample_rate // hop_length
    max_tokens = max_secs * target_sample_rate // hop_length

    # bucket utterances by estimated total mel length, so each batch holds similar durations
    batch_accum = [0] * num_buckets
    utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = \
        ([[] for _ in range(num_buckets)] for _ in range(6))

    mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)

    for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):

        # Audio
        ref_audio, ref_sr = torchaudio.load(prompt_wav)
        ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
        if ref_rms < target_rms:
            ref_audio = ref_audio * target_rms / ref_rms
        assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
        if ref_sr != target_sample_rate:
            resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
            ref_audio = resampler(ref_audio)

        # Text
        if len(prompt_text[-1].encode('utf-8')) == 1:
            prompt_text = prompt_text + " "
        text = [prompt_text + gt_text]
        if tokenizer == "pinyin":
            text_list = convert_char_to_pinyin(text, polyphone = polyphone)
        else:
            text_list = text

        # Duration, mel frame length
        ref_mel_len = ref_audio.shape[-1] // hop_length
        if use_truth_duration:
            gt_audio, gt_sr = torchaudio.load(gt_wav)
            if gt_sr != target_sample_rate:
                resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
                gt_audio = resampler(gt_audio)
            total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)

            # # test vocoder resynthesis
            # ref_audio = gt_audio
        else:
            # estimate generated length from text byte length, weighting zh pause
            # punctuation with 3 extra bytes each to leave room for pauses
            zh_pause_punc = r"。，、；：？！"
            ref_text_len = len(prompt_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, prompt_text))
            gen_text_len = len(gt_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gt_text))
            total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)

        # to mel spectrogram
        ref_mel = mel_spectrogram(ref_audio)
        ref_mel = rearrange(ref_mel, '1 d n -> d n')

        # deal with batch
        assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
        assert min_tokens <= total_mel_len <= max_tokens, \
            f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
        bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)

        utts[bucket_i].append(utt)
        ref_rms_list[bucket_i].append(ref_rms)
        ref_mels[bucket_i].append(ref_mel)
        ref_mel_lens[bucket_i].append(ref_mel_len)
        total_mel_lens[bucket_i].append(total_mel_len)
        final_text_list[bucket_i].extend(text_list)

        batch_accum[bucket_i] += total_mel_len

        # note: infer_batch_size acts as an accumulated mel-frame budget here,
        # not an utterance count
        if batch_accum[bucket_i] >= infer_batch_size:
            # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
            prompts_all.append((
                utts[bucket_i],
                ref_rms_list[bucket_i],
                padded_mel_batch(ref_mels[bucket_i]),
                ref_mel_lens[bucket_i],
                total_mel_lens[bucket_i],
                final_text_list[bucket_i]
            ))
            batch_accum[bucket_i] = 0
            utts[bucket_i], ref_rms_list[bucket_i], ref_mels[bucket_i], ref_mel_lens[bucket_i], total_mel_lens[bucket_i], final_text_list[bucket_i] = [], [], [], [], [], []

    # add residual
    for bucket_i, bucket_frames in enumerate(batch_accum):
        if bucket_frames > 0:
            prompts_all.append((
                utts[bucket_i],
                ref_rms_list[bucket_i],
                padded_mel_batch(ref_mels[bucket_i]),
                ref_mel_lens[bucket_i],
                total_mel_lens[bucket_i],
                final_text_list[bucket_i]
            ))
    # shuffle so the last workers don't get only the easy (short) batches
    random.seed(666)
    random.shuffle(prompts_all)

    return prompts_all
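
# Usage sketch (hedged: path is a placeholder; infer_batch_size is the
# accumulated mel-frame budget per batch, not an utterance count):
#   metainfo = get_seedtts_testset_metainfo("path/to/meta.lst")
#   prompts_all = get_inference_prompt(metainfo, tokenizer="pinyin", infer_batch_size=4096)
#   # each element: (utts, ref_rms_list, padded ref mels, ref mel lens, total mel lens, texts)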


# get wav_res_ref_text of seed-tts test metalst
# https://github.com/BytedanceSpeech/seed-tts-eval

def get_seed_tts_test(metalst, gen_wav_dir, gpus):
    with open(metalst) as f:
        lines = f.readlines()

    test_set_ = []
    for line in tqdm(lines):
        if len(line.strip().split('|')) == 5:
            utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
        elif len(line.strip().split('|')) == 4:
            utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')

        if not os.path.exists(os.path.join(gen_wav_dir, utt + '.wav')):
            continue
        gen_wav = os.path.join(gen_wav_dir, utt + '.wav')
        if not os.path.isabs(prompt_wav):
            prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)

        test_set_.append((gen_wav, prompt_wav, gt_text))

    num_jobs = len(gpus)
    if num_jobs == 1:
        return [(gpus[0], test_set_)]

    wav_per_job = len(test_set_) // num_jobs + 1
    test_set = []
    for i in range(num_jobs):
        test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))

    return test_set
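
# Sharding sketch: with gpus=[0, 1] and 5 kept items, wav_per_job = 5 // 2 + 1 = 3,
# so GPU 0 gets items [0:3] and GPU 1 gets items [3:6] (the remaining 2).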


# get librispeech test-clean cross sentence test

def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = False):
    with open(metalst) as f:
        lines = f.readlines()

    test_set_ = []
    for line in tqdm(lines):
        ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')

        if eval_ground_truth:
            gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
            gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
        else:
            if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + '.wav')):
                raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
            gen_wav = os.path.join(gen_wav_dir, gen_utt + '.wav')

        ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
        ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')

        test_set_.append((gen_wav, ref_wav, gen_txt))

    num_jobs = len(gpus)
    if num_jobs == 1:
        return [(gpus[0], test_set_)]

    wav_per_job = len(test_set_) // num_jobs + 1
    test_set = []
    for i in range(num_jobs):
        test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))

    return test_set


# load asr model

def load_asr_model(lang, ckpt_dir = ""):
    if lang == "zh":
        from funasr import AutoModel
        model = AutoModel(
            model = os.path.join(ckpt_dir, "paraformer-zh"),
            # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
            # punc_model = os.path.join(ckpt_dir, "ct-punc"),
            # spk_model = os.path.join(ckpt_dir, "cam++"),
            disable_update=True,
        )  # following seed-tts setting
    elif lang == "en":
        from faster_whisper import WhisperModel
        model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
        model = WhisperModel(model_size, device="cuda", compute_type="float16")
    return model


# WER Evaluation, the way Seed-TTS does

def run_asr_wer(args):
    rank, lang, test_set, ckpt_dir = args

    if lang == "zh":
        import zhconv
        torch.cuda.set_device(rank)
    elif lang == "en":
        os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
    else:
        raise NotImplementedError("lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now.")

    asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir)

    from zhon.hanzi import punctuation
    punctuation_all = punctuation + string.punctuation
    wers = []

    from jiwer import compute_measures
    for gen_wav, prompt_wav, truth in tqdm(test_set):
        if lang == "zh":
            res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
            hypo = res[0]["text"]
            hypo = zhconv.convert(hypo, 'zh-cn')
        elif lang == "en":
            segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
            hypo = ''
            for segment in segments:
                hypo = hypo + ' ' + segment.text

        # raw_truth = truth
        # raw_hypo = hypo

        for x in punctuation_all:
            truth = truth.replace(x, '')
            hypo = hypo.replace(x, '')

        truth = truth.replace('  ', ' ')
        hypo = hypo.replace('  ', ' ')

        if lang == "zh":
            truth = " ".join([x for x in truth])
            hypo = " ".join([x for x in hypo])
        elif lang == "en":
            truth = truth.lower()
            hypo = hypo.lower()

        measures = compute_measures(truth, hypo)
        wer = measures["wer"]

        # ref_list = truth.split(" ")
        # subs = measures["substitutions"] / len(ref_list)
        # dele = measures["deletions"] / len(ref_list)
        # inse = measures["insertions"] / len(ref_list)

        wers.append(wer)

    return wers
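
# Usage sketch (hedged: paths are placeholders; pass one (rank, lang, shard, ckpt_dir)
# tuple per worker, e.g. via multiprocessing):
#   test_set = get_seed_tts_test("path/to/meta.lst", "path/to/gen_wavs", gpus=[0])
#   wers = run_asr_wer((0, "en", test_set[0][1], ""))
#   print(f"WER: {sum(wers) / len(wers):.4f}")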


# SIM Evaluation

def run_sim(args):
    rank, test_set, ckpt_dir = args
    device = f"cuda:{rank}"

    model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None)
    state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict['model'], strict=False)

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        model = model.cuda(device)
    model.eval()

    sim_list = []
    for wav1, wav2, truth in tqdm(test_set):

        wav1, sr1 = torchaudio.load(wav1)
        wav2, sr2 = torchaudio.load(wav2)

        resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
        resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
        wav1 = resample1(wav1)
        wav2 = resample2(wav2)

        if use_gpu:
            wav1 = wav1.cuda(device)
            wav2 = wav2.cuda(device)
        with torch.no_grad():
            emb1 = model(wav1)
            emb2 = model(wav2)

        sim = F.cosine_similarity(emb1, emb2)[0].item()
        # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
        sim_list.append(sim)

    return sim_list
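
# Usage sketch (hedged: ckpt path is a placeholder for a WavLM-large speaker
# verification checkpoint compatible with ECAPA_TDNN_SMALL):
#   sims = run_sim((0, test_set[0][1], "path/to/wavlm_large_sv_ckpt.pth"))
#   print(f"SIM: {sum(sims) / len(sims):.4f}")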


# filter func for dirty data with many repetitions

def repetition_found(text, length = 2, tolerance = 10):
    pattern_count = defaultdict(int)
    for i in range(len(text) - length + 1):
        pattern = text[i:i + length]
        pattern_count[pattern] += 1
    for pattern, count in pattern_count.items():
        if count > tolerance:
            return True
    return False
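
# Illustrative check: in "ab" * 11 the 2-gram "ab" occurs 11 times, which exceeds
# tolerance=10, so repetition_found returns True; ordinary text rarely repeats
# any single 2-gram that often and returns False.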


# load model checkpoint for inference

def load_checkpoint(model, ckpt_path, device, use_ema = True):
    from ema_pytorch import EMA

    ckpt_type = ckpt_path.split(".")[-1]
    if ckpt_type == "safetensors":
        from safetensors.torch import load_file
        checkpoint = load_file(ckpt_path, device=device)
    else:
        checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)

    if use_ema:
        ema_model = EMA(model, include_online_model = False).to(device)
        if ckpt_type == "safetensors":
            ema_model.load_state_dict(checkpoint)
        else:
            ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
        ema_model.copy_params_from_ema_to_model()
    else:
        model.load_state_dict(checkpoint['model_state_dict'])

    return model
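
# Usage sketch (hedged: checkpoint path is a placeholder):
#   model = load_checkpoint(model, "ckpts/model.safetensors", device="cuda", use_ema=True)
# With use_ema=True the EMA weights are copied into the online model before
# returning, which is usually what you want for inference.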