|
|
|
|
|
from __future__ import annotations |
|
import argparse |
|
import functools |
|
import os |
|
import pathlib |
|
import sys |
|
from typing import Callable |
|
import uuid |
|
|
|
sys.path.insert(0, 'APDrawingGAN2') |
|
|
|
import gradio as gr |
|
import huggingface_hub |
|
import numpy as np |
|
import PIL.Image |
|
|
|
from io import BytesIO |
|
import shutil |
|
|
|
from options.test_options import TestOptions |
|
from data import CreateDataLoader |
|
from models import create_model |
|
import dlib |
|
import preprocess.get_partmask |
|
from util import html |
|
|
|
import ntpath |
|
from util import util |
|
|
|
# Upstream project URL; interpolated into DESCRIPTION below.
ORIGINAL_REPO_URL = 'https://github.com/yiranran/APDrawingGAN2'

# Passed as `title=` to gr.Interface in main().
TITLE = 'yiranran/APDrawingGAN2'

# Passed as `description=` to gr.Interface in main().
DESCRIPTION = f"""This is a demo for {ORIGINAL_REPO_URL}.



"""

# Passed as `article=` to gr.Interface in main(); currently only whitespace.
ARTICLE = """



"""



# Hugging Face Hub repository from which load_checkpoint() downloads checkpoints.zip.
MODEL_REPO = 'hylee/apdrawing_model'
|
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the demo's command-line options.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before.

    Returns:
        Namespace with ``device`` (default ``'cpu'``), ``theme``, ``live``,
        ``share``, ``port``, ``enable_queue`` (``True`` unless
        ``--disable-queue`` is given), ``allow_flagging`` (default
        ``'never'``) and ``allow_screenshot``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--theme', type=str)
    parser.add_argument('--live', action='store_true')
    parser.add_argument('--share', action='store_true')
    parser.add_argument('--port', type=int)
    # Negative flag: queueing is on by default and this switches it off.
    parser.add_argument('--disable-queue',
                        dest='enable_queue',
                        action='store_false')
    parser.add_argument('--allow-flagging', type=str, default='never')
    parser.add_argument('--allow-screenshot', action='store_true')
    return parser.parse_args(argv)
|
|
|
|
|
def load_checkpoint():
    """Download the model bundle from the Hub and unpack it locally.

    Fetches ``checkpoints.zip`` from ``MODEL_REPO`` via
    ``huggingface_hub.hf_hub_download`` and extracts it into ``checkpoint/``.

    Returns:
        Path of the extracted ``checkpoints`` directory
        (``checkpoint/checkpoints``).
    """
    extract_dir = 'checkpoint'  # avoid shadowing the builtin `dir`

    # `force_filename` is deprecated in huggingface_hub (and may be rejected by
    # newer releases). Only the returned path is used, so wherever the
    # cache places the file is fine.
    checkpoint_path = huggingface_hub.hf_hub_download(MODEL_REPO, 'checkpoints.zip')
    print(checkpoint_path)

    shutil.unpack_archive(checkpoint_path, extract_dir=extract_dir)

    checkpoints_dir = os.path.join(extract_dir, 'checkpoints')
    print(os.listdir(checkpoints_dir))

    return checkpoints_dir
|
|
|
|
|
|
|
def save_images2(image_dir, visuals, image_path, aspect_ratio=1.0, width=256):
    """Write every entry of ``visuals`` as a PNG and return the written paths.

    Args:
        image_dir: Directory the PNGs are written into.
        visuals: Mapping of label -> image data. Each value is converted with
            ``util.tensor2im`` (presumably model output tensors; TODO confirm
            against the caller).
        image_path: Sequence whose first element is the source image path.
            Its basename without extension prefixes every output file name.
        aspect_ratio: ``> 1`` widens each image to ``w * aspect_ratio``.
            ``< 1`` heightens it to ``h / aspect_ratio``. ``1.0`` keeps the
            size unchanged.
        width: Unused; kept for signature compatibility.

    Returns:
        List of paths ``<image_dir>/<name>_<label>.png``, in the iteration
        order of ``visuals``.
    """
    short_path = ntpath.basename(image_path[0])
    name = os.path.splitext(short_path)[0]

    imgs = []
    for label, im_data in visuals.items():
        im = util.tensor2im(im_data)
        save_path = os.path.join(image_dir, '%s_%s.png' % (name, label))

        h, w, _ = im.shape
        # PIL's resize() takes the target size as (width, height), which is
        # the reverse of the array's (rows, cols) shape order.
        if aspect_ratio > 1.0:
            im = np.array(PIL.Image.fromarray(im).resize(
                (int(w * aspect_ratio), h), PIL.Image.BICUBIC))
        elif aspect_ratio < 1.0:
            im = np.array(PIL.Image.fromarray(im).resize(
                (w, int(h / aspect_ratio)), PIL.Image.BICUBIC))

        util.save_image(im, save_path)
        imgs.append(save_path)

    return imgs
