"""Gradio demo for a ViT-based deepfake detector: single-image inference,
dataset evaluation, and conversion to a Hugging Face ViT for the Hub."""
import json

import gradio as gr
import pandas as pd
import torch
from datasets import load_dataset
from models import ViTClassifier
from torchvision import transforms
from transformers import ViTConfig, ViTForImageClassification

def load_config(config_path):
    with open(config_path, 'r') as f:
        return json.load(f)
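
# Hedged sketch of the config.json shape this script reads; the keys are the
# ones accessed throughout this file, the values are illustrative only:
#
#     {
#         "checkpoint_path": "checkpoints/vit.pt",
#         "device": "cuda",
#         "model": {
#             "input_size": 224, "patch_size": 16, "hidden_size": 768,
#             "num_attention_heads": 12, "num_hidden_layers": 12,
#             "hidden_dropout_prob": 0.1, "attention_probs_dropout_prob": 0.1,
#             "layer_norm_eps": 1e-12, "num_classes": 1
#         },
#         "preprocessing": {
#             "norm_mean": [0.485, 0.456, 0.406], "norm_std": [0.229, 0.224, 0.225],
#             "resize_size": 256, "crop_size": 224
#         }
#     }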

def load_model(config, device='cuda'):
    # Fall back to CPU when CUDA is unavailable
    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    # The checkpoint is assumed to store the weights under the 'model' key
    ckpt = torch.load(config['checkpoint_path'], map_location=device)
    model = ViTClassifier(config, device=device, dtype=torch.float32)
    model.load_state_dict(ckpt['model'])
    return model.to(device).eval()

def prepare_model_for_push(model, config):
    # Build a Hugging Face ViTConfig mirroring the custom model's hyperparameters
    hidden_size = config['model']['hidden_size']
    vit_config = ViTConfig(
        image_size=config['model']['input_size'],
        patch_size=config['model']['patch_size'],
        hidden_size=hidden_size,
        num_attention_heads=config['model']['num_attention_heads'],
        num_hidden_layers=config['model']['num_hidden_layers'],
        intermediate_size=4 * hidden_size,  # common ViT default (MLP ratio of 4)
        hidden_dropout_prob=config['model']['hidden_dropout_prob'],
        attention_probs_dropout_prob=config['model']['attention_probs_dropout_prob'],
        layer_norm_eps=config['model']['layer_norm_eps'],
        num_labels=config['model']['num_classes'],
    )
    # Instantiate a stock ViTForImageClassification and copy over weights whose
    # names match; parameters absent from the custom model keep their fresh init
    vit_model = ViTForImageClassification(vit_config)
    state_dict = vit_model.state_dict()
    custom_state = model.state_dict()
    for key in state_dict:
        if key in custom_state:
            state_dict[key] = custom_state[key]
    vit_model.load_state_dict(state_dict)
    return vit_model, vit_config
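
# Hedged usage sketch: once converted, the model can be uploaded with the
# standard PreTrainedModel API (the repo id below is hypothetical):
#
#     vit_model, vit_config = prepare_model_for_push(model, config)
#     vit_model.push_to_hub("your-username/deepfake-vit")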

def build_preprocess(config):
    # Evaluation-time transform pipeline, shared by inference and evaluation
    pre = config['preprocessing']
    return transforms.Compose([
        transforms.Resize(pre['resize_size']),
        transforms.CenterCrop(pre['crop_size']),
        transforms.ToTensor(),
        transforms.Normalize(mean=pre['norm_mean'], std=pre['norm_std']),
        transforms.ConvertImageDtype(torch.float32),
    ])

def run_inference(input_image, model, config, device):
    # .change() also fires when the image is cleared, so guard against None
    if input_image is None:
        return {}
    # The raw PIL image must be preprocessed into a normalized tensor batch
    # before it is fed to the model
    tensor = build_preprocess(config)(input_image.convert('RGB')).unsqueeze(0).to(device)
    with torch.no_grad():
        fake_prob = model(tensor).item()
    return {
        "Fake Probability": fake_prob,
        "Result Description": get_result_description(fake_prob)
    }

def get_result_description(fake_prob):
    if fake_prob > 0.5:
        return "The image is likely a fake."
    else:
        return "The image is likely real."

def run_evaluation(dataset_name, model, config, device):
    # load_dataset() without a split returns a DatasetDict, which would make
    # evaluate_model iterate over split names rather than samples; the 'test'
    # split is an assumption about the dataset's layout
    dataset = load_dataset(dataset_name, split='test')
    eval_df, accuracy = evaluate_model(model, dataset, config, device)
    return accuracy, eval_df.to_csv(index=False)

def evaluate_model(model, dataset, config, device):
    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    model.to(device).eval()
    preprocess = build_preprocess(config)
    true_labels = []
    predicted_probs = []
    predicted_labels = []
    with torch.no_grad():
        for sample in dataset:
            label = sample['label']
            # Force RGB in case the dataset stores grayscale or RGBA images
            image = preprocess(sample['image'].convert('RGB')).unsqueeze(0).to(device)
            prob = model(image).item()
            true_labels.append(label)
            predicted_probs.append(prob)
            predicted_labels.append(1 if prob > 0.5 else 0)
    eval_df = pd.DataFrame({
        'True Label': true_labels,
        'Predicted Probability': predicted_probs,
        'Predicted Label': predicted_labels
    })
    accuracy = (eval_df['True Label'] == eval_df['Predicted Label']).mean()
    return eval_df, accuracy
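
# Hedged usage sketch (dataset name is hypothetical; the chosen split must
# expose PIL 'image' and integer 'label' columns):
#
#     dataset = load_dataset("user/deepfake-dataset", split="test")
#     eval_df, accuracy = evaluate_model(model, dataset, config, device="cuda")
#     print(f"Accuracy: {accuracy:.2%}")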

def main():
    # Load configuration, resolve the device once, and load the model so the
    # same CUDA/CPU fallback applies to both the weights and the input tensors
    config = load_config("config.json")
    device = torch.device(config['device'] if torch.cuda.is_available() else 'cpu')
    model = load_model(config, device=config['device'])
    # Gradio callback for single-image inference
    def gradio_interface(input_image):
        return run_inference(input_image, model, config, device)
    # Create Gradio Tabs
    with gr.Blocks() as demo:
        gr.Markdown("# Deepfake Detection")
        with gr.Tab("Image Inference"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("## Upload Image for Evaluation")
                    input_image = gr.Image(type="pil", label="Upload Image")
                with gr.Column():
                    output = gr.JSON(label="Classification Result")
            input_image.change(fn=gradio_interface, inputs=input_image, outputs=output)
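
        # Hedged sketch: a second tab wiring the otherwise-unused run_evaluation
        # into the UI. It assumes the named dataset follows the "image"/"label"
        # schema with a "test" split, as evaluate_model above expects.
        with gr.Tab("Dataset Evaluation"):
            dataset_name = gr.Textbox(label="Hugging Face dataset name")
            accuracy_out = gr.Number(label="Accuracy")
            csv_out = gr.Textbox(label="Per-sample results (CSV)")
            eval_btn = gr.Button("Evaluate")
            eval_btn.click(
                fn=lambda name: run_evaluation(name, model, config, device),
                inputs=dataset_name,
                outputs=[accuracy_out, csv_out],
            )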
    # Launch the Gradio app
    demo.launch()

if __name__ == "__main__":
    main()