added transcribe
src/transcribe/transcribe.py +268 -0
ADDED
@@ -0,0 +1,268 @@
from sys import platform

import logging

import pandas as pd
import torch
from pyannote.audio import Pipeline
from pyannote.core import Segment
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers.utils import is_flash_attn_2_available

# Human-readable language names mapped to the codes Whisper expects.
languages = {
    "English": "en",
    "Chinese": "zh",
    "German": "de",
    "Spanish": "es",
    "Russian": "ru",
    "Korean": "ko",
    "French": "fr",
    "Japanese": "ja",
    "Portuguese": "pt",
    "Turkish": "tr",
    "Polish": "pl",
    "Catalan": "ca",
    "Dutch": "nl",
    "Arabic": "ar",
    "Swedish": "sv",
    "Italian": "it",
    "Indonesian": "id",
    "Hindi": "hi",
    "Finnish": "fi",
    "Vietnamese": "vi",
    "Hebrew": "he",  # Whisper expects "he"; "iw" is the deprecated ISO code
    "Ukrainian": "uk",
    "Greek": "el",
    "Malay": "ms",
    "Czech": "cs",
    "Romanian": "ro",
    "Danish": "da",
    "Hungarian": "hu",
    "Tamil": "ta",
    "Norwegian": "no",
    "Thai": "th",
    "Urdu": "ur",
    "Croatian": "hr",
    "Bulgarian": "bg",
    "Lithuanian": "lt",
    "Latin": "la",
    "Maori": "mi",
    "Malayalam": "ml",
    "Welsh": "cy",
    "Slovak": "sk",
    "Telugu": "te",
    "Persian": "fa",
    "Latvian": "lv",
    "Bengali": "bn",
    "Serbian": "sr",
    "Azerbaijani": "az",
    "Slovenian": "sl",
    "Kannada": "kn",
    "Estonian": "et",
    "Macedonian": "mk",
    "Breton": "br",
    "Basque": "eu",
    "Icelandic": "is",
    "Armenian": "hy",
    "Nepali": "ne",
    "Mongolian": "mn",
    "Bosnian": "bs",
    "Kazakh": "kk",
    "Albanian": "sq",
    "Swahili": "sw",
    "Galician": "gl",
    "Marathi": "mr",
    "Punjabi": "pa",
    "Sinhala": "si",
    "Khmer": "km",
    "Shona": "sn",
    "Yoruba": "yo",
    "Somali": "so",
    "Afrikaans": "af",
    "Occitan": "oc",
    "Georgian": "ka",
    "Belarusian": "be",
    "Tajik": "tg",
    "Sindhi": "sd",
    "Gujarati": "gu",
    "Amharic": "am",
    "Yiddish": "yi",
    "Lao": "lo",
    "Uzbek": "uz",
    "Faroese": "fo",
    "Haitian creole": "ht",
    "Pashto": "ps",
    "Turkmen": "tk",
    "Nynorsk": "nn",
    "Maltese": "mt",
    "Sanskrit": "sa",
    "Luxembourgish": "lb",
    "Myanmar": "my",
    "Tibetan": "bo",
    "Tagalog": "tl",
    "Malagasy": "mg",
    "Assamese": "as",
    "Tatar": "tt",
    "Hawaiian": "haw",
    "Lingala": "ln",
    "Hausa": "ha",
    "Bashkir": "ba",
    "Javanese": "jw",
    "Sundanese": "su",
}
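
# Example lookups (the second name is hypothetical and not in the table;
# the None fallback later lets Whisper auto-detect the language):
#   languages.get("German")    -> "de"
#   languages.get("Esperanto") -> None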

# Pick the best available device: CUDA GPU, Apple Silicon (MPS), else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif platform == "darwin":
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Half precision only on CUDA; MPS and CPU stay in float32.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32


def get_text_with_timestamp(transcribe_res):
    """Turn Whisper "chunks" into (pyannote Segment, text) pairs."""
    timestamp_texts = []
    for item in transcribe_res["chunks"]:
        start, end = item["timestamp"]
        timestamp_texts.append((Segment(start, end), item["text"]))
    return timestamp_texts


def add_speaker_info_to_text(timestamp_texts, ann):
    """Label each (segment, text) pair with its dominant speaker."""
    spk_text = []
    for seg, text in timestamp_texts:
        # argmax() picks the speaker with the most overlap in this segment.
        spk = ann.crop(seg).argmax()
        spk_text.append((seg, spk, text))
    return spk_text


def merge_cache(text_cache):
    """Collapse buffered (segment, speaker, text) triples into one utterance."""
    sentence = "".join(item[-1] for item in text_cache)
    spk = text_cache[0][1]
    start = text_cache[0][0].start
    end = text_cache[-1][0].end
    return Segment(start, end), spk, sentence


# Punctuation that ends a sentence and triggers a buffer flush.
PUNC_SENT_END = [".", "?", "!"]


def merge_sentence(spk_text):
    """Merge chunk-level triples into sentence-level turns.

    The buffer is flushed whenever the speaker changes or a chunk ends
    with sentence-final punctuation.
    """
    merged_spk_text = []
    pre_spk = None
    text_cache = []
    for seg, spk, text in spk_text:
        if spk != pre_spk and pre_spk is not None and len(text_cache) > 0:
            # Speaker changed: flush the buffer and start a new one.
            merged_spk_text.append(merge_cache(text_cache))
            text_cache = [(seg, spk, text)]
        elif text[-1:] in PUNC_SENT_END:
            # Sentence ended: flush the buffer, including this chunk.
            # (The [-1:] slice keeps this safe for empty chunk texts.)
            text_cache.append((seg, spk, text))
            merged_spk_text.append(merge_cache(text_cache))
            text_cache = []
        else:
            text_cache.append((seg, spk, text))
        pre_spk = spk
    if len(text_cache) > 0:
        merged_spk_text.append(merge_cache(text_cache))
    return merged_spk_text


def diarize_text(transcribe_res, diarization_result):
    """Attach speaker labels to a Whisper result and merge into turns."""
    timestamp_texts = get_text_with_timestamp(transcribe_res)
    spk_text = add_speaker_info_to_text(timestamp_texts, diarization_result)
    return merge_sentence(spk_text)
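
# Illustrative end-to-end shape (hypothetical values), assuming a
# Whisper-style result and a pyannote Annotation:
#   transcribe_res = {"chunks": [{"timestamp": (0.0, 2.0), "text": " Hello there."},
#                                {"timestamp": (2.0, 4.5), "text": " Hi."}]}
#   diarize_text(transcribe_res, annotation)
#   -> [(<Segment(0.0, 2.0)>, "SPEAKER_00", " Hello there."),
#       (<Segment(2.0, 4.5)>, "SPEAKER_01", " Hi.")]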


def make_conversation(transcribe_result, diarization_result):
    """Group consecutive turns by the same speaker into (text, speaker) pairs."""
    processed = diarize_text(transcribe_result, diarization_result)
    df = pd.DataFrame(processed, columns=["segment", "speaker", "text"])[
        ["speaker", "text"]
    ]
    # A new key starts whenever the speaker changes between consecutive rows.
    df["key"] = (df["speaker"] != df["speaker"].shift(1)).astype(int).cumsum()
    conversation = df.groupby(["key", "speaker"])["text"].apply(" ".join).reset_index()
    return list(zip(conversation.text, conversation.speaker))
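
# Hypothetical example: two consecutive turns from the same speaker collapse
# into a single entry:
#   [(seg, "SPEAKER_00", "Hello."), (seg, "SPEAKER_00", "How are you?"),
#    (seg, "SPEAKER_01", "Fine.")]
#   -> [("Hello. How are you?", "SPEAKER_00"), ("Fine.", "SPEAKER_01")]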


def transcriber(
    input: str,
    model: str,
    language: str,
    translate: bool,
    diarize: bool,
    input_diarization_token,
) -> str:
    """Transcribe an audio file with a Whisper model.

    Args:
        input: file path to the audio file, in any format
        model: Hugging Face model id of the Whisper checkpoint to use
        language: name of the language in which the audio is recorded
        translate: if True, translate the transcription to English
        diarize: if True, run pyannote speaker diarization on the audio
        input_diarization_token: Hugging Face token used to load the gated
            pyannote diarization pipeline

    Returns:
        The transcription as a string, one chunk per line.
    """
    model_id = model

    if diarize:
        # Load the gated pyannote pipeline (requires an accepted license
        # and a valid Hugging Face token).
        pipeline_diarization = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=input_diarization_token,
        )

        # Send pipeline to GPU (when available).
        pipeline_diarization.to(device)

        # Apply pretrained pipeline.
        diarization = pipeline_diarization(input)

        # To inspect the raw diarization:
        # for turn, _, speaker in diarization.itertracks(yield_label=True):
        #     print(f"start={turn.start:.1f}s stop={turn.end:.1f}s speaker_{speaker}")

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        use_flash_attention_2=is_flash_attn_2_available(),
    )

    logging.info("Running on device: %s", device)

    model.to(device)

    processor = AutoProcessor.from_pretrained(model_id)

    # Map the human-readable name to Whisper's language code; None lets
    # Whisper auto-detect the language.
    language = languages.get(language, None)
    task = "translate" if translate else None

    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        max_new_tokens=128,
        chunk_length_s=15,
        batch_size=16,
        return_timestamps=True,
        torch_dtype=torch_dtype,
        device=device,
        # Forward the language so it is not re-detected on every chunk;
        # None falls back to auto-detection.
        generate_kwargs={"task": task, "language": language},
    )

    results = pipe(input)
    results["text"] = results["text"].strip()

    # Join the chunk texts, one chunk per line.
    text = ""
    chunks = results.get("chunks", [])
    for chunk in chunks:
        text += chunk["text"] + "\n"

    # Diarized output is not wired up yet. make_conversation would turn the
    # two results into speaker-labelled turns, which could then be paired up
    # for a Gradio chatbot component:
    # conversation = make_conversation(results, diarization)
    # conversation_gradio = []
    # for i in range(0, len(conversation), 2):  # step by 2 to pair turns
    #     current_text = conversation[i][0]
    #     next_text = conversation[i + 1][0] if i + 1 < len(conversation) else ""
    #     conversation_gradio.append((current_text, next_text))

    return text
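
# Minimal usage sketch (hypothetical file path; not part of the module's
# public surface). With diarize=False no pyannote token is needed.
if __name__ == "__main__":
    transcript = transcriber(
        input="sample.wav",               # hypothetical audio file
        model="openai/whisper-large-v3",  # any Whisper checkpoint id
        language="English",
        translate=False,
        diarize=False,
        input_diarization_token=None,
    )
    print(transcript)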