# Space status: Running
import contextlib
import io
import os
import tempfile

import soundfile as sf
import streamlit as st
import torch
import torchaudio
# pydub is used to convert MP3/M4A uploads to WAV before classification.
from pydub import AudioSegment

from ced_model.feature_extraction_ced import CedFeatureExtractor
from ced_model.modeling_ced import CedForAudioClassification
model_name = "mispeech/ced-tiny"


@st.cache_resource
def _load_ced(name):
    """Load the CED feature extractor and audio classifier for checkpoint *name*.

    Streamlit re-executes this whole script on every widget interaction.
    Caching with ``st.cache_resource`` keeps one loaded copy per process,
    instead of re-reading the checkpoint from disk on every rerun.

    Returns:
        ``(feature_extractor, model)`` for the given checkpoint name.
    """
    extractor = CedFeatureExtractor.from_pretrained(name)
    classifier = CedForAudioClassification.from_pretrained(name)
    return extractor, classifier


# Module-level names used by the rest of the script.
feature_extractor, model = _load_ced(model_name)
st.title("Audio Classification App")
# Derive the class count from the loaded checkpoint rather than hard-coding a
# dataset name, so the header always matches the labels actually predicted.
st.subheader(f"Model: {model_name} ({len(model.config.id2label)} classes)")
st.write("Upload an audio file to predict its class. Supported formats: .wav, .mp3, .m4a")
# Accept every format the upload handler can decode: WAV is read directly and
# MP3/M4A are converted to WAV with pydub. Accepting only "wav" would make the
# MP3/M4A conversion path unreachable.
audio_file = st.file_uploader("Upload Audio File", type=["wav", "mp3", "m4a"])
def _upload_to_wav(uploaded, wav_path):
    """Write the uploaded audio to *wav_path* as a WAV file.

    ``.mp3`` / ``.m4a`` uploads (extension matched case-insensitively) are
    decoded with pydub and re-exported as WAV. Any other upload is assumed to
    already be WAV and is copied byte-for-byte.
    """
    data = uploaded.getvalue()
    extension = os.path.splitext(uploaded.name)[1].lower().lstrip(".")
    if extension in ("mp3", "m4a"):
        segment = AudioSegment.from_file(io.BytesIO(data), format=extension)
        # export() returns the open output file object; close it explicitly so
        # the WAV is fully flushed before it is read back.
        segment.export(wav_path, format="wav").close()
    else:
        with open(wav_path, "wb") as f:
            f.write(data)


def _load_waveform(wav_path):
    """Load *wav_path* and return ``(waveform, sampling_rate)``.

    The waveform is a float32 tensor laid out as (channels, frames). torchaudio
    is tried first. soundfile is the fallback, asked for float32 and 2-D output
    so both paths produce the same dtype and layout.
    """
    try:
        return torchaudio.load(wav_path)
    except Exception:
        st.warning("Fallback to soundfile for audio loading.")
        data, rate = sf.read(wav_path, dtype="float32", always_2d=True)
        # soundfile returns (frames, channels); transpose to (channels, frames).
        return torch.from_numpy(data.T.copy()), rate


def _to_model_input(waveform, sampling_rate, target_rate):
    """Downmix *waveform* to mono (1, frames) and resample it to *target_rate*.

    Returns ``(waveform, target_rate)``.
    """
    if waveform.size(0) > 1:
        # Presumably the extractor treats the leading dim as the batch (the mono
        # path passes (1, frames)). Several channels would give several logits
        # rows, and `.item()` on the argmax would then raise.
        waveform = waveform.mean(dim=0, keepdim=True)
    if sampling_rate != target_rate:
        st.warning(f"Resampling audio to {target_rate} Hz...")
        resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=target_rate)
        waveform = resampler(waveform)
    return waveform, target_rate


if audio_file is not None:
    st.write(f"Uploaded file: {audio_file.name}")
    temp_file_path = None
    try:
        # Unique per-upload path: a fixed "temp.wav" in the working directory
        # would be shared and overwritten by concurrent sessions.
        fd, temp_file_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        _upload_to_wav(audio_file, temp_file_path)
        audio, sampling_rate = _load_waveform(temp_file_path)
        # Match the extractor's configured rate; 16 kHz is the fallback the
        # original code hard-coded.
        target_rate = getattr(feature_extractor, "sampling_rate", None) or 16000
        audio, sampling_rate = _to_model_input(audio, sampling_rate, target_rate)
        inputs = feature_extractor(audio, sampling_rate=sampling_rate, return_tensors="pt")
        with torch.no_grad():
            logits = model(**inputs).logits
        predicted_class_id = torch.argmax(logits, dim=-1).item()
        predicted_label = model.config.id2label[predicted_class_id]
        st.success(f"Predicted Class: {predicted_label}")
    except Exception as e:
        st.error(f"An error occurred: {e}")
    finally:
        # Clean up on failure as well as success. Removal is best-effort.
        if temp_file_path is not None:
            with contextlib.suppress(OSError):
                os.remove(temp_file_path)
else:
    st.info("Please upload an audio file (WAV, MP3, or M4A) to continue.")