Kabatubare commited on
Commit
38963c6
·
verified ·
1 Parent(s): 49ef139

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -22
app.py CHANGED
@@ -3,44 +3,80 @@ import librosa
3
  import numpy as np
4
  import torch
5
  import logging
6
- from transformers import Wav2Vec2ForSequenceClassification
 
 
 
 
7
 
 
8
  logging.basicConfig(level=logging.INFO)
9
 
 
10
  model_path = "./"
11
  model = Wav2Vec2ForSequenceClassification.from_pretrained(model_path)
 
12
 
13
- def preprocess_audio(audio_path, target_sr=16000):
14
- y, sr = librosa.load(audio_path, sr=target_sr)
15
- y = librosa.resample(y, orig_sr=sr, target_sr=target_sr)
16
- return y
 
 
 
17
 
18
- def predict_voice(audio_file_path):
 
 
 
19
  try:
20
- audio_data = preprocess_audio(audio_file_path)
21
- inputs = model.processor(audio_data, sampling_rate=16000, return_tensors="pt", padding=True)
22
-
 
 
 
 
23
  with torch.no_grad():
24
  outputs = model(**inputs)
25
-
26
  logits = outputs.logits
27
  predicted_index = logits.argmax(dim=1).item()
28
- label = model.config.id2label[predicted_index]
29
  confidence = torch.softmax(logits, dim=1).max().item() * 100
30
-
31
- result = f"The voice is classified as '{label}' with a confidence of {confidence:.2f}%."
32
- logging.info("Prediction successful.")
33
  except Exception as e:
34
- result = f"Error during processing: {e}"
35
- logging.error(result)
36
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
 
38
  iface = gr.Interface(
39
- fn=predict_voice,
40
- inputs=gr.Audio(label="Upload Audio File", type="filepath"),
41
- outputs=gr.Textbox(label="Prediction"),
42
- title="Voice Authenticity Detection",
43
- description="Detects whether a voice is real or AI-generated. Upload an audio file to see the results."
44
  )
45
 
46
  iface.launch()
 
3
  import numpy as np
4
  import torch
5
  import logging
6
+ from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
7
+ from pydub import AudioSegment
8
+ import os
9
+ import tempfile
10
+ import soundfile as sf
11
 
12
+ # Setup logging
13
  logging.basicConfig(level=logging.INFO)
14
 
15
+ # Load model and processor
16
  model_path = "./"
17
  model = Wav2Vec2ForSequenceClassification.from_pretrained(model_path)
18
+ processor = Wav2Vec2Processor.from_pretrained(model_path)
19
 
20
+ def preprocess_audio(audio_file_path, target_sampling_rate=16000):
21
+ """
22
+ Preprocess the input audio file to the target sampling rate and format.
23
+ """
24
+ # Convert audio to target sampling rate using librosa
25
+ y, sr = librosa.load(audio_file_path, sr=target_sampling_rate)
26
+ return y, sr
27
 
28
+ def predict_audio_class(audio_file_path):
29
+ """
30
+ Predict the class of the input audio file using Wav2Vec 2.0 model.
31
+ """
32
  try:
33
+ # Preprocess audio
34
+ audio, sr = preprocess_audio(audio_file_path, target_sampling_rate=16000)
35
+
36
+ # Prepare the audio for the model
37
+ inputs = processor(audio, sampling_rate=sr, return_tensors="pt", padding=True, truncation=True)
38
+
39
+ # Predict
40
  with torch.no_grad():
41
  outputs = model(**inputs)
42
+
43
  logits = outputs.logits
44
  predicted_index = logits.argmax(dim=1).item()
 
45
  confidence = torch.softmax(logits, dim=1).max().item() * 100
46
+ label = model.config.id2label[predicted_index]
47
+
48
+ return f"Predicted class: {label} with confidence: {confidence:.2f}%"
49
  except Exception as e:
50
+ logging.error(f"Error during processing: {e}")
51
+ return "Prediction failed due to an error."
52
+
53
+ def save_temp_audio(file):
54
+ """
55
+ Saves a temporary audio file, returns the path.
56
+ """
57
+ temp_dir = tempfile.gettempdir()
58
+ temp_file = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir, suffix=".wav")
59
+ temp_file_path = temp_file.name
60
+ # Convert to WAV for consistency
61
+ AudioSegment.from_file(file).export(temp_file_path, format="wav")
62
+ return temp_file_path
63
+
64
+ def handle_audio_input(file_info):
65
+ """
66
+ Handles the input audio file for prediction.
67
+ """
68
+ audio_file_path = save_temp_audio(file_info)
69
+ prediction = predict_audio_class(audio_file_path)
70
+ os.unlink(audio_file_path) # Clean up temp file
71
+ return prediction
72
 
73
+ # Setup Gradio interface
74
  iface = gr.Interface(
75
+ fn=handle_audio_input,
76
+ inputs=gr.inputs.Audio(source="upload", type="file", label="Upload Audio"),
77
+ outputs="text",
78
+ title="Audio Class Prediction",
79
+ description="Predicts the class of uploaded audio files using a fine-tuned Wav2Vec 2.0 model."
80
  )
81
 
82
  iface.launch()