|
import os |
|
import subprocess
import sys


def _pip_install(*args):
    """Run ``pip install`` with ``args`` for the interpreter executing this script.

    pip is invoked as ``sys.executable -m pip`` with an argument list, so the
    packages land in the environment that performs the imports below rather
    than in whichever ``python``/``pip`` happens to be first on PATH.
    Installation stays best-effort: a failure is reported on stderr and the
    script continues.
    """
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", *args],
        check=False,
    )
    if result.returncode != 0:
        print(
            f"warning: pip install {' '.join(args)} exited with status {result.returncode}",
            file=sys.stderr,
        )


# Upgrade pip first, then install timm and huggingface_hub from their git repositories.
_pip_install("--upgrade", "pip")
_pip_install("git+https://github.com/rwightman/pytorch-image-models")
_pip_install("git+https://github.com/huggingface/huggingface_hub")
|
|
|
import gradio as gr |
|
import timm |
|
import torch |
|
from torch import nn |
|
from torch.nn import functional as F |
|
import torchvision |
|
|
|
|
|
class Model200M(nn.Module):
    """timm ``convnext_large_mlp`` backbone followed by a two-logit MLP head.

    The backbone is created without pretrained weights and with
    ``num_classes=0``; its output feeds ``clf``
    (Linear 1536 -> 128, ReLU, Linear 128 -> 2).
    """

    def __init__(self):
        super().__init__()
        backbone_name = 'convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384'
        self.model = timm.create_model(backbone_name, pretrained=False, num_classes=0)

        # The layers stay in this order inside a single nn.Sequential so the
        # parameter names remain clf.0.* and clf.2.*.
        head_layers = [
            nn.Linear(1536, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 2),
        ]
        self.clf = nn.Sequential(*head_layers)

    def forward(self, image):
        """Return the head's output for the backbone features of ``image``."""
        return self.clf(self.model(image))
|
|
|
|
|
class Model5M(nn.Module):
    """timm ``tf_mobilenetv3_large_100`` backbone followed by a two-logit MLP head.

    The backbone is created without pretrained weights and with
    ``num_classes=0``; its output feeds ``clf``
    (Linear 1280 -> 128, ReLU, Linear 128 -> 2).
    """

    def __init__(self):
        super().__init__()
        backbone_name = 'timm/tf_mobilenetv3_large_100.in1k'
        self.model = timm.create_model(backbone_name, pretrained=False, num_classes=0)

        # The layers stay in this order inside a single nn.Sequential so the
        # parameter names remain clf.0.* and clf.2.*.
        head_layers = [
            nn.Linear(1280, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 2),
        ]
        self.clf = nn.Sequential(*head_layers)

    def forward(self, image):
        """Return the head's output for the backbone features of ``image``."""
        return self.clf(self.model(image))
|
|
|
def load_model(name: str):
    """Build the network that matches a checkpoint path and load its weights.

    Args:
        name: Path of the checkpoint file. A path containing ``"200M"``
            selects ``Model200M``; any other path selects ``Model5M``.

    Returns:
        The model with the checkpoint's state dict loaded, switched to
        eval mode.
    """
    if "200M" in name:
        network = Model200M()
    else:
        network = Model5M()

    # Tensors are mapped onto the CPU while the checkpoint is read.
    cpu = torch.device('cpu')
    state = torch.load(name, map_location=cpu)
    network.load_state_dict(state)

    # Module.eval() returns the module itself.
    return network.eval()
|
|
|
# Detectors offered by the demo, keyed by the name used to select them.
# Every checkpoint is loaded once, in the order listed, when the module runs.
model_list = {
    label: load_model(checkpoint)
    for label, checkpoint in (
        ('midjourney_200M', 'models/midjourney200M.pt'),
        ('diffusions_200M', 'models/diffusions200M.pt'),
        ('midjourney_5M', 'models/midjourney5M.pt'),
        ('diffusions_5M', 'models/diffusions5M.pt'),
    )
}
|
|
|
# Preprocessing used by predict() for model names containing "200M":
# resize to 640x640, convert to a tensor, then normalize the three channels
# with the per-channel mean and std below.
tfm = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=(640, 640)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        [0.485, 0.456, 0.406],  # mean
        [0.229, 0.224, 0.225],  # std
    ),
])
|
|
|
# Preprocessing used by predict() for model names that do not contain "200M":
# resize to 224x224, convert to a tensor, then normalize the three channels
# with the per-channel mean and std below.
tfm_small = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=(224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        [0.485, 0.456, 0.406],  # mean
        [0.229, 0.224, 0.225],  # std
    ),
])
|
|
|
|
|
def predict_from_model(model, img_1):
    """Run ``model`` on one preprocessed image and return class probabilities.

    Args:
        model: Module mapping a batched input tensor to logits with two
            columns. Column 1 is reported as "created by AI" and column 0
            as "created by human".
        img_1: Input tensor without a batch dimension; a leading axis of
            size 1 is added before the forward pass.

    Returns:
        A dict with the keys ``'created by AI'`` and ``'created by human'``.
        Each value is a list holding one softmax probability per batch row,
        i.e. a single-element list here.
    """
    # Inference only: disable autograd so no graph is recorded, and call the
    # module itself (not .forward) so registered hooks still run.
    with torch.no_grad():
        logits = model(img_1[None, ...])
        # One softmax pass serves both classes.
        probs = F.softmax(logits, dim=1).cpu().numpy()

    return {'created by AI': probs[:, 1].tolist(),
            'created by human': probs[:, 0].tolist()}
