andromeda01111 committed on
Commit
55719c9
·
verified ·
1 Parent(s): d0b37ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -30
app.py CHANGED
@@ -1,16 +1,10 @@
1
  import gradio as gr
2
  import torch
3
- import torch.nn as nn
4
  import torch.nn.functional as F
5
  import torchaudio
6
  from transformers import AutoConfig, Wav2Vec2Processor, Wav2Vec2FeatureExtractor
7
  from src.models import Wav2Vec2ForSpeechClassification
8
-
9
- import librosa
10
- import IPython.display as ipd
11
  import numpy as np
12
- import pandas as pd
13
- import os
14
 
15
  model_name_or_path = "andromeda01111/Malayalam_SA"
16
  config = AutoConfig.from_pretrained(model_name_or_path)
@@ -18,47 +12,39 @@ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name_or_path)
18
  sampling_rate = feature_extractor.sampling_rate
19
  model = Wav2Vec2ForSpeechClassification.from_pretrained(model_name_or_path)
20
 
21
-
22
  def speech_file_to_array_fn(path, sampling_rate):
23
  speech_array, _sampling_rate = torchaudio.load(path)
24
- resampler = torchaudio.transforms.Resample(_sampling_rate)
25
  speech = resampler(speech_array).squeeze().numpy()
26
  return speech
27
 
28
-
29
- def predict(path, sampling_rate):
30
- speech = speech_file_to_array_fn(path, sampling_rate)
31
  features = feature_extractor(speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
32
-
33
  input_values = features.input_values
34
  attention_mask = features.attention_mask
35
-
36
  with torch.no_grad():
37
  logits = model(input_values, attention_mask=attention_mask).logits
38
-
39
  scores = F.softmax(logits, dim=1).detach().cpu().numpy()[0]
40
- output_emotion = [{"Emotion": config.id2label[i], "Score": f"{round(score * 100, 3):.1f}%"} for i, score in enumerate(scores)]
41
-
42
  return output_emotion
43
 
44
-
45
- # Wrapper function for Gradio
46
  def gradio_predict(audio):
47
- predictions = predict(audio)
48
- return [f"{pred['Emotion']}: {pred['Score']}" for pred in predictions]
49
-
50
-
51
- # Gradio interface
52
- emotions = [config.id2label[i] for i in range(len(config.id2label))]
53
- outputs = [gr.Textbox(label=emotion, interactive=False) for emotion in emotions]
54
 
 
55
  interface = gr.Interface(
56
- fn=predict,
57
- inputs=gr.Audio(label="Upload Audio", type="filepath"),
58
- outputs=outputs,
59
  title="Emotion Recognition",
60
- description="Upload an audio file to predict emotions and their corresponding percentages.",
 
61
  )
62
 
63
  # Launch the app
64
- interface.launch()
 
1
  import gradio as gr
2
  import torch
 
3
  import torch.nn.functional as F
4
  import torchaudio
5
  from transformers import AutoConfig, Wav2Vec2Processor, Wav2Vec2FeatureExtractor
6
  from src.models import Wav2Vec2ForSpeechClassification
 
 
 
7
  import numpy as np
 
 
8
 
9
  model_name_or_path = "andromeda01111/Malayalam_SA"
10
  config = AutoConfig.from_pretrained(model_name_or_path)
 
12
  sampling_rate = feature_extractor.sampling_rate
13
  model = Wav2Vec2ForSpeechClassification.from_pretrained(model_name_or_path)
14
 
 
15
  def speech_file_to_array_fn(path, sampling_rate):
16
  speech_array, _sampling_rate = torchaudio.load(path)
17
+ resampler = torchaudio.transforms.Resample(_sampling_rate, sampling_rate)
18
  speech = resampler(speech_array).squeeze().numpy()
19
  return speech
20
 
21
+ def predict(audio_path):
22
+ speech = speech_file_to_array_fn(audio_path, sampling_rate)
 
23
  features = feature_extractor(speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
24
+
25
  input_values = features.input_values
26
  attention_mask = features.attention_mask
27
+
28
  with torch.no_grad():
29
  logits = model(input_values, attention_mask=attention_mask).logits
30
+
31
  scores = F.softmax(logits, dim=1).detach().cpu().numpy()[0]
32
+ output_emotion = {config.id2label[i]: f"{round(score * 100, 3):.1f}%" for i, score in enumerate(scores)}
33
+
34
  return output_emotion
35
 
 
 
36
  def gradio_predict(audio):
37
+ return predict(audio)
 
 
 
 
 
 
38
 
39
+ # Gradio interface with microphone audio input (the 10-second limit is advisory only; nothing in this code enforces it)
40
  interface = gr.Interface(
41
+ fn=gradio_predict,
42
+ inputs=gr.Audio(source="microphone", type="filepath", label="Record or Upload Audio", streaming=False),
43
+ outputs=gr.JSON(label="Emotion Scores"),
44
  title="Emotion Recognition",
45
+ description="Record or upload an audio file (max 10 sec) to predict emotions and their corresponding percentages.",
46
+ live=False,
47
  )
48
 
49
  # Launch the app
50
+ interface.launch()