# AnimeGAN / app.py
# Last commit by ChaosTong: "Update app.py" (531be4a, verified).
# (Header copied from the file-viewer page: raw | history | blame; file size 2.71 kB.)
import os
import cv2
import numpy as np
import gradio as gr
from inference import Predictor
from utils.image_processing import resize_image
# Directory where converted images are written so they can be offered for download.
OUTPUT_DIR = 'output'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Style name (as shown in the UI dropdown) -> generator weight file passed to Predictor.
# Keep in sync with the style dropdown choices.
STYLE_WEIGHTS = {
    "AnimeGANv2_Hayao": "GeneratorV2_gldv2_Hayao.pt",
    "AnimeGANv2_Shinkai": "GeneratorV2_gldv2_Shinkai.pt",
    "AnimeGANv2_Arcane": "GeneratorV2_ffhq_Arcane_210624_e350.pt",
    "AnimeGANv2_Test": "GeneratorV2_train_photo_Hayao.pt",
    "SummerWar": "GeneratorV2_train_photo_SummerWar.pt",
    "Hetalia": "GeneratorV2_train_photo_Hetalia.pt",
}


def inference(
    image: np.ndarray,
    style,
    imgsz=None,
    retain_color=False,
):
    """Convert *image* to the selected anime style and save a downloadable copy.

    Args:
        image: Picture from the Gradio Image input (numpy array). May be
            ``None`` when the user submits without uploading an image.
        style: Key of ``STYLE_WEIGHTS`` selecting the generator weights.
        imgsz: Optional target width; converted to ``int`` when not ``None``
            and passed both to ``Predictor`` and to ``resize_image``.
        retain_color: Forwarded unchanged to ``Predictor``.

    Returns:
        ``(anime_image, save_path)``: the converted array for display and the
        path of the JPEG written under ``OUTPUT_DIR`` for download.

    Raises:
        gr.Error: if no image was supplied or the output file could not be written.
        KeyError: if *style* is not a key of ``STYLE_WEIGHTS``.
    """
    # Fail fast with a user-visible message instead of loading weights and
    # then crashing inside resize_image on a missing input.
    if image is None:
        raise gr.Error("请先上传图片")
    if imgsz is not None:
        imgsz = int(imgsz)
    weight = STYLE_WEIGHTS[style]
    # NOTE(review): the model is rebuilt (and its weights reloaded) on every
    # request; caching one Predictor per (weight, retain_color, imgsz) would
    # avoid that if Predictor is safe to reuse across calls — confirm first.
    predictor = Predictor(
        weight,
        device='cpu',
        retain_color=retain_color,
        imgsz=imgsz,
    )
    image = resize_image(image, width=imgsz)
    anime_image = predictor.transform(image)[0]
    save_path = os.path.join(OUTPUT_DIR, "out.jpg")
    # cv2.imwrite expects BGR channel order, so the last axis is reversed
    # (assumes anime_image is RGB like the displayed output — TODO confirm).
    # imwrite reports failure via its return value rather than raising.
    if not cv2.imwrite(save_path, anime_image[..., ::-1]):
        raise gr.Error("无法保存转换后的图片")
    return anime_image, save_path
title = "图片动漫风格转换"
description = r"""将图片转换成动漫风格"""

# Web UI for `inference`. The inputs map positionally onto its parameters:
# image, style, imgsz and retain_color. The outputs are the converted image
# and a downloadable file.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.components.Image(label="输入图片"),
        gr.Dropdown(
            [
                'AnimeGANv2_Hayao',
                'AnimeGANv2_Shinkai',
                'AnimeGANv2_Arcane',
                'AnimeGANv2_Test',
                'SummerWar',
                'Hetalia',
            ],
            type="value",
            value='AnimeGANv2_Hayao',
            label='转换风格'
        ),
        gr.Dropdown(
            [
                None,
                416,
                512,
                768,
                1024,
                1536,
            ],
            type="value",
            value=None,
            label='图片大小'
        ),
        gr.Checkbox(value=False, label="保留原图颜色"),
    ],
    outputs=[
        gr.components.Image(type="numpy", label="转换后图片"),
        gr.components.File(label="下载转换图片")
    ],
    title=title,
    description=description,
    allow_flagging="never",
    # Each row supplies all four inputs (image, style, size, retain_color)
    # rather than relying on Gradio's handling of rows shorter than `inputs`.
    examples=[
        ['example/face/girl4.jpg', 'AnimeGANv2_Arcane', None, False],
        ['example/face/leo.jpg', 'AnimeGANv2_Arcane', None, False],
        ['example/face/cap.jpg', 'AnimeGANv2_Arcane', None, False],
        ['example/face/anne.jpg', 'AnimeGANv2_Arcane', None, False],
        ['example/landscape/pexels-camilacarneiro-6318793.jpg', 'AnimeGANv2_Hayao', None, False],
        ['example/landscape/pexels-nandhukumar-450441.jpg', 'AnimeGANv2_Hayao', None, False],
    ],
)

if __name__ == "__main__":
    # To listen on all interfaces at a fixed port, call
    # demo.launch(server_name="0.0.0.0", server_port=8080) instead.
    demo.launch()