"""Gradio speech-to-text demo.

Run with:
    >> python app.py
"""
import os

import gradio as gr
import soundfile as sf
from transformers import AutoModelForCTC, Wav2Vec2Processor

from conversationalnlp.models.wav2vec2 import ModelLoader, Wav2Vec2Predict
from conversationalnlp.utils import *

# Prefix for the temporary .wav files written per input widget
# (".../temp_mic.wav", ".../temp_file.wav").
audioheaderpath = os.path.join(os.getcwd(), "temp")

pretrained_model = "codenamewei/speech-to-text"

# Load processor + CTC model once at import time; both are reused by every request.
processor = Wav2Vec2Processor.from_pretrained(pretrained_model)
model = AutoModelForCTC.from_pretrained(pretrained_model)

modelloader = ModelLoader(model, processor)
predictor = Wav2Vec2Predict(modelloader)

audiofileexamples = ["example1.flac", "example2.flac"]

fileextension = ".wav"


def greet(*args):
    """Transcribe the microphone recording and/or the uploaded audio file.

    Parameters
    ----------
    args[0] : tuple(int, np.ndarray) or None
        Microphone input as (sample_rate, samples), or None if not provided.
    args[1] : tuple(int, np.ndarray) or None
        Uploaded audio file as (sample_rate, samples), or None if not provided.

    Returns
    -------
    list[str]
        [mic transcription, file transcription]; an empty string for any
        input that was not supplied. Each transcription is the predicted
        text followed by the corrected text on a new line.
    """
    dictinput = dict(mic=args[0], file=args[1])

    audiofiles = []
    for key, audioarray in dictinput.items():
        if audioarray is not None:
            # WORKAROUND: save to disk and let the predictor reread the file,
            # to get the array shape needed for prediction.
            audioabspath = audioheaderpath + "_" + key + fileextension
            print(f"Audio at path {audioabspath}")
            sf.write(audioabspath, audioarray[1], audioarray[0])
            audiofiles.append(audioabspath)

    # FIX: when neither input was provided, skip the model call entirely —
    # previously predictfiles([]) was invoked with an empty list.
    if not audiofiles:
        return ["", ""]

    predictiontexts = predictor.predictfiles(audiofiles)

    # The mic result (if present) is always first in audiofiles, the file
    # result always last, so index 0 / -1 are correct whether one or both
    # inputs were supplied.
    mictext = (
        predictiontexts["predicted_text"][0]
        + "\n"
        + predictiontexts["corrected_text"][0]
        if dictinput["mic"] is not None
        else ""
    )
    filetext = (
        predictiontexts["predicted_text"][-1]
        + "\n"
        + predictiontexts["corrected_text"][-1]
        if dictinput["file"] is not None
        else ""
    )

    return [mictext, filetext]


demo = gr.Interface(
    fn=greet,
    inputs=["mic", "audio"],
    outputs=["text", "text"],
    title="Speech-to-Text",
    # One example row supplying both the mic and the file widget.
    examples=[audiofileexamples],
)

demo.launch()  # share=True)