import json
import gradio as gr
import torch
from models import ViTClassifier
from datasets import load_dataset
from transformers import ViTConfig, ViTForImageClassification
from torchvision import transforms
import pandas as pd


def load_config(config_path):
    with open(config_path, 'r') as f:
        config = json.load(f)
    print("Config Loaded:", config)  # Debugging
    return config


def load_model(config, device='cuda'):
    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    ckpt = torch.load(config['checkpoint_path'], map_location=device)
    print("Checkpoint Loaded:", ckpt.keys())  # Debugging
    model = ViTClassifier(config, device=device, dtype=torch.float32)
    print("Model Loaded:", model)  # Debugging
    model.load_state_dict(ckpt['model'])
    return model.to(device).eval()


def prepare_model_for_push(model, config):
    # Build a Hugging Face ViTConfig from the custom model config
    vit_config = ViTConfig(
        image_size=config['model']['input_size'],
        patch_size=config['model']['patch_size'],
        hidden_size=config['model']['hidden_size'],
        num_attention_heads=config['model']['num_attention_heads'],
        num_hidden_layers=config['model']['num_hidden_layers'],
        intermediate_size=4 * config['model']['hidden_size'],  # common 4x MLP ratio for ViT
        hidden_dropout_prob=config['model']['hidden_dropout_prob'],
        attention_probs_dropout_prob=config['model']['attention_probs_dropout_prob'],
        layer_norm_eps=config['model']['layer_norm_eps'],
        num_labels=config['model']['num_classes']
    )
    # Create a ViTForImageClassification model
    vit_model = ViTForImageClassification(vit_config)
    # Copy weights from the custom model for any parameter names that match
    state_dict = vit_model.state_dict()
    for key in state_dict.keys():
        if key in model.state_dict():
            state_dict[key] = model.state_dict()[key]
    vit_model.load_state_dict(state_dict)
    return vit_model, vit_config


def run_inference(input_image, model):
    print("Input Image Type:", type(input_image))  # Debugging
    # The custom ViTClassifier is assumed to accept the PIL Image object directly
    fake_prob = model(input_image).item()
    result_description = get_result_description(fake_prob)
    return {
        "Fake Probability": fake_prob,
        "Result Description": result_description
    }


def get_result_description(fake_prob):
    if fake_prob > 0.5:
        return "The image is likely a fake."
    else:
        return "The image is likely real."
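
# NOTE: Example config.json layout, inferred from the keys this script reads.
# All values below are illustrative placeholders, not the actual project config:
#
# {
#     "checkpoint_path": "checkpoints/model.pt",
#     "device": "cuda",
#     "model": {
#         "input_size": 224,
#         "patch_size": 16,
#         "hidden_size": 768,
#         "num_attention_heads": 12,
#         "num_hidden_layers": 12,
#         "hidden_dropout_prob": 0.0,
#         "attention_probs_dropout_prob": 0.0,
#         "layer_norm_eps": 1e-12,
#         "num_classes": 1
#     },
#     "preprocessing": {
#         "norm_mean": [0.485, 0.456, 0.406],
#         "norm_std": [0.229, 0.224, 0.225],
#         "resize_size": 256,
#         "crop_size": 224
#     }
# }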

def run_evaluation(dataset_name, model, config, device):
    # Assumes the Hugging Face dataset exposes a "test" split with 'image' and 'label' columns
    dataset = load_dataset(dataset_name, split="test")
    eval_df, accuracy = evaluate_model(model, dataset, config, device)
    return accuracy, eval_df.to_csv(index=False)


def evaluate_model(model, dataset, config, device):
    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    model.to(device).eval()

    norm_mean = config['preprocessing']['norm_mean']
    norm_std = config['preprocessing']['norm_std']
    resize_size = config['preprocessing']['resize_size']
    crop_size = config['preprocessing']['crop_size']

    augment_list = [
        transforms.Resize(resize_size),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
        transforms.ConvertImageDtype(torch.float32),
    ]
    preprocess = transforms.Compose(augment_list)

    true_labels = []
    predicted_probs = []
    predicted_labels = []
    with torch.no_grad():
        for sample in dataset:
            image = sample['image']
            label = sample['label']
            image = preprocess(image).unsqueeze(0).to(device)
            output = model(image)
            prob = output.item()
            true_labels.append(label)
            predicted_probs.append(prob)
            predicted_labels.append(1 if prob > 0.5 else 0)

    eval_df = pd.DataFrame({
        'True Label': true_labels,
        'Predicted Probability': predicted_probs,
        'Predicted Label': predicted_labels
    })
    accuracy = (eval_df['True Label'] == eval_df['Predicted Label']).mean()
    return eval_df, accuracy


def main():
    # Load configuration
    config_path = "config.json"
    config = load_config(config_path)

    # Load model
    device = config['device']
    model = load_model(config, device=device)

    # Define Gradio interface for inference
    def gradio_interface(input_image):
        return run_inference(input_image, model)

    # Create Gradio tabs
    with gr.Blocks() as demo:
        gr.Markdown("# Community Forensics Dataset (Transformers OTW)")
        with gr.Tab("Image Inference"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(type="pil", label="Upload Image")
                with gr.Column():
                    output = gr.JSON(label="Classification Result")
            input_image.change(fn=gradio_interface, inputs=input_image, outputs=output)

    # Launch the Gradio app
    demo.launch()


if __name__ == "__main__":
    main()
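
# Example (untested) usage sketch for the helpers that the Gradio UI does not
# wire up: `run_evaluation` and `prepare_model_for_push`. The dataset name and
# Hub repo id below are hypothetical placeholders.
#
#     config = load_config("config.json")
#     model = load_model(config, device=config['device'])
#     accuracy, csv_report = run_evaluation("user/eval-dataset", model, config, config['device'])
#     vit_model, vit_config = prepare_model_for_push(model, config)
#     vit_model.push_to_hub("user/community-forensics-vit")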