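# JoJoGAN one-shot face stylization demo (Hugging Face Space).
# Downloads the pretrained e4e encoder, the base StyleGAN2-FFHQ generator, and the
# fine-tuned style checkpoints at startup, then serves a Gradio interface on CPU.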
import os
from PIL import Image
import torch

# gradio is pinned to 2.5.3 and installed at startup, before it is imported below.
os.system("pip install gradio==2.5.3")
import gradio as gr

torch.backends.cudnn.benchmark = True
from torchvision import transforms, utils
from util import *
import math
import random
import numpy as np
from torch import nn, autograd, optim
from torch.nn import functional as F
from tqdm import tqdm
import lpips
from model import *
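# Fetch the pretrained e4e FFHQ encoder, used later (via e4e_projection) to invert
# the aligned face into a StyleGAN latent code.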
os.makedirs('models', exist_ok=True)
os.system("gdown https://drive.google.com/uc?id=1jtCg8HQ6RlTmLdnbT2PfW1FJ2AYkWqsK")
os.system("cp e4e_ffhq_encode.pt models/e4e_ffhq_encode.pt")
from e4e_projection import projection as e4e_projection
from copy import deepcopy
import imageio
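# Fetch dlib's 68-point facial landmark model, used by align_face to crop and align
# the input face.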
os.system("wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2")
os.system("bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2")
os.system("mv shape_predictor_68_face_landmarks.dat models/dlibshape_predictor_68_face_landmarks.dat")
device = 'cpu'
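# Download the base StyleGAN2-FFHQ checkpoint (stylegan2-ffhq-config-f.pt), loaded below.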
os.system("gdown https://drive.google.com/uc?id=1_cTsjqzD_X9DK3t3IZE53huKgnzj_btZ")
latent_dim = 512
original_generator = Generator(1024, latent_dim, 8, 2).to(device)
ckpt = torch.load('stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
original_generator.load_state_dict(ckpt["g_ema"], strict=False)
mean_latent = original_generator.mean_latent(10000)
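# One copy of the base generator per style; each copy receives its fine-tuned
# JoJoGAN weights below.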
generatorjojo = deepcopy(original_generator)
generatordisney = deepcopy(original_generator)
generatorjinx = deepcopy(original_generator)
generatorcaitlyn = deepcopy(original_generator)
generatoryasuho = deepcopy(original_generator)
generatorarcanemulti = deepcopy(original_generator)
generatorart = deepcopy(original_generator)
generatorspider = deepcopy(original_generator)
transform = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
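# Download the fine-tuned JoJoGAN style checkpoints and load each one's generator
# weights ("g") into the corresponding generator copy.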
os.system("wget https://huggingface.co/akhaliq/JoJoGAN-jojo/resolve/main/jojo_preserve_color.pt")
ckptjojo = torch.load('jojo_preserve_color.pt', map_location=lambda storage, loc: storage)
generatorjojo.load_state_dict(ckptjojo["g"], strict=False)
os.system("wget https://huggingface.co/akhaliq/jojogan-disney/resolve/main/disney_preserve_color.pt")
ckptdisney = torch.load('disney_preserve_color.pt', map_location=lambda storage, loc: storage)
generatordisney.load_state_dict(ckptdisney["g"], strict=False)
os.system("wget https://huggingface.co/akhaliq/jojo-gan-jinx/resolve/main/arcane_jinx_preserve_color.pt")
ckptjinx = torch.load('arcane_jinx_preserve_color.pt', map_location=lambda storage, loc: storage)
generatorjinx.load_state_dict(ckptjinx["g"], strict=False)
os.system("wget https://huggingface.co/akhaliq/jojogan-arcane/resolve/main/arcane_caitlyn_preserve_color.pt")
ckptcaitlyn = torch.load('arcane_caitlyn_preserve_color.pt', map_location=lambda storage, loc: storage)
generatorcaitlyn.load_state_dict(ckptcaitlyn["g"], strict=False)
os.system("wget https://huggingface.co/akhaliq/JoJoGAN-jojo/resolve/main/jojo_yasuho_preserve_color.pt")
ckptyasuho = torch.load('jojo_yasuho_preserve_color.pt', map_location=lambda storage, loc: storage)
generatoryasuho.load_state_dict(ckptyasuho["g"], strict=False)
os.system("wget https://huggingface.co/akhaliq/jojogan-arcane/resolve/main/arcane_multi_preserve_color.pt")
ckptarcanemulti = torch.load('arcane_multi_preserve_color.pt', map_location=lambda storage, loc: storage)
generatorarcanemulti.load_state_dict(ckptarcanemulti["g"], strict=False)
os.system("wget https://huggingface.co/akhaliq/jojo-gan-art/resolve/main/art.pt")
ckptart = torch.load('art.pt', map_location=lambda storage, loc: storage)
generatorart.load_state_dict(ckptart["g"], strict=False)
os.system("wget https://huggingface.co/akhaliq/jojo-gan-spiderverse/resolve/main/Spiderverse-face-500iters-8face.pt")
ckptspider = torch.load('Spiderverse-face-500iters-8face.pt', map_location=lambda storage, loc: storage)
generatorspider.load_state_dict(ckptspider["g"], strict=False)
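# Align the uploaded face, project it into latent space with e4e, then run the
# generator for the selected style and save the result as a JPEG.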
def inference(img, model):
    # img is a filepath (see gr.inputs.Image(type="filepath") below).
    aligned_face = align_face(img)
    # Project the aligned face into the StyleGAN latent space with the e4e encoder.
    my_w = e4e_projection(aligned_face, "test.pt", device).unsqueeze(0)

    # Pick the generator that matches the requested style; the 'Spider-Verse' choice
    # (and any unrecognized value) falls back to the Spider-Verse generator.
    generators = {
        'JoJo': generatorjojo,
        'Disney': generatordisney,
        'Jinx': generatorjinx,
        'Caitlyn': generatorcaitlyn,
        'Yasuho': generatoryasuho,
        'Arcane Multi': generatorarcanemulti,
        'Art': generatorart,
    }
    generator = generators.get(model, generatorspider)
    with torch.no_grad():
        my_sample = generator(my_w, input_is_latent=True)

    # The generator outputs values in roughly [-1, 1]; rescale to uint8 before saving.
    npimage = my_sample[0].detach().cpu().permute(1, 2, 0).numpy()
    npimage = ((npimage + 1) / 2).clip(0, 1)
    imageio.imwrite('filename.jpeg', (npimage * 255).astype(np.uint8))
    return 'filename.jpeg'
title = "JoJoGAN"
description = "Gradio Demo for JoJoGAN: One Shot Face Stylization. To use it, simply upload your image, or click one of the examples to load it. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a> | <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo PyTorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center>"
examples=[['mona.png','Jinx']]
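# Gradio 2.x Interface API (gr.inputs / gr.outputs), matching the pinned gradio==2.5.3.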
gr.Interface(
    inference,
    [gr.inputs.Image(type="filepath"),
     gr.inputs.Dropdown(choices=['JoJo', 'Disney', 'Jinx', 'Caitlyn', 'Yasuho', 'Arcane Multi', 'Art', 'Spider-Verse'],
                        type="value", default='JoJo', label="Model")],
    gr.outputs.Image(type="file"),
    title=title, description=description, article=article,
    allow_flagging="never", examples=examples,
    allow_screenshot=False, enable_queue=True,
).launch()