# asr-test/app.py
import os

import gradio as gr
import librosa
import numpy as np
import torch
import torchaudio
from transformers import (
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    WhisperTokenizer,
)
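
# Run inference on the GPU when one is available, otherwise fall back to CPU.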
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
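
# The model repo and access token are supplied via environment variables
# (typically configured as Space secrets).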
model_path = os.environ.get("HF_REPO_ID")
access_token = os.environ.get("HF_TOKEN")
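
# Load the Whisper components from the (possibly private) repo. The processor
# already bundles the feature extractor and tokenizer; they are kept as
# separate objects here as well.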
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_path, token=access_token)
tokenizer = WhisperTokenizer.from_pretrained(model_path, token=access_token)
processor = WhisperProcessor.from_pretrained(model_path, token=access_token)
model = WhisperForConditionalGeneration.from_pretrained(model_path, token=access_token).to(device)


def transcribe_audio(file_path):
    # Load the recording and keep the first channel as a mono waveform;
    # torchaudio infers the file format from the header.
    speech_array, sampling_rate = torchaudio.load(file_path)
    speech_array = speech_array[0].numpy()
    # Whisper expects 16 kHz audio.
    speech_array = librosa.resample(np.asarray(speech_array), orig_sr=sampling_rate, target_sr=16000)
    # Convert the waveform into log-Mel input features.
    input_features = feature_extractor(speech_array, sampling_rate=16000, return_tensors="pt").input_features
    # Generate token IDs, then decode them to text.
    predicted_ids = model.generate(input_features.to(device))[0]
    transcription = processor.decode(predicted_ids, skip_special_tokens=True)
    return transcription


# Populate the examples panel from the bundled test_sample/ clips, if present.
examples = [f"test_sample/{x}" for x in sorted(os.listdir("test_sample"))] if os.path.isdir("test_sample") else None

# Create the Gradio interface
interface = gr.Interface(
    fn=transcribe_audio,
    inputs=gr.Audio(sources=["microphone"], type="filepath"),
    outputs=gr.Textbox(),
    examples=examples,
)

# Launch the interface
interface.launch()
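
# A quick way to try this locally, assuming the env vars point at a usable
# checkpoint (the repo ID below is only a placeholder; any repo loadable by
# WhisperForConditionalGeneration works):
#   HF_REPO_ID=openai/whisper-small HF_TOKEN=<your token> python app.py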