LPX55 committed
Commit c0f6bb1 · verified · 1 Parent(s): d5f3277

Create app.py

Files changed (1)
  1. app.py +133 -0
app.py ADDED
@@ -0,0 +1,133 @@
import json
import gradio as gr
import torch
from models import ViTClassifier
from datasets import load_dataset
from transformers import ViTConfig, ViTForImageClassification
from torchvision import transforms
import pandas as pd


def load_config(config_path):
    with open(config_path, 'r') as f:
        config = json.load(f)
    print("Config Loaded:", config)  # Debugging
    return config

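# The keys read from config.json across this file imply roughly the
# following schema. This is a sketch inferred from usage; the values are
# placeholder assumptions, not part of this commit:
#
# {
#   "checkpoint_path": "checkpoints/vit_deepfake.pt",
#   "device": "cuda",
#   "model": {
#     "input_size": 224, "patch_size": 16, "hidden_size": 768,
#     "num_attention_heads": 12, "num_hidden_layers": 12,
#     "hidden_dropout_prob": 0.1, "attention_probs_dropout_prob": 0.1,
#     "layer_norm_eps": 1e-12, "num_classes": 1
#   },
#   "preprocessing": {
#     "norm_mean": [0.485, 0.456, 0.406], "norm_std": [0.229, 0.224, 0.225],
#     "resize_size": 256, "crop_size": 224
#   }
# }
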
def load_model(config, device='cuda'):
    # Fall back to CPU when CUDA is unavailable.
    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    ckpt = torch.load(config['checkpoint_path'], map_location=device)
    print("Checkpoint Loaded:", ckpt.keys())  # Debugging
    model = ViTClassifier(config, device=device, dtype=torch.float32)
    print("Model Loaded:", model)  # Debugging
    model.load_state_dict(ckpt['model'])
    return model.to(device).eval()


def prepare_model_for_push(model, config):
    # Build a Hugging Face ViTConfig. Note the transformers parameter names:
    # num_attention_heads, num_hidden_layers, intermediate_size, and
    # num_labels (num_heads/num_layers/mlp_ratio/num_classes are not
    # ViTConfig arguments and would be silently ignored).
    vit_config = ViTConfig(
        image_size=config['model']['input_size'],
        patch_size=config['model']['patch_size'],
        hidden_size=config['model']['hidden_size'],
        num_attention_heads=config['model']['num_attention_heads'],
        num_hidden_layers=config['model']['num_hidden_layers'],
        intermediate_size=4 * config['model']['hidden_size'],  # MLP ratio 4, a common ViT default
        hidden_dropout_prob=config['model']['hidden_dropout_prob'],
        attention_probs_dropout_prob=config['model']['attention_probs_dropout_prob'],
        layer_norm_eps=config['model']['layer_norm_eps'],
        num_labels=config['model']['num_classes']
    )
    # Create a ViTForImageClassification model and copy over the weights
    # whose keys match the custom model's state dict.
    vit_model = ViTForImageClassification(vit_config)
    state_dict = vit_model.state_dict()
    for key in state_dict.keys():
        if key in model.state_dict():
            state_dict[key] = model.state_dict()[key]
    vit_model.load_state_dict(state_dict)
    return vit_model, vit_config


def run_inference(input_image, model):
    print("Input Image Type:", type(input_image))  # Debugging
    # The custom ViTClassifier accepts the PIL Image object directly.
    fake_prob = model.forward(input_image).item()
    result_description = get_result_description(fake_prob)
    return {
        "Fake Probability": fake_prob,
        "Result Description": result_description
    }


def get_result_description(fake_prob):
    if fake_prob > 0.5:
        return "The image is likely a fake."
    else:
        return "The image is likely real."


def run_evaluation(dataset_name, model, config, device):
    # load_dataset returns a DatasetDict; select a single split so that
    # evaluate_model iterates over samples rather than over split names
    # (this assumes the dataset provides a "test" split).
    dataset = load_dataset(dataset_name, split="test")
    eval_df, accuracy = evaluate_model(model, dataset, config, device)
    return accuracy, eval_df.to_csv(index=False)


def evaluate_model(model, dataset, config, device):
    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    model.to(device).eval()
    norm_mean = config['preprocessing']['norm_mean']
    norm_std = config['preprocessing']['norm_std']
    resize_size = config['preprocessing']['resize_size']
    crop_size = config['preprocessing']['crop_size']
    augment_list = [
        transforms.Resize(resize_size),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
        transforms.ConvertImageDtype(torch.float32),
    ]
    preprocess = transforms.Compose(augment_list)
    true_labels = []
    predicted_probs = []
    predicted_labels = []
    with torch.no_grad():
        for sample in dataset:
            image = sample['image']
            label = sample['label']
            image = preprocess(image).unsqueeze(0).to(device)
            output = model.forward(image)
            prob = output.item()
            true_labels.append(label)
            predicted_probs.append(prob)
            predicted_labels.append(1 if prob > 0.5 else 0)
    eval_df = pd.DataFrame({
        'True Label': true_labels,
        'Predicted Probability': predicted_probs,
        'Predicted Label': predicted_labels
    })
    accuracy = (eval_df['True Label'] == eval_df['Predicted Label']).mean()
    return eval_df, accuracy

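# Note: run_evaluation and prepare_model_for_push are defined but not wired
# into the Gradio UI below. A hypothetical manual call (the dataset name is
# an assumption for illustration) might look like:
#   accuracy, csv_text = run_evaluation("user/deepfake-eval", model, config, "cuda")
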
def main():
    # Load configuration
    config_path = "config.json"
    config = load_config(config_path)
    # Load model
    device = config['device']
    model = load_model(config, device=device)

    # Define Gradio interface for inference
    def gradio_interface(input_image):
        return run_inference(input_image, model)

    # Create Gradio Tabs
    with gr.Blocks() as demo:
        gr.Markdown("# Deepfake Detection")
        with gr.Tab("Image Inference"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("## Upload Image for Evaluation")
                    input_image = gr.Image(type="pil", label="Upload Image")
                with gr.Column():
                    output = gr.JSON(label="Classification Result")
            input_image.change(fn=gradio_interface, inputs=input_image, outputs=output)
    # Launch the Gradio app
    demo.launch()


if __name__ == "__main__":
    main()
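
app.py imports ViTClassifier from a local models module that is not part of this commit. From the call sites above, the class needs roughly the interface sketched below; this is a hypothetical stand-in for orientation, not the actual implementation.

import torch

class ViTClassifier(torch.nn.Module):
    """Hypothetical stand-in matching the interface app.py calls."""

    def __init__(self, config, device='cpu', dtype=torch.float32):
        super().__init__()
        # The real model presumably builds a ViT backbone from config['model'].
        self.device = torch.device(device)
        self.dtype = dtype

    def forward(self, x):
        # Must accept either a PIL image (inference tab) or a preprocessed
        # [1, C, H, W] tensor (evaluate_model), and return a scalar
        # fake-probability tensor, since callers invoke .item() on the result.
        return torch.tensor(0.5)  # placeholder probability

The real checkpoint's state dict must also line up with load_state_dict in load_model above; the stub carries no parameters and exists only to document the expected call signature.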