Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import pipeline | |
import time | |
import os | |
import torch | |
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer, pipeline | |
import numpy as np | |
# Access token read from the "key" environment variable (None when unset).
# os.environ only accepts str values, so the token is exported for the
# Hugging Face libraries only when it is actually present; assigning None
# would raise TypeError while the module is being imported.
auth_token = os.environ.get("key")
if auth_token is not None:
    os.environ["HUGGING_FACE_HUB_TOKEN"] = auth_token
# Speech-to-text: use the first CUDA device when torch reports one,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

transcriber = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-base.en",
    device=device,
)

# Translation: identifier of the M2M100 checkpoint loaded further down.
translation_model_path = "mutisya/m2m100_418M-en-kik-v24.03.2"
def update_tokenizer_settings(tokenizer):
    """Expose the tokenizer's additional special tokens as language tokens.

    The ids of all ``tokenizer.additional_special_tokens`` are merged into
    the existing ``lang_token_to_id`` table, and the four language lookup
    tables on the tokenizer are rebuilt from the merged result. A language
    code is the token text with every underscore removed.

    Args:
        tokenizer: Object providing ``additional_special_tokens``,
            ``convert_tokens_to_ids`` and ``lang_token_to_id``. It is
            modified in place.
    """
    extra_ids = {}
    for token in tokenizer.additional_special_tokens:
        extra_ids[token] = tokenizer.convert_tokens_to_ids(token)

    # Existing language tokens first; an additional token with the same
    # text replaces the existing id.
    token_ids = dict(tokenizer.lang_token_to_id.items())
    token_ids.update(extra_ids)

    id_to_token = {}
    code_to_token = {}
    code_to_id = {}
    for token, token_id in token_ids.items():
        code = token.replace("_", "")
        id_to_token[token_id] = token
        code_to_token[code] = token
        code_to_id[code] = token_id

    tokenizer.id_to_lang_token = id_to_token
    tokenizer.lang_token_to_id = dict(token_ids)
    tokenizer.lang_code_to_token = code_to_token
    tokenizer.lang_code_to_id = code_to_id
# Load the translation checkpoint and its tokenizer, then register the
# tokenizer's additional special tokens as language codes.
translation_model = M2M100ForConditionalGeneration.from_pretrained(
    translation_model_path
)
translation_tokenizer = M2M100Tokenizer.from_pretrained(translation_model_path)
update_tokenizer_settings(translation_tokenizer)

# Translation direction: "en" source, "kik" target.
src_lang, tgt_lang = "en", "kik"
translation_tokenizer.src_lang = src_lang
translation_tokenizer.tgt_lang = tgt_lang

# Pipeline device index: 0 when CUDA is available, -1 otherwise.
if torch.cuda.is_available():
    translation_device = 0
else:
    translation_device = -1

translator = pipeline(
    task='translation',
    model=translation_model,
    tokenizer=translation_tokenizer,
    device=translation_device,
)
# Module-level streaming state, shared by transcribe() and
# get_next_translation_block().
chunk_tracker: list = []        # one dict per audio chunk still in the stream
ready_to_translate: list = []   # sentence blocks whose chunks were trimmed away
text_at_chunk_end: str = ""     # transcription obtained after the latest chunk
chunk_index: int = 0            # incremented once per transcribe() call
translated_text: str = ""       # translation accumulated so far
transcribed_text: str = ""      # transcription already handed to the translator
def get_next_translation_block():
    """Translate the completed sentences of the current transcription.

    The completed sentences are the part of ``text_at_chunk_end`` up to and
    including its last full stop, ignoring the final character of the text.
    They are translated only when some tracked chunk left the transcription
    unchanged (its text at the beginning equals its text at the end) and
    that text is exactly the completed sentences. On success the sentences
    are appended to ``transcribed_text`` and their translation to
    ``translated_text``.

    Returns:
        tuple: ``(ready_sentences, chunks_to_remove)``. ``ready_sentences``
        is the completed-sentence text (possibly empty).
        ``chunks_to_remove`` lists the tracked chunks from the first one up
        to and including the matching chunk, or is empty when there was
        nothing to translate or no chunk matched.
    """
    global translated_text
    global transcribed_text

    # The last character is left out of the search, so a full stop at the
    # very end of the text does not yet count as a finished sentence.
    last_stop = text_at_chunk_end[0:-1].rfind('.')
    ready_sentences = text_at_chunk_end[0:last_stop + 1]
    if not ready_sentences:
        return ready_sentences, []

    print("Trying to match: " + ready_sentences)
    chunks_to_remove = []
    for curr_chunk in chunk_tracker:
        chunks_to_remove.append(curr_chunk)
        if curr_chunk["text_at_begining"] == curr_chunk["text_at_end"] == ready_sentences:
            break
    else:
        # The loop ran to completion without finding a matching chunk.
        print("ERROR: no match found for " + ready_sentences)
        return ready_sentences, []

    transcribed_text += ready_sentences
    translation = translator(
        ready_sentences, src_lang=src_lang, tgt_lang=tgt_lang
    )[0]['translation_text']
    # Separate successive translations; plain concatenation would glue the
    # end of one block directly onto the start of the next.
    if translated_text and translation and not translation[0].isspace():
        translated_text += " "
    translated_text += translation
    print("TRANSLATED: " + translated_text)
    return ready_sentences, chunks_to_remove
