Spaces:
Runtime error
Runtime error
File size: 4,051 Bytes
8d4d98f 46afa05 1fe1457 8d4d98f 3f2d4dc 8d4d98f 1fe1457 8d4d98f feb9da2 8d4d98f ddaf006 8d4d98f 8fd4221 8d4d98f 474b6cf 46a015d feb9da2 aa0db9c 91959e5 38887ff 91959e5 1fe1457 91959e5 aa0db9c a058c0e 3f2d4dc 8d4d98f a058c0e 5d457fc aa0db9c fec4733 5d457fc 7fee812 28df4ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import os
from PIL import Image
import torch
import gradio as gr
os.system("pip install dlib")
import torch
torch.backends.cudnn.benchmark = True
from torchvision import transforms, utils
from util import *
from PIL import Image
import math
import random
import numpy as np
from torch import nn, autograd, optim
from torch.nn import functional as F
from tqdm import tqdm
import lpips
from model import *
from e4e_projection import projection as e4e_projection
from copy import deepcopy
import imageio
def _sh(cmd):
    """Run *cmd* through ``os.system``; warn instead of failing silently.

    Startup stays best-effort (nothing raises here, like the original bare
    ``os.system`` calls), but a non-zero status is reported so a failed
    download is visible in the logs next to its cause, rather than showing up
    later as a confusing missing-file error from ``torch.load``.

    Returns:
        The raw status value returned by ``os.system``.
    """
    import warnings  # local import keeps this block self-contained

    status = os.system(cmd)
    if status != 0:
        warnings.warn(
            f"shell command failed (status {status}): {cmd}",
            RuntimeWarning,
            stacklevel=2,
        )
    return status


def _load_generator_weights(generator, path, key):
    """Load ``torch.load(path)[key]`` into *generator* (non-strict) and return it.

    The checkpoint dict only lives in this function's scope. Once
    ``load_state_dict`` has copied its tensors into the model, it can be
    freed, instead of being kept alive as a module-level global.
    """
    checkpoint = torch.load(path, map_location=device)
    generator.load_state_dict(checkpoint[key], strict=False)
    return generator


# Working directories used by the alignment / inversion pipeline.
os.makedirs('inversion_codes', exist_ok=True)
os.makedirs('style_images', exist_ok=True)
os.makedirs('style_images_aligned', exist_ok=True)
os.makedirs('models', exist_ok=True)

# dlib 68-point landmark model, moved to the path the face-alignment code
# presumably loads it from (util.align_face — confirm).
_sh("wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2")
_sh("bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2")
_sh("mv shape_predictor_68_face_landmarks.dat models/dlibshape_predictor_68_face_landmarks.dat")

device = 'cpu'

# Presumably fetches 'stylegan2-ffhq-config-f.pt', which is loaded just below.
_sh("gdown https://drive.google.com/uc?id=1-AG7JPTWc9REBrkll3OyEpZwSOWhlX0j")
latent_dim = 512

# Load the original FFHQ generator from the checkpoint's "g_ema" entry.
original_generator = Generator(1024, latent_dim, 8, 2).to(device)
_load_generator_weights(original_generator, 'stylegan2-ffhq-config-f.pt', "g_ema")
mean_latent = original_generator.mean_latent(10000)

# Copies of the base generator; each receives style-specific weights below.
generatorjojo = deepcopy(original_generator)
generatordisney = deepcopy(original_generator)

transform = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# e4e encoder checkpoint, copied under models/ (presumably where the
# projection code looks for it — confirm in e4e_projection).
_sh("gdown https://drive.google.com/uc?id=1jtCg8HQ6RlTmLdnbT2PfW1FJ2AYkWqsK")
_sh("cp e4e_ffhq_encode.pt models/e4e_ffhq_encode.pt")

# Style-specific fine-tuned generator weights (checkpoint key "g").
_sh("gdown https://drive.google.com/uc?id=1-8E0PFT37v5fZs-61oIrFbNpE28Unp2y")
_load_generator_weights(generatorjojo, 'jojo.pt', "g")
_sh("gdown https://drive.google.com/uc?id=1Bnh02DjfvN_Wm8c4JdOiNV4q9J7Z_tsi")
_load_generator_weights(generatordisney, 'disney_preserve_color.pt', "g")
def inference(img, model):
    """Stylize the face in *img* with the selected fine-tuned generator.

    Args:
        img: Value from the Gradio ``Image(type="filepath")`` input
            (presumably a file path); passed unchanged to ``align_face``.
        model: Dropdown value. ``'JoJo'`` selects ``generatorjojo``; any
            other value falls back to ``generatordisney``.

    Returns:
        The path ``'filename.jpeg'`` that the stylized image is written to.
    """
    aligned_face = align_face(img)
    # Invert the aligned face to a latent; "test.pt" is forwarded to
    # e4e_projection (presumably where it saves the latent — confirm).
    my_w = e4e_projection(aligned_face, "test.pt", device).unsqueeze(0)

    # Run the generator that was actually chosen. The previous code called
    # `.eval()` on an undefined name `generator`, which raised NameError on
    # every request.
    generator = generatorjojo if model == 'JoJo' else generatordisney
    generator.eval()
    with torch.no_grad():
        my_sample = generator(my_w, input_is_latent=True)

    # Convert the first sample (CHW) to an 8-bit HWC array explicitly.
    # Assumes the generator's output range is [-1, 1], matching `transform`'s
    # normalization (TODO confirm). Without this, imageio's implicit float
    # conversion rescales each image by its own min/max and shifts contrast.
    img_chw = my_sample[0].detach().cpu().clamp(-1.0, 1.0)
    npimage = (
        ((img_chw + 1.0) * 127.5).round().to(torch.uint8).permute(1, 2, 0).numpy()
    )
    imageio.imwrite('filename.jpeg', npimage)
    return 'filename.jpeg'
# --- Gradio UI ---------------------------------------------------------------
title = "JoJoGAN"
description = (
    "Gradio Demo for JoJoGAN: One Shot Face Stylization. "
    "To use it, simply upload your image, or click one of the examples to load them. "
    "Read more at the links below."
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>"
    "| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> "
    "<center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center> "
    "<p style='text-align: center'>samples from repo: "
    "<img src='https://raw.githubusercontent.com/mchong6/JoJoGAN/main/teaser.jpg' alt='animation'/></p>"
)
examples = [['iu.jpeg']]

# Inputs map positionally onto inference(img, model): an uploaded image
# (delivered as a file path) and the style dropdown.
face_input = gr.inputs.Image(type="filepath", shape=(256, 256))
style_choice = gr.inputs.Dropdown(
    choices=['JoJo', 'Disney'],
    type="value",
    default='JoJo',
    label="Model",
)

# inference() returns a path to the written JPEG.
gr.Interface(
    inference,
    [face_input, style_choice],
    gr.outputs.Image(type="file"),
    title=title,
    description=description,
    article=article,
    enable_queue=True,
    allow_flagging=False,
    examples=examples,
).launch()
|