# alaahilal's picture
# uploaded files
# 837062b verified
# raw
# history blame
# 2 kB
import os
import tempfile

import soundfile as sf
import streamlit as st
import torch
import torchaudio

from ced_model.feature_extraction_ced import CedFeatureExtractor
from ced_model.modeling_ced import CedForAudioClassification
# Hub identifier of the pretrained checkpoint used for classification.
model_name = "mispeech/ced-tiny"


@st.cache_resource
def _load_model(name):
    """Load the feature extractor and classifier for checkpoint *name*.

    Streamlit re-executes the whole script on every user interaction, so
    without caching both ``from_pretrained`` calls would run again on each
    upload. ``st.cache_resource`` keeps a single shared instance per *name*
    for the lifetime of the server process.

    Args:
        name: Checkpoint identifier passed to ``from_pretrained``.

    Returns:
        A ``(feature_extractor, model)`` tuple.
    """
    extractor = CedFeatureExtractor.from_pretrained(name)
    classifier = CedForAudioClassification.from_pretrained(name)
    return extractor, classifier


feature_extractor, model = _load_model(model_name)

# Static page header and the upload widget; only .wav files are accepted.
st.title("Audio Classification App")
st.subheader("Trained on 50 classes of ESC 50 dataset")
st.write("Upload an audio file to predict its class.")
audio_file = st.file_uploader("Upload Audio File", type=["wav"])
def _load_waveform(path):
    """Read the audio file at *path* into a ``(channels, samples)`` tensor.

    torchaudio is tried first. If it fails for any reason (for example a
    missing backend or an unsupported WAV encoding), soundfile is used
    instead and its output is converted to the same layout and dtype.

    Args:
        path: Filesystem path of the audio file.

    Returns:
        A ``(waveform, sampling_rate)`` tuple.
    """
    try:
        return torchaudio.load(path)
    except Exception:
        # Deliberately broad: the failure type depends on the torchaudio
        # backend, and any failure should trigger the fallback.
        st.warning("Fallback to soundfile for audio loading.")
        # soundfile returns (frames, channels) and float64 by default; ask
        # for float32 and always 2-D so the transpose below yields the
        # (channels, frames) layout used by torchaudio.
        samples, rate = sf.read(path, dtype="float32", always_2d=True)
        return torch.from_numpy(samples).T.contiguous(), rate


def _to_mono_16k(waveform, rate):
    """Mix *waveform* down to one channel and resample it to 16000 Hz.

    Args:
        waveform: Tensor shaped ``(channels, samples)``.
        rate: Sampling rate of *waveform* in Hz.

    Returns:
        A ``(1, samples)`` tensor sampled at 16000 Hz.

    Raises:
        ValueError: If *waveform* contains no samples.
    """
    if waveform.numel() == 0:
        raise ValueError("The uploaded file contains no audio samples.")
    if waveform.size(0) > 1:
        # NOTE(review): the feature extractor appears to treat the first
        # dimension as a batch, so a stereo file would give one prediction
        # per channel and break the ``.item()`` call. Confirm against
        # CedFeatureExtractor.
        waveform = waveform.mean(dim=0, keepdim=True)
    if rate != 16000:
        st.warning("Resampling audio to 16000 Hz...")
        resampler = torchaudio.transforms.Resample(orig_freq=rate, new_freq=16000)
        waveform = resampler(waveform)
    return waveform


if audio_file is not None:
    st.write(f"Uploaded file: {audio_file.name}")
    temp_file_path = None
    try:
        # A unique temp file per run keeps concurrent sessions from
        # overwriting each other's upload. delete=False lets the file be
        # reopened by path after the handle is closed.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            temp_file_path = tmp.name
            # getvalue() returns the full upload regardless of the current
            # stream position.
            tmp.write(audio_file.getvalue())

        audio, sampling_rate = _load_waveform(temp_file_path)
        audio = _to_mono_16k(audio, sampling_rate)
        sampling_rate = 16000

        inputs = feature_extractor(audio, sampling_rate=sampling_rate, return_tensors="pt")
        with torch.no_grad():
            logits = model(**inputs).logits
        predicted_class_id = torch.argmax(logits, dim=-1).item()
        predicted_label = model.config.id2label[predicted_class_id]
        st.success(f"Predicted Class: {predicted_label}")
    except Exception as e:
        # Top-level boundary of the app: show the failure to the user
        # instead of crashing the script run.
        st.error(f"An error occurred: {e}")
    finally:
        # Remove the temp file on failure paths too.
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.remove(temp_file_path)
else:
    st.info("Please upload a .wav audio file to continue.")