# app.py -- last change: "Update app.py" by wli3221134 (commit 3b375a6, verified), 3.97 kB
import gradio as gr
import os
import dataset
import torch
from model import Wav2Vec2BERT_Llama
# init
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# init model
model = Wav2Vec2BERT_Llama().to(device)
checkpoint_path = "ckpt/model_checkpoint.pth"
if os.path.exists(checkpoint_path):
checkpoint = torch.load(checkpoint_path)
model_state_dict = checkpoint['model_state_dict']
# 处理模型状态字典
if hasattr(model, 'module') and not any(key.startswith('module.') for key in model_state_dict.keys()):
model_state_dict = {'module.' + key: value for key, value in model_state_dict.items()}
elif not hasattr(model, 'module') and any(key.startswith('module.') for key in model_state_dict.keys()):
model_state_dict = {key.replace('module.', ''): value for key, value in model_state_dict.items()}
model.load_state_dict(model_state_dict)
model.eval()
else:
raise FileNotFoundError(f"Not found checkpoint: {checkpoint_path}")
def detect(dataset, model):
    """Move every batch's features onto the module-level ``device``.

    Iterates ``dataset`` with gradient tracking disabled. For each batch, the
    ``input_features`` and ``attention_mask`` tensors of ``batch['main_features']``
    and of every entry in ``batch['prompt_features']`` are transferred to
    ``device``.

    Args:
        dataset: Iterable of batches; each batch is indexed with the keys
            ``'main_features'`` and ``'prompt_features'``.
        model: Accepted for the caller's convenience; not used by the
            current body.

    Returns:
        None. NOTE(review): the model is never invoked and no result is
        returned here, so a caller that indexes the return value will fail --
        confirm what the intended forward call and result schema are.
    """
    def _on_device(features):
        # Same two keys, same order of transfer as the per-batch dicts need.
        return {
            'input_features': features['input_features'].to(device),
            'attention_mask': features['attention_mask'].to(device),
        }

    with torch.no_grad():
        for batch in dataset:
            main_features = _on_device(batch['main_features'])
            prompt_features = [_on_device(pf) for pf in batch['prompt_features']]
def audio_deepfake_detection(demonstration_paths, audio_path, model):
    """Run deepfake detection for a query clip and format the outcome.

    Args:
        demonstration_paths: Demonstration audio input(s); forwarded unchanged
            as the first argument of ``dataset.DemoDataset``.
        audio_path: Query audio to classify; forwarded unchanged as the second
            argument of ``dataset.DemoDataset``.
        model: Model handed to ``detect`` together with the dataset.

    Returns:
        A dict with two entries: ``"Is AI Generated"`` (the ``is_fake`` field
        of the detection result) and ``"Confidence"`` (the ``confidence``
        field formatted with two decimals and a trailing ``%``).
    """
    print(f"Demonstration audio paths: {demonstration_paths}")
    print(f"Query audio path: {audio_path}")

    # Use a distinct local name. Assigning to `dataset` inside this function
    # would make `dataset` a local variable for the whole body, so the lookup
    # of the imported `dataset` module on the right-hand side would raise
    # UnboundLocalError before DemoDataset could be constructed.
    demo_dataset = dataset.DemoDataset(demonstration_paths, audio_path)

    result = detect(demo_dataset, model)

    # Shape the raw result into the JSON payload shown in the UI.
    return {
        "Is AI Generated": result["is_fake"],
        "Confidence": f"{result['confidence']:.2f}%",
    }
def _run_detection(demonstration_files, demonstration_label, query_path):
    """Adapt the three UI input values to ``audio_deepfake_detection``.

    Gradio passes the values of the ``inputs`` components positionally, in
    order: demonstration files, demonstration label, query audio path.
    ``audio_deepfake_detection`` takes ``(demonstration_paths, audio_path,
    model)``, so wiring it directly would put the label into ``audio_path``
    and the query path into ``model``. It has no parameter for the label,
    so the label is accepted here but not forwarded.
    """
    return audio_deepfake_detection(demonstration_files, query_path, model)


# Build the web UI: two upload areas, a submit button and a JSON result panel.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# Audio Deepfake Detection System
This demo helps you detect whether an audio clip is AI-generated or authentic.
"""
    )
    gr.Markdown(
        """
## Upload Audio
**Note**: Supports common audio formats (wav, mp3, etc.).
"""
    )

    # Container for the demonstration audio and its label.
    with gr.Row():
        # Demonstration audio file upload (several files allowed).
        demonstration_audio_input = gr.File(
            file_count="multiple",
            file_types=["audio"],
            label="Demonstration Audios",
        )
        # Label that applies to the uploaded demonstration audio.
        demonstration_type = gr.Dropdown(
            choices=["bonafide", "spoof"],
            value="bonafide",
            label="Demonstration Label",
        )

    # Query audio input; the handler receives it as a file path.
    query_audio_input = gr.Audio(
        sources=["upload"],
        label="Query Audio (Audio for Detection)",
        type="filepath",
    )

    # Submit button.
    submit_btn = gr.Button(value="Start Detection", variant="primary")

    # Output panel for the detection result dict.
    output_labels = gr.Json(label="Detection Results")

    # Run detection on click, going through the adapter so each UI value
    # reaches the parameter it is meant for.
    submit_btn.click(
        fn=_run_detection,
        inputs=[demonstration_audio_input, demonstration_type, query_audio_input],
        outputs=[output_labels],
    )

    # Examples section: clicking a row fills the three input components.
    gr.Markdown("## Test Examples")
    gr.Examples(
        examples=[
            ["examples/real_audio.wav", "bonafide", "examples/query_audio.wav"],
            ["examples/fake_audio.wav", "spoof", "examples/query_audio.wav"],
        ],
        inputs=[demonstration_audio_input, demonstration_type, query_audio_input],
    )

# Start the server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()