Kabatubare committed on
Commit
e8e81bf
·
verified ·
1 Parent(s): 6f6f035

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -31
app.py CHANGED
@@ -8,7 +8,7 @@ import librosa
8
  # Path to the local directory where the model files are stored within the Space
9
  local_model_path = "./"
10
 
11
- # Initialize the feature extractor and model from the local files
12
  extractor = AutoFeatureExtractor.from_pretrained(local_model_path)
13
  model = AutoModelForAudioClassification.from_pretrained(local_model_path)
14
 
@@ -22,44 +22,41 @@ def predict_voice(audio_file_path):
22
  Returns:
23
  A string with the prediction and confidence level.
24
  """
25
-
26
- # Load the audio file. librosa automatically resamples to the target sample rate if needed.
27
- waveform, sample_rate = librosa.load(audio_file_path, sr=16000) # Force resampling to 16000 Hz
28
-
29
- # Ensure waveform is mono
30
- if waveform.ndim > 1:
31
- waveform = np.mean(waveform, axis=0)
32
-
33
- # Convert the input audio file to model's expected format.
34
- inputs = extractor(waveform, return_tensors="pt", sampling_rate=sample_rate)
35
-
36
- # Generate predictions from the model.
37
- with torch.no_grad(): # Ensure no gradients are calculated
38
- outputs = model(**inputs)
39
-
40
- # Extract logits and compute the class with the highest score.
41
- logits = outputs.logits
42
- predicted_index = logits.argmax()
43
-
44
- # Translate index to label
45
- label = model.config.id2label[predicted_index.item()]
46
-
47
- # Calculate the confidence of the prediction using softmax.
48
- confidence = softmax(logits, dim=1).max().item() * 100
49
-
50
- # Prepare the output string.
51
- result = f"The voice is classified as '{label}' with a confidence of {confidence:.2f}%."
52
  return result
53
 
54
  # Setting up the Gradio interface
55
  iface = gr.Interface(
56
  fn=predict_voice,
57
- inputs=gr.Audio(label="Upload Audio File", type="filepath"), # Correct parameter usage
58
  outputs=gr.Textbox(label="Prediction"),
59
  title="Voice Authenticity Detection",
60
  description="Detects whether a voice is real or AI-generated. Upload an audio file to see the results.",
61
  theme="huggingface"
62
  )
63
 
64
- # Run the Gradio interface with share=True for creating a public link
65
- iface.launch(share=True)
 
8
  # Path to the local directory where the model files are stored within the Space
9
  local_model_path = "./"
10
 
11
+ # Load the model and feature extractor outside the function to improve performance
12
  extractor = AutoFeatureExtractor.from_pretrained(local_model_path)
13
  model = AutoModelForAudioClassification.from_pretrained(local_model_path)
14
 
 
22
  Returns:
23
  A string with the prediction and confidence level.
24
  """
25
+ try:
26
+ # Ensure the file path does not lead to unintended directories
27
+ if not audio_file_path.startswith("/expected/path/for/safety"):
28
+ return "Error: Invalid file path."
29
+
30
+ # Load and preprocess the audio file
31
+ waveform, sample_rate = librosa.load(audio_file_path, sr=16000, mono=True)
32
+
33
+ # Convert the input audio file to model's expected format
34
+ inputs = extractor(waveform, return_tensors="pt", sampling_rate=sample_rate)
35
+
36
+ # Generate predictions from the model
37
+ with torch.no_grad(): # Ensure no gradients are calculated
38
+ outputs = model(**inputs)
39
+
40
+ # Extract logits, compute the class with the highest score, and calculate confidence
41
+ logits = outputs.logits
42
+ predicted_index = logits.argmax()
43
+ label = model.config.id2label[predicted_index.item()]
44
+ confidence = softmax(logits, dim=1).max().item() * 100
45
+
46
+ result = f"The voice is classified as '{label}' with a confidence of {confidence:.2f}%."
47
+ except Exception as e:
48
+ result = f"An error occurred during processing: {str(e)}"
 
 
 
49
  return result
50
 
51
  # Setting up the Gradio interface
52
  iface = gr.Interface(
53
  fn=predict_voice,
54
+ inputs=gr.Audio(label="Upload Audio File", type="filepath"),
55
  outputs=gr.Textbox(label="Prediction"),
56
  title="Voice Authenticity Detection",
57
  description="Detects whether a voice is real or AI-generated. Upload an audio file to see the results.",
58
  theme="huggingface"
59
  )
60
 
61
+ # Run the Gradio interface; consider using enable_queue=True if processing is expected to be long or the app faces high traffic
62
+ iface.launch(share=True, enable_queue=True)