# transcribe-api-proxy / app_previous.py
# (Hugging Face Spaces page metadata, kept as a comment so the file parses:
#  commit "Rename app.py to app_previous.py", 92ff9dd verified,
#  raw / history / blame, 2.45 kB)
import os
import gradio as gr
from pydub import AudioSegment
import pyaudioconvert as pac
import torch
import torchaudio
import sox
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor,Wav2Vec2ProcessorWithLM
def convert(audio):
    """Normalise an uploaded audio file to 16-bit mono WAV, in place.

    Parameters
    ----------
    audio : str
        Path to the input file. Supported extensions: .mp3, .wav, .ogg.

    Returns
    -------
    bool
        True when the file was converted in place; False for an
        unsupported format (the file is left untouched).
    """
    file_name = audio
    # Match the extension proper (".mp3"), not any name merely ending in
    # the letters "mp3" — the original check accepted e.g. "foomp3".
    if file_name.endswith((".mp3", ".wav", ".ogg")):
        if file_name.endswith(".mp3"):
            # Re-encode MP3 as WAV; note the file keeps its original path
            # (and .mp3 suffix) — only the contents change.
            sound = AudioSegment.from_mp3(file_name)
            sound.export(audio, format="wav")
        elif file_name.endswith(".ogg"):
            sound = AudioSegment.from_ogg(audio)
            sound.export(audio, format="wav")
        # .wav input needs no container re-encode before the sample fix-up.
        pac.convert_wav_to_16bit_mono(audio, audio)
        return True
    return False
def parse_transcription_with_lm(logits):
    """Decode CTC logits with the LM-backed processor and strip the <s> token."""
    decoded = processor_with_LM.batch_decode(logits.cpu().numpy())
    first_text = decoded.text[0]
    return first_text.replace('<s>', '')
def parse_transcription(logits):
    """Greedy-decode CTC logits into a transcription string (no LM)."""
    best_ids = logits.argmax(dim=-1)
    return processor.decode(best_ids[0], skip_special_tokens=True)
def transcribe(audio_path, applyLM):
    """Transcribe an audio file with the wav2vec2 CTC model.

    Parameters
    ----------
    audio_path : str
        Path to the recorded audio file (as produced by the Gradio widget).
    applyLM : bool
        When True, decode with the language-model-backed processor;
        otherwise use plain greedy CTC decoding.

    Returns
    -------
    str
        The recognized text.
    """
    speech_array, sampling_rate = torchaudio.load(audio_path)
    # The model was trained on 16 kHz audio, so resample whatever rate
    # the browser recorded at down/up to 16 kHz first.
    speech = torchaudio.functional.resample(
        speech_array, orig_freq=sampling_rate, new_freq=16000
    ).squeeze().numpy()
    # (Removed a dead, commented-out code path that re-converted the file
    # via convert() and reloaded it — it was never executed.)
    inputs = processor(speech, sampling_rate=16_000, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(inputs.input_values).logits
    if applyLM:
        return parse_transcription_with_lm(logits)
    return parse_transcription(logits)
# Hugging Face auth token: read from the "key" env var; falling back to True
# makes from_pretrained() use the locally cached login token instead.
auth_token = os.environ.get("key") or True
# Kikuyu wav2vec2 fine-tune; the same repo serves both processor variants.
model_id = "mutisya/wav2vec2-300m-kik-t22-1k-ft-withLM"
processor = Wav2Vec2Processor.from_pretrained(model_id, use_auth_token=auth_token)
processor_with_LM = Wav2Vec2ProcessorWithLM.from_pretrained(model_id, use_auth_token=auth_token)
model = Wav2Vec2ForCTC.from_pretrained(model_id, use_auth_token=auth_token)
# NOTE(review): this uses the pre-4.x Gradio API (source=, optional=,
# gr.outputs.Textbox, enable_queue=) — confirm gradio is pinned accordingly.
gradio_ui = gr.Interface(
    fn=transcribe,
    title="Speech Recognition",
    description="",
    inputs=[gr.Audio(source="microphone", type="filepath", optional=True, label="Record from microphone"),
            gr.Checkbox(label="Apply LM", value=False)],
    outputs=[gr.outputs.Textbox(label="Recognized speech")]
)
gradio_ui.launch(enable_queue=True)