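"""Gradio app for live speech translation.

Audio is streamed from the microphone, transcribed incrementally with
Whisper (openai/whisper-base.en), and completed sentences are translated
from English to Kikuyu ("kik") with a fine-tuned M2M100 model
(mutisya/m2m100_418M-en-kik-v24.03.2).
"""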
import os

import gradio as gr
import numpy as np
import torch
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer, pipeline
# The Hugging Face auth token is expected in the Space's "key" secret; only
# set the hub variable when it is actually present.
auth_token = os.environ.get("key")
if auth_token:
    os.environ["HUGGING_FACE_HUB_TOKEN"] = auth_token
# set up transcription pipeline
device = "cuda:0" if torch.cuda.is_available() else "cpu"
transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base.en", device=device)
# set up translation pipeline
translation_model_path = "mutisya/m2m100_418M-en-kik-v24.03.2"
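
# The fine-tuned checkpoint appears to add language tokens (e.g. for "kik")
# that the stock M2M100Tokenizer does not know about, so its language-token
# lookup tables are rebuilt below from the tokenizer's additional special
# tokens before the model is used.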
def update_tokenizer_settings(tokenizer):
    # Merge any additional special tokens into the stock language-token map.
    new_lang_tokens = {k: tokenizer.convert_tokens_to_ids(k) for k in tokenizer.additional_special_tokens}
    all_lang_tokens = dict(list(tokenizer.lang_token_to_id.items()) + list(new_lang_tokens.items()))

    # Rebuild all four lookup tables so old and new languages resolve alike.
    tokenizer.id_to_lang_token = {v: k for k, v in all_lang_tokens.items()}
    tokenizer.lang_token_to_id = {k: v for k, v in all_lang_tokens.items()}
    tokenizer.lang_code_to_token = {k.replace("_", ""): k for k in all_lang_tokens.keys()}
    tokenizer.lang_code_to_id = {k.replace("_", ""): v for k, v in all_lang_tokens.items()}
translation_model = M2M100ForConditionalGeneration.from_pretrained(translation_model_path)
translation_tokenizer = M2M100Tokenizer.from_pretrained(translation_model_path)
update_tokenizer_settings(translation_tokenizer)
# set translation direction
src_lang = "en"
tgt_lang = "kik"
translation_tokenizer.src_lang = src_lang
translation_tokenizer.tgt_lang = tgt_lang
translation_device = 0 if torch.cuda.is_available() else -1
translator = pipeline('translation', model=translation_model, tokenizer=translation_tokenizer, device=translation_device)
# transcribe sections while keeping state
chunk_tracker = []
ready_to_translate = []
text_at_chunk_end = ""
chunk_index = 0
translated_text = ""
transcribed_text = ""
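
# Strategy: buffer incoming audio chunks in chunk_tracker, each with a
# snapshot of the transcript before and after it arrived. Once the
# transcript ends in a full sentence, translate that sentence, then drop
# the audio chunks that produced it so the buffered stream stays short.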
def get_next_translation_block():
    global text_at_chunk_end
    global chunk_tracker
    global ready_to_translate
    global translated_text
    global transcribed_text

    # Treat everything up to the last full stop (excluding the final
    # character, which may still change) as completed sentences.
    last_stop = text_at_chunk_end[0:-1].rfind('.')
    ready_sentences = text_at_chunk_end[0:last_stop + 1]
    chunks_to_remove = []

    if len(ready_sentences) > 0:
        print("Trying to match: " + ready_sentences)
        # Find the chunk at whose boundary the transcript was exactly the
        # completed sentences; all chunks up to that point can be dropped.
        found_match = False
        for i in range(0, len(chunk_tracker)):
            curr_chunk = chunk_tracker[i]
            chunks_to_remove.append(curr_chunk)
            if curr_chunk["text_at_begining"] == curr_chunk["text_at_end"] and curr_chunk["text_at_begining"] == ready_sentences:
                found_match = True
                break

        if not found_match:
            print("ERROR: no match found for " + ready_sentences)
            chunks_to_remove = []
        else:
            transcribed_text += ready_sentences
            translated_text += translator(ready_sentences, src_lang=src_lang, tgt_lang=tgt_lang)[0]['translation_text']
            print("TRANSLATED: " + translated_text)

    return ready_sentences, chunks_to_remove
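
# Gradio streaming callback: receives the accumulated audio buffer as state
# plus the newest (sample_rate, samples) chunk, and returns the updated
# state followed by one value per output component.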
def transcribe(stream, new_chunk):
    global text_at_chunk_end
    global chunk_tracker
    global ready_to_translate
    global chunk_index
    global translated_text
    global transcribed_text

    chunk_index += 1

    sr, y = new_chunk
    y = y.astype(np.float32)
    # Normalise to [-1, 1]; guard against an all-silent chunk to avoid
    # dividing by zero.
    peak = np.max(np.abs(y))
    if peak > 0:
        y /= peak

    chunk_value = y
    chunk_length = len(y)

    if stream is not None:
        stream = np.concatenate([stream, y])
    else:
        stream = y

    # Re-transcribe the whole buffered stream and snapshot the text before
    # and after this chunk was added.
    text_at_chunk_begining = text_at_chunk_end
    text_at_chunk_end = transcriber({"sampling_rate": sr, "raw": stream})["text"]

    curr_chunk = {
        "value": chunk_value,
        "length": chunk_length,
        "text_at_begining": text_at_chunk_begining,
        "text_at_end": text_at_chunk_end
    }
    chunk_tracker.append(curr_chunk)

    # Every 5 chunks, try to peel off completed sentences for translation.
    if chunk_index % 5 == 0:
        ready_sentences, chunks_to_remove = get_next_translation_block()

        if len(chunks_to_remove) > 0:
            ready_to_translate.append(ready_sentences)
            total_trim_length = 0
            for i in range(0, len(chunks_to_remove)):
                total_trim_length += chunks_to_remove[i]["length"]
                removed = chunk_tracker.pop(0)

            # Rebuild the stream from the chunks that are still pending.
            if len(chunk_tracker) > 0:
                new_stream = chunk_tracker[0]["value"]
                for i in range(1, len(chunk_tracker)):
                    new_stream = np.concatenate([new_stream, chunk_tracker[i]["value"]])
                stream = new_stream
            else:
                # All buffered chunks were consumed; start a fresh stream.
                stream = None

    return stream, text_at_chunk_end, transcribed_text, translated_text
# set up UI
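# The "state" input/output pair carries the raw audio buffer between
# streaming calls; the three textboxes show the in-progress transcript,
# the committed transcript, and its translation.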
demo = gr.Interface(
    transcribe,
    ["state", gr.Audio(sources=["microphone"], streaming=True)],
    ["state", gr.Textbox(label="In progress"), gr.Textbox(label="Transcribed text"), gr.Textbox(label="Translated text")],
    live=True,
    allow_flagging="never"
)
# Hide the progress indicator while the live outputs refresh. This pokes at
# gradio internals, so it may stop working on newer gradio versions.
demo.dependencies[0]["show_progress"] = False
if __name__ == "__main__":
demo.launch(debug=True)