wav2vec2-demo / app.py
rohitp1's picture
Create app.py
abc4a4a
raw
history blame
5.58 kB
# import gradio as gr
# gr.Interface.load("models/rohitp1/kkkh_whisper_small_distillation_att_loss_libri360_epochs_100_batch_4_concat_dataset").launch()
import gradio as gr
import os
import transformers
from transformers import pipeline, Wav2Vec2ForCTC,Wav2Vec2Processor
import time
import torch
# def greet_from_secret(ignored_param):
# name = os.environ.get('TOKEN')
# return
# Hugging Face access token used to download the checkpoints below; None when
# the TOKEN environment variable is not set.
auth_token = os.environ.get('TOKEN')
M1 = "rohitp1/dgx1_w2v2_base_teacher_student_distillation_mozilla_epochs_100_batch_16_concatenate_datasets"
M2 = "rohitp1/finetune_teacher_babble_noise_mozilla_200_epochs"
M3 = "rohitp1/finetune_teacher_clean_mozilla_200_epochs"


def _load_checkpoint(repo_id):
    """Load a Wav2Vec2 CTC model and its matching processor from ``repo_id``.

    Both are fetched with ``auth_token``.

    Returns:
        (model, processor) tuple.
    """
    model = Wav2Vec2ForCTC.from_pretrained(repo_id, use_auth_token=auth_token)
    processor = Wav2Vec2Processor.from_pretrained(repo_id, use_auth_token=auth_token)
    return model, processor


model1, processor1 = _load_checkpoint(M1)  # distilled student ("RobustDistillation")
model2, processor2 = _load_checkpoint(M2)  # teacher fine-tuned on babble noise
model3, processor3 = _load_checkpoint(M3)  # teacher fine-tuned on clean audio

# Dynamic int8 quantisation of the Linear layers of the *distilled* model (M1).
# It is served through p1_quant together with processor1 and is exposed as the
# "DistilledQuantised" option, so it must be derived from model1, not model3.
# quantize_dynamic defaults to inplace=False, so model1 itself stays float for p1.
quantized_model1 = torch.quantization.quantize_dynamic(
    model1, {torch.nn.Linear}, dtype=torch.qint8
)

p1 = pipeline('automatic-speech-recognition', model=model1, processor=processor1)
p2 = pipeline('automatic-speech-recognition', model=model2, processor=processor2)
p3 = pipeline('automatic-speech-recognition', model=model3, processor=processor3)
p1_quant = pipeline('automatic-speech-recognition', model=quantized_model1, processor=processor1)
def transcribe(mic_input, upl_input, model_type):
    """Transcribe one audio clip with the ASR pipeline selected in the UI.

    Args:
        mic_input: filepath of the microphone recording, or a falsy value
            when nothing was recorded. Takes precedence over ``upl_input``.
        upl_input: filepath of the uploaded audio file, or a falsy value.
        model_type: dropdown choice. "NoisyFinetuned" -> p2,
            "CleanFinetuned" -> p3, "DistilledQuantised" (or the UI label
            "DistilledAndQuantised") -> p1_quant. Anything else, including
            "RobustDistillation" and no selection, uses p1.

    Returns:
        (text, time_taken): the transcription, and the inference wall-clock
        time in seconds rounded to 4 decimals. When no audio was supplied,
        a prompt message is returned together with ``None`` as the time.
    """
    audio = mic_input if mic_input else upl_input
    if not audio:
        # Calling the pipeline with None would raise deep inside transformers;
        # tell the user what is missing instead.
        return "Please record or upload an audio file.", None

    pipelines = {
        "NoisyFinetuned": p2,
        "CleanFinetuned": p3,
        "DistilledQuantised": p1_quant,
        # Label actually offered by the dropdown; without this alias that
        # choice silently fell through to the non-quantised p1.
        "DistilledAndQuantised": p1_quant,
    }
    asr = pipelines.get(model_type, p1)

    start = time.perf_counter()
    text = asr(audio)["text"]
    # Seconds, to match the "Time Taken (in sec)" output label.
    time_taken = round(time.perf_counter() - start, 4)
    return text, time_taken
