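# Streamlit demo: classify an uploaded audio clip with the mispeech/ced-tiny model.
# Assumed launch command (script name is a guess): streamlit run app.py
# Note: pydub relies on ffmpeg being installed on the system to decode MP3/M4A uploads.
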
import streamlit as st
from ced_model.feature_extraction_ced import CedFeatureExtractor
from ced_model.modeling_ced import CedForAudioClassification
import torchaudio
import torch
import os
import soundfile as sf
# pydub and io are used to convert MP3/M4A uploads to WAV
from pydub import AudioSegment
import io

model_name = "mispeech/ced-tiny"
feature_extractor = CedFeatureExtractor.from_pretrained(model_name)
model = CedForAudioClassification.from_pretrained(model_name)
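# from_pretrained downloads and caches the checkpoint from the Hugging Face Hub on first run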

st.title("Audio Classification App")
st.subheader("Trained on the 50 classes of the ESC-50 dataset")
st.write("Upload an audio file to predict its class. WAV, MP3, and M4A files are supported.")

audio_file = st.file_uploader("Upload Audio File", type=["wav", "mp3", "m4a"])

if audio_file is not None:
    st.write(f"Uploaded file: {audio_file.name}")
    
    try:
        # Write the upload to a temporary WAV file that the loaders below can read
        temp_file_path = "temp.wav"
        if audio_file.name.lower().endswith(('.mp3', '.m4a')):
            # Convert MP3/M4A uploads to WAV via pydub (requires ffmpeg)
            audio_bytes = audio_file.read()
            segment = AudioSegment.from_file(io.BytesIO(audio_bytes), format=audio_file.name.split('.')[-1].lower())
            segment.export(temp_file_path, format="wav")
        else:
            # For WAV files, write directly
            with open(temp_file_path, "wb") as f:
                f.write(audio_file.read())
        
        try:
            audio, sampling_rate = torchaudio.load(temp_file_path)
        except Exception:
            st.warning("Fallback to soundfile for audio loading.")
            audio_data, sampling_rate = sf.read(temp_file_path)
            audio = torch.tensor(audio_data).unsqueeze(0)  

        # The pipeline runs at 16 kHz, so resample anything else before feature extraction
        if sampling_rate != 16000:
            st.warning("Resampling audio to 16000 Hz...")
            resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)
            audio = resampler(audio)
            sampling_rate = 16000
        
        inputs = feature_extractor(audio, sampling_rate=sampling_rate, return_tensors="pt")
        
        with torch.no_grad():
            logits = model(**inputs).logits
        
        predicted_class_id = torch.argmax(logits, dim=-1).item()
        predicted_label = model.config.id2label[predicted_class_id]
        
        st.success(f"Predicted Class: {predicted_label}")
        
        os.remove(temp_file_path)
    except Exception as e:
        st.error(f"An error occurred: {e}")
else:
    st.info("Please upload an audio file (WAV, MP3, or M4A) to continue.")