|
|
|
|
|
# 64-symbol alphabet: position i is the symbol for the 6-bit value i.
SAFEHASH = list("0123456789-abcdefghijklmnopqrstuvwxyz_ABCDEFGHIJKLMNOPQRSTUVWXYZ")


def compress_UUID():
    """Return a short random identifier derived from a UUID4.

    The first 30 of the UUID's 32 hex digits are consumed three at a time.
    Each 12-bit group becomes two 6-bit indices into ``SAFEHASH``, which
    contains the characters 0-9, a-z, A-Z, '-' and '_' (cf. RFC 1738). This
    yields 20 symbols. Every '-' symbol is then stripped from the result, so
    the returned string never contains '-' and may be shorter than 20
    characters.

    :return: String
    """
    hex_digits = uuid.uuid4().hex
    symbols = []
    for start in range(0, 30, 3):
        bits = format(int(hex_digits[start:start + 3], 16), '012b')
        symbols.append(SAFEHASH[int(bits[:6], 2)])
        symbols.append(SAFEHASH[int(bits[6:], 2)])
    return ''.join(symbols).replace('-', '')
|
|
|
|
|
|
|
def get_68lm(imgfile, savepath, detector, predictor):
    """Detect a face in ``imgfile`` and write its landmarks to ``savepath``.

    Args:
        imgfile: Path of the image to analyse.
        savepath: Text file to (over)write; one ``"x y"`` line per landmark.
        detector: dlib frontal face detector (``dlib.get_frontal_face_detector()``).
        predictor: dlib shape predictor. ``main()`` loads
            ``shape_predictor_68_face_landmarks.dat``.

    Raises:
        ValueError: If the image cannot be read or no face is detected.
            Without a landmark file the later part-mask step cannot run.
    """
    # cv2 was used here but never imported anywhere in the file. Import it
    # locally so the dependency is explicit where it is needed.
    import cv2

    image = cv2.imread(imgfile)
    if image is None:
        raise ValueError(f'Could not read image: {imgfile}')
    rgb_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # The second argument asks dlib to upsample the image once before
    # detecting, which helps find smaller faces.
    rects = detector(rgb_img, 1)
    if len(rects) == 0:
        raise ValueError(f'No face detected in {imgfile}')

    # One landmark file per image: keep the largest detected face instead of
    # rewriting the file once per face.
    rect = max(rects, key=lambda r: r.width() * r.height())
    shape = predictor(rgb_img, rect)

    with open(savepath, 'w') as f:
        for k in range(shape.num_parts):
            point = shape.part(k)
            print(point.x, point.y, file=f)
