|
import json |
|
import gradio as gr |
|
import torch |
|
import PIL |
|
import os |
|
from models import ViTClassifier |
|
from datasets import load_dataset |
|
from transformers import TrainingArguments, ViTConfig, ViTForImageClassification |
|
from torchvision import transforms |
|
import pandas as pd |
|
|
|
def load_config(config_path):
    """Read a JSON configuration file and return its contents as a dict.

    Args:
        config_path: Path to the JSON config file.

    Returns:
        The parsed configuration dictionary.
    """
    with open(config_path, 'r') as handle:
        loaded = json.load(handle)
    print("Config Loaded:", loaded)
    return loaded
|
|
|
def load_model(config, device='cuda'):
    """Restore a ViTClassifier from the checkpoint named in the config.

    Falls back to CPU when CUDA is unavailable. The checkpoint is expected
    to hold the weights under the 'model' key.

    Args:
        config: Configuration dict; must contain 'checkpoint_path'.
        device: Preferred device string (default 'cuda').

    Returns:
        The classifier, moved to the resolved device and set to eval mode.
    """
    target = torch.device(device if torch.cuda.is_available() else 'cpu')
    checkpoint = torch.load(config['checkpoint_path'], map_location=target)
    print("Checkpoint Loaded:", checkpoint.keys())
    classifier = ViTClassifier(config, device=target, dtype=torch.float32)
    print("Model Loaded:", classifier)
    classifier.load_state_dict(checkpoint['model'])
    return classifier.to(target).eval()
|
|
|
def prepare_model_for_push(model, config):
    """Build a Hugging Face ViT model/config pair and copy over matching weights.

    Bug fixed: the previous keyword names (``num_heads``, ``num_layers``,
    ``mlp_ratio``, ``num_classes``) are not ``ViTConfig`` parameters — the
    correct names are ``num_attention_heads``, ``num_hidden_layers``,
    ``intermediate_size``, and ``num_labels``. Because ``PretrainedConfig``
    silently stores unknown kwargs as extra attributes, the config fell back
    to its defaults instead of the values in ``config``.

    Args:
        model: Source model whose overlapping state-dict entries are copied.
        config: Project config dict with a 'model' sub-dict.

    Returns:
        Tuple of (ViTForImageClassification, ViTConfig).
    """
    model_cfg = config['model']
    hidden_size = model_cfg['hidden_size']
    vit_config = ViTConfig(
        image_size=model_cfg['input_size'],
        patch_size=model_cfg['patch_size'],
        hidden_size=hidden_size,
        num_attention_heads=model_cfg['num_attention_heads'],
        num_hidden_layers=model_cfg['num_hidden_layers'],
        # mlp_ratio of 4 expressed the HF way: FFN width = 4 * hidden size.
        intermediate_size=4 * hidden_size,
        hidden_dropout_prob=model_cfg['hidden_dropout_prob'],
        attention_probs_dropout_prob=model_cfg['attention_probs_dropout_prob'],
        layer_norm_eps=model_cfg['layer_norm_eps'],
        num_labels=model_cfg['num_classes'],
    )

    vit_model = ViTForImageClassification(vit_config)

    # Copy every weight whose key exists in both models; keys unique to the
    # HF model keep their freshly initialized values.
    state_dict = vit_model.state_dict()
    source_state = model.state_dict()  # hoisted: was re-built every iteration
    for key in state_dict:
        if key in source_state:
            state_dict[key] = source_state[key]
    vit_model.load_state_dict(state_dict)
    return vit_model, vit_config
|
|
|
def run_inference(input_image, model):
    """Score a single image with the classifier and describe the result.

    Fixes: call the model via ``model(...)`` instead of ``model.forward(...)``
    so registered hooks run, and wrap inference in ``torch.no_grad()`` to
    avoid building an autograd graph.

    NOTE(review): unlike evaluate_model, no preprocessing is applied here —
    confirm ViTClassifier accepts the raw Gradio PIL image directly.

    Args:
        input_image: Image as delivered by the Gradio component.
        model: Classifier returning a single fake-probability value.

    Returns:
        Dict with the fake probability and a human-readable description.
    """
    print("Input Image Type:", type(input_image))
    with torch.no_grad():
        fake_prob = model(input_image).item()
    result_description = get_result_description(fake_prob)
    return {
        "Fake Probability": fake_prob,
        "Result Description": result_description
    }
