Kabatubare commited on
Commit
9ff14b4
·
verified ·
1 Parent(s): 36bf420

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -52
app.py CHANGED
@@ -5,63 +5,19 @@ from torch.nn.functional import softmax
5
  import librosa
6
  import os
7
 
8
- # Path to the local directory where the model files are stored
9
  local_model_path = "./"
10
-
11
- # Load the model and feature extractor outside the function to improve performance
12
  extractor = AutoFeatureExtractor.from_pretrained(local_model_path)
13
  model = AutoModelForAudioClassification.from_pretrained(local_model_path)
14
 
15
- def safe_path_join(base_path, path):
16
- """
17
- Safely join a base path and a potentially unsafe relative path.
18
-
19
- Args:
20
- base_path: The base directory path.
21
- path: The relative path to join with the base path.
22
-
23
- Returns:
24
- The safely joined path if it's a subpath of the base_path, otherwise None.
25
- """
26
- base_path = os.path.abspath(os.path.normpath(base_path))
27
- target_path = os.path.abspath(os.path.normpath(os.path.join(base_path, path)))
28
- if os.path.commonpath([base_path]) == os.path.commonpath([base_path, target_path]):
29
- return target_path
30
- else:
31
- return None
32
-
33
  def preprocess_audio(audio_file_path, target_sample_rate=16000):
34
- """
35
- Preprocesses the audio file for compatibility with the model's expectations.
36
-
37
- Args:
38
- audio_file_path: Path to the audio file.
39
- target_sample_rate: Desired sample rate compatible with the model.
40
-
41
- Returns:
42
- Processed waveform and sample rate.
43
- """
44
  waveform, _ = librosa.load(audio_file_path, sr=target_sample_rate, mono=True)
45
  return waveform, target_sample_rate
46
 
47
  def predict_voice(audio_file_path):
48
- """
49
- Predicts whether a voice is real or spoofed from an audio file.
50
-
51
- Args:
52
- audio_file_path: The path to the input audio file to be classified.
53
-
54
- Returns:
55
- A string with the prediction and confidence level.
56
- """
57
- expected_base_path = "/expected/path/for/safety"
58
- safe_audio_file_path = safe_path_join(expected_base_path, audio_file_path)
59
-
60
- if not safe_audio_file_path:
61
- return "Error: Invalid file path."
62
-
63
  try:
64
- waveform, sample_rate = preprocess_audio(safe_audio_file_path)
 
 
65
  inputs = extractor(waveform, return_tensors="pt", sampling_rate=sample_rate)
66
 
67
  with torch.no_grad():
@@ -78,15 +34,12 @@ def predict_voice(audio_file_path):
78
 
79
  return result
80
 
81
- # Initialize Gradio interface without the enable_queue parameter
82
  iface = gr.Interface(
83
  fn=predict_voice,
84
  inputs=gr.Audio(label="Upload Audio File", type="filepath"),
85
  outputs=gr.Textbox(label="Prediction"),
86
  title="Voice Authenticity Detection",
87
- description="Detects whether a voice is real or AI-generated. Upload an audio file to see the results.",
88
- theme="huggingface"
89
  )
90
 
91
- # Launch the Gradio app
92
- iface.launch(share=True)
 
5
  import librosa
6
  import os
7
 
 
8
  local_model_path = "./"
 
 
9
  extractor = AutoFeatureExtractor.from_pretrained(local_model_path)
10
  model = AutoModelForAudioClassification.from_pretrained(local_model_path)
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def preprocess_audio(audio_file_path, target_sample_rate=16000):
 
 
 
 
 
 
 
 
 
 
13
  waveform, _ = librosa.load(audio_file_path, sr=target_sample_rate, mono=True)
14
  return waveform, target_sample_rate
15
 
16
  def predict_voice(audio_file_path):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  try:
18
+ # In Hugging Face Spaces, uploaded files are temporarily stored in a way that's accessible
19
+ # to the app, so there's no need for a strict path check here.
20
+ waveform, sample_rate = preprocess_audio(audio_file_path)
21
  inputs = extractor(waveform, return_tensors="pt", sampling_rate=sample_rate)
22
 
23
  with torch.no_grad():
 
34
 
35
  return result
36
 
 
37
  iface = gr.Interface(
38
  fn=predict_voice,
39
  inputs=gr.Audio(label="Upload Audio File", type="filepath"),
40
  outputs=gr.Textbox(label="Prediction"),
41
  title="Voice Authenticity Detection",
42
+ description="Detects whether a voice is real or AI-generated. Upload an audio file to see the results."
 
43
  )
44
 
45
+ iface.launch()