Spaces:
Running
Running
uploaded files
Browse files- app.py +54 -0
- requirements.txt +6 -0
app.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from ced_model.feature_extraction_ced import CedFeatureExtractor
|
3 |
+
from ced_model.modeling_ced import CedForAudioClassification
|
4 |
+
import torchaudio
|
5 |
+
import torch
|
6 |
+
import os
|
7 |
+
import soundfile as sf
|
8 |
+
|
9 |
+
model_name = "mispeech/ced-tiny"
|
10 |
+
feature_extractor = CedFeatureExtractor.from_pretrained(model_name)
|
11 |
+
model = CedForAudioClassification.from_pretrained(model_name)
|
12 |
+
|
13 |
+
st.title("Audio Classification App")
|
14 |
+
st.subheader("Trained on 50 classes of ESC 50 dataset")
|
15 |
+
st.write("Upload an audio file to predict its class.")
|
16 |
+
|
17 |
+
audio_file = st.file_uploader("Upload Audio File", type=["wav"])
|
18 |
+
|
19 |
+
if audio_file is not None:
|
20 |
+
st.write(f"Uploaded file: {audio_file.name}")
|
21 |
+
|
22 |
+
try:
|
23 |
+
temp_file_path = "temp.wav"
|
24 |
+
with open(temp_file_path, "wb") as f:
|
25 |
+
f.write(audio_file.read())
|
26 |
+
|
27 |
+
try:
|
28 |
+
audio, sampling_rate = torchaudio.load(temp_file_path)
|
29 |
+
except Exception:
|
30 |
+
st.warning("Fallback to soundfile for audio loading.")
|
31 |
+
audio_data, sampling_rate = sf.read(temp_file_path)
|
32 |
+
audio = torch.tensor(audio_data).unsqueeze(0)
|
33 |
+
|
34 |
+
if sampling_rate != 16000:
|
35 |
+
st.warning("Resampling audio to 16000 Hz...")
|
36 |
+
resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)
|
37 |
+
audio = resampler(audio)
|
38 |
+
sampling_rate = 16000
|
39 |
+
|
40 |
+
inputs = feature_extractor(audio, sampling_rate=sampling_rate, return_tensors="pt")
|
41 |
+
|
42 |
+
with torch.no_grad():
|
43 |
+
logits = model(**inputs).logits
|
44 |
+
|
45 |
+
predicted_class_id = torch.argmax(logits, dim=-1).item()
|
46 |
+
predicted_label = model.config.id2label[predicted_class_id]
|
47 |
+
|
48 |
+
st.success(f"Predicted Class: {predicted_label}")
|
49 |
+
|
50 |
+
os.remove(temp_file_path)
|
51 |
+
except Exception as e:
|
52 |
+
st.error(f"An error occurred: {e}")
|
53 |
+
else:
|
54 |
+
st.info("Please upload a .wav audio file to continue.")
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit
|
2 |
+
git+https://github.com/jimbozhang/hf_transformers_custom_model_ced.git
|
3 |
+
numpy==1.26.4
|
4 |
+
torchaudio==2.1.1+cpu
|
5 |
+
torch==2.1.1+cpu
|
6 |
+
soundfile
|