File size: 10,805 Bytes
95a3ca6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7a34b6
95a3ca6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7a34b6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
# imports
import os
import sys
import gradio as gr
import whisper
import torch
import traceback
import shutil
import yaml
import re
from pydub import AudioSegment
from huggingface_hub import snapshot_download
import json
import requests
import wave
from pynvml import *
import time

import mRASPloader

# Release unoccupied cached CUDA memory before the models below are loaded.
torch.cuda.empty_cache()

# Eden AI text-to-speech endpoint and auth header (used by predict).
# SECURITY: the fallback bearer token below is committed to source control and
# must be treated as leaked -- rotate it and provide the key through the
# EDENAI_API_KEY environment variable instead. The fallback is kept only so
# existing deployments keep working until the variable is configured.
_EDENAI_FALLBACK_TOKEN = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoiYTI5NDFhMmEtYzA5ZS00YTcyLWI5ZGItODM5ODEzZDIwMGEwIiwidHlwZSI6ImFwaV90b2tlbiJ9.StBap5nQtNqjh1BMz9DledR5tg5FTWdUMVBrDwY6DjY"
headers = {"Authorization": "Bearer " + (os.environ.get("EDENAI_API_KEY") or _EDENAI_FALLBACK_TOKEN)}
url = "https://api.edenai.run/v2/audio/text_to_speech"

# Whisper checkpoint used for ASR. Options are small, medium, large and
# large-v2 (large and large-v2 don't fit on the Hugging Face CPU hardware).
model = whisper.load_model("medium")



# Map the language names shown in the UI to the two-letter codes passed to
# Whisper (source language) and to mRASP2 (source and target language).
language_id_lookup = {
    "Arabic": "ar",
    "English": "en",
    "Chinese": "zh",
    "Spanish": "es",
    "Russian": "ru",
    "French": "fr",
    "German": "de",
    "Italian": "it",
    "Netherlands": "nl",
    "Portuguese": "pt",
    "Romanian": "ro",
}


# Per-target-language generation settings for ConST, keyed by language code
# (presumably beam size and length penalty, judging by the key names).
# NOTE(review): not referenced anywhere else in this file.
LANG_GEN_SETUPS = {
    "de": {"beam": 10, "lenpen": 0.7},
    "es": {"beam": 10, "lenpen": 0.1},
    "fr": {"beam": 10, "lenpen": 1.0},
    "it": {"beam": 10, "lenpen": 0.5},
    "nl": {"beam": 10, "lenpen": 0.4},
    "pt": {"beam": 10, "lenpen": 0.9},
    "ro": {"beam": 10, "lenpen": 1.0},
    "ru": {"beam": 10, "lenpen": 0.3},
}


# Eden AI TTS settings per UI language: [language code, voice option].
# predict sends these as the "language" and "option" payload fields.
lang2voice = {
    "Arabic": ["ar-XA", "MALE"],
    "English": ["en-US", "FEMALE"],
    "Chinese": ["cmn-TW", "MALE"],
    "Spanish": ["es-ES", "MALE"],
    "Russian": ["ru-RU", "FEMALE"],
    "French": ["fr-FR", "FEMALE"],
    "German": ["de-DE", "MALE"],
    "Italian": ["it-IT", "FEMALE"],
    "Netherlands": ["nl-NL", "MALE"],
    "Portuguese": ["pt-BR", "FEMALE"],
    "Romanian": ["ro-RO", "MALE"],
}



# Install whisper from GitHub at startup.
# NOTE(review): `whisper` is already imported at the top of this file and a
# model has already been loaded above, so this reinstall cannot affect the
# running process; it only changes the environment for later runs and slows
# startup. Consider pinning whisper in requirements instead.
os.system("pip install git+https://github.com/openai/whisper.git")

# load mRASP2 (done further down via mRASPloader.loadmRASP2)




# load ConST -- currently disabled; the setup steps are kept commented out.
#os.system("git clone https://github.com/ReneeYe/ConST")
#os.system("mv ConST ConST_git")
#os.system('mv -n ConST_git/* ./')
#os.system("rm -rf ConST_git")
#os.system("pip3 install --editable ./")
#os.system("mkdir -p data checkpoint")




def restrict_src_options(model_type):
    """Show or hide the language controls that apply to the chosen model.

    Returns Gradio visibility updates, in order, for: the source-language
    dropdown, the mRASP2 target dropdown, the ConST target dropdown and the
    "switch languages" button.
    """
    use_mrasp = model_type == 'Whisper+mRASP2'
    return (
        gr.Dropdown.update(visible=use_mrasp),
        gr.Dropdown.update(visible=use_mrasp),
        gr.Dropdown.update(visible=not use_mrasp),
        gr.Button.update(visible=use_mrasp),
    )