|
|
|
def get_result_description(fake_prob):
    """Translate a fake-probability score into a human-readable verdict.

    Scores strictly above 0.5 are reported as fake; everything else as real.
    """
    return (
        "The image is likely a fake."
        if fake_prob > 0.5
        else "The image is likely real."
    )
|
|
|
def run_evaluation(dataset_name, model, config, device):
    """Evaluate the model on a hub dataset; return (accuracy, csv_string).

    Bug fixed: ``load_dataset(name)`` without a split returns a DatasetDict;
    iterating that in ``evaluate_model`` yields split *names* (strings), so
    ``sample['image']`` would raise. We now select a concrete split,
    preferring 'test' and otherwise taking the first available one.

    Args:
        dataset_name: Hugging Face hub dataset identifier.
        model: Classifier to evaluate.
        config: Project config (supplies preprocessing parameters).
        device: Preferred device string for evaluation.

    Returns:
        Tuple of (accuracy, per-sample results as a CSV string).
    """
    dataset = load_dataset(dataset_name)
    if hasattr(dataset, 'keys'):  # DatasetDict — pick a single split
        split = 'test' if 'test' in dataset else next(iter(dataset))
        dataset = dataset[split]
    eval_df, accuracy = evaluate_model(model, dataset, config, device)
    return accuracy, eval_df.to_csv(index=False)
|
|
|
def evaluate_model(model, dataset, config, device):
    """Run the classifier over a dataset and tabulate per-sample results.

    Fix: call the model via ``model(...)`` instead of ``model.forward(...)``
    so registered forward hooks run.

    Args:
        model: Classifier whose output supports ``.item()`` (single score).
        dataset: Iterable of samples with 'image' (PIL) and 'label' keys.
        config: Project config; 'preprocessing' supplies resize/crop/norm.
        device: Preferred device string; falls back to CPU without CUDA.

    Returns:
        Tuple of (DataFrame with true/predicted columns, accuracy float).
    """
    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    model.to(device).eval()

    pre = config['preprocessing']
    preprocess = transforms.Compose([
        transforms.Resize(pre['resize_size']),
        transforms.CenterCrop(pre['crop_size']),
        transforms.ToTensor(),
        transforms.Normalize(mean=pre['norm_mean'], std=pre['norm_std']),
        # NOTE(review): ToTensor already yields float32, so this conversion
        # is a no-op; kept for parity with the original pipeline.
        transforms.ConvertImageDtype(torch.float32),
    ])

    true_labels = []
    predicted_probs = []
    predicted_labels = []
    with torch.no_grad():
        for sample in dataset:
            image = preprocess(sample['image']).unsqueeze(0).to(device)
            # __call__ (not .forward) so hooks fire; output is a scalar score.
            prob = model(image).item()
            true_labels.append(sample['label'])
            predicted_probs.append(prob)
            # Threshold at 0.5: above = fake (1), otherwise real (0).
            predicted_labels.append(1 if prob > 0.5 else 0)

    eval_df = pd.DataFrame({
        'True Label': true_labels,
        'Predicted Probability': predicted_probs,
        'Predicted Label': predicted_labels
    })
    accuracy = (eval_df['True Label'] == eval_df['Predicted Label']).mean()
    return eval_df, accuracy
|
|
|
def main():
    """Entry point: load config and model, then serve the Gradio demo.

    Builds a single-tab UI where uploading an image triggers classification
    and displays the JSON result.
    """
    cfg = load_config("config.json")
    model = load_model(cfg, device=cfg['device'])

    # Closure binding the loaded model into the Gradio callback.
    def classify(image):
        return run_inference(image, model)

    with gr.Blocks() as demo:
        gr.Markdown("# Community Forensics Dataset (Transformers OTW)")
        with gr.Tab("Image Inference"):
            with gr.Row():
                with gr.Column():
                    uploaded = gr.Image(type="pil", label="Upload Image")
                with gr.Column():
                    result_box = gr.JSON(label="Classification Result")
            uploaded.change(fn=classify, inputs=uploaded, outputs=result_box)

    demo.launch()
|
|
|
# Launch the demo only when run as a script (not on import).
if __name__ == "__main__":

    main()