|
|
|
|
|
def run(
    image,
    model,
    opt,
    detector,
    predictor,
) -> tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
    """Preprocess one uploaded portrait and return its four part masks.

    Each call creates ``images/<id>/`` with ``src/``, ``results/``,
    ``landmark/`` and ``mask/`` sub-directories. The call then:

    - copies the upload into ``src/``;
    - writes its landmarks to ``landmark/<name>.txt``;
    - writes one mask per part to ``mask/<part>/<name>.png``.

    Args:
        image: Gradio file upload; only its ``.name`` (the temp-file path) is used.
        model: The loaded model; not used by this function yet.
        opt: Options object whose ``dataroot``, ``results_dir``, ``lm_dir`` and
            ``bg_dir`` attributes are overwritten in place.
            NOTE(review): main() shares one ``opt`` across calls through
            functools.partial, so concurrent requests would race on these
            fields.
        detector: dlib face detector, forwarded to ``get_68lm``.
        predictor: dlib landmark predictor, forwarded to ``get_68lm``.

    Returns:
        The masks for the parts ``'eyel'``, ``'eyer'``, ``'nose'`` and
        ``'mouth'``, in that order.

    Raises:
        ValueError: If no image was uploaded, or if ``get_68lm`` raises it.
    """
    if image is None:
        raise ValueError('Please upload an image.')

    dataroot = os.path.join('images', compress_UUID())
    opt.dataroot = os.path.join(dataroot, 'src/')
    opt.results_dir = os.path.join(dataroot, 'results/')
    opt.lm_dir = os.path.join(dataroot, 'landmark/')
    opt.bg_dir = os.path.join(dataroot, 'mask/')
    for directory in (opt.dataroot, opt.results_dir, opt.lm_dir, opt.bg_dir):
        os.makedirs(directory, exist_ok=True)

    shutil.copy(image.name, opt.dataroot)

    fullname = os.path.basename(image.name)
    # splitext keeps inner dots of the stem ("a.b.jpg" -> "a.b"); split('.')[0]
    # would truncate such names to "a".
    name = os.path.splitext(fullname)[0]

    imgfile = os.path.join(opt.dataroot, fullname)
    lmfile = os.path.join(opt.lm_dir, name + '.txt')
    get_68lm(imgfile, lmfile, detector, predictor)

    imgs = []
    for part in ['eyel', 'eyer', 'nose', 'mouth']:
        part_dir = opt.bg_dir + part
        # Only mask/ itself is created above; the per-part folder must exist
        # before the mask is written into it.
        os.makedirs(part_dir, exist_ok=True)
        savepath = os.path.join(part_dir, name + '.png')
        # `import preprocess.get_partmask` binds only the package name
        # `preprocess`, so the function has to be reached through it.
        preprocess.get_partmask.get_partmask(imgfile, part, lmfile, savepath)
        imgs.append(savepath)

    return tuple(PIL.Image.open(path) for path in imgs)
|
|
|
|
|
def main():
    """Load the model and dlib helpers, then serve the Gradio demo.

    Steps:

    - downloads and unpacks the checkpoints;
    - configures ``TestOptions`` after the upstream single-image test command;
    - builds the model and the dlib face detector / 68-point landmark predictor;
    - launches a Gradio interface that maps one uploaded image to the four
      images returned by ``run``.
    """
    gr.close_all()

    args = parse_args()

    checkpoint_dir = load_checkpoint()

    # NOTE(review): TestOptions().parse() presumably reads sys.argv too, so the
    # demo's own flags (--share, --port, ...) may clash with it. Confirm.
    opt = TestOptions().parse()
    # Single-image test-time settings.
    opt.num_threads = 1
    opt.batch_size = 1
    opt.serial_batches = True
    opt.no_flip = True
    opt.display_id = -1

    # Mirrors the upstream command
    #   python test.py --dataroot dataset/test_single --name apdrawinggan++_author
    #       --model test --use_resnet --netG resnet_9blocks --which_epoch 150
    #       --how_many 1000 --gpu_ids 0 --gpu_ids_p 0 --imagefolder images-single
    # except that the GPU ids are left empty here.
    opt.dataroot = 'dataset/test_single'
    opt.name = 'apdrawinggan++_author'
    opt.model = 'test'
    opt.use_resnet = True
    opt.netG = 'resnet_9blocks'
    opt.which_epoch = 150
    opt.how_many = 1000
    opt.gpu_ids = ''
    opt.gpu_ids_p = ''
    opt.imagefolder = 'images-single'

    opt.checkpoints_dir = checkpoint_dir

    model = create_model(opt)
    model.setup(opt)

    # Preprocessing helpers: face detector and 68-point landmark predictor.
    detector = dlib.get_frontal_face_detector()
    predictor = dlib.shape_predictor(
        os.path.join(checkpoint_dir, 'shape_predictor_68_face_landmarks.dat'))

    func = functools.partial(run, model=model, opt=opt, detector=detector, predictor=predictor)
    func = functools.update_wrapper(func, run)

    gr.Interface(
        func,
        [
            gr.inputs.Image(type='file', label='Input Image'),
        ],
        # `run` returns a 4-tuple (one mask per facial part), so expose one
        # output slot per element. A single slot cannot hold the tuple.
        [
            gr.outputs.Image(type='pil', label=f'Mask: {part}')
            for part in ('eyel', 'eyer', 'nose', 'mouth')
        ],
        theme=args.theme,
        title=TITLE,
        description=DESCRIPTION,
        article=ARTICLE,
        allow_screenshot=args.allow_screenshot,
        allow_flagging=args.allow_flagging,
        live=args.live,
    ).launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )
|
|
|
|
|
# Launch the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':

    main()
|
|