# Sumsub-ffs-demo / app.py
import os

# Hugging Face Spaces pattern: upgrade pip and install recent timm and
# huggingface_hub builds straight from GitHub before they are imported below.
os.system("python -m pip install --upgrade pip")
os.system("pip install git+https://github.com/rwightman/pytorch-image-models")
os.system("pip install git+https://github.com/huggingface/huggingface_hub")

import gradio as gr
import timm
import torch
from torch import nn
from torch.nn import functional as F
import torchvision

class Model200M(torch.nn.Module):
    """ConvNeXt-Large backbone (1536-d features) with a small 2-class head."""

    def __init__(self):
        super().__init__()
        self.model = timm.create_model('convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384',
                                       pretrained=False, num_classes=0)
        self.clf = nn.Sequential(
            nn.Linear(1536, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 2))

    def forward(self, image):
        image_features = self.model(image)
        return self.clf(image_features)

class Model5M(torch.nn.Module):
    """MobileNetV3-Large backbone (1280-d features) with a small 2-class head."""

    def __init__(self):
        super().__init__()
        self.model = timm.create_model('timm/tf_mobilenetv3_large_100.in1k', pretrained=False, num_classes=0)
        self.clf = nn.Sequential(
            nn.Linear(1280, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 2))

    def forward(self, image):
        image_features = self.model(image)
        return self.clf(image_features)
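
# Quick shape sanity check for both architectures (illustrative only, not executed
# at import time): each head emits two logits, index 0 -> "human", index 1 -> "AI".
#   m = Model5M()
#   out = m(torch.zeros(1, 3, 224, 224))
#   assert out.shape == (1, 2)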

def load_model(name: str):
    """Pick the matching architecture by checkpoint name and load CPU weights."""
    model = Model200M() if "200M" in name else Model5M()
    ckpt = torch.load(name, map_location=torch.device('cpu'))
    model.load_state_dict(ckpt)
    model.eval()
    return model

model_list = {
    'midjourney_200M': load_model('models/midjourney200M.pt'),
    'diffusions_200M': load_model('models/diffusions200M.pt'),
    'midjourney_5M': load_model('models/midjourney5M.pt'),
    'diffusions_5M': load_model('models/diffusions5M.pt'),
}
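# Note: the checkpoints above are expected locally under models/. If they were
# hosted on the Hub instead, they could be fetched first, e.g. (repo id taken from
# the model cards below, filename hypothetical):
#   from huggingface_hub import hf_hub_download
#   path = hf_hub_download("Sumsub/Sumsub-ffs-synthetic-1.0_sd_200", "diffusions200M.pt")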

# Preprocessing: 640x640 input for the *_200M models, 224x224 for the *_5M models,
# both normalized with ImageNet statistics.
tfm = torchvision.transforms.Compose([
    torchvision.transforms.Resize((640, 640)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
])

tfm_small = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
])
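# Illustrative check of the preprocessing output (assumes any RGB PIL image):
#   from PIL import Image
#   tfm(Image.new("RGB", (800, 600))).shape        # torch.Size([3, 640, 640])
#   tfm_small(Image.new("RGB", (800, 600))).shape  # torch.Size([3, 224, 224])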

def predict_from_model(model, img_1):
    """Run one preprocessed image through a detector and return class probabilities."""
    with torch.no_grad():
        y = model(img_1[None, ...])
    probs = F.softmax(y, dim=1)
    y_1 = probs[:, 1].cpu().numpy()  # index 1: generated / "AI"
    y_2 = probs[:, 0].cpu().numpy()  # index 0: real / "human"
    return {'created by AI': y_1.tolist(),
            'created by human': y_2.tolist()}
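
# Example of scoring a single image outside the UI (illustrative; assumes a local
# file "sample.jpg" exists and the checkpoints above loaded successfully):
#   from PIL import Image
#   raw = Image.open("sample.jpg").convert("RGB")
#   predict_from_model(model_list['diffusions_200M'], tfm(raw))
#   # -> e.g. {'created by AI': [...], 'created by human': [...]}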

def predict(raw_image, model_name):
    """Preprocess a PIL image and score it with the selected detector."""
    if model_name not in model_list:
        return {'error': [0.]}
    model = model_list[model_name]
    # 200M (ConvNeXt) models expect 640x640 input, 5M (MobileNetV3) models 224x224.
    img = tfm(raw_image) if "200M" in model_name else tfm_small(raw_image)
    return predict_from_model(model, img)
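
# The Gradio submit handler below calls predict(raw_image, model_name); the same
# call can be made programmatically (illustrative, assumes "sample.jpg" exists):
#   from PIL import Image
#   predict(Image.open("sample.jpg").convert("RGB"), "midjourney_5M")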

general_examples = [
    ["images/general/img_1.jpg"],
    ["images/general/img_2.jpg"],
    ["images/general/img_3.jpg"],
    ["images/general/img_4.jpg"],
    ["images/general/img_5.jpg"],
    ["images/general/img_6.jpg"],
    ["images/general/img_7.jpg"],
    ["images/general/img_8.jpg"],
    ["images/general/img_9.jpg"],
    ["images/general/img_10.jpg"],
]

optic_examples = [
    ["images/optic/img_1.jpg"],
    ["images/optic/img_2.jpg"],
    ["images/optic/img_3.jpg"],
    ["images/optic/img_4.jpg"],
    ["images/optic/img_5.jpg"],
]

famous_deepfake_examples = [
    ["images/famous_deepfakes/img_1.jpg"],
    ["images/famous_deepfakes/img_2.jpg"],
    ["images/famous_deepfakes/img_3.jpg"],
    ["images/famous_deepfakes/img_4.webp"],
]

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
<h1 style="text-align: center;">For Fake's Sake: a set of models for detecting generated and synthetic images</h1>
This is a demo space for synthetic image detectors:
<a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_mj_200'>midjourney200M</a>,
<a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_mj_5'>midjourney5M</a>,
<a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_sd_200'>diffusions200M</a>,
<a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_sd_5'>diffusions5M</a>.<br>
We provide several detectors for images generated by popular tools such as Midjourney and Stable Diffusion.<br>
Please refer to the model cards for evaluation metrics and limitations.
"""
    )

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil")
            drop_down = gr.Dropdown(model_list.keys(), type="value", label="Model", value="diffusions_200M")
            with gr.Row():
                gr.ClearButton(components=[image_input])
                submit_button = gr.Button("Submit", variant="primary")
        with gr.Column():
            result_score = gr.Label(label='result', num_top_classes=2)

    with gr.Tab("Examples"):
        gr.Examples(examples=general_examples, inputs=image_input)
    # with gr.Tab("More examples"):
    #     gr.Examples(examples=optic_examples, inputs=image_input)
    with gr.Tab("Widely known deepfakes"):
        gr.Examples(examples=famous_deepfake_examples, inputs=image_input)

    submit_button.click(predict, inputs=[image_input, drop_down], outputs=result_score)

    gr.Markdown(
        """
<h3>Models</h3>
<p><code>*_200M</code> models are based on <code>convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384</code> with image size <code>640x640</code>.</p>
<p><code>*_5M</code> models are based on <code>tf_mobilenetv3_large_100.in1k</code> with image size <code>224x224</code>.</p>
<h3>Details</h3>
<ul>
<li>Model cards: <a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_mj_200'>midjourney200M</a>,
<a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_mj_5'>midjourney5M</a>,
<a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_sd_200'>diffusions200M</a>,
<a href='https://huggingface.co/Sumsub/Sumsub-ffs-synthetic-1.0_sd_5'>diffusions5M</a>.
</li>
<li>License: CC-By-SA-3.0</li>
</ul>
"""
    )

demo.launch()