def switchLang(src_lang, tgt_lang):
    """Swap the source and target language selections.

    Returns (new_src, new_tgt). "Detect Language" is only offered as a source
    choice and has no entry in the language-code table, so when it is the
    current source the selections are returned unchanged rather than moving
    it into the target dropdown.
    """
    if src_lang == "Detect Language":
        return src_lang, tgt_lang
    return tgt_lang, src_lang


def _text_to_speech(text, language, out_path="output.wav"):
    """Synthesize *text* via Eden AI (Google provider) and save it to *out_path*.

    Args:
        text: the text to speak.
        language: UI language name; looked up in ``lang2voice`` to get the
            provider language code and voice option.
        out_path: where the downloaded audio is written.

    Returns:
        out_path on success, or None when the TTS request, the response
        parsing or the audio download fails (the error is printed).
    """
    voice_lang, voice_option = lang2voice[language]
    payload = {
        "providers": "google",
        "language": voice_lang,
        "option": voice_option,
        "text": text,
    }
    try:
        response = requests.post(url, json=payload, headers=headers, timeout=60)
        result = response.json()
        audio_url = result['google']['audio_resource_url']
        # Download with requests rather than shelling out to wget: the URL
        # comes from a remote API response and must never be interpolated
        # into a shell command.
        audio_response = requests.get(audio_url, timeout=60)
        audio_response.raise_for_status()
    except (requests.RequestException, ValueError, KeyError, TypeError):
        traceback.print_exc()
        return None
    with open(out_path, "wb") as f:
        f.write(audio_response.content)
    return out_path


def predict(audio, src_language, tgt_language_mRASP, tgt_language_ConST, model_type, mic_audio=None):
    """Gradio submit callback: transcribe, translate and voice the input speech.

    All arguments come straight from the Gradio input components wired up in
    ``submit_button.click``.

    Args:
        audio: file path of the uploaded audio, or None.
        src_language: UI name of the spoken language, or "Detect Language".
        tgt_language_mRASP: UI name of the target language for Whisper+mRASP2.
        tgt_language_ConST: UI name of the target language for ConST.
        model_type: 'Whisper+mRASP2' or 'ConST'.
        mic_audio: file path of the microphone recording, or None.

    Returns:
        (transcript, translation, audio_path) for the three output
        components; audio_path is None when no speech could be produced.
    """
    start_predict = time.time()
    # Prefer the microphone recording when both sources are provided.
    if mic_audio is not None:
        input_audio = mic_audio
    elif audio is not None:
        input_audio = audio
    else:
        # The handler is wired to three outputs, so always return three values.
        return "(please provide audio)", "", None

    transcript = "Undefined"
    translation = "Undefined"

    if model_type == 'Whisper+mRASP2':
        transcript, translation = predictWithmRASP2(input_audio, src_language, tgt_language_mRASP)
        language = tgt_language_mRASP
    elif model_type == 'ConST':
        # NOTE(review): predictWithConST is not defined in this file and its
        # result is discarded, so transcript/translation stay "Undefined".
        # The model dropdown currently only offers 'Whisper+mRASP2'.
        predictWithConST(input_audio, tgt_language_ConST)
        language = tgt_language_ConST
    else:
        # Previously this fell through and crashed on the unbound `language`.
        return transcript, f"(unsupported model type: {model_type})", None

    start_tts = time.time()
    audio_path = _text_to_speech(translation, language)
    tts_time = time.time() - start_tts
    print(f"Took {tts_time} to do text to speech")

    total_time = time.time() - start_predict
    print(f"Took {total_time} to do entire prediction")

    return transcript, translation, audio_path



def predictWithmRASP2(input_audio, src_language, tgt_language):
    """Transcribe *input_audio* with Whisper, then translate it with mRASP2.

    Args:
        input_audio: path to the audio file.
        src_language: UI language name (e.g. "English"), or "Detect Language"
            to let Whisper identify the spoken language.
        tgt_language: UI language name of the desired translation.

    Returns:
        (transcript, translation) strings.

    Raises:
        KeyError: if a language name is not in ``language_id_lookup``.
    """
    print("Called predictWithmRASP2")
    asr_start = time.time()

    # Whisper's preprocessing: load the audio, pad/trim it to the model's
    # fixed input window, then compute the log-Mel spectrogram on the
    # model's device.
    audio = whisper.load_audio(input_audio)
    audio = whisper.pad_or_trim(audio)
    mel = whisper.log_mel_spectrogram(audio).to(model.device)

    # Map UI names to the short language codes used by Whisper and mRASP2;
    # None asks Whisper to detect the source language itself.
    if src_language == "Detect Language":
        src_language = None
    else:
        src_language = language_id_lookup[src_language.split()[0]]
    tgt_language = language_id_lookup[tgt_language.split()[0]]

    # ASR. Only request FP16 on an accelerator: whisper.transcribe itself
    # switches to FP32 on CPU ("FP16 is not supported on CPU").
    options = whisper.DecodingOptions(fp16=model.device.type != "cpu", language=src_language)
    result = whisper.decode(model, mel, options)
    if src_language is None:
        src_language = result.language
    transcript = result.text

    asr_time = time.time() - asr_start
    print(f"Took {asr_time} to do ASR")

    # MT
    mt_start_time = time.time()
    # NOTE(review): these files are presumably consumed by mRASPloader.infer;
    # confirm. When the source and target codes are equal, the second write
    # overwrites the transcript. UTF-8 is explicit so CJK/Arabic/Cyrillic
    # transcripts don't fail on non-UTF-8 locales.
    with open("input." + src_language, 'w', encoding='utf-8') as w:
        w.write(result.text)
    with open("input." + tgt_language, 'w', encoding='utf-8') as w:
        w.write('LANG_TOK_' + src_language.upper())

    translation = mRASPloader.infer(cfg, models, task, max_positions, tokenizer, bpe, use_cuda, generator, src_dict, tgt_dict, align_dict, start_time, start_id, src_language, tgt_language)
    # Drop the first space-separated token of the model output.
    translation = ' '.join(translation.split(' ')[1:]).strip()

    mt_time = time.time() - mt_start_time
    print(f"Took {mt_time} to do machine translation")

    return transcript, translation



