mutisya commited on
Commit
3ad86d1
·
verified ·
1 Parent(s): da08ce1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -16
app.py CHANGED
@@ -3,25 +3,150 @@ from transformers import pipeline
3
  import time
4
  import torch
5
 
 
 
 
 
6
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- pipe = pipeline("automatic-speech-recognition", model="openai/whisper-base.en", device=device)
9
 
10
- def transcribe(audio, state=""):
11
- #print(audio)
12
- time.sleep(2)
13
- text = pipe(audio)["text"]
14
- state += text + " "
15
- return state, state
16
 
 
 
 
 
 
 
 
 
17
 
18
- with gr.Blocks() as demo:
19
- state = gr.State(value="")
20
- with gr.Row():
21
- with gr.Column():
22
- audio = gr.Audio(sources="microphone", type="filepath")
23
- with gr.Column():
24
- textbox = gr.Textbox()
25
- audio.stream(fn=transcribe, inputs=[audio, state], outputs=[textbox, state])
26
 
27
- demo.launch(debug=True)
 
 
3
  import time
4
  import torch
5
 
6
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer, pipeline
import numpy as np

# set up transcription pipeline
# Whisper runs on GPU when available; "cuda:0"/"cpu" are the device strings
# accepted by the transformers pipeline() factory.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base.en", device=device)


# set up translation pipeline
# English -> Kikuyu M2M100 fine-tune hosted on the Hugging Face Hub.
translation_model_path = "mutisya/m2m100_418M-en-kik-v24.03.2"
16
+
17
def update_tokenizer_settings(tokenizer):
    """Extend an M2M100 tokenizer's language-token tables in place.

    A fine-tuned checkpoint may declare extra language codes only as
    ``additional_special_tokens`` (e.g. ``"__kik__"``); the stock
    tokenizer's language maps then do not know about them.  This merges
    those extra tokens into every lookup table the translation pipeline
    consults.

    Args:
        tokenizer: an M2M100Tokenizer-like object; mutated in place,
            nothing is returned.
    """
    # Map each extra special token (e.g. "__kik__") to its vocabulary id.
    extra_tokens = {
        tok: tokenizer.convert_tokens_to_ids(tok)
        for tok in tokenizer.additional_special_tokens
    }
    # Existing language tokens first so the extras win on a key collision,
    # matching the original list-concatenation merge order.
    all_lang_tokens = {**tokenizer.lang_token_to_id, **extra_tokens}

    tokenizer.id_to_lang_token = {v: k for k, v in all_lang_tokens.items()}
    tokenizer.lang_token_to_id = dict(all_lang_tokens)
    # Plain language codes: "__kik__" -> "kik".
    tokenizer.lang_code_to_token = {k.replace("_", ""): k for k in all_lang_tokens}
    tokenizer.lang_code_to_id = {k.replace("_", ""): v for k, v in all_lang_tokens.items()}
25
+
26
+
27
translation_model = M2M100ForConditionalGeneration.from_pretrained(translation_model_path)
translation_tokenizer = M2M100Tokenizer.from_pretrained(translation_model_path)

# Register the fine-tune's extra language tokens (e.g. Kikuyu) with the tokenizer.
update_tokenizer_settings(translation_tokenizer)

# set translation direction: English ("en") -> Kikuyu ("kik")
src_lang = "en"
tgt_lang = "kik"

translation_tokenizer.src_lang = src_lang
translation_tokenizer.tgt_lang = tgt_lang


