jcvsalinas committed on
Commit eb270a4 · verified · 1 Parent(s): 3b66fb1

Upload app.py

Files changed (1)
  1. app.py +141 -24
app.py CHANGED
@@ -1,29 +1,146 @@
 import gradio as gr
 import numpy as np
 import matplotlib.pyplot as plt
+import librosa
 
-# This function returns the waveform data to be displayed
-def record_audio(audio):
-    if audio is not None:
-        data = audio[0] # Extract the audio data from the tuple
-        sample_rate = audio[1] # Extract the sample rate
-
-        # Create a mm plot
-        plt.figure(figsize=(10, 4))
-        plt.plot(data)
-        plt.title("Real-Time Audio Waveform")
-        plt.xlabel("Sample Number")
-        plt.ylabel("Amplitude")
-        plt.grid(True)
-        return plt
-
-# Define the Gradio interface
-gr.Interface(
-    fn=record_audio,
+HOME_DIR = ""
+local_config_path = 'config.json'
+local_preprocessor_config_path = 'preprocessor_config.json'
+local_weights_path = 'pytorch_model.bin'
+local_training_args_path = 'training_args.bin'
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+from tqdm import tqdm
+
+# Define the id2label mapping
+id2label = {
+    0: "angry",
+    1: "disgust",
+    2: "fear",
+    3: "happy",
+    4: "neutral",
+    5: "sad"
+}
+
+
+def predict(model, feature_extractor, data, max_length, id2label):
+    # Extract features
+    inputs = feature_extractor(data, sampling_rate=16000, max_length=max_length, return_tensors='tf', padding=True, truncation=True)
+    torch_inputs = torch.tensor(inputs['input_values'].numpy(), dtype=torch.float32)
+
+    # Forward pass
+    outputs = model(input_values=torch_inputs)
+
+    # Extract logits from the output
+    logits = outputs
+
+    # Apply softmax to get probabilities
+    probabilities = F.softmax(logits, dim=-1)
+
+    # Get the predicted class index
+    predicted_class_idx = torch.argmax(probabilities, dim=-1).item()
+    predicted_label = id2label[predicted_class_idx]
+    #predicted_label = predicted_class_idx
+
+    return predicted_label
+
+from transformers import Wav2Vec2Config, Wav2Vec2Model
+import torch.nn as nn
+from huggingface_hub import PyTorchModelHubMixin
+
+config = Wav2Vec2Config.from_pretrained(local_config_path)
+class Wav2Vec2ForSpeechClassification(nn.Module, PyTorchModelHubMixin):
+    def __init__(self, config):
+        super(Wav2Vec2ForSpeechClassification, self).__init__()
+        self.wav2vec2 = Wav2Vec2Model(config)
+
+        self.classifier = nn.ModuleDict({
+            'dense': nn.Linear(config.hidden_size, config.hidden_size),
+            'activation': nn.ReLU(),
+            'dropout': nn.Dropout(config.final_dropout),
+            'out_proj': nn.Linear(config.hidden_size, config.num_labels)
+        })
+
+    def forward(self, input_values):
+        outputs = self.wav2vec2(input_values)
+        hidden_states = outputs.last_hidden_state
+
+        x = self.classifier['dense'](hidden_states[:, 0, :])
+        x = self.classifier['activation'](x)
+        x = self.classifier['dropout'](x)
+        logits = self.classifier['out_proj'](x)
+
+        return logits
+
+import json
+from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Processor
+
+# Load the preprocessor configuration from the local file
+with open(local_preprocessor_config_path, 'r') as file:
+    preprocessor_config = json.load(file)
+
+# Initialize the preprocessor using the loaded configuration
+feature_extractor = Wav2Vec2FeatureExtractor(
+    do_normalize=preprocessor_config["do_normalize"],
+    feature_extractor_type=preprocessor_config["feature_extractor_type"],
+    feature_size=preprocessor_config["feature_size"],
+    padding_side=preprocessor_config["padding_side"],
+    padding_value=preprocessor_config["padding_value"],
+    processor_class_from_name=preprocessor_config["processor_class"],
+    return_attention_mask=preprocessor_config["return_attention_mask"],
+    sampling_rate=preprocessor_config["sampling_rate"]
+)
+
+# load the newly finetuned model from huggingface repo
+
+from huggingface_hub import hf_hub_download
+
+model_path = hf_hub_download(
+    repo_id="kvilla/wav2vec-english-speech-emotion-recognition-finetuned",
+    filename="model_finetuned.pth"
+)
+
+# load the newly finetuned model! from local
+saved_model = torch.load(model_path, map_location=torch.device('cpu'))
+
+# Create the model with the loaded configuration
+model = Wav2Vec2ForSpeechClassification(config=config)
+
+# Load the state dictionary
+model.load_state_dict(saved_model, strict=False)
+
+print("Model initialized successfully.")
+
+model.eval()
+
+
+def recognize_emotion(audio):
+    # Load the audio file using librosa
+    #audio, _ = librosa.load(file_path, sr=16000)
+    sample_rate, audio_data = audio
+    print(audio_data)
+
+    # Ensure audio data is in floating-point format
+    if not np.issubdtype(audio_data.dtype, np.floating):
+        audio_data = audio_data.astype(np.float32)
+    print(audio_data)
+    # If you still want to process it with librosa, e.g., to change sample rate:
+    if sample_rate != 16000:
+        audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
+    return predict(model, feature_extractor, audio_data, len(audio_data), id2label)
+
+
+demo = gr.Blocks()
+with demo:
     theme= gr.themes.Soft(),
-    inputs=gr.Audio(type="numpy"), # Capture audio as numpy array
-    outputs="plot", # Output the waveform plot
-    live=True, # Enable real-time recording
-    title="Real-Time Audio Recording",
-    description="Record audio in real time and view the waveform."
-).launch(share = True)
+    audio_input = gr.Audio(type="numpy",
+                           sources=["microphone"],
+                           show_label=True
+                           )
+    text_output = gr.Textbox(label="Recognized Emotion")
+
+    # Automatically call the recognize_emotion function when audio is recorded
+    audio_input.stop_recording(fn=recognize_emotion, inputs=audio_input, outputs=text_output)
+demo.launch(share=True)
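
For context, a minimal sketch of how the new pipeline could be exercised outside the Gradio UI, assuming the objects defined in app.py (model, feature_extractor, predict, id2label) are already loaded in the same session and TensorFlow is installed (the feature extractor is called with return_tensors='tf'); the file name test.wav is a placeholder, not something referenced by the commit:

    import librosa

    # Load a local clip and resample to 16 kHz mono, mirroring what
    # recognize_emotion() does for microphone input ("test.wav" is hypothetical).
    audio_data, _ = librosa.load("test.wav", sr=16000, mono=True)

    # Reuse the same prediction call that the Gradio callback makes.
    label = predict(model, feature_extractor, audio_data, len(audio_data), id2label)
    print("Predicted emotion:", label)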