# Page title, rendered as a Markdown heading at the top of the demo.
title = "Demo for Speech Translation (Whisper+mRASP2 and ConST)"

# Usage notes rendered under the title; the <b> tag is passed through Markdown.
# NOTE(review): the text still describes ConST, but the model dropdown in the
# UI only offers 'Whisper+mRASP2' -- update the copy if ConST stays disabled.
description = """
<b>How to use:</b> Upload an audio file or record using the microphone. The audio is either processed by being inputted into the openai whisper model for transcription 
and then mRASP2 for translation, or by ConST, which directly takes the audio input and produces text in the desired language. When using Whisper+mRASP2, 
you can ask the model to detect a language, it will tell you what language it detected. ConST only supports translating from English to another language. 
"""

# Build the mRASP2 config and load the model once, at import time.
# predictWithmRASP2 reads all of the names unpacked below as module-level
# globals when it calls mRASPloader.infer.
cfg = mRASPloader.createCFG()
print(cfg)
(
    models, task, max_positions, tokenizer, bpe, use_cuda, generator,
    src_dict, tgt_dict, align_dict, start_time, start_id,
) = mRASPloader.loadmRASP2(cfg)

# Gradio UI: inputs on the left, transcript/translation/speech on the right.
demo = gr.Blocks()

with demo:
    gr.Markdown("# " + title)
    gr.Markdown("###" + description)
    with gr.Row():
        with gr.Column():
            # Only Whisper+mRASP2 is offered, so the ConST controls stay hidden.
            model_type = gr.Dropdown(['Whisper+mRASP2'], type="value", value='Whisper+mRASP2', label="Select the model you want to use.")
            audio_file = gr.Audio(label="Upload Speech", source="upload", type="filepath")
            src_language = gr.Dropdown(['Arabic',
                                        'Chinese',
                                        'English',
                                        'Spanish',
                                        'Russian',
                                        'French',
                                        'Detect Language'], value='English', label="Select the language of input")
            tgt_language_mRASP = gr.Dropdown(['Arabic',
                                              'Chinese',
                                              'English',
                                              'Spanish',
                                              'Russian',
                                              'French'], type="value", value='English', label="Select the language of output")
            # Choices must match the keys of lang2voice, which predict uses
            # to pick the TTS voice ('Portuguese', not 'Portugese').
            tgt_language_ConST = gr.Dropdown(['German',
                                              'Spanish',
                                              'French',
                                              'Italian',
                                              'Netherlands',
                                              'Portuguese',
                                              'Romanian',
                                              'Russian'], type='value', value='German', label="Select the language of output", visible=False)
            switch_lang_button = gr.Button("Switch input and output languages")
            mic_audio = gr.Audio(label="Record Speech", source="microphone", type="filepath")

            # Output order must match the tuple returned by restrict_src_options.
            model_type.change(fn=restrict_src_options, inputs=[model_type], outputs=[src_language, tgt_language_mRASP, tgt_language_ConST, switch_lang_button])
            submit_button = gr.Button("Submit")
        with gr.Column():
            transcript = gr.Text(label="Transcription")
            translate = gr.Text(label="Translation")
            translated_speech = gr.Audio(label="Translation Speech")

    # Input order must match predict's positional parameters.
    submit_button.click(fn=predict, inputs=[audio_file, src_language, tgt_language_mRASP, tgt_language_ConST, model_type, mic_audio], outputs=[transcript, translate, translated_speech])
    switch_lang_button.click(switchLang, [src_language, tgt_language_mRASP], [src_language, tgt_language_mRASP])

demo.launch()