raffaelsiregar committed · verified · Commit 3ab0534 · 1 parent: 06b7cba

Create app.py

Files changed (1):
  1. app.py +174 -0
app.py ADDED
@@ -0,0 +1,174 @@
+import gradio as gr
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchaudio
+import numpy as np
+from sklearn.preprocessing import LabelEncoder
+
+class CNN1DLSTMAudioClassifier(nn.Module):
+    def __init__(self, num_classes, input_channels=1, sample_rate=16000, n_fft=400, hop_length=160):
+        super(CNN1DLSTMAudioClassifier, self).__init__()
+
+        # 1D CNN layers
+        self.conv1 = nn.Conv1d(input_channels, 8, kernel_size=5, stride=1, padding=2)
+        self.bn1 = nn.BatchNorm1d(8)
+        self.pool1 = nn.MaxPool1d(kernel_size=2)
+        self.conv2 = nn.Conv1d(8, 16, kernel_size=5, stride=1, padding=2)
+        self.bn2 = nn.BatchNorm1d(16)
+        self.pool2 = nn.MaxPool1d(kernel_size=2)
+        self.conv3 = nn.Conv1d(16, 32, kernel_size=5, stride=1, padding=2)
+        self.bn3 = nn.BatchNorm1d(32)
+        self.pool3 = nn.MaxPool1d(kernel_size=2)
+
+        # Calculate the output size of the last CNN layer
+        self._to_linear = None
+        self._calculate_to_linear(input_channels, sample_rate, n_fft, hop_length)
+
+        # LSTM layers
+        self.lstm = nn.LSTM(input_size=32, hidden_size=64, num_layers=3, batch_first=True, bidirectional=True)
+
+        # Fully connected layers
+        self.fc1 = nn.Linear(128, 64)
+        self.fc2 = nn.Linear(64, 32)
+        self.fc3 = nn.Linear(32, num_classes)
+
+        # Dropout
+        self.dropout = nn.Dropout(0.2)
+
+    def _calculate_to_linear(self, input_channels, sample_rate, n_fft, hop_length):
+        # Calculate the size of the input to the LSTM layer (channels after the conv stack)
+        num_frames = (sample_rate - n_fft) // hop_length + 1
+        x = torch.randn(1, input_channels, num_frames)
+        x = self.convs(x)
+        self._to_linear = x.shape[1]
+
+    def convs(self, x):
+        x = self.pool1(self.bn1(F.relu(self.conv1(x))))
+        x = self.pool2(self.bn2(F.relu(self.conv2(x))))
+        x = self.pool3(self.bn3(F.relu(self.conv3(x))))
+        return x
+
+    def forward(self, x):
+        x = x.view(x.size(0), 1, -1)
+        x = self.convs(x)
+
+        x = x.permute(0, 2, 1)
+        x, _ = self.lstm(x)
+        x = x[:, -1, :]
+
+        # Fully connected layers
+        x = self.dropout(x)
+        x = self.fc1(x)
+        x = self.dropout(x)
+        x = self.fc2(x)
+        x = self.fc3(x)  # project down to num_classes logits
+
+        return x
+
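+# Rough shape walk-through, derived from the layer definitions above and the fixed
+# 12288-value input built by preprocess_single_audio below:
+#   (batch, 1, 12288) --conv/pool x3--> (batch, 32, 1536) --permute--> (batch, 1536, 32)
+#   --BiLSTM(hidden=64, bidirectional)--> (batch, 1536, 128) --last step--> (batch, 128)
+#   --fc1--> 64 --fc2--> 32 --fc3--> num_classes logits
+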
+num_class = 6
+model = CNN1DLSTMAudioClassifier(num_class)
+
+# Load the trained weights shipped next to app.py; map to CPU so the app also runs
+# on CPU-only hardware.
+model.load_state_dict(torch.load("speech-emotion-recognition-best-model.bin", map_location="cpu", weights_only=False))
+model.eval()
+
+def preprocess_single_audio(file_path, sample_rate=16000, n_mels=128, n_fft=2048, hop_length=512):
+    # Load the audio file
+    waveform, sr = torchaudio.load(file_path)
+
+    # Mix down to mono so stereo uploads match the single-channel input the model expects
+    if waveform.size(0) > 1:
+        waveform = waveform.mean(dim=0, keepdim=True)
+
+    # Resample if necessary
+    if sr != sample_rate:
+        resampler = torchaudio.transforms.Resample(sr, sample_rate)
+        waveform = resampler(waveform)
+
+    # Ensure consistent audio length (2 seconds)
+    target_length = 2 * sample_rate
+    if waveform.size(1) > target_length:
+        waveform = waveform[:, :target_length]
+    else:
+        waveform = torch.nn.functional.pad(waveform, (0, target_length - waveform.size(1)))
+
+    # Apply Mel Spectrogram transform
+    mel_transform = torchaudio.transforms.MelSpectrogram(
+        sample_rate=sample_rate,
+        n_mels=n_mels,
+        n_fft=n_fft,
+        hop_length=hop_length
+    )
+    mel_spectrogram = mel_transform(waveform)
+
+    # Normalize (mean and std taken from the training data)
+    mean = 12.65
+    std = 117.07
+    normalized_mel_spectrogram = (mel_spectrogram - mean) / std
+
+    # Flatten the mel spectrogram and pad/truncate to the fixed length the model expects
+    # (2 s at 16 kHz with these settings gives 128 mels x 63 frames = 8064 values,
+    # which is then zero-padded up to 12288)
+    flattened = normalized_mel_spectrogram.flatten()
+
+    if flattened.shape[0] < 12288:
+        flattened = torch.nn.functional.pad(flattened, (0, 12288 - flattened.shape[0]))
+    elif flattened.shape[0] > 12288:
+        flattened = flattened[:12288]
+
+    return flattened
+
+def decode_emotion_prediction(prediction_tensor, label_encoder):
+    """
+    Decodes the prediction tensor into an emotion label.
+
+    Args:
+        prediction_tensor (torch.Tensor): The model's output tensor of shape [1, 6]
+        label_encoder (LabelEncoder): The LabelEncoder used during training
+
+    Returns:
+        str: The predicted emotion label
+        float: The confidence score for the prediction
+    """
+    # Get the index of the highest probability
+    max_index = torch.argmax(prediction_tensor, dim=1).item()
+
+    # Get the confidence score (probability) for the prediction
+    confidence = torch.softmax(prediction_tensor, dim=1)[0, max_index].item()
+
+    # Decode the index to get the emotion label
+    predicted_emotion = label_encoder.inverse_transform([max_index])[0]
+
+    return predicted_emotion, confidence
+
+
+def predict(wave):
+    # Gradio passes the recording as a file path (or None if nothing was recorded)
+    if wave is None:
+        return "Please record or upload an audio clip first."
+    wave = preprocess_single_audio(wave)
+    le = LabelEncoder()
+    le.classes_ = np.array(['Angry', 'Disgusting', 'Fear', 'Happy', 'Neutral', 'Sad'])
+    wave = wave.unsqueeze(0)
+    with torch.no_grad():
+        prediction = model(wave)
+    predicted_emotion, confidence = decode_emotion_prediction(prediction, le)
+    return f"Predicted emotion: {predicted_emotion} (Confidence: {confidence:.2f})"
+
+with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    gr.Markdown("# Audio Prediction App")
+    gr.Markdown("Upload an audio file or record directly to get a prediction")
+
+    with gr.Row():
+        # microphone recording, passed to predict() as a temporary file path
+        audio_input = gr.Audio(source="microphone", type="filepath")
+        audio_output = gr.Audio(label="Processed Audio")
+
+    with gr.Row():
+        submit_btn = gr.Button("Get Prediction", variant="primary")
+        clear_btn = gr.Button("Clear")
+
+    prediction_output = gr.Textbox(label="Prediction")
+
+    # predict() takes the recorded file path and returns the prediction text
+    submit_btn.click(
+        fn=predict,
+        inputs=audio_input,
+        outputs=prediction_output
+    )
+
+    clear_btn.click(
+        fn=lambda: (None, None, ""),
+        outputs=[audio_input, audio_output, prediction_output]
+    )
+
+demo.launch()
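
The app assumes the checkpoint speech-emotion-recognition-best-model.bin sits next to app.py, and the imports above imply gradio, torch, torchaudio, numpy, and scikit-learn in the Space's requirements. For a quick local sanity check one could temporarily call the pipeline before demo.launch(); the clip name below is only a placeholder:

    # hypothetical smoke test with a local 16 kHz mono clip
    print(predict("sample.wav"))  # prints "Predicted emotion: <label> (Confidence: <score>)"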