# gr.Interface(
# fn=transcribe,
# inputs=[
# gr.inputs.Audio(source="microphone", type="filepath"),
# 'state'
# ],
# outputs=[
# "textbox",
# "state"
# ],
# live=False).launch()
# demo = gr.load(
# "huggingface/rohitp1/kkkh_whisper_small_distillation_att_loss_libri360_epochs_100_batch_4_concat_dataset",
# title="Speech-to-text",
# inputs="mic",
# description="Let me try to guess what you're saying!",
# api_key=auth_token  # read from the TOKEN env var; never hard-code tokens (revoke any token previously committed here)
# )
# demo.launch()
def clear_inputs_and_outputs():
    """Produce the reset values for the Clear button.

    The list is positional and lines up with the ``outputs`` wired to
    ``clr_btn.click``: microphone audio, uploaded audio, model selector,
    transcription label, time label. Both audio inputs and both outputs are
    emptied; the model selector goes back to "CleanFinetuned".
    """
    empty = None
    default_model = "CleanFinetuned"
    audio_resets = [empty, empty]
    output_resets = [empty, empty]
    return audio_resets + [default_model] + output_resets
# Main function: build and launch the Gradio Blocks UI.
if __name__ == "__main__":
    demo = gr.Blocks()
    with demo:
        gr.Markdown(
            """
<center><h1> Noise Robust English Automatic Speech Recognition LibriSpeech Dataset</h1></center> \
This space is a demo of an English ASR model using Huggingface.<br> \
In this space, you can record your voice or upload a wav file and the model will predict the text spoken in the audio<br><br>
"""
        )
        with gr.Row():
            # Inputs
            with gr.Column():
                mic_input = gr.Audio(source="microphone", type="filepath", label="Record your own voice")
                upl_input = gr.Audio(source="upload", type="filepath", label="Upload a wav file")
                with gr.Row():
                    # Choice strings must match the keys transcribe() dispatches on.
                    # gr.Dropdown replaces the deprecated gr.inputs.Dropdown.
                    model_type = gr.Dropdown(
                        choices=["RobustDistillation", "NoisyFinetuned", "CleanFinetuned", "DistilledQuantised"],
                        label='Model Type',
                    )
                with gr.Row():
                    clr_btn = gr.Button(value="Clear", variant="secondary")
                    prd_btn = gr.Button(value="Predict")
            # Outputs
            with gr.Column():
                lbl_output = gr.Label(label="Transcription")
                with gr.Row():
                    time_output = gr.Label(label="Time Taken (in sec)")
                with gr.Row():
                    # Clicking an example only fills upl_input. transcribe()
                    # needs three inputs, so it is not attached here: passing it
                    # with a single input would raise a TypeError if example
                    # caching or run_on_click were ever enabled.
                    gr.Examples(
                        examples=[
                            os.path.join(os.path.dirname(__file__), "audio/sample3.wav"),
                        ],
                        inputs=upl_input,
                    )
        # Credits
        with gr.Row():
            # NOTE(review): this link points to a Whisper checkpoint, while the
            # app loads the wav2vec2 checkpoints M1-M3 -- confirm the intended URL.
            gr.Markdown(
                """
<h4>Credits</h4>
Author: Rohit Prasad <br>
Check out the model <a href="https://huggingface.co/rohitp1/subh_whisper_small_distil_att_loss_mozilla_epochs_50_batch_8">here</a>
"""
            )
        # Output order must match the list returned by clear_inputs_and_outputs().
        clr_btn.click(
            fn=clear_inputs_and_outputs,
            inputs=[],
            outputs=[mic_input, upl_input, model_type, lbl_output, time_output],
        )
        prd_btn.click(
            fn=transcribe,
            inputs=[mic_input, upl_input, model_type],
            outputs=[lbl_output, time_output],
        )
    demo.launch(debug=True)