|
|
|
|
|
def predict(raw_image, model_name):
    """Classify an image with the detector selected by ``model_name``.

    Args:
        raw_image: Image to classify, in a form accepted by ``tfm`` and
            ``tfm_small`` (the UI supplies a PIL image), or ``None`` when no
            image was provided.
        model_name: Key into ``model_list``.

    Returns:
        The dict produced by ``predict_from_model``, or ``{'error': [0.]}``
        when there is no image or ``model_name`` is not a known model.
    """
    # Validate before preprocessing: a missing image would otherwise fail
    # inside the transform, and an unknown model needs no tensors at all.
    if raw_image is None or model_name not in model_list:
        return {'error': [0.]}

    # Build only the tensor the chosen model consumes: names containing
    # "200M" use the 640x640 pipeline, all others the 224x224 one.
    transform = tfm if "200M" in model_name else tfm_small
    return predict_from_model(model_list[model_name], transform(raw_image))
|
|
|
# Example inputs for the "Examples" tab: one single-element row per image path.
general_examples = [
    [image_path]
    for image_path in (
        "images/general/img_1.jpg",
        "images/general/img_2.jpg",
        "images/general/img_3.jpg",
        "images/general/img_4.jpg",
        "images/general/img_5.jpg",
        "images/general/img_6.jpg",
        "images/general/img_7.jpg",
        "images/general/img_8.jpg",
        "images/general/img_9.jpg",
        "images/general/img_10.jpg",
    )
]
|
|
|
# Example image paths under images/optic: one single-element row per image.
optic_examples = [
    [image_path]
    for image_path in (
        "images/optic/img_1.jpg",
        "images/optic/img_2.jpg",
        "images/optic/img_3.jpg",
        "images/optic/img_4.jpg",
        "images/optic/img_5.jpg",
    )
]
|
|
|
# Example inputs for the "Widely known deepfakes" tab: one single-element
# row per image path.
famous_deepfake_examples = [
    [image_path]
    for image_path in (
        "images/famous_deepfakes/img_1.jpg",
        "images/famous_deepfakes/img_2.jpg",
        "images/famous_deepfakes/img_3.jpg",
        "images/famous_deepfakes/img_4.webp",
    )
]
|
|
|
# Demo page: intro text, an input column (image, model selector, buttons)
# beside a result column, example tabs, and a footer with model details.
# NOTE(review): the nesting below was reconstructed from a copy of the file
# whose indentation was lost - confirm the layout against the deployed app.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        <h1 style="text-align: center;">For Fake's Sake: a set of models for detecting generated and synthetic images</h1>
        This is a demo space for synthetic image detectors:
        <a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_mj_200'>midjourney200M</a>,
        <a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_mj_5'>midjourney5M</a>,
        <a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_sd_200'>diffusions200M</a>,
        <a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_sd_5'>diffusions5M</a>.<br>
        We provide several detectors for images generated by popular tools, such as Midjourney and Stable Diffusion.<br>
        Please refer to model cards for evaluation metrics and limitations.
        """
    )

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil")
            # Pass the model names as a plain list; the dict's insertion
            # order determines the order of the choices.
            drop_down = gr.Dropdown(list(model_list), type="value", label="Model", value="diffusions_200M")
            with gr.Row():
                gr.ClearButton(components=[image_input])
                submit_button = gr.Button("Submit", variant="primary")
        with gr.Column():
            result_score = gr.Label(label='result', num_top_classes=2)

    # Selecting an example fills the image input; classification still
    # happens only through the Submit button.
    with gr.Tab("Examples"):
        gr.Examples(examples=general_examples, inputs=image_input)

    with gr.Tab("Widely known deepfakes"):
        gr.Examples(examples=famous_deepfake_examples, inputs=image_input)

    submit_button.click(predict, inputs=[image_input, drop_down], outputs=result_score)

    gr.Markdown(
        """
        <h3>Models</h3>
        <p><code>*_200M</code> models are based on <code>convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384</code> with image size <code>640x640</code></p>
        <p><code>*_5M</code> models are based on <code>tf_mobilenetv3_large_100.in1k</code> with image size <code>224x224</code></p>

        <h3>Details</h3>
        <ul>
        <li>Model cards: <a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_mj_200'>midjourney200M</a>,
        <a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_mj_5'>midjourney5M</a>,
        <a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_sd_200'>diffusions200M</a>,
        <a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_sd_5'>diffusions5M</a>.
        </li>
        <li>License: CC-By-SA-3.0</li>
        </ul>
        """
    )

# Start the app as soon as the module runs.
demo.launch()