# import gradio as gr
# gr.Interface.load("models/rohitp1/kkkh_whisper_small_distillation_att_loss_libri360_epochs_100_batch_4_concat_dataset").launch()
import gradio as gr
import os
import transformers
from transformers import pipeline, WhisperForConditionalGeneration, WhisperTokenizer, WhisperFeatureExtractor
import time
import torch
# def greet_from_secret(ignored_param):
# name = os.environ.get('TOKEN')
# return
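# Hugging Face access token, expected as an environment variable named TOKEN
# (e.g. a Space secret); it is passed to from_pretrained below so the
# checkpoints can be fetched from the Hub.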
auth_token = os.environ.get('TOKEN')
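# The three Whisper-small checkpoints offered in the UI (per their repo names):
# M1 is the distilled "robust" student, M2 a teacher fine-tuned with babble
# noise, and M3 a teacher fine-tuned without added noise.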
M1 = "rohitp1/subh_whisper_small_distil_att_loss_mozilla_epochs_50_batch_8"
M2 = "rohitp1/dgx1_whisper_small_finetune_teacher_babble_noise_mozilla_40_epochs_batch_8"
M3 = "rohitp1/dgx1_whisper_small_finetune_teacher_no_noise_mozilla_40_epochs_batch_8"
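# Load the model, tokenizer, and feature extractor for each checkpoint.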
model1 = WhisperForConditionalGeneration.from_pretrained(M1, use_auth_token=auth_token)
tokenizer1 = WhisperTokenizer.from_pretrained(M1, use_auth_token=auth_token)
feat_ext1 = WhisperFeatureExtractor.from_pretrained(M1, use_auth_token=auth_token)
model2 = WhisperForConditionalGeneration.from_pretrained(M2, use_auth_token=auth_token)
tokenizer2 = WhisperTokenizer.from_pretrained(M2, use_auth_token=auth_token)
feat_ext2 = WhisperFeatureExtractor.from_pretrained(M2, use_auth_token=auth_token)
model3 = WhisperForConditionalGeneration.from_pretrained(M3, use_auth_token=auth_token)
tokenizer3 = WhisperTokenizer.from_pretrained(M3, use_auth_token=auth_token)
feat_ext3 = WhisperFeatureExtractor.from_pretrained(M3, use_auth_token=auth_token)
# make quantized model
# quantized_model1 = torch.quantization.quantize_dynamic(
# model1, {torch.nn.Linear}, dtype=torch.qint8
# )
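# Build one ASR pipeline per checkpoint; these back the model choices in the UI.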
p1 = pipeline('automatic-speech-recognition', model=model1, tokenizer=tokenizer1, feature_extractor=feat_ext1)
p2 = pipeline('automatic-speech-recognition', model=model2, tokenizer=tokenizer2, feature_extractor=feat_ext2)
p3 = pipeline('automatic-speech-recognition', model=model3, tokenizer=tokenizer3, feature_extractor=feat_ext3)
# p1_quant = pipeline('automatic-speech-recognition', model=quantized_model1, tokenizer=tokenizer1, feature_extractor=feat_ext1)
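# Transcription callback: picks the pipeline selected in the dropdown,
# preferring the microphone recording over an uploaded file, and also
# reports the elapsed inference time.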
def transcribe(mic_input, upl_input, model_type):
    # Prefer the microphone recording; otherwise use the uploaded file.
    if mic_input:
        audio = mic_input
    else:
        audio = upl_input
    time.sleep(3)
    st_time = time.time()
    if model_type == 'NoisyFinetuned':
        text = p2(audio)["text"]
    elif model_type == 'CleanFinetuned':
        text = p3(audio)["text"]
    # elif model_type == 'DistilledAndQuantised':
    #     text = p1_quant(audio)['text']
    else:
        text = p1(audio)["text"]
    end_time = time.time()
    # state = text + " "
    # Report elapsed time in seconds, matching the "Time Taken (in sec)" label.
    time_taken = round(end_time - st_time, 4)
    return text, time_taken
# gr.Interface(
# fn=transcribe,
# inputs=[
# gr.inputs.Audio(source="microphone", type="filepath"),
# 'state'
# ],
# outputs=[
# "textbox",
# "state"
# ],
# live=False).launch()
# demo = gr.load(
# "huggingface/rohitp1/kkkh_whisper_small_distillation_att_loss_libri360_epochs_100_batch_4_concat_dataset",
# title="Speech-to-text",
# inputs="mic",
# description="Let me try to guess what you're saying!",
# api_key=os.environ.get('TOKEN')
# )
# demo.launch()
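# Reset every input and output; the model dropdown falls back to "CleanFinetuned".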
def clear_inputs_and_outputs():
    return [None, None, "CleanFinetuned", None, None]
# Build and launch the Gradio demo
if __name__ == "__main__":
    demo = gr.Blocks()

    with demo:
        gr.Markdown(
            """
            <center><h1>Noise-Robust English Automatic Speech Recognition on the LibriSpeech Dataset</h1></center>
            This Space is a demo of a noise-robust English ASR model built with Hugging Face.<br>
            You can record your own voice or upload a wav file, and the selected model will transcribe the speech in the audio.<br><br>
            """
        )
        with gr.Row():
            ## Input
            with gr.Column():
                mic_input = gr.Audio(source="microphone", type="filepath", label="Record your own voice")
                upl_input = gr.Audio(
                    source="upload", type="filepath", label="Upload a wav file"
                )
                with gr.Row():
                    model_type = gr.Dropdown(
                        ["RobustDistillation", "NoisyFinetuned", "CleanFinetuned"], label='Model Type'
                    )
                with gr.Row():
                    clr_btn = gr.Button(value="Clear", variant="secondary")
                    prd_btn = gr.Button(value="Predict")
            # Outputs
            with gr.Column():
                lbl_output = gr.Label(label="Transcription")
                with gr.Row():
                    time_output = gr.Label(label="Time Taken (in sec)")
                # with gr.Group():
                #     gr.Markdown("<center>Prediction per time slot</center>")
                #     plt_output = gr.Plot(
                #         label="Prediction per time slot", show_label=False
                #     )
        with gr.Row():
            gr.Examples(
                [
                    os.path.join(os.path.dirname(__file__), "audio/sample1.wav"),
                    os.path.join(os.path.dirname(__file__), "audio/sample2.wav"),
                    os.path.join(os.path.dirname(__file__), "audio/sample3.wav"),
                ],
                upl_input,
                [lbl_output, time_output],
                transcribe,
            )
        # Credits
        with gr.Row():
            gr.Markdown(
                """
                <h4>Credits</h4>
                Author: Rohit Prasad <br>
                Check out the model <a href="https://huggingface.co/rohitp1/subh_whisper_small_distil_att_loss_mozilla_epochs_50_batch_8">here</a>.
                """
            )
        clr_btn.click(
            fn=clear_inputs_and_outputs,
            inputs=[],
            outputs=[mic_input, upl_input, model_type, lbl_output, time_output],
        )
        prd_btn.click(
            fn=transcribe,
            inputs=[mic_input, upl_input, model_type],
            outputs=[lbl_output, time_output],
        )

    demo.launch(debug=True)