aizanlabs commited on
Commit
9a50e70
·
verified ·
1 Parent(s): ddf85e2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import librosa
3
+ import pickle
4
+ import numpy as np
5
+ import gradio as gr
6
+
7
+ class ML_model:
8
+ def __init__(self):
9
+ self.ml_model = torch.load("support_file/resnet_carcrash_94.pth", map_location=torch.device('cpu'))
10
+ self.ml_model.eval()
11
+ with open('support_file/indtocat.pkl', 'rb') as f:
12
+ self.i2c = pickle.load(f)
13
+
14
+ def spec_to_image(self, spec, eps=1e-6):
15
+ mean = spec.mean()
16
+ std = spec.std()
17
+ spec_norm = (spec - mean) / (std + eps)
18
+ spec_min, spec_max = spec_norm.min(), spec_norm.max()
19
+ spec_scaled = 255 * (spec_norm - spec_min) / (spec_max - spec_min)
20
+ spec_scaled = spec_scaled.astype(np.uint8)
21
+ return spec_scaled
22
+
23
+ def get_melspectrogram_db(self, file_path):
24
+ # Load audio file
25
+ wav, sr = librosa.load(file_path, sr=None)
26
+ sr= 44100
27
+ # Parameters for mel spectrogram
28
+ n_fft = 2048
29
+ hop_length = 512
30
+ n_mels = 128
31
+ fmin = 20
32
+ fmax = 8300
33
+ if wav.shape[0]<5*sr:
34
+ wav=np.pad(wav,int(np.ceil((5*sr-wav.shape[0])/2)),mode='reflect')
35
+ else:
36
+ wav=wav[:5*sr]
37
+ # Compute mel spectrogram
38
+ spec = librosa.feature.melspectrogram(y=wav, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, fmin=fmin, fmax=fmax)
39
+
40
+ # Convert to dB scale
41
+ spec_db = librosa.power_to_db(spec, ref=np.max)
42
+ return spec_db
43
+
44
+ def get_prediction(self, file_path):
45
+ spec_db = self.get_melspectrogram_db(file_path)
46
+ input_image = self.spec_to_image(spec_db)
47
+ input_tensor = torch.tensor(input_image[np.newaxis, np.newaxis, ...], dtype=torch.float32).to('cpu')
48
+ predictions = self.ml_model(input_tensor)
49
+ predicted_index = predictions.argmax(dim=1).item()
50
+ return self.i2c[predicted_index]
51
+
52
+ def predict(file_path):
53
+ ml_model = ML_model() # Initialize model
54
+ prediction = ml_model.get_prediction(file_path)
55
+ return prediction
56
+
57
+ interface = gr.Interface(
58
+ fn=predict,
59
+ inputs=gr.Audio(type="filepath", label="Upload your audio file"),
60
+ outputs="text",
61
+ title="Car Crash Sound Detection",
62
+ description="Upload a car crash sound clip and the model will identify the crash type.",
63
+ examples=["input_fileszQ1QmqrakIA_5-talking.wav","input_fileszQ1QmqrakIA_13-crash.wav"],
64
+ cache_examples=False
65
+ )
66
+
67
+ interface.launch(share=True)