File size: 5,084 Bytes
c0f6bb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import json
import gradio as gr
import torch
import PIL
import os
from models import ViTClassifier
from datasets import load_dataset
from transformers import TrainingArguments, ViTConfig, ViTForImageClassification
from torchvision import transforms
import pandas as pd
def load_config(config_path):
    """Load a JSON configuration file and return it as a dict.

    Args:
        config_path: Path to the JSON config file.

    Returns:
        The parsed configuration dictionary.
    """
    # Explicit encoding avoids platform-dependent decoding of the JSON file.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    print("Config Loaded:", config)  # Debugging
    return config
def load_model(config, device='cuda'):
    """Restore a ViTClassifier from the checkpoint named in *config*, in eval mode."""
    # Fall back to CPU when CUDA is not available on this machine.
    target = torch.device(device if torch.cuda.is_available() else 'cpu')
    checkpoint = torch.load(config['checkpoint_path'], map_location=target)
    print("Checkpoint Loaded:", checkpoint.keys())  # Debugging
    classifier = ViTClassifier(config, device=target, dtype=torch.float32)
    print("Model Loaded:", classifier)  # Debugging
    classifier.load_state_dict(checkpoint['model'])
    return classifier.to(target).eval()
def prepare_model_for_push(model, config):
    """Convert the custom classifier into a Hugging Face ViTForImageClassification.

    Builds a ViTConfig from the project config, instantiates the HF model, and
    copies over every weight whose key matches between the two state dicts.

    Args:
        model: Trained custom model whose weights are transferred.
        config: Project configuration dict with a 'model' section.

    Returns:
        Tuple of (vit_model, vit_config) ready to push to the hub.
    """
    # BUG FIX: the original passed num_heads/num_layers/mlp_ratio/num_classes,
    # which are not ViTConfig parameters. HF configs silently accept unknown
    # kwargs, so the model would have been built with *default* sizes. Use the
    # real parameter names instead.
    vit_config = ViTConfig(
        image_size=config['model']['input_size'],
        patch_size=config['model']['patch_size'],
        hidden_size=config['model']['hidden_size'],
        num_attention_heads=config['model']['num_attention_heads'],
        num_hidden_layers=config['model']['num_hidden_layers'],
        # Common ViT default: MLP hidden dim is 4x the embedding size.
        intermediate_size=4 * config['model']['hidden_size'],
        hidden_dropout_prob=config['model']['hidden_dropout_prob'],
        attention_probs_dropout_prob=config['model']['attention_probs_dropout_prob'],
        layer_norm_eps=config['model']['layer_norm_eps'],
        num_labels=config['model']['num_classes'],
    )
    vit_model = ViTForImageClassification(vit_config)
    # Copy weights for every key the two models share; keys without a match
    # keep the freshly-initialized HF values.
    source_state = model.state_dict()
    state_dict = vit_model.state_dict()
    for key in state_dict:
        if key in source_state:
            state_dict[key] = source_state[key]
    vit_model.load_state_dict(state_dict)
    return vit_model, vit_config
def run_inference(input_image, model):
    """Classify one image and return its fake probability plus a verdict string."""
    print("Input Image Type:", type(input_image))  # Debugging
    # The model consumes the PIL image directly; .item() unwraps the scalar output.
    probability = model.forward(input_image).item()
    return {
        "Fake Probability": probability,
        "Result Description": get_result_description(probability),
    }
def get_result_description(fake_prob):
    """Map a fake-probability score to a human-readable sentence."""
    # Threshold at 0.5: only scores strictly above it are reported as fake.
    return "The image is likely a fake." if fake_prob > 0.5 else "The image is likely real."
def run_evaluation(dataset_name, model, config, device):
    """Evaluate *model* on a Hugging Face dataset and return (accuracy, csv_text).

    Args:
        dataset_name: Name of the dataset on the Hugging Face hub.
        model: Model under evaluation.
        config: Project configuration dict (its 'preprocessing' section is used).
        device: Preferred device string, e.g. 'cuda'.

    Returns:
        Tuple of (accuracy, CSV string of the per-sample evaluation table).
    """
    dataset = load_dataset(dataset_name)
    # BUG FIX: load_dataset without a split returns a DatasetDict; iterating it
    # (as evaluate_model does) yields split *names*, not samples. Select a
    # concrete split first — prefer 'test', else the first available split.
    # NOTE(review): a Dataset has a .features attribute, a DatasetDict does
    # not — confirm this holds for the datasets version in use.
    if not hasattr(dataset, 'features'):
        split = 'test' if 'test' in dataset else next(iter(dataset))
        dataset = dataset[split]
    eval_df, accuracy = evaluate_model(model, dataset, config, device)
    return accuracy, eval_df.to_csv(index=False)
def evaluate_model(model, dataset, config, device):
    """Run *model* over every sample in *dataset*; return (DataFrame, accuracy)."""
    run_device = torch.device(device if torch.cuda.is_available() else 'cpu')
    model.to(run_device).eval()

    prep_cfg = config['preprocessing']
    # Build the same deterministic preprocessing pipeline the model expects.
    preprocess = transforms.Compose([
        transforms.Resize(prep_cfg['resize_size']),
        transforms.CenterCrop(prep_cfg['crop_size']),
        transforms.ToTensor(),
        transforms.Normalize(mean=prep_cfg['norm_mean'], std=prep_cfg['norm_std']),
        transforms.ConvertImageDtype(torch.float32),
    ])

    true_labels, fake_probs, pred_labels = [], [], []
    with torch.no_grad():
        for sample in dataset:
            # Each sample is expected to carry an 'image' and a 'label' field.
            batch = preprocess(sample['image']).unsqueeze(0).to(run_device)
            score = model.forward(batch).item()
            true_labels.append(sample['label'])
            fake_probs.append(score)
            pred_labels.append(1 if score > 0.5 else 0)

    eval_df = pd.DataFrame({
        'True Label': true_labels,
        'Predicted Probability': fake_probs,
        'Predicted Label': pred_labels,
    })
    accuracy = (eval_df['True Label'] == eval_df['Predicted Label']).mean()
    return eval_df, accuracy
def main():
    """Entry point: load config and model, then serve the Gradio demo."""
    # Configuration drives everything else (checkpoint path, device, ...).
    config = load_config("config.json")
    model = load_model(config, device=config['device'])

    # Closure so the Gradio callback only needs the image argument.
    def classify(input_image):
        return run_inference(input_image, model)

    # Tabbed UI: currently a single tab for one-off image inference.
    with gr.Blocks() as demo:
        gr.Markdown("# Deepfake Detection")
        with gr.Tab("Image Inference"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("## Upload Image for Evaluation")
                    input_image = gr.Image(type="pil", label="Upload Image")
                with gr.Column():
                    output = gr.JSON(label="Classification Result")
            input_image.change(fn=classify, inputs=input_image, outputs=output)
    demo.launch()


if __name__ == "__main__":
    main()