# NOTE: translation pipeline takes an integer device index (-1 == CPU),
# unlike the "cuda:0"/"cpu" string used for the ASR pipeline above.
translation_device = 0 if torch.cuda.is_available() else -1
translator = pipeline('translation', model=translation_model, tokenizer=translation_tokenizer, device=translation_device)
42
+
43
+
44
# transcribe sections while keeping state
# Module-level mutable state shared by transcribe() / get_next_translation_block().
chunk_tracker = []        # per-chunk dicts: raw samples + transcript seen before/after the chunk
ready_to_translate = []   # sentence blocks already handed to the translator
text_at_chunk_end = ""    # running transcript of the audio currently held in the stream
chunk_index = 0           # chunks received so far; drives the every-5-chunks flush
translated_text = ""      # accumulated translated output shown in the UI
transcribed_text = ""     # accumulated finalized transcript shown in the UI
51
+
52
+
53
def get_next_translation_block():
    """Find the finished sentences in the running transcript and translate them.

    Looks for the last full stop in ``text_at_chunk_end`` (ignoring one in
    the final position, since that sentence may still change) and takes
    everything up to it as a candidate block.  The block is accepted only
    when some chunk boundary saw exactly this text both before and after
    the chunk — i.e. its transcription has stabilized.

    Returns:
        (ready_sentences, chunks_to_remove): the stable text and the chunks
        (a prefix of ``chunk_tracker``) that produced it.  ``chunks_to_remove``
        is empty when no stable block was found.

    Side effects: on success, appends to the module-level ``transcribed_text``
    and ``translated_text`` accumulators via the ``translator`` pipeline.
    """
    global text_at_chunk_end
    global chunk_tracker
    global ready_to_translate
    global translated_text
    global transcribed_text

    # Ignore a full stop in the very last position: that sentence may grow.
    last_stop = text_at_chunk_end[:-1].rfind('.')
    ready_sentences = text_at_chunk_end[:last_stop + 1]
    chunks_to_remove = []

    if ready_sentences:
        print("Trying to match: " + ready_sentences)
        found_match = False
        for curr_chunk in chunk_tracker:
            chunks_to_remove.append(curr_chunk)
            # A chunk whose before/after transcripts both equal the candidate
            # block means the block was already final at that boundary.
            if (curr_chunk["text_at_begining"] == curr_chunk["text_at_end"]
                    and curr_chunk["text_at_begining"] == ready_sentences):
                found_match = True
                break

        if not found_match:
            print("ERROR: no match found for " + ready_sentences)
            chunks_to_remove = []  # keep everything; try again on a later flush
        else:
            transcribed_text += ready_sentences
            translated_text += translator(ready_sentences, src_lang=src_lang, tgt_lang=tgt_lang)[0]['translation_text']
            print("TRANSLATED: " + translated_text)

    return ready_sentences, chunks_to_remove
83
+
84
def transcribe(stream, new_chunk):
    """Gradio streaming callback: accumulate audio, transcribe, translate.

    Args:
        stream: float32 numpy array of all audio kept so far (None on the
            first call — Gradio passes the initial state through).
        new_chunk: ``(sample_rate, samples)`` tuple from a streaming
            ``gr.Audio`` input.

    Returns:
        (stream, in-progress transcript, finalized transcript, translation):
        the first item is the carried state; the rest feed the three
        output textboxes.
    """
    global text_at_chunk_end
    global chunk_tracker
    global ready_to_translate
    global chunk_index
    global translated_text
    global transcribed_text

    chunk_index += 1

    sr, y = new_chunk
    y = y.astype(np.float32)
    # Peak-normalize the chunk.  Guard the all-zero (silent) case: the
    # original divided unconditionally and filled the buffer with NaNs.
    peak = np.max(np.abs(y))
    if peak > 0:
        y /= peak

    chunk_value = y
    chunk_length = len(y)

    stream = y if stream is None else np.concatenate([stream, y])

    text_at_chunk_begining = text_at_chunk_end
    text_at_chunk_end = transcriber({"sampling_rate": sr, "raw": stream})["text"]

    curr_chunk = {
        "value": chunk_value,
        "length": chunk_length,
        "text_at_begining": text_at_chunk_begining,
        "text_at_end": text_at_chunk_end,
    }
    chunk_tracker.append(curr_chunk)

    # Every 5 chunks, try to flush a finished sentence block to the
    # translator and drop the audio that produced it from the rolling stream.
    if chunk_index % 5 == 0:
        ready_sentences, chunks_to_remove = get_next_translation_block()
        if chunks_to_remove:
            ready_to_translate.append(ready_sentences)
            for _ in chunks_to_remove:
                chunk_tracker.pop(0)

            # Rebuild the stream from the chunks still pending.  Guard the
            # empty case: the original indexed chunk_tracker[0] and raised
            # IndexError when every chunk had been flushed.
            if chunk_tracker:
                stream = np.concatenate([c["value"] for c in chunk_tracker])
            else:
                stream = np.zeros(0, dtype=np.float32)

    return stream, text_at_chunk_end, transcribed_text, translated_text
138
 
 
 
 
 
 
 
139
 
140
# set up UI
demo = gr.Interface(
    transcribe,
    # Inputs: the carried audio-stream state + live microphone chunks.
    ["state", gr.Audio(sources=["microphone"], streaming=True)],
    # Outputs: new state + the three text panes returned by transcribe().
    ["state", gr.Textbox(label="in progress"), gr.Textbox(label="Transcribed text"), gr.Textbox(label="Translated text")],
    live=True,
    allow_flagging="never"
)

# NOTE(review): this pokes a private structure to suppress the per-event
# progress spinner; `Interface.dependencies` is not a stable public API and
# presumably breaks on newer Gradio versions — confirm before upgrading.
demo.dependencies[0]["show_progress"] = False  # this should hide the progress report?

if __name__ == "__main__":
    demo.launch(debug=True)