import warnings
warnings.filterwarnings("ignore")

import numpy as np
import jax.numpy as jnp
import torch
import torch.nn as nn
import torchvision.transforms as T
from vqgan_jax.modeling_flax_vqgan import VQModel
import gradio as gr

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Model_Z1(nn.Module):
    """Residual CNN that maps a deepfake's VQGAN latent to the latent of source image 1."""

    def __init__(self):
        super(Model_Z1, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=256, out_channels=2048, kernel_size=3, padding=1)
        self.batchnorm = nn.BatchNorm2d(2048)
        self.conv2 = nn.Conv2d(in_channels=2048, out_channels=256, kernel_size=3, padding=1)
        self.batchnorm2 = nn.BatchNorm2d(256)
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=1024, kernel_size=3, padding=1)
        self.batchnorm3 = nn.BatchNorm2d(1024)
        self.conv4 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=3, padding=1)
        self.batchnorm4 = nn.BatchNorm2d(256)
        self.conv5 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.batchnorm5 = nn.BatchNorm2d(512)
        self.conv6 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1)
        self.elu = nn.ELU()

    def forward(self, x):
        res = x  # identity skip connection in latent space
        x = self.elu(self.conv1(x))
        x = self.batchnorm(x)
        x = self.elu(self.conv2(x)) + res
        x = self.batchnorm2(x)
        x = self.elu(self.conv3(x))
        x = self.batchnorm3(x)
        x = self.elu(self.conv4(x)) + res
        x = self.batchnorm4(x)
        x = self.elu(self.conv5(x))
        x = self.batchnorm5(x)
        out = self.elu(self.conv6(x)) + res
        return out


class Model_Z(nn.Module):
    """Deeper residual CNN used for the source-2 and deepfake latent mappings."""

    def __init__(self):
        super(Model_Z, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=256, out_channels=2048, kernel_size=3, padding=1)
        self.batchnorm = nn.BatchNorm2d(2048)
        self.conv2 = nn.Conv2d(in_channels=2048, out_channels=256, kernel_size=3, padding=1)
        self.batchnorm2 = nn.BatchNorm2d(256)
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=1024, kernel_size=3, padding=1)
        self.batchnorm3 = nn.BatchNorm2d(1024)
        self.conv4 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=3, padding=1)
        self.batchnorm4 = nn.BatchNorm2d(256)
        self.conv5 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.batchnorm5 = nn.BatchNorm2d(512)
        self.conv6 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1)
        self.batchnorm6 = nn.BatchNorm2d(256)
        self.conv7 = nn.Conv2d(in_channels=256, out_channels=448, kernel_size=3, padding=1)
        self.batchnorm7 = nn.BatchNorm2d(448)
        self.conv8 = nn.Conv2d(in_channels=448, out_channels=384, kernel_size=3, padding=1)
        self.batchnorm8 = nn.BatchNorm2d(384)
        self.conv9 = nn.Conv2d(in_channels=384, out_channels=320, kernel_size=3, padding=1)
        self.batchnorm9 = nn.BatchNorm2d(320)
        self.conv10 = nn.Conv2d(in_channels=320, out_channels=256, kernel_size=3, padding=1)
        self.elu = nn.ELU()

    def forward(self, x):
        res = x  # identity skip connection in latent space
        x = self.elu(self.conv1(x))
        x = self.batchnorm(x)
        x = self.elu(self.conv2(x)) + res
        x = self.batchnorm2(x)
        x = self.elu(self.conv3(x))
        x = self.batchnorm3(x)
        x = self.elu(self.conv4(x)) + res
        x = self.batchnorm4(x)
        x = self.elu(self.conv5(x))
        x = self.batchnorm5(x)
        x = self.elu(self.conv6(x)) + res
        x = self.batchnorm6(x)
        x = self.elu(self.conv7(x))
        x = self.batchnorm7(x)
        x = self.elu(self.conv8(x))
        x = self.batchnorm8(x)
        x = self.elu(self.conv9(x))
        x = self.batchnorm9(x)
        out = self.elu(self.conv10(x)) + res
        return out


def tensor_jax(x):
    """Convert a PyTorch tensor from (N, C, H, W) to a JAX array in (N, H, W, C)."""
    if x.dim() == 3:
        x = x.unsqueeze(0)
    x_np = x.detach().permute(0, 2, 3, 1).cpu().numpy()  # channels last, moved to CPU
    return jnp.array(x_np)


def jax_to_tensor(x):
    """Convert a JAX array from (N, H, W, C) to a PyTorch tensor in (N, C, H, W) on `device`."""
    return torch.tensor(np.array(x)).permute(0, 3, 1, 2).to(device)


# Preprocessing for the input image
transform = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),
])

# Load the frozen VQGAN and the three latent-space models once at startup,
# rather than on every request.
model_name = "dalle-mini/vqgan_imagenet_f16_16384"
model_vaq = VQModel.from_pretrained(model_name)

model_z1 = Model_Z1().to(device)
model_z1.load_state_dict(torch.load("./model_z1.pth", map_location=device))
model_z1.eval()

model_z2 = Model_Z().to(device)
model_z2.load_state_dict(torch.load("./model_z2.pth", map_location=device))
model_z2.eval()

model_zdf = Model_Z().to(device)
model_zdf.load_state_dict(torch.load("./model_zdf.pth", map_location=device))
model_zdf.eval()

criterion = nn.MSELoss()


def gen_sources(img):
    """Recover the two source images behind a deepfake and report a reconstruction loss."""
    with torch.no_grad():
        # Preprocess: RGB, 256x256, shape (1, 3, 256, 256)
        img = img.convert("RGB")
        df_img = transform(img).unsqueeze(0).to(device)

        # Encode the deepfake to its quantized VQGAN code z_df (tensor -> JAX -> tensor)
        df_img_jax = tensor_jax(df_img)
        z_df, _ = model_vaq.encode(df_img_jax)
        z_df_tensor = jax_to_tensor(z_df)

        # ---------------------- model_z1: recover source image 1 ----------------------
        outputs_z1 = model_z1(z_df_tensor)
        rec_img1 = model_vaq.decode(tensor_jax(outputs_z1))

        # ---------------------- model_z2: recover source image 2 ----------------------
        outputs_z2 = model_z2(z_df_tensor)
        rec_img2 = model_vaq.decode(tensor_jax(outputs_z2))

        # ------------ model_zdf: re-fuse the two latents into the deepfake ------------
        z_rec = outputs_z1 + outputs_z2
        outputs_zdf = model_zdf(z_rec)
        rec_df = model_vaq.decode(tensor_jax(outputs_zdf))

        # Deepfake reconstruction loss in image space
        rec_df_tensor = jax_to_tensor(rec_df)
        dfimgloss = criterion(rec_df_tensor, df_img)

        # Convert the decoded arrays back to PIL images; clamp to [0, 1] so
        # ToPILImage does not overflow on out-of-range decoder outputs
        rec_img1 = jax_to_tensor(rec_img1).squeeze(0).clamp(0, 1)
        rec_img2 = jax_to_tensor(rec_img2).squeeze(0).clamp(0, 1)
        rec_df = rec_df_tensor.squeeze(0).clamp(0, 1)
        rec_img1_pil = T.ToPILImage()(rec_img1)
        rec_img2_pil = T.ToPILImage()(rec_img2)
        rec_df_pil = T.ToPILImage()(rec_df)  # kept for the (currently disabled) deepfake output

        return (rec_img1_pil, rec_img2_pil, round(dfimgloss.item(), 3))


# Create the Gradio interface
interface = gr.Interface(
    fn=gen_sources,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=[
        gr.Image(type="pil", label="Source Image 1"),
        gr.Image(type="pil", label="Source Image 2"),
        # gr.Image(type="pil", label="Deepfake Image"),
        gr.Number(label="Reconstruction Loss"),
    ],
    examples=["./df1.jpg", "./df2.jpg", "./df3.jpg", "./df4.jpg"],
    theme=gr.themes.Soft(),
    title="Uncovering Deepfake Image",
    description="Upload a deepfake image to recover its two source images and the image reconstruction loss.",
)

if __name__ == "__main__":
    interface.launch()
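# ----------------------------------------------------------------------
# Optional smoke test (a sketch, not part of the original app): runs one
# example image through gen_sources without the Gradio UI. It assumes
# ./df1.jpg exists next to this script, as the examples above do, and
# the output filenames are placeholders. Uncomment to try it.
#
# from PIL import Image
# src1, src2, loss = gen_sources(Image.open("./df1.jpg"))
# src1.save("source1.png")
# src2.save("source2.png")
# print("deepfake reconstruction loss:", loss)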