# Author: alaahilal
# Commit 212fcb5 (verified): "working as of now for .wav files"
import io
import os
import tempfile

import soundfile as sf
import streamlit as st
import torch
import torchaudio
# New imports for handling MP3 and M4A files
from pydub import AudioSegment

from ced_model.feature_extraction_ced import CedFeatureExtractor
from ced_model.modeling_ced import CedForAudioClassification
# Streamlit app: classify an uploaded audio clip with the CED-tiny model.

model_name = "mispeech/ced-tiny"

# Every upload is resampled to this rate before feature extraction.
TARGET_SAMPLING_RATE = 16000

# File types offered by the uploader. Only WAV for now; the MP3/M4A
# conversion path below is reachable only if this list is extended
# (pydub typically needs ffmpeg installed for those formats).
ACCEPTED_TYPES = ["wav"]


@st.cache_resource
def load_model(name):
    """Load the CED feature extractor and classifier for *name*.

    Decorated with ``st.cache_resource`` so the weights are loaded once and
    reused across Streamlit reruns instead of on every widget interaction.

    Returns:
        A ``(feature_extractor, model)`` tuple.
    """
    return (
        CedFeatureExtractor.from_pretrained(name),
        CedForAudioClassification.from_pretrained(name),
    )


feature_extractor, model = load_model(model_name)


def _save_upload_as_wav(uploaded_file):
    """Write *uploaded_file* to a new temporary ``.wav`` file; return its path.

    MP3/M4A uploads are converted to WAV with pydub; anything else is written
    byte-for-byte. The caller must delete the returned file. If writing fails,
    the temporary file is removed before the exception propagates.
    """
    # A unique file per request: a fixed "temp.wav" in the working directory
    # would be shared (and overwritten) by concurrent sessions.
    fd, path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    data = uploaded_file.getvalue()
    lowered_name = uploaded_file.name.lower()
    try:
        if lowered_name.endswith((".mp3", ".m4a")):
            # Use the lowercased extension so "SONG.MP3" gives format "mp3".
            extension = lowered_name.rsplit(".", 1)[-1]
            segment = AudioSegment.from_file(io.BytesIO(data), format=extension)
            # export() returns the opened output file; close it so the file
            # can be read and deleted later.
            segment.export(path, format="wav").close()
        else:
            with open(path, "wb") as f:
                f.write(data)
    except Exception:
        os.remove(path)
        raise
    return path


def _load_waveform(path):
    """Load *path* and return ``(waveform, sampling_rate)``.

    Tries ``torchaudio.load`` first. If that raises, falls back to soundfile
    and returns the same ``(channels, frames)`` float32 layout.
    """
    try:
        return torchaudio.load(path)
    except Exception:
        st.warning("Fallback to soundfile for audio loading.")
        # always_2d yields (frames, channels) even for mono input; transpose
        # to (channels, frames) to match torchaudio.load's output.
        data, sampling_rate = sf.read(path, dtype="float32", always_2d=True)
        return torch.from_numpy(data.T.copy()), sampling_rate


def _predict_label(waveform, sampling_rate, extractor, classifier):
    """Return the predicted class label for a ``(channels, frames)`` waveform."""
    if waveform.shape[0] > 1:
        # Downmix to mono: `.item()` below requires a single prediction, and a
        # multi-channel tensor presumably yields one logits row per channel
        # (TODO confirm how CedFeatureExtractor treats 2-D input).
        waveform = waveform.mean(dim=0, keepdim=True)
    if sampling_rate != TARGET_SAMPLING_RATE:
        st.warning(f"Resampling audio to {TARGET_SAMPLING_RATE} Hz...")
        resampler = torchaudio.transforms.Resample(
            orig_freq=sampling_rate, new_freq=TARGET_SAMPLING_RATE
        )
        waveform = resampler(waveform)
        sampling_rate = TARGET_SAMPLING_RATE
    inputs = extractor(waveform, sampling_rate=sampling_rate, return_tensors="pt")
    with torch.no_grad():
        logits = classifier(**inputs).logits
    predicted_class_id = torch.argmax(logits, dim=-1).item()
    return classifier.config.id2label[predicted_class_id]


st.title("Audio Classification App")
st.subheader("Trained on 50 classes of ESC 50 dataset")
st.write("Upload an audio file to predict its class. It takes .wav file format")

audio_file = st.file_uploader("Upload Audio File", type=ACCEPTED_TYPES)

if audio_file is not None:
    st.write(f"Uploaded file: {audio_file.name}")
    temp_file_path = None
    try:
        temp_file_path = _save_upload_as_wav(audio_file)
        audio, sampling_rate = _load_waveform(temp_file_path)
        predicted_label = _predict_label(
            audio, sampling_rate, feature_extractor, model
        )
        st.success(f"Predicted Class: {predicted_label}")
    except Exception as e:
        # Top-level UI boundary: report any failure to the user.
        st.error(f"An error occurred: {e}")
    finally:
        # Runs on success and failure alike, so the temp file never leaks.
        if temp_file_path is not None:
            try:
                os.remove(temp_file_path)
            except OSError:
                pass  # best-effort cleanup of a temporary file
else:
    # Derived from ACCEPTED_TYPES so the prompt matches what the uploader accepts.
    accepted = ", ".join(t.upper() for t in ACCEPTED_TYPES)
    st.info(f"Please upload an audio file ({accepted}) to continue.")