|
import io |
|
from io import BytesIO |
|
import os |
|
import shutil |
|
import requests |
|
import numpy as np |
|
from PIL import Image, ImageOps |
|
import math |
|
import matplotlib.pyplot as plt |
|
import pickle |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torchvision.transforms as T |
|
import torchvision.transforms.functional as TF |
|
from torch.utils.checkpoint import checkpoint |
|
from torchvision.models import vgg16 |
|
from torchmetrics.image.fid import FrechetInceptionDistance |
|
from torchmetrics.functional import structural_similarity_index_measure |
|
from facenet_pytorch import InceptionResnetV1 |
|
from taming.models.vqgan import VQModel |
|
from omegaconf import OmegaConf |
|
from taming.models.vqgan import GumbelVQ |
|
import gradio as gr |
|
from modules.finetunedvqgan import Generator |
|
from modules.modelz import DeepfakeToSourceTransformer |
|
from modules.frameworkeval import DF |
|
from modules.segmentface import FaceSegmenter |
|
from modules.denormalize import denormalize_bin, denormalize_tr, denormalize_ar |
|
|
|
# Run on the GPU when one is present; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
|
|
# Preprocessing pipeline shared by every image fed to the models:
# resize to the VQGAN's 256x256 input size, convert to a tensor, and
# normalize each channel from [0, 1] to [-1, 1].
_preprocess_steps = [
    T.Resize((256, 256)),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
transform = T.Compose(_preprocess_steps)
|
|
|
|
|
|
|
def gen_sources(deepfake_img):
    """Recover the two candidate source images behind a deepfake.

    Encodes the deepfake and its face-segmented view with two VQGAN
    generators, translates each latent with a dedicated transformer, decodes
    the results back to images, re-runs a face swap on the recovered pair via
    a remote Gradio client, and scores the round-trip against the input.

    Args:
        deepfake_img: Path to the deepfake image file (it is passed to
            ``Image.open`` below, so a filepath/file-like is expected).

    Returns:
        Tuple ``(recovered_target_pil, recovered_source_pil,
        reconstructed_deepfake_pil, reconstruction_loss)`` where the loss is
        a float rounded to 3 decimal places.
    """
    # Segment the face region of the input; the result is later opened with
    # Image.open, so segment_face presumably returns a path — confirm.
    segmenter = FaceSegmenter(threshold=0.5)
    deepfake_seg = segmenter.segment_face(deepfake_img)

    config_path = "./models/config.yaml"

    def _load_vqgan(checkpoint_path):
        """Build a VQGAN Generator, load its weights, and set eval mode."""
        model = Generator(config_path, device)
        state = torch.load(checkpoint_path, map_location=device)
        # BUG FIX: load_state_dict() returns an IncompatibleKeys namedtuple,
        # not the model. The original code rebound the model variable to that
        # return value and then crashed on .eval(); keep the model object.
        model.load_state_dict(state, strict=True)
        model.eval()
        return model

    def _load_translator(checkpoint_path):
        """Build a latent-translation transformer, load weights, eval mode."""
        model = DeepfakeToSourceTransformer().to(device)
        model.load_state_dict(
            torch.load(checkpoint_path, map_location=device), strict=True
        )
        model.eval()
        return model

    model_vaq_f = _load_vqgan("./models/model_vaq1_ff.pth")
    model_vaq_g = _load_vqgan("./models/model_vaq2_gg.pth")
    model_z1 = _load_translator("./models/model_z1_ff.pth")
    model_z2 = _load_translator("./models/model_z2_gg.pth")

    criterion = DF()

    with torch.no_grad():
        img = Image.open(deepfake_img).convert('RGB')
        segimg = Image.open(deepfake_seg).convert('RGB')
        df_img = transform(img).unsqueeze(0).to(device)
        seg_img = transform(segimg).unsqueeze(0).to(device)

        # Encode both views, translate each latent toward its source domain,
        # then decode back to image space.
        z_df, _, _ = model_vaq_f.encode(df_img)
        z_seg, _, _ = model_vaq_g.encode(seg_img)
        rec_z_img1 = model_z1(z_df)
        rec_z_img2 = model_z2(z_seg)
        rec_img1 = model_vaq_f.decode(rec_z_img1).squeeze(0)
        rec_img2 = model_vaq_g.decode(rec_z_img2).squeeze(0)
        # NOTE(review): `transform` normalizes inputs to [-1, 1] while
        # ToPILImage expects [0, 1]; the imported denormalize_* helpers look
        # intended for this step — confirm the decoder's output range.
        rec_img1_pil = T.ToPILImage()(rec_img1)
        rec_img2_pil = T.ToPILImage()(rec_img2)

        buffer1 = BytesIO()
        buffer2 = BytesIO()
        rec_img1_pil.save(buffer1, format="PNG")
        rec_img2_pil.save(buffer2, format="PNG")
        # BUG FIX: save() leaves the buffers positioned at EOF; rewind before
        # handing them to the client so their contents are actually read.
        buffer1.seek(0)
        buffer2.seek(0)

        # NOTE(review): `client` and `file` are not defined anywhere in this
        # module — presumably `from gradio_client import Client, file` plus a
        # Client(<space-url>) instance are required; as written this line
        # raises NameError. Confirm and wire up the client at module level.
        result = client.predict(
            target=file(buffer1),
            source=file(buffer2),
            slider=100,
            adv_slider=100,
            settings=["Adversarial Defense"],
            api_name="/run_inference",
        )

        # Round-trip the swapped result through the same preprocessing so it
        # can be compared against the original deepfake tensor.
        dfimage_pil = Image.open(result)
        buffer3 = BytesIO()
        dfimage_pil.save(buffer3, format="PNG")
        # BUG FIX: Image.open reads from the current position; rewind first.
        buffer3.seek(0)
        rec_df = transform(Image.open(buffer3)).unsqueeze(0).to(device)
        rec_loss, _ = criterion(df_img, rec_df)

    return (rec_img1_pil, rec_img2_pil, dfimage_pil, round(rec_loss.item(), 3))
|
|
|
|
|
# Gradio UI: one deepfake image in, the two recovered sources, the
# re-synthesized deepfake, and the reconstruction loss out.
interface = gr.Interface(
    fn=gen_sources,
    # BUG FIX: gen_sources opens its argument with Image.open(), which needs
    # a filepath — with type="pil" Gradio passes a PIL.Image and the handler
    # crashes. "filepath" hands the handler a temp-file path instead.
    inputs=gr.Image(type="filepath", label="Input Image"),
    outputs=[
        gr.Image(type="pil", label="Recovered Source Image 1 (Target Image)"),
        gr.Image(type="pil", label="Recovered Source Image 2 (Source Image)"),
        gr.Image(type="pil", label="Reconstructed Deepfake Image"),
        gr.Number(label="Reconstruction Loss"),
    ],
    examples=[
        "./images/df1.jpg",
        "./images/df2.jpg",
        "./images/df3.jpg",
        "./images/df4.jpg",
    ],
    theme=gr.themes.Soft(),
    title="Uncovering Deepfake Image for Identifying Source Images",
    # Typo fix in the user-facing text: "an DeepFake" -> "a DeepFake".
    description="Upload a DeepFake image.",
)

interface.launch(debug=True)