File size: 7,271 Bytes
cdbb2b2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunClip). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import re
def time_convert(ms):
    """Format a millisecond offset as an SRT timestamp ``HH:MM:SS,mmm``.

    Args:
        ms: Offset in milliseconds; anything accepted by ``int()`` (a float
            is truncated toward zero).

    Returns:
        The timestamp string. Hours, minutes and seconds are zero-padded to
        at least two digits and milliseconds to exactly three digits.
    """
    total_seconds, millis = divmod(int(ms), 1000)
    total_minutes, seconds = divmod(total_seconds, 60)
    hours, minutes = divmod(total_minutes, 60)
    # The millisecond field must be three digits wide: without padding,
    # 1005 ms would render as "00:00:01,5", which players read as 500 ms.
    return "{:02d}:{:02d}:{:02d},{:03d}".format(hours, minutes, seconds, millis)
def str2list(text):
    """Tokenise ``text`` into a list of strings.

    Each match is either a single character in the range U+4E00..U+9FFF or a
    maximal run of word characters and hyphens; everything else (spaces,
    punctuation) acts as a separator and is dropped.
    """
    return re.findall(r'[\u4e00-\u9fff]|[\w-]+', text, flags=re.UNICODE)
class Text2SRT():
    """One subtitle entry built from tokens and per-token ``[start, end]`` pairs.

    ``start_sec`` / ``end_sec`` hold the entry bounds after subtracting
    ``offset``; despite the names they are in the same unit as the
    timestamps (milliseconds, as they are fed to ``time_convert`` and divided
    by 1000 in ``time``).
    """

    def __init__(self, text, timestamp, offset=0):
        """Store the tokens and derive the entry bounds.

        Args:
            text: Either a ready-made string or a list of tokens.
            timestamp: Non-empty sequence of ``[start, end]`` pairs; the first
                pair's start and the last pair's end delimit the entry.
            offset: Amount subtracted from both bounds.
        """
        self.token_list = text
        self.timestamp = timestamp
        clip_start = timestamp[0][0] - offset
        clip_end = timestamp[-1][1] - offset
        self.start_sec = clip_start
        self.end_sec = clip_end
        self.start_time = time_convert(clip_start)
        self.end_time = time_convert(clip_end)

    def text(self):
        """Return the entry text.

        A string is returned unchanged. For a token list, tokens that compare
        within U+4E00..U+9FFF are appended directly, all others are preceded
        by a space; leading whitespace of the result is stripped.
        """
        tokens = self.token_list
        if isinstance(tokens, str):
            return tokens
        pieces = [
            tok if '\u4e00' <= tok <= '\u9fff' else " " + tok
            for tok in tokens
        ]
        return "".join(pieces).lstrip()

    def srt(self, acc_ost=0.0):
        """Render ``"<start> --> <end>\\n<text>\\n"``, shifted by ``acc_ost`` seconds."""
        shift_ms = acc_ost * 1000
        begin = time_convert(self.start_sec + shift_ms)
        finish = time_convert(self.end_sec + shift_ms)
        return "{} --> {}\n{}\n".format(begin, finish, self.text())

    def time(self, acc_ost=0.0):
        """Return ``(start, end)`` in seconds, each shifted by ``acc_ost``."""
        begin = self.start_sec / 1000 + acc_ost
        finish = self.end_sec / 1000 + acc_ost
        return (begin, finish)
