omsandeeppatil committed
Commit 1de8eea · verified · 1 Parent(s): eb64d62

Update app.py

Files changed (1)
  1. app.py +30 -17
app.py CHANGED
@@ -26,18 +26,26 @@ def process_audio(audio):
     if audio is None:
         return ""
 
-    # Get the audio data
-    if isinstance(audio, tuple):
-        audio = audio[1]
-
-    # Convert to numpy array if needed
-    audio = np.array(audio)
-
-    # Ensure we have mono audio
-    if len(audio.shape) > 1:
-        audio = audio.mean(axis=1)
-
     try:
+        # Get the audio data
+        if isinstance(audio, tuple):
+            audio = audio[1]
+
+        # Convert to numpy array and ensure float32 type
+        audio = np.array(audio, dtype=np.float32)
+
+        # Ensure we have mono audio
+        if len(audio.shape) > 1:
+            audio = audio.mean(axis=1)
+
+        # Normalize audio if needed
+        if audio.max() > 1.0 or audio.min() < -1.0:
+            audio = audio / max(abs(audio.max()), abs(audio.min()))
+
+        # Ensure we have non-zero audio
+        if len(audio) == 0 or np.all(audio == 0):
+            return "No audio detected"
+
         # Prepare input for the model
         inputs = feature_extractor(
             audio,
@@ -46,8 +54,8 @@ def process_audio(audio):
             padding=True
         )
 
-        # Move to appropriate device
-        inputs = {k: v.to(device) for k, v in inputs.items()}
+        # Ensure all tensors are float32
+        inputs = {k: v.to(device, dtype=torch.float32) for k, v in inputs.items()}
 
         # Get prediction
         with torch.no_grad():
@@ -55,12 +63,16 @@ def process_audio(audio):
             logits = outputs.logits
             predicted_id = torch.argmax(logits, dim=-1).item()
 
+        # Get probabilities
+        probs = torch.nn.functional.softmax(logits, dim=-1)
+        confidence = probs[0][predicted_id].item() * 100
+
         emotion = EMOTION_LABELS[predicted_id]
-        return emotion
+        return f"{emotion} (confidence: {confidence:.1f}%)"
 
     except Exception as e:
-        print(f"Error processing audio: {e}")
-        return "Error processing audio"
+        print(f"Error in audio processing: {str(e)}")
+        return "Error processing audio. Please try again."
 
 # Create Gradio interface
 demo = gr.Interface(
@@ -82,4 +94,5 @@ demo = gr.Interface(
 )
 
 # Launch with a small queue for better real-time performance
-demo.queue(max_size=1).launch(share=True)
+if __name__ == "__main__":
+    demo.queue(max_size=1).launch(share=True)
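For reviewers reading the hunks in isolation: the diff references feature_extractor, model, device, and EMOTION_LABELS, which are defined earlier in app.py and are not visible in this commit. A minimal sketch of that assumed setup follows; the checkpoint name and label list are placeholders, not values taken from the repo.

import gradio as gr
import numpy as np
import torch
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

# Hypothetical checkpoint; the real one is set earlier in app.py and
# does not appear in this diff.
MODEL_NAME = "superb/wav2vec2-base-superb-er"

device = "cuda" if torch.cuda.is_available() else "cpu"
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME).to(device)
model.eval()

# Placeholder labels; the real list must match the model's classification head.
EMOTION_LABELS = ["angry", "happy", "neutral", "sad"]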
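The peak normalization added in the first hunk rescales samples that fall outside [-1.0, 1.0] (e.g., raw 16-bit integer PCM, which Gradio's microphone input delivers as a (sample_rate, data) tuple) by dividing by the largest absolute sample. A standalone illustration with made-up samples:

import numpy as np

# Made-up 16-bit-range samples, as the Gradio Audio component might deliver them
audio = np.array([12000, -32768, 500], dtype=np.float32)

# Same peak normalization as in the diff: divide by the largest absolute sample
if audio.max() > 1.0 or audio.min() < -1.0:
    audio = audio / max(abs(audio.max()), abs(audio.min()))

print(audio)  # approximately [0.366, -1.0, 0.015]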
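The confidence figure appended to the return string in the third hunk is simply the softmax probability of the argmax class, scaled to a percentage. The same arithmetic on a made-up 4-class logit vector:

import torch

# Made-up single-example logits; a real run would use outputs.logits
logits = torch.tensor([[1.2, 3.4, 0.1, -0.5]])

predicted_id = torch.argmax(logits, dim=-1).item()
probs = torch.nn.functional.softmax(logits, dim=-1)
confidence = probs[0][predicted_id].item() * 100

print(f"class {predicted_id} (confidence: {confidence:.1f}%)")  # class 1 (confidence: 85.6%)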