def transcribe(stream, new_chunk):
    """Append an audio chunk to the stream, transcribe it, and translate.

    Each chunk is peak-normalised, appended to ``stream`` and recorded in
    ``chunk_tracker`` together with the transcription before and after it.
    Every fifth call, ``get_next_translation_block`` is asked for completed
    sentences; when it reports chunks to drop, they are removed from the
    front of ``chunk_tracker`` and the stream is rebuilt from the rest.

    Args:
        stream: Audio accumulated so far as a NumPy array, or ``None`` when
            nothing has been accumulated yet.
        new_chunk: ``(sampling_rate, samples)`` pair for the new audio, or
            ``None`` when no audio was delivered.

    Returns:
        tuple: ``(stream, text_at_chunk_end, transcribed_text,
        translated_text)``: the updated stream, the latest transcription of
        it, and the accumulated transcribed and translated text.
    """
    global text_at_chunk_end
    global chunk_index

    # No audio delivered: leave all state untouched instead of failing on
    # the tuple unpacking below.
    if new_chunk is None:
        return stream, text_at_chunk_end, transcribed_text, translated_text

    chunk_index += 1
    sr, y = new_chunk
    y = y.astype(np.float32)
    # NOTE(review): assumes multi-channel audio arrives as
    # (samples, channels); it is averaged down to mono. Confirm against the
    # Gradio version in use.
    if y.ndim > 1:
        y = y.mean(axis=1)
    # Peak-normalise, skipping empty or all-zero (silent) chunks: dividing
    # by a zero peak would fill the stream with NaNs.
    peak = np.max(np.abs(y)) if y.size else 0.0
    if peak > 0:
        y /= peak

    stream = y if stream is None else np.concatenate([stream, y])

    text_at_chunk_begining = text_at_chunk_end
    text_at_chunk_end = transcriber({"sampling_rate": sr, "raw": stream})["text"]
    chunk_tracker.append({
        "value": y,
        "length": len(y),
        "text_at_begining": text_at_chunk_begining,
        "text_at_end": text_at_chunk_end,
    })

    # Look for a translatable block on every fifth chunk only.
    if chunk_index % 5 == 0:
        ready_sentences, chunks_to_remove = get_next_translation_block()
        if chunks_to_remove:
            ready_to_translate.append(ready_sentences)
            # The chunks to remove are the leading entries of the tracker.
            del chunk_tracker[:len(chunks_to_remove)]
            # Rebuild the stream from the remaining chunks in a single
            # concatenation.
            if chunk_tracker:
                stream = np.concatenate(
                    [chunk["value"] for chunk in chunk_tracker]
                )
            else:
                stream = None

    return stream, text_at_chunk_end, transcribed_text, translated_text
# set up UI
# Input: streaming microphone audio; the accumulated stream travels through
# the "state" slot on both the input and the output side.
_microphone = gr.Audio(sources=["microphone"], streaming=True)

# Outputs, in the order transcribe() returns them after the state.
_in_progress_box = gr.Textbox(label="in progress")
_transcribed_box = gr.Textbox(label="Transcribed text")
_translated_box = gr.Textbox(label="Translated text")

demo = gr.Interface(
    fn=transcribe,
    inputs=["state", _microphone],
    outputs=["state", _in_progress_box, _transcribed_box, _translated_box],
    live=True,
    allow_flagging="never",
)

# NOTE(review): intended to hide the progress indicator of the first
# registered event; whether this takes effect depends on the Gradio version.
demo.dependencies[0]["show_progress"] = False

if __name__ == "__main__":
    demo.launch(debug=True)