class Text2SRT_audio():
    """One subtitle entry built from text plus explicit start/end times.

    ``start`` and ``end`` are multiplied by 1000 before ``offset`` is
    subtracted, so ``start_sec`` / ``end_sec`` end up in the unit expected by
    ``time_convert`` (milliseconds) despite their names.
    """

    def __init__(self, text, start, end, offset=0):
        """Store the text and derive the entry bounds.

        Args:
            text: Either a ready-made string or a list of tokens.
            start: Entry start; scaled by 1000 internally.
            end: Entry end; scaled by 1000 internally.
            offset: Amount subtracted from both scaled bounds.
        """
        self.token_list = text
        clip_start = start * 1000 - offset
        clip_end = end * 1000 - offset
        self.start_sec = clip_start
        self.end_sec = clip_end
        self.start_time = time_convert(clip_start)
        self.end_time = time_convert(clip_end)

    def text(self):
        """Return the entry text.

        A string is returned unchanged. For a token list, tokens that compare
        within U+4E00..U+9FFF are appended directly, all others are preceded
        by a space; leading whitespace of the result is stripped.
        """
        tokens = self.token_list
        if isinstance(tokens, str):
            return tokens
        pieces = [
            tok if '\u4e00' <= tok <= '\u9fff' else " " + tok
            for tok in tokens
        ]
        return "".join(pieces).lstrip()

    def srt(self, acc_ost=0.0):
        """Render ``"<start> --> <end>\\n<text>\\n"``, shifted by ``acc_ost`` seconds."""
        shift_ms = acc_ost * 1000
        begin = time_convert(self.start_sec + shift_ms)
        finish = time_convert(self.end_sec + shift_ms)
        return "{} --> {}\n{}\n".format(begin, finish, self.text())

    def time(self, acc_ost=0.0):
        """Return ``(start, end)`` in seconds, each shifted by ``acc_ost``."""
        begin = self.start_sec / 1000 + acc_ost
        finish = self.end_sec / 1000 + acc_ost
        return (begin, finish)
def generate_srt(sentence_list):
    """Concatenate subtitle entries for every sentence in ``sentence_list``.

    Each sentence is a mapping with ``'text'`` and ``'timestamp'`` and an
    optional ``'spk'`` entry. Entries are numbered by their position in the
    list (starting at 0); when ``'spk'`` is present the header line becomes
    ``"<index> spk<spk>"``.

    Returns:
        The concatenated subtitle text, or an empty string for empty input.
    """
    entries = []
    for idx, sent in enumerate(sentence_list):
        body = Text2SRT(sent['text'], sent['timestamp']).srt()
        if 'spk' in sent:
            header = "{} spk{}\n".format(idx, sent['spk'])
        else:
            header = "{}\n".format(idx)
        entries.append(header + body)
    return ''.join(entries)
def trans_format(text):
    """Convert a Whisper recognition result into the downstream data layout.

    Args:
        text: Mapping with a ``"text"`` entry (the full transcript) and a
            ``"segments"`` list. Every segment needs ``"start"``, ``"end"``
            and ``"text"``; ``"words"`` (items with ``"start"``/``"end"``) is
            optional.

    Returns:
        A one-element list holding a dict with:
          * ``"text"`` / ``"raw_text"``: the full transcript,
          * ``"timestamp"``: ``[start, end]`` per segment,
          * ``"sentence_info"``: one dict per segment that has word-level
            timing, with per-word ``[start, end]`` pairs under ``"timestamp"``.
        All times are the input values multiplied by 1000 and truncated to
        ``int``.
    """
    timestamp_list = []
    sentence_info = []
    for segment in text["segments"]:
        seg_start = int(segment["start"] * 1000)
        seg_end = int(segment["end"] * 1000)
        timestamp_list.append([seg_start, seg_end])
        # Segments produced without word-level timestamps carry no "words"
        # key at all; treat that like an empty list instead of raising.
        words = segment.get("words")
        if not words:
            continue
        sentence_info.append({
            "text": segment["text"],
            "start": seg_start,
            "end": seg_end,
            "timestamp": [
                [int(word['start'] * 1000), int(word['end'] * 1000)]
                for word in words
            ],
            "raw_text": segment["text"],
        })
    return [{
        "text": text["text"],
        "raw_text": text["text"],
        "timestamp": timestamp_list,
        "sentence_info": sentence_info,
    }]
