"""Streamlit app that classifies an uploaded audio clip with a CED model."""
import contextlib
import io
import os
import tempfile

import soundfile as sf
import streamlit as st
import torch
import torchaudio

# pydub (backed by ffmpeg) decodes compressed formats such as MP3/M4A.
from pydub import AudioSegment

from ced_model.feature_extraction_ced import CedFeatureExtractor
from ced_model.modeling_ced import CedForAudioClassification

model_name = "mispeech/ced-tiny"
TARGET_SAMPLE_RATE = 16000
# Extensions accepted by the uploader; the compressed ones are converted to WAV first.
SUPPORTED_EXTENSIONS = ("wav", "mp3", "m4a")
COMPRESSED_EXTENSIONS = ("mp3", "m4a")


@st.cache_resource
def load_model(name):
    """Load and cache the CED feature extractor and classifier.

    Streamlit re-executes this script on every widget interaction; caching
    keeps the weights from being reloaded on each rerun.

    Args:
        name: Pretrained model identifier passed to ``from_pretrained``.

    Returns:
        A ``(feature_extractor, model)`` tuple.
    """
    extractor = CedFeatureExtractor.from_pretrained(name)
    classifier = CedForAudioClassification.from_pretrained(name)
    return extractor, classifier


def _write_upload_as_wav(uploaded_file, wav_path):
    """Write *uploaded_file* to *wav_path* as WAV, converting MP3/M4A via pydub."""
    extension = os.path.splitext(uploaded_file.name)[1].lower().lstrip(".")
    audio_bytes = uploaded_file.read()
    if extension in COMPRESSED_EXTENSIONS:
        segment = AudioSegment.from_file(io.BytesIO(audio_bytes), format=extension)
        segment.export(wav_path, format="wav")
    else:
        # WAV uploads are already in the target container; copy the bytes as-is.
        with open(wav_path, "wb") as f:
            f.write(audio_bytes)


def _load_waveform(wav_path):
    """Load *wav_path* as a ``(channels, frames)`` float32 tensor plus its sample rate."""
    try:
        waveform, sampling_rate = torchaudio.load(wav_path)
    except Exception:
        st.warning("Fallback to soundfile for audio loading.")
        # soundfile yields (frames, channels); transpose to torchaudio's
        # (channels, frames) layout and keep float32 so resampling works.
        data, sampling_rate = sf.read(wav_path, dtype="float32", always_2d=True)
        waveform = torch.tensor(data.T)
    return waveform, sampling_rate


def _to_mono(waveform):
    """Average a ``(channels, frames)`` waveform down to a single channel."""
    # NOTE(review): the feature extractor appears to treat dim 0 as a batch
    # dimension, so multi-channel input would yield one prediction per
    # channel and break ``.item()`` below; downmixing avoids that.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    return waveform


def _predict(extractor, classifier, waveform, sampling_rate):
    """Resample to 16 kHz if needed and return the top predicted label."""
    if sampling_rate != TARGET_SAMPLE_RATE:
        st.warning("Resampling audio to 16000 Hz...")
        resampler = torchaudio.transforms.Resample(
            orig_freq=sampling_rate, new_freq=TARGET_SAMPLE_RATE
        )
        waveform = resampler(waveform)
    inputs = extractor(waveform, sampling_rate=TARGET_SAMPLE_RATE, return_tensors="pt")
    with torch.no_grad():
        logits = classifier(**inputs).logits
    predicted_class_id = torch.argmax(logits, dim=-1).item()
    return classifier.config.id2label[predicted_class_id]


feature_extractor, model = load_model(model_name)

st.title("Audio Classification App")
# NOTE(review): the mispeech/ced-* checkpoints appear to be AudioSet-trained
# rather than ESC-50; confirm this subheader matches the loaded model.
st.subheader("Trained on 50 classes of ESC 50 dataset")
st.write("Upload an audio file to predict its class. Supported formats: .wav, .mp3, .m4a")

audio_file = st.file_uploader("Upload Audio File", type=list(SUPPORTED_EXTENSIONS))

if audio_file is not None:
    st.write(f"Uploaded file: {audio_file.name}")
    # A unique temp file per request: a fixed "temp.wav" would be shared
    # (and clobbered) by concurrent sessions.
    fd, temp_file_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        _write_upload_as_wav(audio_file, temp_file_path)
        waveform, sampling_rate = _load_waveform(temp_file_path)
        predicted_label = _predict(
            feature_extractor, model, _to_mono(waveform), sampling_rate
        )
        st.success(f"Predicted Class: {predicted_label}")
    except Exception as e:
        st.error(f"An error occurred: {e}")
    finally:
        # Always clean up, including when decoding or inference failed.
        with contextlib.suppress(OSError):
            os.remove(temp_file_path)
else:
    st.info("Please upload an audio file (WAV, MP3, or M4A) to continue.")