Create app.py
app.py
ADDED
@@ -0,0 +1,174 @@
import gradio as gr
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import numpy as np
from sklearn.preprocessing import LabelEncoder


class CNN1DLSTMAudioClassifier(nn.Module):
    def __init__(self, num_classes, input_channels=1, sample_rate=16000, n_fft=400, hop_length=160):
        super(CNN1DLSTMAudioClassifier, self).__init__()

        # 1D CNN layers
        self.conv1 = nn.Conv1d(input_channels, 8, kernel_size=5, stride=1, padding=2)
        self.bn1 = nn.BatchNorm1d(8)
        self.pool1 = nn.MaxPool1d(kernel_size=2)
        self.conv2 = nn.Conv1d(8, 16, kernel_size=5, stride=1, padding=2)
        self.bn2 = nn.BatchNorm1d(16)
        self.pool2 = nn.MaxPool1d(kernel_size=2)
        self.conv3 = nn.Conv1d(16, 32, kernel_size=5, stride=1, padding=2)
        self.bn3 = nn.BatchNorm1d(32)
        self.pool3 = nn.MaxPool1d(kernel_size=2)

        # Calculate the output size of the last CNN layer
        self._to_linear = None
        self._calculate_to_linear(input_channels, sample_rate, n_fft, hop_length)

        # LSTM layers (bidirectional, so the output feature size is 2 * hidden_size = 128)
        self.lstm = nn.LSTM(input_size=32, hidden_size=64, num_layers=3, batch_first=True, bidirectional=True)

        # Fully connected layers
        self.fc1 = nn.Linear(128, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, num_classes)

        # Dropout
        self.dropout = nn.Dropout(0.2)

    def _calculate_to_linear(self, input_channels, sample_rate, n_fft, hop_length):
        # Run a dummy input through the CNN stack to record the channel size fed to the LSTM
        num_frames = (sample_rate - n_fft) // hop_length + 1
        x = torch.randn(1, input_channels, num_frames)
        x = self.convs(x)
        self._to_linear = x.shape[1]

    def convs(self, x):
        x = self.pool1(self.bn1(F.relu(self.conv1(x))))
        x = self.pool2(self.bn2(F.relu(self.conv2(x))))
        x = self.pool3(self.bn3(F.relu(self.conv3(x))))
        return x

    def forward(self, x):
        # (batch, features) -> (batch, channels=1, time)
        x = x.view(x.size(0), 1, -1)
        x = self.convs(x)

        # (batch, channels, time) -> (batch, time, channels) for the LSTM
        x = x.permute(0, 2, 1)
        x, _ = self.lstm(x)
        x = x[:, -1, :]  # keep the last time step

        # Fully connected layers
        x = self.dropout(x)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.fc3(x)

        return x


num_class = 6
model = CNN1DLSTMAudioClassifier(num_class)

model.load_state_dict(torch.load("speech-emotion-recognition-best-model.bin", weights_only=False))
model.eval()


def preprocess_single_audio(file_path, sample_rate=16000, n_mels=128, n_fft=2048, hop_length=512):
    # Load the audio file
    waveform, sr = torchaudio.load(file_path)

    # Resample if necessary
    if sr != sample_rate:
        resampler = torchaudio.transforms.Resample(sr, sample_rate)
        waveform = resampler(waveform)

    # Ensure consistent audio length (2 seconds)
    target_length = 2 * sample_rate
    if waveform.size(1) > target_length:
        waveform = waveform[:, :target_length]
    else:
        waveform = torch.nn.functional.pad(waveform, (0, target_length - waveform.size(1)))

    # Apply Mel spectrogram transform
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_mels=n_mels,
        n_fft=n_fft,
        hop_length=hop_length
    )
    mel_spectrogram = mel_transform(waveform)

    # Normalize (mean and std taken from the training data)
    mean = 12.65
    std = 117.07
    normalized_mel_spectrogram = (mel_spectrogram - mean) / std

    # Flatten the mel spectrogram and pad/truncate to the fixed input size
    flattened = normalized_mel_spectrogram.flatten()

    if flattened.shape[0] < 12288:
        flattened = torch.nn.functional.pad(flattened, (0, 12288 - flattened.shape[0]))
    elif flattened.shape[0] > 12288:
        flattened = flattened[:12288]

    return flattened


def decode_emotion_prediction(prediction_tensor, label_encoder):
    """
    Decodes the prediction tensor into an emotion label.

    Args:
        prediction_tensor (torch.Tensor): The model's output tensor of shape [1, 6]
        label_encoder (LabelEncoder): The LabelEncoder used during training

    Returns:
        str: The predicted emotion label
        float: The confidence score for the prediction
    """
    # Get the index of the highest score
    max_index = torch.argmax(prediction_tensor, dim=1).item()

    # Get the confidence score (probability) for the prediction
    confidence = torch.softmax(prediction_tensor, dim=1)[0, max_index].item()

    # Decode the index to get the emotion label
    predicted_emotion = label_encoder.inverse_transform([max_index])[0]

    return predicted_emotion, confidence


def predict(wave):
    wave = preprocess_single_audio(wave)
    le = LabelEncoder()
    le.classes_ = np.array(['Angry', 'Disgusting', 'Fear', 'Happy', 'Neutral', 'Sad'])
    wave = wave.unsqueeze(0)
    with torch.no_grad():
        prediction = model(wave)
    predicted_emotion, confidence = decode_emotion_prediction(prediction, le)
    return f"Predicted emotion: {predicted_emotion} (Confidence: {confidence:.2f})"


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Audio Prediction App")
    gr.Markdown("Upload an audio file or record directly to get a prediction")

    with gr.Row():
        audio_input = gr.Audio(source="microphone", type="filepath")
        audio_output = gr.Audio(label="Processed Audio")

    with gr.Row():
        submit_btn = gr.Button("Get Prediction", variant="primary")
        clear_btn = gr.Button("Clear")

    prediction_output = gr.Textbox(label="Prediction")

    # predict() takes the recorded file path and returns a single prediction string
    submit_btn.click(
        fn=predict,
        inputs=audio_input,
        outputs=prediction_output
    )

    clear_btn.click(
        fn=lambda: (None, None, ""),
        outputs=[audio_input, audio_output, prediction_output]
    )

demo.launch()
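
Because demo.launch() runs at import time, the snippet below is meant to be pasted into the same session where CNN1DLSTMAudioClassifier is already defined, rather than imported from app. It is a minimal shape sanity check and not part of the committed app: the 12288-value random input mirrors the length of the vector preprocess_single_audio returns, and the expected [1, 6] output matches the six emotion labels.

import torch

# Hypothetical smoke test: feed a random feature vector of the same length
# that preprocess_single_audio produces (12288 values) through a freshly
# initialised model and confirm one score per emotion class comes back.
m = CNN1DLSTMAudioClassifier(num_classes=6)
m.eval()
with torch.no_grad():
    logits = m(torch.randn(1, 12288))
print(logits.shape)  # expected: torch.Size([1, 6])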