File size: 2,707 Bytes
6915a56 04917da 6915a56 04917da 6915a56 04917da 6915a56 7929161 6915a56 7929161 6915a56 531be4a 7929161 d5baf32 0fd923a 6915a56 7929161 6915a56 7929161 6915a56 7929161 d5baf32 0fd923a 6915a56 7929161 6915a56 7929161 6915a56 7929161 6915a56 ef81edb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import os
import uuid

import cv2
import numpy as np
import gradio as gr

from inference import Predictor
from utils.image_processing import resize_image
# Directory that receives the converted images offered for download.
_OUTPUT_DIR = 'output'
os.makedirs(_OUTPUT_DIR, exist_ok=True)

# Style names shown in the UI dropdown -> generator weight file passed to Predictor.
_STYLE_WEIGHTS = {
    "AnimeGANv2_Hayao": "GeneratorV2_gldv2_Hayao.pt",
    "AnimeGANv2_Shinkai": "GeneratorV2_gldv2_Shinkai.pt",
    "AnimeGANv2_Arcane": "GeneratorV2_ffhq_Arcane_210624_e350.pt",
    "AnimeGANv2_Test": "GeneratorV2_train_photo_Hayao.pt",
    "SummerWar": "GeneratorV2_train_photo_SummerWar.pt",
    "Hetalia": "GeneratorV2_train_photo_Hetalia.pt",
}


def inference(
    image: np.ndarray,
    style,
    imgsz=None,
    retain_color=False,
):
    """Convert *image* to the anime style named by *style*.

    Args:
        image: Picture from the Gradio ``Image`` input (a numpy array;
            presumably RGB — TODO confirm). ``None`` when the user submits
            without uploading anything.
        style: Key of ``_STYLE_WEIGHTS`` selecting the generator weights.
        imgsz: Optional target size. It is coerced with ``int()`` (the
            dropdown value may not arrive as an int) and passed to both
            ``resize_image`` (as ``width``) and ``Predictor``; how ``None``
            is handled is defined by those helpers.
        retain_color: Forwarded to ``Predictor`` as ``retain_color``.

    Returns:
        ``(anime_image, save_path)``: the converted array for display, and
        the path of the JPEG written under ``_OUTPUT_DIR`` for download.

    Raises:
        gr.Error: if no image was supplied (shown to the user in the UI).
        KeyError: if *style* is not a known style name.
        RuntimeError: if the output JPEG could not be written.
    """
    if image is None:
        # "Please upload an image first."
        raise gr.Error("请先上传图片")
    if imgsz is not None:
        imgsz = int(imgsz)
    weight = _STYLE_WEIGHTS[style]
    # NOTE(review): a new Predictor (and presumably a weight load) is built on
    # every request; caching one per (weight, retain_color, imgsz) could help
    # if Predictor is safe to reuse across calls — confirm before adding.
    predictor = Predictor(
        weight,
        device='cpu',
        retain_color=retain_color,
        imgsz=imgsz,
    )
    image = resize_image(image, width=imgsz)
    # [0]: presumably transform returns a batch and we take the single result.
    anime_image = predictor.transform(image)[0]
    # One file per request: a fixed name would let concurrent requests
    # overwrite each other's download. Old files are not cleaned up here.
    save_path = os.path.join(_OUTPUT_DIR, f"out_{uuid.uuid4().hex}.jpg")
    # Reverse the channel axis for cv2, which writes BGR (assumes the
    # converted image is RGB).
    if not cv2.imwrite(save_path, anime_image[..., ::-1]):
        raise RuntimeError(f"Failed to write converted image to {save_path}")
    return anime_image, save_path
# UI strings are Chinese; English translations are given in the comments.
title = "图片动漫风格转换"  # "Image anime-style conversion"
description = "将图片转换成动漫风格"  # "Convert a picture into anime style"

demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.components.Image(label="输入图片"),  # "Input image"
        gr.Dropdown(
            [
                'AnimeGANv2_Hayao',
                'AnimeGANv2_Shinkai',
                'AnimeGANv2_Arcane',
                'AnimeGANv2_Test',
                'SummerWar',
                'Hetalia',
            ],
            type="value",
            value='AnimeGANv2_Hayao',
            label='转换风格',  # "Conversion style"
        ),
        gr.Dropdown(
            [
                None,
                416,
                512,
                768,
                1024,
                1536,
            ],
            type="value",
            value=None,
            label='图片大小',  # "Image size"
        ),
        gr.Checkbox(value=False, label="保留原图颜色"),  # "Keep original colors"
    ],
    outputs=[
        gr.components.Image(type="numpy", label="转换后图片"),  # "Converted image"
        gr.components.File(label="下载转换图片"),  # "Download converted image"
    ],
    title=title,
    description=description,
    allow_flagging="never",
    # Each row supplies all four inputs: image path, style, size, retain_color.
    examples=[
        ['example/face/girl4.jpg', 'AnimeGANv2_Arcane', None, False],
        ['example/face/leo.jpg', 'AnimeGANv2_Arcane', None, False],
        ['example/face/cap.jpg', 'AnimeGANv2_Arcane', None, False],
        ['example/face/anne.jpg', 'AnimeGANv2_Arcane', None, False],
        ['example/landscape/pexels-camilacarneiro-6318793.jpg', 'AnimeGANv2_Hayao', None, False],
        ['example/landscape/pexels-nandhukumar-450441.jpg', 'AnimeGANv2_Hayao', None, False],
    ],
)
# To serve on the network instead of localhost, call e.g.
# demo.launch(server_name="0.0.0.0", server_port=8080).
demo.launch()