alaahilal commited on
Commit
837062b
·
verified ·
1 Parent(s): f036141

uploaded files

Browse files
Files changed (2) hide show
  1. app.py +54 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from ced_model.feature_extraction_ced import CedFeatureExtractor
3
+ from ced_model.modeling_ced import CedForAudioClassification
4
+ import torchaudio
5
+ import torch
6
+ import os
7
+ import soundfile as sf
8
+
9
+ model_name = "mispeech/ced-tiny"
10
+ feature_extractor = CedFeatureExtractor.from_pretrained(model_name)
11
+ model = CedForAudioClassification.from_pretrained(model_name)
12
+
13
+ st.title("Audio Classification App")
14
+ st.subheader("Trained on 50 classes of ESC 50 dataset")
15
+ st.write("Upload an audio file to predict its class.")
16
+
17
+ audio_file = st.file_uploader("Upload Audio File", type=["wav"])
18
+
19
+ if audio_file is not None:
20
+ st.write(f"Uploaded file: {audio_file.name}")
21
+
22
+ try:
23
+ temp_file_path = "temp.wav"
24
+ with open(temp_file_path, "wb") as f:
25
+ f.write(audio_file.read())
26
+
27
+ try:
28
+ audio, sampling_rate = torchaudio.load(temp_file_path)
29
+ except Exception:
30
+ st.warning("Fallback to soundfile for audio loading.")
31
+ audio_data, sampling_rate = sf.read(temp_file_path)
32
+ audio = torch.tensor(audio_data).unsqueeze(0)
33
+
34
+ if sampling_rate != 16000:
35
+ st.warning("Resampling audio to 16000 Hz...")
36
+ resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)
37
+ audio = resampler(audio)
38
+ sampling_rate = 16000
39
+
40
+ inputs = feature_extractor(audio, sampling_rate=sampling_rate, return_tensors="pt")
41
+
42
+ with torch.no_grad():
43
+ logits = model(**inputs).logits
44
+
45
+ predicted_class_id = torch.argmax(logits, dim=-1).item()
46
+ predicted_label = model.config.id2label[predicted_class_id]
47
+
48
+ st.success(f"Predicted Class: {predicted_label}")
49
+
50
+ os.remove(temp_file_path)
51
+ except Exception as e:
52
+ st.error(f"An error occurred: {e}")
53
+ else:
54
+ st.info("Please upload a .wav audio file to continue.")
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ streamlit
2
+ git+https://github.com/jimbozhang/hf_transformers_custom_model_ced.git
3
+ numpy==1.26.4
4
+ torchaudio==2.1.1+cpu
5
+ torch==2.1.1+cpu
6
+ soundfile