|
|
|
|
|
from __future__ import annotations |
|
import argparse |
|
import functools |
|
import os |
|
import pathlib |
|
import sys |
|
from typing import Callable |
|
import uuid |
|
|
|
sys.path.insert(0, 'APDrawingGAN2') |
|
|
|
import gradio as gr |
|
import huggingface_hub |
|
import numpy as np |
|
import PIL.Image |
|
|
|
from io import BytesIO |
|
import shutil |
|
|
|
from options.test_options import TestOptions |
|
from data import CreateDataLoader |
|
from models import create_model |
|
|
|
from util import html |
|
|
|
import ntpath |
|
from util import util |
|
|
|
|
|
# Upstream repository this demo wraps; interpolated into DESCRIPTION below.
ORIGINAL_REPO_URL = 'https://github.com/yiranran/APDrawingGAN2'

# Title shown by the Gradio interface (passed as ``title=`` in main()).
TITLE = 'yiranran/APDrawingGAN2'

# Description text for the Gradio interface (passed as ``description=``).
DESCRIPTION = f"""This is a demo for {ORIGINAL_REPO_URL}.



"""

# Article text for the Gradio interface (passed as ``article=``); currently
# only whitespace.
ARTICLE = """



"""

# Hugging Face Hub repository that load_checkpoint() downloads
# 'checkpoints.zip' from.
MODEL_REPO = 'hylee/apdrawing_model'
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line flags that configure the Gradio server.

    Returns:
        Namespace with the attributes ``device``, ``theme``, ``live``,
        ``share``, ``port``, ``enable_queue``, ``allow_flagging`` and
        ``allow_screenshot``.
    """
    cli = argparse.ArgumentParser()
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ('--device', dict(type=str, default='cpu')),
        ('--theme', dict(type=str)),
        ('--live', dict(action='store_true')),
        ('--share', dict(action='store_true')),
        ('--port', dict(type=int)),
        # Queueing defaults to on; passing this flag stores False.
        ('--disable-queue', dict(dest='enable_queue', action='store_false')),
        ('--allow-flagging', dict(type=str, default='never')),
        ('--allow-screenshot', dict(action='store_true')),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
|
|
|
|
|
def load_checkpoint():
    """Download the pretrained checkpoints archive and unpack it locally.

    Fetches ``checkpoints.zip`` from the Hugging Face Hub repository
    ``MODEL_REPO`` and extracts it into the ``checkpoint`` directory
    (relative to the current working directory).

    Returns:
        Path of the extracted ``checkpoint/checkpoints`` directory, which
        main() uses as ``opt.checkpoints_dir``.
    """
    extract_root = 'checkpoint'
    # ``force_filename`` is deprecated in huggingface_hub; the returned cache
    # path already keeps the Hub file name ('checkpoints.zip'), which is all
    # shutil.unpack_archive needs to detect the archive format.
    archive_path = huggingface_hub.hf_hub_download(MODEL_REPO,
                                                   'checkpoints.zip')
    print(archive_path)
    shutil.unpack_archive(archive_path, extract_dir=extract_root)

    checkpoints_dir = os.path.join(extract_root, 'checkpoints')
    print(os.listdir(checkpoints_dir))
    return checkpoints_dir
|
|
|
|
|
def save_images2(image_dir, visuals, image_path, aspect_ratio=1.0, width=256):
    """Save every model visual as ``<name>_<label>.png`` in ``image_dir``.

    Args:
        image_dir: Directory the PNG files are written to.
        visuals: Mapping of label -> image data; each value is converted with
            ``util.tensor2im`` (presumably a tensor, as produced by
            ``model.get_current_visuals()`` in run()).
        image_path: Sequence of source paths; only ``image_path[0]`` is used,
            and its base name without extension becomes the file-name prefix.
        aspect_ratio: If > 1 the image is stretched horizontally to
            ``int(w * aspect_ratio)``; if < 1 it is stretched vertically to
            ``int(h / aspect_ratio)``. 1.0 leaves it unchanged.
        width: Unused here; accepted so callers (run() passes
            ``opt.display_winsize``) keep working.

    Returns:
        List of the written file paths, in ``visuals`` iteration order.
    """
    name = os.path.splitext(ntpath.basename(image_path[0]))[0]

    saved_paths = []
    for label, im_data in visuals.items():
        im = util.tensor2im(im_data)
        save_path = os.path.join(image_dir, '%s_%s.png' % (name, label))

        # Assumes tensor2im returns an (H, W, C) array -- TODO confirm.
        h, w, _ = im.shape
        # PIL's resize takes the target size as (width, height).
        if aspect_ratio > 1.0:
            im = np.array(PIL.Image.fromarray(im).resize(
                (int(w * aspect_ratio), h), PIL.Image.BICUBIC))
        if aspect_ratio < 1.0:
            im = np.array(PIL.Image.fromarray(im).resize(
                (w, int(h / aspect_ratio)), PIL.Image.BICUBIC))

        util.save_image(im, save_path)
        saved_paths.append(save_path)

    return saved_paths
|
|
|
|
|
# 64-symbol URL-safe alphabet: digits, '-', lowercase, '_', uppercase.
# The index of a symbol is the 6-bit value it encodes.
SAFEHASH = list("0123456789-abcdefghijklmnopqrstuvwxyz_ABCDEFGHIJKLMNOPQRSTUVWXYZ")


def _encode_hex_triplets(hex_digits, groups=10):
    """Encode the first ``3 * groups`` hex digits with the SAFEHASH alphabet.

    Each 3-hex-digit group is a 12-bit value whose high and low 6 bits each
    select one SAFEHASH symbol. Any ``'-'`` symbols produced are removed, so
    the result has at most ``2 * groups`` characters.
    """
    symbols = []
    for start in range(0, 3 * groups, 3):
        value = int(hex_digits[start:start + 3], 16)
        symbols.append(SAFEHASH[value >> 6])
        symbols.append(SAFEHASH[value & 0x3F])
    return ''.join(symbols).replace('-', '')


def compress_UUID():
    """Return a short random identifier derived from a fresh UUID4.

    The first 30 of the UUID's 32 hex digits (120 bits, a few of which are
    fixed UUID version/variant bits) are re-encoded with the 64-symbol
    SAFEHASH alphabet, giving 20 symbols. ``'-'`` symbols are then dropped,
    so the result is at most 20 characters long and contains only
    ``[0-9a-zA-Z_]``.

    :return: String
    """
    return _encode_hex_triplets(uuid.uuid4().hex)
|
|
|
|
|
def run(
    image,
    model,
    opt,
) -> PIL.Image.Image:
    """Run the model on one uploaded image and return the generated result.

    Every call creates a fresh working directory ``images/<id>/``. The upload
    is copied into its ``src/`` subdirectory (set as ``opt.dataroot``), and
    the model outputs are written to ``results/`` (set as
    ``opt.results_dir``).

    Args:
        image: Uploaded file object; only its ``.name`` (a path on disk) is
            used.
        model: Model built in main(); must provide ``set_input``, ``test``,
            ``get_current_visuals`` and ``get_image_paths``.
        opt: Options namespace. NOTE: ``opt.dataroot`` and
            ``opt.results_dir`` are overwritten in place on every call.

    Returns:
        The first visual saved for the last processed sample, opened with
        ``PIL.Image.open``. If the data loader yields no samples, the input
        image itself is returned.
    """
    # NOTE(review): these per-request directories are never removed.
    work_dir = os.path.join('images', compress_UUID())
    opt.dataroot = os.path.join(work_dir, 'src/')
    os.makedirs(opt.dataroot, exist_ok=True)
    opt.results_dir = os.path.join(work_dir, 'results/')
    os.makedirs(opt.results_dir, exist_ok=True)

    shutil.copy(image.name, opt.dataroot)

    dataset = CreateDataLoader(opt).load_data()

    # Fallback result: echo the input back if the loop below never runs.
    imgs = [image.name]
    for i, data in enumerate(dataset):
        if i >= opt.how_many:
            break
        model.set_input(data)
        model.test()
        visuals = model.get_current_visuals()
        img_path = model.get_image_paths()
        imgs = save_images2(opt.results_dir, visuals, img_path,
                            aspect_ratio=opt.aspect_ratio,
                            width=opt.display_winsize)

    print(imgs)
    return PIL.Image.open(imgs[0])
|
|
|
|
|
def main():
    """Download checkpoints, build the APDrawingGAN2 model and serve the demo.

    Server-level settings (theme, port, sharing, queueing, flagging, ...) come
    from parse_args(). Model options come from ``TestOptions().parse()`` and
    are then overridden below with fixed single-image test settings.
    """
    gr.close_all()

    cli_args = parse_args()
    # NOTE(review): cli_args.device is parsed but not used anywhere here.

    ckpt_dir = load_checkpoint()

    opt = TestOptions().parse()
    # Fixed test-time overrides, applied in this order. The model-specific
    # values mirror this upstream command line:
    #   python test.py --dataroot dataset/test_single
    #       --name apdrawinggan++_author --model test --use_resnet
    #       --netG resnet_9blocks --which_epoch 150 --how_many 1000
    #       --gpu_ids 0 --gpu_ids_p 0 --imagefolder images-single
    overrides = {
        'num_threads': 1,
        'batch_size': 1,
        'serial_batches': True,
        'no_flip': True,
        'display_id': -1,
        'dataroot': 'dataset/test_single',
        'name': 'apdrawinggan++_author',
        'model': 'test',
        'use_resnet': True,
        'netG': 'resnet_9blocks',
        'which_epoch': 150,
        'how_many': 1000,
        # NOTE(review): the command line above uses GPU 0, but these are set
        # to the int -1, presumably to force CPU. Whether the model code
        # accepts a bare int here (rather than a list) is not visible from
        # this file; confirm.
        'gpu_ids': -1,
        'gpu_ids_p': -1,
        'imagefolder': 'images-single',
        'checkpoints_dir': ckpt_dir,
    }
    for attr, value in overrides.items():
        setattr(opt, attr, value)

    model = create_model(opt)
    model.setup(opt)

    # Bind model/opt so the interface only supplies the image; update_wrapper
    # copies run's metadata (name, docstring, ...) onto the partial.
    predict = functools.update_wrapper(
        functools.partial(run, model=model, opt=opt), run)

    input_components = [gr.inputs.Image(type='file', label='Input Image')]
    output_components = [gr.outputs.Image(type='pil', label='Result')]
    interface = gr.Interface(
        predict,
        input_components,
        output_components,
        theme=cli_args.theme,
        title=TITLE,
        description=DESCRIPTION,
        article=ARTICLE,
        allow_screenshot=cli_args.allow_screenshot,
        allow_flagging=cli_args.allow_flagging,
        live=cli_args.live,
    )
    interface.launch(
        enable_queue=cli_args.enable_queue,
        server_port=cli_args.port,
        share=cli_args.share,
    )


if __name__ == '__main__':
    main()
|
|