"""Gradio demo app for the "For Fake's Sake" synthetic-image detectors.

Four binary classifiers (two timm backbones x two training sets) are loaded
from ``models/*.pt`` at import time and served through a Gradio Blocks UI that
reports the probability that an uploaded image was created by AI vs. a human.
"""
import os

# NOTE: these installs run on every start-up and happen before ``timm`` is
# imported, so that (if they succeed) the git versions of timm /
# huggingface_hub are the ones imported below. os.system's exit codes are
# ignored, i.e. installation is best effort.
os.system("python -m pip install --upgrade pip")
os.system("pip install git+https://github.com/rwightman/pytorch-image-models")
os.system("pip install git+https://github.com/huggingface/huggingface_hub")

import gradio as gr
import timm
import torch
from torch import nn
from torch.nn import functional as F
import torchvision


class _TimmBinaryClassifier(nn.Module):
    """timm backbone without its classifier, followed by a small 2-way MLP head.

    ``load_model`` loads checkpoints with strict key matching, so the attribute
    names ``model`` and ``clf`` (and the layer order inside ``clf``) must stay
    exactly as they are to match the keys stored in the ``.pt`` files.
    """

    def __init__(self, backbone_name, feature_dim, hidden_dim=128):
        super().__init__()
        # num_classes=0 asks timm for pooled features instead of class logits.
        self.model = timm.create_model(backbone_name, pretrained=False, num_classes=0)
        self.clf = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, 2),
        )

    def forward(self, image):
        """Return raw 2-class logits for a batch of images."""
        return self.clf(self.model(image))


class Model200M(_TimmBinaryClassifier):
    """ConvNeXt-Large (CLIP LAION-2B) backbone with 1536-d features; fed 640x640 inputs."""

    def __init__(self):
        super().__init__('convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384', feature_dim=1536)


class Model5M(_TimmBinaryClassifier):
    """MobileNetV3-Large backbone with 1280-d features; fed 224x224 inputs."""

    def __init__(self):
        super().__init__('timm/tf_mobilenetv3_large_100.in1k', feature_dim=1280)


def load_model(name: str):
    """Build the architecture implied by *name* and load its weights from that path.

    Paths containing "200M" get ``Model200M``; every other path gets ``Model5M``.
    Weights are mapped to CPU and the model is switched to eval mode.
    """
    model = Model200M() if "200M" in name else Model5M()
    # torch.load unpickles the file: only ever point it at trusted,
    # locally shipped checkpoints, never at user-supplied files.
    state_dict = torch.load(name, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Keyed by the names shown in the UI dropdown; "200M" in a key selects the
# 640x640 preprocessing in predict().
model_list = {
    'midjourney_200M': load_model('models/midjourney200M.pt'),
    'diffusions_200M': load_model('models/diffusions200M.pt'),
    'midjourney_5M': load_model('models/midjourney5M.pt'),
    'diffusions_5M': load_model('models/diffusions5M.pt'),
}

# Standard ImageNet channel statistics.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def _make_transform(size):
    """Resize to a square *size* x *size* image, convert to a tensor and normalize."""
    return torchvision.transforms.Compose([
        torchvision.transforms.Resize((size, size)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ])


tfm = _make_transform(640)        # used for the *_200M models
tfm_small = _make_transform(224)  # used for the *_5M models


def predict_from_model(model, img_1):
    """Classify one preprocessed image tensor with *model*.

    Args:
        model: one of the classifiers in ``model_list``.
        img_1: a single image tensor as produced by ``tfm`` / ``tfm_small``
            (no batch dimension; one is added here).

    Returns:
        ``{'created by AI': p1, 'created by human': p0}``, where p1 / p0 are the
        softmax probabilities of logit index 1 / 0, as plain floats (the value
        type ``gr.Label`` expects for confidences).
    """
    # Inference only: no need for autograd bookkeeping.
    with torch.no_grad():
        logits = model(img_1.unsqueeze(0))  # add the batch dimension
        probs = F.softmax(logits, dim=1)[0].cpu()
    return {'created by AI': float(probs[1]), 'created by human': float(probs[0])}


def predict(raw_image, model_name):
    """Gradio callback: score the uploaded PIL image with the detector *model_name*.

    Returns the ``gr.Label`` dict from ``predict_from_model``, or
    ``{'error': 0.0}`` for an unknown model name. Raises ``gr.Error`` (shown to
    the user) when Submit is pressed without an image.
    """
    if model_name not in model_list:
        return {'error': 0.0}
    if raw_image is None:
        raise gr.Error("Please upload an image first.")
    model = model_list[model_name]
    # Only run the preprocessing the chosen model needs
    # (640x640 for *_200M, 224x224 for *_5M).
    transform = tfm if "200M" in model_name else tfm_small
    # Normalize above has 3 channel stats; force RGB so RGBA/greyscale inputs work.
    img = transform(raw_image.convert("RGB"))
    return predict_from_model(model, img)


general_examples = [[f"images/general/img_{i}.jpg"] for i in range(1, 11)]
# Only referenced by the commented-out "More examples" tab below.
optic_examples = [[f"images/optic/img_{i}.jpg"] for i in range(1, 6)]
famous_deepfake_examples = [
    ["images/famous_deepfakes/img_1.jpg"],
    ["images/famous_deepfakes/img_2.jpg"],
    ["images/famous_deepfakes/img_3.jpg"],
    ["images/famous_deepfakes/img_4.webp"],
]

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """

For Fake's Sake: a set of models for detecting generated and synthetic images

This is a demo space for synthetic image detectors: midjourney200M, midjourney5M, diffusions200M, diffusions5M.
We provide several detectors for images generated by popular tools, such as Midjourney and Stable Diffusion.
Please refer to model cards for evaluation metrics and limitations. """
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil")
            drop_down = gr.Dropdown(list(model_list), type="value", label="Model", value="diffusions_200M")
            with gr.Row():
                gr.ClearButton(components=[image_input])
                submit_button = gr.Button("Submit", variant="primary")
        with gr.Column():
            result_score = gr.Label(label='result', num_top_classes=2)
            with gr.Tab("Examples"):
                gr.Examples(examples=general_examples, inputs=image_input)
            # with gr.Tab("More examples"):
            #     gr.Examples(examples=optic_examples, inputs=image_input)
            with gr.Tab("Widely known deepfakes"):
                gr.Examples(examples=famous_deepfake_examples, inputs=image_input)

    submit_button.click(predict, inputs=[image_input, drop_down], outputs=result_score)
    gr.Markdown(
        """

Models

*_200M models are based on convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384 with image size 640x640

*_5M models are based on tf_mobilenetv3_large_100.in1k with image size 224x224

Details

  • Model cards: midjourney200M, midjourney5M, diffusions200M, diffusions5M.
  • License: CC-By-SA-3.0
  • """
    )

demo.launch()