Kabatubare committed on
Commit fe0bcff · verified · 1 Parent(s): af80923

Update app.py

Files changed (1)
  1. app.py +39 -34
app.py CHANGED
@@ -1,49 +1,54 @@
+ import gradio as gr
+ import librosa
  import numpy as np
  import torch
- import librosa
- import gradio as gr
- from transformers import AutoModelForAudioClassification, Wav2Vec2Processor
- import logging
+ import torch.nn.functional as F
+ from transformers import AutoModelForAudioClassification
+ import random

- logging.basicConfig(level=logging.INFO)
+ model = AutoModelForAudioClassification.from_pretrained("./")

- model_path = "./"
- model = AutoModelForAudioClassification.from_pretrained(model_path)
- processor = Wav2Vec2Processor.from_pretrained(model_path)
+ def custom_feature_extraction(audio_file_path, sr=16000, n_mels=128, n_fft=2048, hop_length=512, target_length=1024):
+     waveform, sample_rate = librosa.load(audio_file_path, sr=sr)
+     S = librosa.feature.melspectrogram(y=waveform, sr=sample_rate, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)
+     S_DB = librosa.power_to_db(S, ref=np.max)
+     pitches, _ = librosa.piptrack(y=waveform, sr=sample_rate, n_fft=n_fft, hop_length=hop_length)
+     spectral_centroids = librosa.feature.spectral_centroid(y=waveform, sr=sample_rate, n_fft=n_fft, hop_length=hop_length)
+     features = np.concatenate([S_DB, pitches, spectral_centroids], axis=0)
+     features_tensor = torch.tensor(features).float()
+     if features_tensor.shape[1] > target_length:
+         features_tensor = features_tensor[:, :target_length]
+     else:
+         features_tensor = F.pad(features_tensor, (0, target_length - features_tensor.shape[1]), 'constant', 0)
+     return features_tensor.unsqueeze(0)

- def preprocess_audio(audio_path, sr=16000):
-     audio, _ = librosa.load(audio_path, sr=sr)
-     audio, _ = librosa.effects.trim(audio)
-     return audio
-
- def extract_features(audio, sr=16000):
-     inputs = processor(audio, sampling_rate=sr, return_tensors="pt", padding=True)
-     return inputs
+ def apply_time_shift(waveform, max_shift_fraction=0.1):
+     shift = random.randint(-int(max_shift_fraction * len(waveform)), int(max_shift_fraction * len(waveform)))
+     return np.roll(waveform, shift)

  def predict_voice(audio_file_path):
      try:
-         audio = preprocess_audio(audio_file_path)
-         features = extract_features(audio)
-
+         waveform, sample_rate = librosa.load(audio_file_path, sr=None)
+         augmented_waveform = apply_time_shift(waveform)
+         original_features = custom_feature_extraction(audio_file_path, sr=sample_rate)
+         augmented_features = custom_feature_extraction(augmented_waveform, sr=sample_rate)
          with torch.no_grad():
-             outputs = model(**features)
-             logits = outputs.logits
-             predicted_index = logits.argmax(dim=-1)
-             label = processor.decode(predicted_index)
-             confidence = torch.softmax(logits, dim=-1).max().item() * 100
-
-         result = f"The voice is classified as '{label}' with a confidence of {confidence:.2f}%."
-         logging.info("Prediction successful.")
+             outputs_original = model(original_features)
+             outputs_augmented = model(augmented_features)
+             logits = (outputs_original.logits + outputs_augmented.logits) / 2
+             predicted_index = logits.argmax()
+             label = model.config.id2label[predicted_index.item()]
+             confidence = torch.softmax(logits, dim=1).max().item() * 100
+             return f"The voice is classified as '{label}' with a confidence of {confidence:.2f}%."
      except Exception as e:
-         result = f"Error during processing: {e}"
-         logging.error(result)
-
-     return result
+         return f"Error during processing: {e}"

  iface = gr.Interface(
      fn=predict_voice,
      inputs=gr.Audio(label="Upload Audio File", type="filepath"),
-     outputs=gr.Text(label="Prediction"),
+     outputs=gr.Textbox(label="Prediction"),
      title="Voice Authenticity Detection",
-     description="This system uses advanced audio processing to detect whether a voice is real or AI-generated. Upload an audio file to see the results."
- ).launch()
+     description="Detects whether a voice is real or AI-generated. Upload an audio file to see the results."
+ )
+
+ iface.launch()
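
Note on the new custom_feature_extraction: with the defaults above, melspectrogram contributes n_mels = 128 rows, piptrack contributes 1 + n_fft // 2 = 1025 rows, and spectral_centroid a single row, so after truncating or padding to target_length the model always receives a (1, 1154, 1024) tensor. A quick sanity check against a local file (the file name here is hypothetical):

# Sketch: verify the fixed input shape produced by custom_feature_extraction.
features = custom_feature_extraction("sample.wav")
print(features.shape)  # torch.Size([1, 1154, 1024]) = (batch, 128 + 1025 + 1 feature rows, frames)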
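
One caveat in the updated predict_voice: custom_feature_extraction(augmented_waveform, sr=sample_rate) hands a raw NumPy array to a function whose first step is librosa.load, which expects a path or file-like object, so the augmented branch raises at runtime and the except handler returns the error string instead of a prediction. A minimal sketch of one possible fix, not part of this commit, is to load through a small helper that accepts either input type:

import librosa
import numpy as np

# Hypothetical helper (not in commit fe0bcff): accept a file path or an
# already-loaded waveform, so the time-shifted copy can be featurized too.
def load_waveform(audio, sr=16000):
    if isinstance(audio, str):
        return librosa.load(audio, sr=sr)  # decode and resample from disk
    return np.asarray(audio, dtype=np.float32), sr  # use in-memory samples as-is

custom_feature_extraction would then start with waveform, sample_rate = load_waveform(audio_file_path, sr=sr), and the original-path and augmented-waveform calls would share one code path.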