Kabatubare committed
Commit dfabd2f · verified · 1 Parent(s): 30a5efb

Update app.py

Files changed (1)
  1. app.py +26 -13
app.py CHANGED
@@ -1,14 +1,13 @@
+import numpy as np
 import torch
-import torch.nn.functional as F
 import librosa
-import numpy as np
 import gradio as gr
 from transformers import AutoModelForAudioClassification
 import logging
 
 logging.basicConfig(level=logging.INFO)
 
-# Load your model here
+# Placeholder for loading your AST-compatible model
 model_path = "./"
 model = AutoModelForAudioClassification.from_pretrained(model_path)
 
@@ -17,26 +16,40 @@ def preprocess_audio(audio_path, sr=22050):
     audio, _ = librosa.effects.trim(audio)
     return audio, sr
 
+def extract_patches(S_DB, patch_size=16, patch_overlap=6):
+    stride = patch_size - patch_overlap
+    num_patches_x = (S_DB.shape[1] - patch_size) // stride + 1
+    num_patches_y = (S_DB.shape[0] - patch_size) // stride + 1
+    patches = []
+
+    for i in range(num_patches_y):
+        for j in range(num_patches_x):
+            start_i = i * stride
+            start_j = j * stride
+            patch = S_DB[start_i:start_i+patch_size, start_j:start_j+patch_size]
+            patches.append(patch.reshape(-1))
+
+    return np.array(patches)
+
 def extract_features(audio, sr):
     S = librosa.feature.melspectrogram(y=audio, sr=sr, n_mels=128, hop_length=512, n_fft=2048)
     S_DB = librosa.power_to_db(S, ref=np.max)
 
-    # Reshape the spectrogram to a sequence of overlapping 16x16 patches
-    patches = librosa.util.frame(S_DB.flatten(), frame_length=16*16, hop_length=(16-6)*(16-6)).T
-    patches = patches.reshape(patches.shape[0], 16, 16)
-
-    # Linear projection layer equivalent (patch embedding layer)
-    patch_embeddings = patches.reshape(patches.shape[0], -1)
-    patch_embeddings = torch.tensor(patch_embeddings).float()
+    patches = extract_patches(S_DB)
+    patch_embeddings = torch.tensor(patches).float()
 
-    # Assuming positional embeddings and [CLS] token embedding are handled within the model
-    return patch_embeddings.unsqueeze(0)  # Add batch dimension for compatibility with model
+    # Assuming the model includes a patch embedding layer internally
+    return patch_embeddings.unsqueeze(0)  # Add batch dimension
 
 def predict_voice(audio_file_path):
     try:
         audio, sr = preprocess_audio(audio_file_path)
         features = extract_features(audio, sr)
 
+        # Flatten the patches to match the model's expected input shape
+        # Adjust this based on your AST model's input requirements
+        features = features.view(1, -1, 768)  # Reshape assuming the model expects (batch_size, seq_len, embedding_dim)
+
         with torch.no_grad():
             outputs = model(features)
             logits = outputs.logits
@@ -57,7 +70,7 @@ iface = gr.Interface(
     inputs=gr.Audio(label="Upload Audio File", type="filepath"),
     outputs=gr.Text(label="Prediction"),
     title="Voice Authenticity Detection",
-    description="This system uses advanced audio processing to detect whether a voice is real or AI-generated. Upload an audio file to see the results."
+    description="Detects whether a voice is real or AI-generated using an advanced AST model. Upload an audio file to see the results."
 )
 
 iface.launch()
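
For context on the reshape introduced in `predict_voice`, below is a minimal, self-contained sketch of how the new `extract_patches` helper tiles a 128-band mel spectrogram and how its output relates to the `view(1, -1, 768)` call. The ~5-second input duration and the random array standing in for a real `power_to_db` output are illustrative assumptions, not part of app.py.

```python
import numpy as np

def extract_patches(S_DB, patch_size=16, patch_overlap=6):
    # Same logic as the helper added in this commit: slide an overlapping
    # patch_size x patch_size window over the spectrogram and flatten each patch.
    stride = patch_size - patch_overlap
    num_patches_x = (S_DB.shape[1] - patch_size) // stride + 1
    num_patches_y = (S_DB.shape[0] - patch_size) // stride + 1
    patches = []
    for i in range(num_patches_y):
        for j in range(num_patches_x):
            patch = S_DB[i * stride:i * stride + patch_size,
                         j * stride:j * stride + patch_size]
            patches.append(patch.reshape(-1))
    return np.array(patches)

# Illustrative stand-in for S_DB: 128 mel bands x ~216 frames, roughly what
# 5 s of 22,050 Hz audio yields with hop_length=512.
S_DB = np.random.randn(128, 216).astype(np.float32)

patches = extract_patches(S_DB)
print(patches.shape)  # (252, 256): 12 x 21 patches, each flattened to 16*16 = 256 values

# predict_voice reshapes the batched patches with view(1, -1, 768); that only
# succeeds when the total element count is divisible by 768, i.e. when the
# patch count is divisible by 3 (3 * 256 = 768).
print(patches.size % 768 == 0)  # True for this example: 252 * 256 = 64512 = 84 * 768
```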