import os

import gradio as gr
import torch

import dataset
from model import Wav2Vec2BERT_Llama
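# NOTE: `dataset` and `model` are local modules of this Space. Based on how
# detect() consumes batches below, `dataset.DemoDataset` is assumed to yield
# dicts of the form:
#   {'main_features':   {'input_features': Tensor, 'attention_mask': Tensor},
#    'prompt_features': [{'input_features': Tensor, 'attention_mask': Tensor}, ...]}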
# Initialize device and model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Wav2Vec2BERT_Llama().to(device)
checkpoint_path = "ckpt/model_checkpoint.pth"
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model_state_dict = checkpoint['model_state_dict']

    # Reconcile the state dict with the model: add or strip the 'module.'
    # prefix that torch.nn.DataParallel prepends to parameter names.
    if hasattr(model, 'module') and not any(key.startswith('module.') for key in model_state_dict):
        model_state_dict = {'module.' + key: value for key, value in model_state_dict.items()}
    elif not hasattr(model, 'module') and any(key.startswith('module.') for key in model_state_dict):
        model_state_dict = {key.replace('module.', '', 1): value for key, value in model_state_dict.items()}

    model.load_state_dict(model_state_dict)
    model.eval()
else:
    raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
def detect(dataset, model):
    """Run the model over the demo dataset and aggregate a spoof score."""
    scores = []
    with torch.no_grad():
        for batch in dataset:
            main_features = {
                'input_features': batch['main_features']['input_features'].to(device),
                'attention_mask': batch['main_features']['attention_mask'].to(device)
            }
            prompt_features = [{
                'input_features': pf['input_features'].to(device),
                'attention_mask': pf['attention_mask'].to(device)
            } for pf in batch['prompt_features']]
            # Assumed model interface: a forward pass over the query and
            # demonstration features returning logits with class 1 = spoof;
            # adapt to the actual Wav2Vec2BERT_Llama outputs.
            logits = model(main_features, prompt_features)
            scores.append(torch.softmax(logits, dim=-1)[..., 1].mean().item())
    fake_prob = sum(scores) / max(len(scores), 1)
    confidence = max(fake_prob, 1.0 - fake_prob) * 100
    return {"is_fake": fake_prob > 0.5, "confidence": confidence}
def audio_deepfake_detection(demonstration_paths, demonstration_type, audio_path):
    """Audio deepfake detection function."""
    # Replace with your actual detection logic
    print("Demonstration audio paths: {}".format(demonstration_paths))
    print("Demonstration label: {}".format(demonstration_type))
    print("Query audio path: {}".format(audio_path))

    # Build the demo dataset (named to avoid shadowing the imported `dataset`
    # module); pass demonstration_type through if DemoDataset supports labels.
    demo_dataset = dataset.DemoDataset(demonstration_paths, audio_path)

    # Example return value, modify according to your model
    result = detect(demo_dataset, model)

    # Return detection results and confidence scores
    return {
        "Is AI Generated": result["is_fake"],
        "Confidence": f"{result['confidence']:.2f}%"
    }
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Audio Deepfake Detection System
        This demo helps you detect whether an audio clip is AI-generated or authentic.
        """
    )
    gr.Markdown(
        """
        ## Upload Audio
        **Note**: Supports common audio formats (wav, mp3, etc.).
        """
    )

    # Container for demonstration audio
    with gr.Row():
        # Demonstration audio file upload
        demonstration_audio_input = gr.File(
            file_count="multiple",
            file_types=["audio"],
            label="Demonstration Audios",
        )
        # Demonstration label selection
        demonstration_type = gr.Dropdown(
            choices=["bonafide", "spoof"],
            value="bonafide",
            label="Demonstration Label",
        )

    # Query audio input component
    query_audio_input = gr.Audio(
        sources=["upload"],
        label="Query Audio (Audio for Detection)",
        type="filepath",
    )

    # Submit button
    submit_btn = gr.Button(value="Start Detection", variant="primary")
    # Output results
    output_labels = gr.JSON(label="Detection Results")
    # Set click event
    submit_btn.click(
        fn=audio_deepfake_detection,
        inputs=[demonstration_audio_input, demonstration_type, query_audio_input],
        outputs=[output_labels]
    )

    # Examples section
    gr.Markdown("## Test Examples")
    # Each example wraps the demonstration path in a list because the file
    # component accepts multiple files.
    gr.Examples(
        examples=[
            [["examples/real_audio.wav"], "bonafide", "examples/query_audio.wav"],
            [["examples/fake_audio.wav"], "spoof", "examples/query_audio.wav"],
        ],
        inputs=[demonstration_audio_input, demonstration_type, query_audio_input],
    )
if __name__ == "__main__":
    demo.launch()
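# Local usage (a sketch, assuming the Space ships a requirements.txt; on
# Hugging Face Spaces the app is launched automatically):
#   pip install -r requirements.txt
#   python app.py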