import warnings
warnings.filterwarnings("ignore")

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
import jax.numpy as jnp
from vqgan_jax.modeling_flax_vqgan import VQModel
import gradio as gr

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Model_Z1(nn.Module):
    def __init__(self):
        super(Model_Z1, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=256, out_channels=2048, kernel_size=3, padding=1)
        self.batchnorm = nn.BatchNorm2d(2048)
        self.conv2 = nn.Conv2d(in_channels=2048, out_channels=256, kernel_size=3, padding=1)
        self.batchnorm2 = nn.BatchNorm2d(256)
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=1024, kernel_size=3, padding=1)
        self.batchnorm3 = nn.BatchNorm2d(1024)
        self.conv4 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=3, padding=1)
        self.batchnorm4 = nn.BatchNorm2d(256)
        self.conv5 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.batchnorm5 = nn.BatchNorm2d(512)
        self.conv6 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1)
        self.elu = nn.ELU()

    def forward(self, x):
        res = x
        x = self.elu(self.conv1(x))
        x = self.batchnorm(x)
        x = self.elu(self.conv2(x)) + res
        x = self.batchnorm2(x)
        x = self.elu(self.conv3(x))
        x = self.batchnorm3(x)
        x = self.elu(self.conv4(x)) + res
        x = self.batchnorm4(x)
        x = self.elu(self.conv5(x))
        x = self.batchnorm5(x)
        out = self.elu(self.conv6(x)) + res
        return out
class Model_Z(nn.Module):
    def __init__(self):
        super(Model_Z, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=256, out_channels=2048, kernel_size=3, padding=1)
        self.batchnorm = nn.BatchNorm2d(2048)
        self.conv2 = nn.Conv2d(in_channels=2048, out_channels=256, kernel_size=3, padding=1)
        self.batchnorm2 = nn.BatchNorm2d(256)
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=1024, kernel_size=3, padding=1)
        self.batchnorm3 = nn.BatchNorm2d(1024)
        self.conv4 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=3, padding=1)
        self.batchnorm4 = nn.BatchNorm2d(256)
        self.conv5 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.batchnorm5 = nn.BatchNorm2d(512)
        self.conv6 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1)
        self.batchnorm6 = nn.BatchNorm2d(256)
        self.conv7 = nn.Conv2d(in_channels=256, out_channels=448, kernel_size=3, padding=1)
        self.batchnorm7 = nn.BatchNorm2d(448)
        self.conv8 = nn.Conv2d(in_channels=448, out_channels=384, kernel_size=3, padding=1)
        self.batchnorm8 = nn.BatchNorm2d(384)
        self.conv9 = nn.Conv2d(in_channels=384, out_channels=320, kernel_size=3, padding=1)
        self.batchnorm9 = nn.BatchNorm2d(320)
        self.conv10 = nn.Conv2d(in_channels=320, out_channels=256, kernel_size=3, padding=1)
        self.elu = nn.ELU()

    def forward(self, x):
        res = x
        x = self.elu(self.conv1(x))
        x = self.batchnorm(x)
        x = self.elu(self.conv2(x)) + res
        x = self.batchnorm2(x)
        x = self.elu(self.conv3(x))
        x = self.batchnorm3(x)
        x = self.elu(self.conv4(x)) + res
        x = self.batchnorm4(x)
        x = self.elu(self.conv5(x))
        x = self.batchnorm5(x)
        x = self.elu(self.conv6(x)) + res
        x = self.batchnorm6(x)
        x = self.elu(self.conv7(x))
        x = self.batchnorm7(x)
        x = self.elu(self.conv8(x))
        x = self.batchnorm8(x)
        x = self.elu(self.conv9(x))
        x = self.batchnorm9(x)
        out = self.elu(self.conv10(x)) + res
        return out
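
# Both models are residual refiners over the VQGAN latent grid: every residual
# add requires the 256-channel latent shape to be preserved end to end. A
# minimal shape check (illustrative sketch, not executed by the app):
#
#   z = torch.randn(1, 256, 16, 16)
#   assert Model_Z1()(z).shape == z.shape
#   assert Model_Z()(z).shape == z.shape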
def tensor_jax(x):
    """Convert an NCHW torch tensor to an NHWC JAX array for the Flax VQGAN."""
    if x.dim() == 3:
        x = x.unsqueeze(0)
    x_np = x.detach().permute(0, 2, 3, 1).cpu().numpy()  # (N, C, H, W) -> (N, H, W, C), on CPU
    x_jax = jnp.array(x_np)
    return x_jax

def jax_to_tensor(x):
    """Convert an NHWC JAX array back to an NCHW torch tensor on the active device."""
    x_tensor = torch.tensor(np.array(x), requires_grad=True).permute(0, 3, 1, 2).to(device)  # (N, H, W, C) -> (N, C, H, W)
    return x_tensor
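
# tensor_jax and jax_to_tensor are inverses up to dtype and device: NCHW torch
# tensors go to NHWC JAX arrays for the Flax VQGAN and back. Illustrative
# round-trip sketch (assumes a 4-D float32 input; not executed by the app):
#
#   x = torch.rand(1, 3, 256, 256)
#   assert torch.allclose(jax_to_tensor(tensor_jax(x)).cpu(), x)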
# Resize inputs to 256x256 and convert to a float tensor in [0, 1]
transform = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor()
])
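
# With 256x256 inputs and this f16 VQGAN (downsampling factor 16, 256-dim
# codebook embeddings), encoding yields latents of shape (1, 256, 16, 16)
# after jax_to_tensor, which is the shape Model_Z1 and Model_Z expect.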
def gen_sources(img):
    # NOTE: the VQGAN and all three checkpoints are reloaded on every call;
    # see the performance note after this function.
    model_name = "dalle-mini/vqgan_imagenet_f16_16384"
    model_vaq = VQModel.from_pretrained(model_name)
    model_z1 = Model_Z1().to(device)
    model_z1.load_state_dict(torch.load("./model_z1.pth", map_location=device))
    model_z2 = Model_Z().to(device)
    model_z2.load_state_dict(torch.load("./model_z2.pth", map_location=device))
    model_zdf = Model_Z().to(device)
    model_zdf.load_state_dict(torch.load("./model_zdf.pth", map_location=device))
    criterion = nn.MSELoss()
    model_z1.eval()
    model_z2.eval()
    model_zdf.eval()
    with torch.no_grad():
        img = img.convert("RGB")
        df_img = transform(img).unsqueeze(0).to(device)  # shape (1, 3, 256, 256)
        # Encode the deepfake image to its quantized VQGAN latent z
        df_img_jax = tensor_jax(df_img)
        z_df, _ = model_vaq.encode(df_img_jax)
        z_df_tensor = jax_to_tensor(z_df)
        # ---------------------- model_z1: recover source 1 ----------------------
        outputs_z1 = model_z1(z_df_tensor)
        z1_rec_jax = tensor_jax(outputs_z1)
        rec_img1 = model_vaq.decode(z1_rec_jax)
        # ---------------------- model_z2: recover source 2 ----------------------
        outputs_z2 = model_z2(z_df_tensor)
        z2_rec_jax = tensor_jax(outputs_z2)
        rec_img2 = model_vaq.decode(z2_rec_jax)
        # ------------- model_zdf: re-fuse the two latents into the deepfake -------------
        z_rec = outputs_z1 + outputs_z2
        outputs_zdf = model_zdf(z_rec)
        lossdf = criterion(outputs_zdf, z_df_tensor)  # latent-space loss (computed but not returned)
        # Deepfake reconstruction loss in pixel space
        zdf_rec_jax = tensor_jax(outputs_zdf)
        rec_df = model_vaq.decode(zdf_rec_jax)
        rec_df_tensor = jax_to_tensor(rec_df)
        dfimgloss = criterion(rec_df_tensor, df_img)
        # Convert the decoded tensors back to PIL images, clamping to [0, 1]
        # so out-of-range values do not wrap when cast to uint8
        rec_img1 = jax_to_tensor(rec_img1).squeeze(0).clamp(0, 1)
        rec_img2 = jax_to_tensor(rec_img2).squeeze(0).clamp(0, 1)
        rec_df = rec_df_tensor.squeeze(0).clamp(0, 1)
        rec_img1_pil = T.ToPILImage()(rec_img1)
        rec_img2_pil = T.ToPILImage()(rec_img2)
        rec_df_pil = T.ToPILImage()(rec_df)  # reconstructed deepfake (output currently disabled in the UI)
    return (rec_img1_pil, rec_img2_pil, round(dfimgloss.item(), 3))
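
# Performance note: gen_sources reloads the VQGAN and all three checkpoints on
# every request. A common alternative (sketch, assuming the same file paths) is
# to load everything once at module level and reuse it inside gen_sources:
#
#   model_vaq = VQModel.from_pretrained("dalle-mini/vqgan_imagenet_f16_16384")
#   model_z1 = Model_Z1().to(device)
#   model_z1.load_state_dict(torch.load("./model_z1.pth", map_location=device))
#   model_z1.eval()
#   # ...and likewise for model_z2 and model_zdf.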
# Create the Gradio interface
interface = gr.Interface(
    fn=gen_sources,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=[
        gr.Image(type="pil", label="Source Image 1"),
        gr.Image(type="pil", label="Source Image 2"),
        # gr.Image(type="pil", label="Deepfake Image"),
        gr.Number(label="Reconstruction Loss"),
    ],
    examples=["./df1.jpg", "./df2.jpg", "./df3.jpg", "./df4.jpg"],
    theme=gr.themes.Soft(),
    title="Uncovering Deepfake Image",
    description="Upload a deepfake image to recover its two source images and the pixel-space reconstruction loss.",
)

interface.launch()