SuryaT1 committed on
Commit 1ef3e75 · verified · 1 Parent(s): 24b78ca

Create app.py

Files changed (1)
  1. app.py +143 -0
app.py ADDED
@@ -0,0 +1,143 @@
import gradio as gr
import numpy as np
import librosa
import pickle
import tensorflow as tf
from sklearn.preprocessing import LabelEncoder  # LabelEncoder instances are loaded from the .pkl files below

# Paths to the trained models and label encoders
lstm_speaker_model = '/content/lstm_speaker_model.h5'
lstm_gender_model = '/content/lstm_gender_model.h5'
lstm_speaker_label = '/content/lstm_speaker_label.pkl'
lstm_gender_label = '/content/lstm_gender_label.pkl'

# ------------------- Feature Extraction -------------------
def extract_features(audio_data, max_len=34):
    """Extract MFCC, chroma, and spectral-contrast features from an audio file."""
    audio, sr = librosa.load(audio_data, sr=None)

    # MFCC features (13 coefficients), averaged over time
    mfccs = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=13)
    mfccs_mean = np.mean(mfccs, axis=1)

    # Spectral features: chroma
    chroma = librosa.feature.chroma_stft(y=audio, sr=sr)
    chroma_mean = np.mean(chroma, axis=1)

    # Spectral features: spectral contrast
    spectral_contrast = librosa.feature.spectral_contrast(y=audio, sr=sr)
    spectral_contrast_mean = np.mean(spectral_contrast, axis=1)

    # Combine the feature blocks; the slices cap each block so the stacked
    # vector matches the model's expected input size
    features = np.hstack([mfccs_mean[:13], chroma_mean[:13], spectral_contrast_mean[:8]])

    # Pad or truncate to the fixed length (max_len)
    if features.shape[0] < max_len:
        padding = np.zeros((max_len - features.shape[0],))
        features = np.concatenate((features, padding))
    elif features.shape[0] > max_len:
        features = features[:max_len]

    return features

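# Note on dimensions: with librosa defaults, chroma_stft returns 12 chroma bins
# and spectral_contrast returns 7 bands (6 sub-bands + 1), so the slices above
# keep 13 + 12 + 7 = 32 values; the zero-padding then fills the remaining
# 2 slots to reach max_len=34.
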
def preprocess_audio_for_model(audio_data, max_len=34):
    """Preprocess an audio file for model prediction."""
    features = extract_features(audio_data, max_len=max_len)
    # LSTM input shape is (samples, timesteps, features); the whole clip is one timestep
    features = features.reshape(1, 1, features.shape[0])
    return features

# ------------------- Load Pre-trained Models and Label Encoders -------------------
def load_trained_model(model_path=lstm_speaker_model):
    """Load the trained speaker model."""
    return tf.keras.models.load_model(model_path)

def load_gender_model(model_path=lstm_gender_model):
    """Load the trained gender model."""
    return tf.keras.models.load_model(model_path)

def load_label_encoder(label_encoder_path=lstm_speaker_label):
    """Load the label encoder for speaker labels."""
    with open(label_encoder_path, 'rb') as f:
        label_encoder = pickle.load(f)
    return label_encoder

def load_gender_label_encoder(label_encoder_path=lstm_gender_label):
    """Load the label encoder for gender labels."""
    with open(label_encoder_path, 'rb') as f:
        label_encoder = pickle.load(f)
    return label_encoder

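# Optional (a minimal sketch, not part of the original app): since the Gradio
# handler below reloads the models on every request, wrapping the loaders with
# functools.lru_cache would load each model and encoder only once, e.g.:
#
#   from functools import lru_cache
#
#   @lru_cache(maxsize=None)
#   def load_trained_model(model_path=lstm_speaker_model):
#       return tf.keras.models.load_model(model_path)
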
# ------------------- Predict Top 3 Speakers and Gender -------------------
def predict_top_3_speakers_and_gender(audio_data, speaker_model, gender_model, speaker_encoder, gender_encoder, max_len=34):
    """Predict the top 3 speakers and the gender from an uploaded audio file."""
    features = preprocess_audio_for_model(audio_data, max_len=max_len)

    # Predict the speaker probabilities
    speaker_pred = speaker_model.predict(features)

    # Get the top 3 speakers by descending probability
    top_3_speakers_idx = np.argsort(speaker_pred[0])[::-1][:3]
    top_3_speakers_probs = speaker_pred[0][top_3_speakers_idx] * 100  # convert to percentages
    top_3_speakers = speaker_encoder.inverse_transform(top_3_speakers_idx)

    # Predict the gender from the same feature tensor
    gender_pred = gender_model.predict(features)
    predicted_gender = gender_encoder.inverse_transform([np.argmax(gender_pred)])[0]

    return top_3_speakers, top_3_speakers_probs, predicted_gender

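# Example usage (hypothetical file name; assumes the models and encoders have
# already been loaded with the helpers above):
#
#   speakers, probs, gender = predict_top_3_speakers_and_gender(
#       'sample.wav', speaker_model, gender_model, speaker_encoder, gender_encoder
#   )
#   # speakers -> 3 decoded labels, probs -> percentages, gender -> one decoded label
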
# ------------------- Gradio Interface -------------------
def gradio_interface(audio):
    # Load the trained models and label encoders
    # (note: these are reloaded on every request; see the caching sketch above)
    speaker_model = load_trained_model(lstm_speaker_model)
    gender_model = load_gender_model(lstm_gender_model)
    speaker_encoder = load_label_encoder(lstm_speaker_label)
    gender_encoder = load_gender_label_encoder(lstm_gender_label)

    # Predict the top 3 speakers and gender from the uploaded audio file
    top_3_speakers, top_3_speakers_probs, predicted_gender = predict_top_3_speakers_and_gender(
        audio, speaker_model, gender_model, speaker_encoder, gender_encoder
    )

    # Format the results for the Gradio text output
    result = "The top 3 predicted speakers are:\n"
    for speaker, prob in zip(top_3_speakers, top_3_speakers_probs):
        result += f"{speaker}: {prob:.2f}%\n"

    result += f"\nThe predicted gender is: {predicted_gender}"

    return result

# Gradio interface creation
demo = gr.Interface(
    fn=gradio_interface,               # the prediction function
    inputs=gr.Audio(type="filepath"),  # audio input (upload or microphone)
    outputs="text",                    # prediction result as text
    live=False,                        # disable live feedback
    title="Speaker and Gender Prediction",
    description="Upload or record an audio file to predict the top 3 speakers and gender.",
    allow_flagging="never",            # disable flagging
    theme="compact",                   # set the theme
    css="""
    body {
        margin: 0;
        padding: 0;
        background-color: #f1f1f1;
        font-family: 'Roboto', sans-serif;
    }

    .gradio-container {
        background-color: #ffffff;
        padding: 20px;
        border-radius: 8px;
        box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.1);
    }

    h1, p {
        color: #333;
    }
    """
)

# Launch the Gradio app
demo.launch()
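
# A common variant (not in the original commit): guarding the launch keeps the
# module importable, e.g. for testing, without starting the server:
#
#   if __name__ == "__main__":
#       demo.launch()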