import torch
import codecs
import logging
import argparse
import numpy as np
# import edit_distance
from itertools import zip_longest


def cif_wo_hidden(alphas, threshold):
    """Continuous Integrate-and-Fire (CIF) without hidden-state output.

    Accumulates the per-frame weights ``alphas`` over time and records the
    accumulated value at every frame; whenever the accumulation reaches
    ``threshold`` (a "fire"), the threshold is subtracted and integration
    continues with the remainder.
    """
    batch_size, len_time = alphas.size()
    # loop vars
    integrate = torch.zeros([batch_size], device=alphas.device)
    # intermediate vars along time
    list_fires = []
    for t in range(len_time):
        alpha = alphas[:, t]
        integrate += alpha
        list_fires.append(integrate)
        fire_place = integrate >= threshold
        integrate = torch.where(
            fire_place,
            integrate - torch.ones([batch_size], device=alphas.device) * threshold,
            integrate,
        )
    fires = torch.stack(list_fires, 1)
    return fires
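

# Minimal usage sketch for cif_wo_hidden (illustrative helper, not part of the
# original module): the toy weights below are made-up values chosen so that the
# accumulation crosses the threshold a few times. The frames where the returned
# accumulated weights exceed the threshold are the CIF "fires" that
# ts_prediction_lfr6_standard later turns into token boundaries.
def _cif_wo_hidden_example():
    toy_alphas = torch.tensor([[0.2, 0.9, 0.1, 0.8, 0.3, 0.7]])  # (batch=1, T=6)
    fires = cif_wo_hidden(toy_alphas, threshold=1.0 - 1e-4)
    fire_frames = torch.where(fires[0] > 1.0 - 1e-4)[0]  # indices of firing frames
    return fires, fire_frames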


def ts_prediction_lfr6_standard(
    us_alphas,
    us_peaks,
    char_list,
    vad_offset=0.0,
    force_time_shift=-1.5,
    sil_in_str=True,
):
    """Convert upsampled CIF weights/peaks into per-token timestamps.

    Returns a formatted string "token start end;..." (in seconds) and a list of
    [start_ms, end_ms] pairs for the non-silence tokens in ``char_list``.
    """
    if not len(char_list):
        return "", []
    START_END_THRESHOLD = 5
    MAX_TOKEN_DURATION = 12
    TIME_RATE = 10.0 * 6 / 1000 / 3  # seconds per frame: 10 ms shift * LFR-6, 3x upsampled
    if len(us_alphas.shape) == 2:
        alphas, peaks = us_alphas[0], us_peaks[0]  # supports batch_size=1 inference only
    else:
        alphas, peaks = us_alphas, us_peaks
    if char_list[-1] == "</s>":
        char_list = char_list[:-1]
    fire_place = (
        torch.where(peaks > 1.0 - 1e-4)[0].cpu().numpy() + force_time_shift
    )  # total offset
    if len(fire_place) != len(char_list) + 1:
        # the number of peaks should be number of tokens + 1; if not, rescale the
        # weights so they integrate to that count and re-run CIF
        alphas /= alphas.sum() / (len(char_list) + 1)
        alphas = alphas.unsqueeze(0)
        peaks = cif_wo_hidden(alphas, threshold=1.0 - 1e-4)[0]
        fire_place = (
            torch.where(peaks > 1.0 - 1e-4)[0].cpu().numpy() + force_time_shift
        )  # total offset
    num_frames = peaks.shape[0]
    timestamp_list = []
    new_char_list = []
    # for a bicif model trained on large data, cif2 actually fires when a character
    # starts, so the frames between two peaks are treated as the duration of the
    # preceding token
    fire_place = (
        torch.where(peaks > 1.0 - 1e-4)[0].cpu().numpy() + force_time_shift
    )  # total offset
    # assert num_peak == len(char_list) + 1  # number of peaks is supposed to be number of tokens + 1

    # begin silence
    if fire_place[0] > START_END_THRESHOLD:
        # char_list.insert(0, '<sil>')
        timestamp_list.append([0.0, fire_place[0] * TIME_RATE])
        new_char_list.append("<sil>")
    # token timestamps
    for i in range(len(fire_place) - 1):
        new_char_list.append(char_list[i])
        if (
            MAX_TOKEN_DURATION < 0
            or fire_place[i + 1] - fire_place[i] <= MAX_TOKEN_DURATION
        ):
            timestamp_list.append(
                [fire_place[i] * TIME_RATE, fire_place[i + 1] * TIME_RATE]
            )
        else:
            # the gap between two fires is too long: keep MAX_TOKEN_DURATION frames
            # for the token and mark the remaining low-weight frames as silence
            _split = fire_place[i] + MAX_TOKEN_DURATION
            timestamp_list.append([fire_place[i] * TIME_RATE, _split * TIME_RATE])
            timestamp_list.append([_split * TIME_RATE, fire_place[i + 1] * TIME_RATE])
            new_char_list.append("<sil>")
    # tail token and end silence
    # new_char_list.append(char_list[-1])
    if num_frames - fire_place[-1] > START_END_THRESHOLD:
        _end = (num_frames + fire_place[-1]) * 0.5
        # _end = fire_place[-1]
        timestamp_list[-1][1] = _end * TIME_RATE
        timestamp_list.append([_end * TIME_RATE, num_frames * TIME_RATE])
        new_char_list.append("<sil>")
    else:
        timestamp_list[-1][1] = num_frames * TIME_RATE
    if vad_offset:  # add the VAD segment offset (given in ms) to all timestamps
        for i in range(len(timestamp_list)):
            timestamp_list[i][0] = timestamp_list[i][0] + vad_offset / 1000.0
            timestamp_list[i][1] = timestamp_list[i][1] + vad_offset / 1000.0
    res_txt = ""
    for char, timestamp in zip(new_char_list, timestamp_list):
        # if char != '<sil>':
        if not sil_in_str and char == "<sil>":
            continue
        res_txt += "{} {} {};".format(
            char, str(timestamp[0] + 0.0005)[:5], str(timestamp[1] + 0.0005)[:5]
        )
    res = []
    for char, timestamp in zip(new_char_list, timestamp_list):
        if char != "<sil>":
            res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)])
    return res_txt, res
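

# Illustrative sketch (not part of the original module): calling
# ts_prediction_lfr6_standard on synthetic peaks. In real use, us_alphas and
# us_peaks come from the timestamp predictor of a CIF-based model; the frame
# indices, tensor values and token list below are assumptions chosen only to
# produce len(char_list) + 1 peaks, as the function expects.
def _ts_prediction_example():
    num_frames = 30
    us_alphas = torch.zeros(num_frames)
    us_peaks = torch.zeros(num_frames)
    for idx in (8, 18, 28):  # pretend CIF fired at these upsampled frames
        us_alphas[idx] = 1.0
        us_peaks[idx] = 1.0
    char_list = ["hello", "world", "</s>"]  # 2 tokens + sentence end
    res_txt, res = ts_prediction_lfr6_standard(us_alphas, us_peaks, char_list)
    # res_txt reads roughly like "<sil> 0.000 0.130;hello 0.130 0.330;..." and
    # res holds one [start_ms, end_ms] pair per non-silence token
    return res_txt, res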


def timestamp_sentence(
    punc_id_list, timestamp_postprocessed, text_postprocessed, return_raw_text=False
):
    """Group token-level timestamps into punctuated sentences.

    ``punc_id_list`` carries one punctuation id per token; ids greater than 1
    index into ``punc_list`` and close the current sentence. Each returned
    entry holds the sentence text, its start/end time and the token timestamps
    it covers.
    """
    punc_list = ["，", "。", "？", "、"]
    res = []
    if text_postprocessed is None:
        return res
    if timestamp_postprocessed is None:
        return res
    if len(timestamp_postprocessed) == 0:
        return res
    if len(text_postprocessed) == 0:
        return res

    if punc_id_list is None or len(punc_id_list) == 0:
        # no punctuation available: return the whole utterance as a single entry
        res.append(
            {
                "text": text_postprocessed.split(),
                "start": timestamp_postprocessed[0][0],
                "end": timestamp_postprocessed[-1][1],
                "timestamp": timestamp_postprocessed,
            }
        )
        return res
    if len(punc_id_list) != len(timestamp_postprocessed):
        logging.warning("length mismatch between punc and timestamp")
    sentence_text = ""
    sentence_text_seg = ""
    ts_list = []
    sentence_start = timestamp_postprocessed[0][0]
    sentence_end = timestamp_postprocessed[0][1]
    texts = text_postprocessed.split()
    punc_stamp_text_list = list(
        zip_longest(punc_id_list, timestamp_postprocessed, texts, fillvalue=None)
    )
    for punc_stamp_text in punc_stamp_text_list:
        punc_id, timestamp, text = punc_stamp_text
        # sentence_text += text if text is not None else ''
        if text is not None:
            # separate Latin words with spaces, concatenate CJK tokens directly
            if "a" <= text[0] <= "z" or "A" <= text[0] <= "Z":
                sentence_text += " " + text
            elif len(sentence_text) and (
                "a" <= sentence_text[-1] <= "z" or "A" <= sentence_text[-1] <= "Z"
            ):
                sentence_text += " " + text
            else:
                sentence_text += text
            sentence_text_seg += text + " "
        ts_list.append(timestamp)

        punc_id = int(punc_id) if punc_id is not None else 1
        sentence_end = timestamp[1] if timestamp is not None else sentence_end
        # drop the trailing segmentation space (endswith also covers the empty string)
        sentence_text_seg = (
            sentence_text_seg[:-1]
            if sentence_text_seg.endswith(" ")
            else sentence_text_seg
        )
        if punc_id > 1:
            sentence_text += punc_list[punc_id - 2]
            if return_raw_text:
                res.append(
                    {
                        "text": sentence_text,
                        "start": sentence_start,
                        "end": sentence_end,
                        "timestamp": ts_list,
                        "raw_text": sentence_text_seg,
                    }
                )
            else:
                res.append(
                    {
                        "text": sentence_text,
                        "start": sentence_start,
                        "end": sentence_end,
                        "timestamp": ts_list,
                    }
                )
            sentence_text = ""
            sentence_text_seg = ""
            ts_list = []
            sentence_start = sentence_end
    return res
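

# Demo sketch (not part of the original module): how timestamp_sentence groups
# tokens into punctuated sentences. The punctuation id convention assumed here
# matches the indexing above: ids > 1 pick punc_list[punc_id - 2] (so 3 -> "。"),
# while 1 means "no punctuation". All values below are made up.
if __name__ == "__main__":
    example_text = "hello world 你 好"
    example_ts = [[130, 330], [330, 600], [600, 800], [800, 990]]  # ms per token
    example_punc = [1, 3, 1, 3]  # sentence breaks after "world" and "好"
    for sent in timestamp_sentence(example_punc, example_ts, example_text):
        print(sent["text"], sent["start"], sent["end"])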