def generate_audio_srt(sentence_list):
    """Generate SRT-format subtitles from audio transcription results.

    Each sentence is a mapping with ``'text'``, ``'start'`` and ``'end'`` and
    an optional ``'spk'`` entry. Entries are numbered by their position in
    the list (starting at 0); when ``'spk'`` is present the header line
    becomes ``"<index> spk<spk>"``.

    Returns:
        The concatenated subtitle text, or an empty string for empty input.
    """
    entries = []
    for idx, sent in enumerate(sentence_list):
        body = Text2SRT_audio(sent['text'], sent['start'], sent['end']).srt()
        if 'spk' in sent:
            header = "{} spk{}\n".format(idx, sent['spk'])
        else:
            header = "{}\n".format(idx)
        entries.append(header + body)
    return ''.join(entries)
def _first_index_ending_after(timestamps, bound):
    """Return the index of the first ``[start, end]`` pair whose end exceeds ``bound``."""
    for idx, ts in enumerate(timestamps):
        if ts[1] > bound:
            return idx
    return len(timestamps)


def generate_srt_clip(sentence_list, start, end, begin_index=0, time_acc_ost=0.0):
    """Generate the subtitle entries that fall inside a clip window.

    Sentences are visited in order. A sentence ending at or before the window
    start is skipped; the first sentence starting at or after the window end
    stops the scan. A sentence that straddles a window edge is trimmed to the
    tokens that lie inside the window: leading tokens ending at or before
    ``start`` and trailing tokens ending after ``end`` are dropped. Entry
    times are made relative to ``start`` and then shifted by ``time_acc_ost``.

    Args:
        sentence_list: Sentences as mappings with ``'text'`` (string or token
            list) and ``'timestamp'`` (per-token ``[start, end]`` pairs in
            milliseconds). NOTE: a string ``'text'`` is tokenised with
            ``str2list`` and the token list is written back into the mapping.
        start: Window start in seconds.
        end: Window end in seconds.
        begin_index: Number of entries already emitted; numbering continues
            at ``begin_index + 1``.
        time_acc_ost: Offset in seconds added to every emitted time.

    Returns:
        A tuple ``(srt_total, subs, cc)``:
          * ``srt_total``: the generated SRT-format subtitle text,
          * ``subs``: ``[((start_s, end_s), text), ...]`` for each entry,
          * ``cc``: the index the next entry would receive.
    """
    start, end = int(start * 1000), int(end * 1000)
    srt_total = ''
    cc = 1 + begin_index
    subs = []
    for sent in sentence_list:
        if isinstance(sent['text'], str):
            sent['text'] = str2list(sent['text'])
        tokens = sent['text']
        timestamps = sent['timestamp']
        sent_start = timestamps[0][0]
        sent_end = timestamps[-1][1]

        if sent_end <= start:
            # Entirely before the window.
            continue
        if sent_start >= end:
            # Entirely after the window; later sentences are not examined.
            break

        fully_inside = (
            (sent_end <= end and sent_start > start)
            or (sent_end == end and sent_start == start)
        )
        if fully_inside:
            lo, hi = None, None
        else:
            # Trim whichever side(s) of the sentence stick out of the window.
            lo = _first_index_ending_after(timestamps, start) if sent_start <= start else None
            hi = _first_index_ending_after(timestamps, end) if sent_end > end else None

        clip_tokens = tokens[lo:hi]
        clip_ts = timestamps[lo:hi]
        # Check the trimmed slice itself: a single token spanning the whole
        # window trims to nothing and must be skipped, not handed to Text2SRT
        # (which indexes the first/last timestamp and would raise).
        if not clip_ts:
            continue

        t2s = Text2SRT(clip_tokens, clip_ts, offset=start)
        srt_total += "{}\n{}".format(cc, t2s.srt(time_acc_ost))
        subs.append((t2s.time(time_acc_ost), t2s.text()))
        cc += 1
    return srt_total, subs, cc
|