# mutisya's picture
# Duplicate from mutisya/kik_asr_demo_1
# 6becda8
# raw
# history blame
# 2.61 kB
import os
import gradio as gr
from pydub import AudioSegment
import pyaudioconvert as pac
import torch
import torchaudio
import sox
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor,Wav2Vec2ProcessorWithLM
def convert(audio):
    """Normalize *audio* in place to a 16-bit mono WAV file.

    mp3/ogg inputs are re-encoded to WAV over the original path (the file
    extension is left untouched); wav inputs pass straight through to the
    bit-depth/mono conversion. Returns True on success, False for any
    other extension.
    """
    path = audio
    # Reject anything that is not mp3, wav or ogg up front.
    if not (path.endswith("mp3") or path.endswith("wav") or path.endswith("ogg")):
        return False
    if path.endswith("mp3"):
        AudioSegment.from_mp3(path).export(audio, format="wav")
    elif path.endswith("ogg"):
        AudioSegment.from_ogg(audio).export(audio, format="wav")
    # Force 16-bit mono in place, whatever the source container was.
    pac.convert_wav_to_16bit_mono(audio, audio)
    return True
def parse_transcription_with_lm(logits):
    """Decode CTC logits with the LM-backed processor and return the text."""
    decoded = processor_with_LM.batch_decode(logits.cpu().numpy())
    first_text = decoded.text[0]
    # Strip the BOS marker the LM decoder can leave in the output.
    return first_text.replace('<s>','')
def parse_transcription(logits):
    """Greedy (argmax) CTC decode of *logits* into a transcription string."""
    best_ids = torch.argmax(logits, dim=-1)
    return processor.decode(best_ids[0], skip_special_tokens=True)
def transcribe(audio, audio_microphone, applyLM):
    """Transcribe an uploaded or recorded audio file.

    Parameters:
        audio: filepath of an uploaded audio file (may be None).
        audio_microphone: filepath of a microphone recording; takes
            precedence over *audio* when present.
        applyLM: when True, decode with the language-model-backed
            processor; otherwise use greedy argmax decoding.

    Returns:
        The recognized text as a string.
    """
    # Prefer the microphone recording when both inputs are supplied.
    audio_path = audio_microphone if audio_microphone else audio
    speech_array, sampling_rate = torchaudio.load(audio_path)
    # The wav2vec2 model expects 16 kHz input; resample whatever we got.
    speech = torchaudio.functional.resample(
        speech_array, orig_freq=sampling_rate, new_freq=16000
    ).squeeze().numpy()
    inputs = processor(speech, sampling_rate=16_000, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(inputs.input_values).logits
    return parse_transcription_with_lm(logits) if applyLM else parse_transcription(logits)
# Hugging Face auth token: read from the "key" env var if set; fall back to
# True (meaning "use the locally cached / logged-in token") when absent.
auth_token = os.environ.get("key") or True
model_id = "mutisya/wav2vec2-300m-kik-t22-1k-ft-withLM"
# Plain CTC processor (feature extractor + tokenizer) for greedy decoding.
processor = Wav2Vec2Processor.from_pretrained(model_id, use_auth_token=auth_token)
# LM-backed processor, used when the "Apply LM" checkbox is ticked.
processor_with_LM = Wav2Vec2ProcessorWithLM.from_pretrained(model_id, use_auth_token=auth_token)
model = Wav2Vec2ForCTC.from_pretrained(model_id, use_auth_token=auth_token)
# Gradio UI: two optional audio inputs (upload + microphone) and an LM toggle.
gradio_ui = gr.Interface(
    fn=transcribe,
    title="Kikuyu Speech Recognition",
    description="",
    inputs=[gr.Audio(label="Upload Audio File", type="filepath", optional=True),
            gr.Audio(source="microphone", type="filepath", optional=True, label="Record from microphone"),
            gr.Checkbox(label="Apply LM", value=False)],
    # gr.Textbox matches the modern component API already used for the inputs;
    # gr.outputs.Textbox is the deprecated legacy namespace.
    outputs=[gr.Textbox(label="Recognized speech")]
)
gradio_ui.launch(enable_queue=True)