Kabatubare commited on
Commit
df3ef47
·
verified ·
1 Parent(s): e8e81bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -16
app.py CHANGED
@@ -4,14 +4,36 @@ import numpy as np
4
  import torch
5
  from torch.nn.functional import softmax
6
  import librosa
 
7
 
8
- # Path to the local directory where the model files are stored within the Space
9
  local_model_path = "./"
10
 
11
  # Load the model and feature extractor outside the function to improve performance
12
  extractor = AutoFeatureExtractor.from_pretrained(local_model_path)
13
  model = AutoModelForAudioClassification.from_pretrained(local_model_path)
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def predict_voice(audio_file_path):
16
  """
17
  Predicts whether a voice is real or spoofed from an audio file.
@@ -22,22 +44,21 @@ def predict_voice(audio_file_path):
22
  Returns:
23
  A string with the prediction and confidence level.
24
  """
25
- try:
26
- # Ensure the file path does not lead to unintended directories
27
- if not audio_file_path.startswith("/expected/path/for/safety"):
28
- return "Error: Invalid file path."
 
 
29
 
 
30
  # Load and preprocess the audio file
31
- waveform, sample_rate = librosa.load(audio_file_path, sr=16000, mono=True)
32
-
33
- # Convert the input audio file to model's expected format
34
  inputs = extractor(waveform, return_tensors="pt", sampling_rate=sample_rate)
35
 
36
- # Generate predictions from the model
37
- with torch.no_grad(): # Ensure no gradients are calculated
38
  outputs = model(**inputs)
39
-
40
- # Extract logits, compute the class with the highest score, and calculate confidence
41
  logits = outputs.logits
42
  predicted_index = logits.argmax()
43
  label = model.config.id2label[predicted_index.item()]
@@ -46,17 +67,18 @@ def predict_voice(audio_file_path):
46
  result = f"The voice is classified as '{label}' with a confidence of {confidence:.2f}%."
47
  except Exception as e:
48
  result = f"An error occurred during processing: {str(e)}"
 
49
  return result
50
 
51
- # Setting up the Gradio interface
52
  iface = gr.Interface(
53
  fn=predict_voice,
54
  inputs=gr.Audio(label="Upload Audio File", type="filepath"),
55
  outputs=gr.Textbox(label="Prediction"),
56
  title="Voice Authenticity Detection",
57
  description="Detects whether a voice is real or AI-generated. Upload an audio file to see the results.",
58
- theme="huggingface"
 
59
  )
60
 
61
- # Run the Gradio interface, consider using enable_queue=True if processing is expected to be long or the app faces high traffic
62
- iface.launch(share=True, enable_queue=True)
 
4
  import torch
5
  from torch.nn.functional import softmax
6
  import librosa
7
+ import os
8
 
9
+ # Path to the local directory where the model files are stored
10
  local_model_path = "./"
11
 
12
  # Load the model and feature extractor outside the function to improve performance
13
  extractor = AutoFeatureExtractor.from_pretrained(local_model_path)
14
  model = AutoModelForAudioClassification.from_pretrained(local_model_path)
15
 
16
+ def safe_path_join(base_path, path):
17
+ """
18
+ Safely join a base path and a potentially unsafe relative path.
19
+
20
+ Args:
21
+ base_path: The base directory path.
22
+ path: The relative path to join with the base path.
23
+
24
+ Returns:
25
+ The safely joined path if it's a subpath of the base_path, otherwise None.
26
+ """
27
+ # Normalize and absolute both paths
28
+ base_path = os.path.abspath(os.path.normpath(base_path))
29
+ target_path = os.path.abspath(os.path.normpath(os.path.join(base_path, path)))
30
+
31
+ # Ensure the target path is within the base_path directory
32
+ if os.path.commonpath([base_path]) == os.path.commonpath([base_path, target_path]):
33
+ return target_path
34
+ else:
35
+ return None
36
+
37
  def predict_voice(audio_file_path):
38
  """
39
  Predicts whether a voice is real or spoofed from an audio file.
 
44
  Returns:
45
  A string with the prediction and confidence level.
46
  """
47
+ # Safety check and path normalization
48
+ expected_base_path = "/expected/path/for/safety"
49
+ safe_audio_file_path = safe_path_join(expected_base_path, audio_file_path)
50
+
51
+ if not safe_audio_file_path:
52
+ return "Error: Invalid file path."
53
 
54
+ try:
55
  # Load and preprocess the audio file
56
+ waveform, sample_rate = librosa.load(safe_audio_file_path, sr=16000, mono=True)
 
 
57
  inputs = extractor(waveform, return_tensors="pt", sampling_rate=sample_rate)
58
 
59
+ with torch.no_grad(): # No gradients needed
 
60
  outputs = model(**inputs)
61
+
 
62
  logits = outputs.logits
63
  predicted_index = logits.argmax()
64
  label = model.config.id2label[predicted_index.item()]
 
67
  result = f"The voice is classified as '{label}' with a confidence of {confidence:.2f}%."
68
  except Exception as e:
69
  result = f"An error occurred during processing: {str(e)}"
70
+
71
  return result
72
 
73
+ # Gradio interface setup with enhancements for scalability and performance
74
  iface = gr.Interface(
75
  fn=predict_voice,
76
  inputs=gr.Audio(label="Upload Audio File", type="filepath"),
77
  outputs=gr.Textbox(label="Prediction"),
78
  title="Voice Authenticity Detection",
79
  description="Detects whether a voice is real or AI-generated. Upload an audio file to see the results.",
80
+ theme="huggingface",
81
+ enable_queue=True # Enable queuing to handle high traffic efficiently
82
  )
83
 
84
+ iface.launch(share=True)