import os

import gradio as gr
import torch

import dataset
from model import Wav2Vec2BERT_Llama

# Initialize device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize model and load checkpoint
model = Wav2Vec2BERT_Llama().to(device)
checkpoint_path = "ckpt/model_checkpoint.pth"
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model_state_dict = checkpoint['model_state_dict']

    # Reconcile the 'module.' prefix between DataParallel and plain checkpoints
    if hasattr(model, 'module') and not any(key.startswith('module.') for key in model_state_dict.keys()):
        model_state_dict = {'module.' + key: value for key, value in model_state_dict.items()}
    elif not hasattr(model, 'module') and any(key.startswith('module.') for key in model_state_dict.keys()):
        model_state_dict = {key.replace('module.', '', 1): value for key, value in model_state_dict.items()}

    model.load_state_dict(model_state_dict)
    model.eval()
else:
    raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")


def detect(demo_dataset, model):
    """Run the model over the demo dataset and return the detection result."""
    with torch.no_grad():
        for batch in demo_dataset:
            main_features = {
                'input_features': batch['main_features']['input_features'].to(device),
                'attention_mask': batch['main_features']['attention_mask'].to(device),
            }
            prompt_features = [{
                'input_features': pf['input_features'].to(device),
                'attention_mask': pf['attention_mask'].to(device),
            } for pf in batch['prompt_features']]

            # NOTE: the forward signature and output layout below are assumptions;
            # adapt them to Wav2Vec2BERT_Llama's actual interface. We assume binary
            # logits with index 1 = fake, and return after the first (only) batch.
            logits = model(main_features, prompt_features)
            prob_fake = torch.softmax(logits, dim=-1)[0, 1].item()
            return {
                "is_fake": prob_fake > 0.5,
                "confidence": prob_fake * 100,
            }


def audio_deepfake_detection(demonstration_paths, demonstration_label, audio_path):
    """Audio deepfake detection function."""
    print("Demonstration audio paths: {}".format(demonstration_paths))
    print("Demonstration label: {}".format(demonstration_label))
    print("Query audio path: {}".format(audio_path))

    # Build the demo dataset (DemoDataset takes the demonstration and query paths;
    # the demonstration label is not consumed by its constructor here)
    demo_dataset = dataset.DemoDataset(demonstration_paths, audio_path)

    # Run detection
    result = detect(demo_dataset, model)

    # Return detection results and confidence scores
    return {
        "Is AI Generated": result["is_fake"],
        "Confidence": f"{result['confidence']:.2f}%",
    }


with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Audio Deepfake Detection System

        This demo helps you detect whether an audio clip is AI-generated or authentic.
        """
    )
    gr.Markdown(
        """
        ## Upload Audio

        **Note**: Common audio formats (wav, mp3, etc.) are supported.
        """
    )

    # Demonstration audio inputs
    with gr.Row():
        # Demonstration audio file upload
        demonstration_audio_input = gr.File(
            file_count="multiple",
            file_types=["audio"],
            label="Demonstration Audios",
        )
        # Demonstration label selection
        demonstration_type = gr.Dropdown(
            choices=["bonafide", "spoof"],
            value="bonafide",
            label="Demonstration Label",
        )

    # Query audio input component
    query_audio_input = gr.Audio(
        sources=["upload"],
        label="Query Audio (Audio for Detection)",
        type="filepath",
    )

    # Submit button
    submit_btn = gr.Button(value="Start Detection", variant="primary")

    # Output results
    output_labels = gr.JSON(label="Detection Results")

    # Wire the click event
    submit_btn.click(
        fn=audio_deepfake_detection,
        inputs=[demonstration_audio_input, demonstration_type, query_audio_input],
        outputs=[output_labels],
    )

    # Examples section (the multi-file input takes a list of paths per example)
    gr.Markdown("## Test Examples")
    gr.Examples(
        examples=[
            [["examples/real_audio.wav"], "bonafide", "examples/query_audio.wav"],
            [["examples/fake_audio.wav"], "spoof", "examples/query_audio.wav"],
        ],
        inputs=[demonstration_audio_input, demonstration_type, query_audio_input],
    )

if __name__ == "__main__":
    demo.launch()
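# A minimal programmatic smoke test that bypasses the UI and exercises the same
# detection path as the Start Detection button. The paths below match the
# gr.Examples entries above and are assumptions about the repo layout; run this
# instead of demo.launch():
#
#   result = audio_deepfake_detection(
#       ["examples/real_audio.wav"],   # demonstration audio paths
#       "bonafide",                    # demonstration label
#       "examples/query_audio.wav",    # query audio to classify
#   )
#   print(result)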