reagvis commited on
Commit
b9ec101
·
verified ·
1 Parent(s): 824db7b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -2
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  import torch
3
  import torchaudio
 
4
  from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
5
 
6
  # Load the HF feature extractor and model
@@ -11,17 +12,24 @@ model = AutoModelForAudioClassification.from_pretrained(
11
  "MelodyMachine/Deepfake-audio-detection-V2"
12
  )
13
 
 
 
14
  def detect_deepfake_audio(audio_path: str) -> str:
15
  # Load audio file
16
- waveform, sample_rate = torchaudio.load(audio_path)
17
 
18
  # Mix to mono if necessary
19
  if waveform.shape[0] > 1:
20
  waveform = torch.mean(waveform, dim=0, keepdim=True)
21
 
 
 
 
 
 
22
  # Prepare inputs
23
  inputs = feature_extractor(
24
- waveform, sampling_rate=sample_rate, return_tensors="pt"
25
  )
26
  with torch.no_grad():
27
  outputs = model(**inputs)
 
1
  import gradio as gr
2
  import torch
3
  import torchaudio
4
+ from torchaudio.transforms import Resample
5
  from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
6
 
7
  # Load the HF feature extractor and model
 
12
  "MelodyMachine/Deepfake-audio-detection-V2"
13
  )
14
 
15
+ TARGET_SR = feature_extractor.sampling_rate # should be 16000
16
+
17
  def detect_deepfake_audio(audio_path: str) -> str:
18
  # Load audio file
19
+ waveform, orig_sr = torchaudio.load(audio_path)
20
 
21
  # Mix to mono if necessary
22
  if waveform.shape[0] > 1:
23
  waveform = torch.mean(waveform, dim=0, keepdim=True)
24
 
25
+ # Resample if not already 16 kHz
26
+ if orig_sr != TARGET_SR:
27
+ resampler = Resample(orig_sr, TARGET_SR)
28
+ waveform = resampler(waveform)
29
+
30
  # Prepare inputs
31
  inputs = feature_extractor(
32
+ waveform, sampling_rate=TARGET_SR, return_tensors="pt"
33
  )
34
  with torch.no_grad():
35
  outputs = model(**inputs)