Spaces:
Running
Running
import numpy as np | |
import gradio as gr | |
import cv2 | |
import os | |
import argparse | |
from inference import Predictor | |
import io | |
#from black import to_black | |
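# Optional startup download of the Hayao weights. Use the /resolve/ URL:
# the /tree/ URL serves the repository's HTML file browser, not the file.
# os.system("wget https://huggingface.co/YANGYYYY/cartoonize/resolve/main/GeneratorV2_train_photo_Hayao_init.pt")
# if os.path.exists("GeneratorV2_train_photo_Hayao_init.pt"):
#     print("Download succeeded!")
# else:
#     print("Download failed!")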
def parse_args():
    """Parse the command-line options used for single-image inference.

    Options are read from ``sys.argv``.

    Returns:
        argparse.Namespace: ``weight`` (checkpoint file name, defaulting to
        the Hayao weights) and ``device`` (defaulting to ``'cpu'``).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--weight',
        type=str,
        default='GeneratorV2_train_photo_Hayao_init.pt',
    )
    cli.add_argument(
        '--device',
        type=str,
        default='cpu',
        help='Device, cuda or cpu',
    )
    options = cli.parse_args()
    return options
def parse_args_video():
    """Parse the command-line options used for video inference.

    Options are read from ``sys.argv``.

    Returns:
        argparse.Namespace: ``weight``, ``src``, ``out``, ``batch_size``,
        ``start`` and ``end``, each falling back to the default declared
        below when the flag is absent.
    """
    cli = argparse.ArgumentParser()

    # Model checkpoint.
    cli.add_argument(
        '--weight',
        type=str,
        default='GeneratorV2_train_photo_Hayao_init.pt',
    )

    # Input / output locations.
    cli.add_argument(
        '--src',
        type=str,
        default='dataset/video/花.mp4',
        help='Path to input video',
    )
    cli.add_argument(
        '--out',
        type=str,
        default='dataset/video_Hayao/hua_hayao.mp4',
        help='Path to save new video',
    )

    # Processing window and batching.
    cli.add_argument('--batch-size', type=int, default=4)
    cli.add_argument(
        '--start',
        type=int,
        default=0,
        help='Start time of video (second)',
    )
    cli.add_argument(
        '--end',
        type=int,
        default=10,
        help='End time of video (second), 0 if not set',
    )

    options = cli.parse_args()
    return options
def transfer(image, transfer_style):
    """Cartoonize a single image in the selected anime style.

    Args:
        image: Image value delivered by the Gradio image input, or ``None``
            when nothing has been uploaded.
        transfer_style: ``"Hayao"``, ``"Shinkai"`` or ``"Kon Satoshi"``.
            Any other value (including ``None``) selects no style.

    Returns:
        The result of ``Predictor.transform_image`` for a recognised style.
        The input is returned unchanged when no image was supplied or the
        style is not recognised.
    """
    # Style -> checkpoint override. ``None`` keeps whatever ``--weight``
    # resolved to in parse_args() (the Hayao checkpoint by default).
    style_weights = {
        "Hayao": None,
        "Shinkai": 'GeneratorV2_train_photo_Shinkai_init.pt',
        "Kon Satoshi": 'GeneratorV2_train_photo_Paprika_init.pt',
    }

    # Bail out before loading any weights: there is nothing to transform
    # when the image is missing, and nothing to apply for an unknown style.
    if image is None or transfer_style not in style_weights:
        return image

    args = parse_args()
    weight_override = style_weights[transfer_style]
    if weight_override is not None:
        args.weight = weight_override

    predictor = Predictor(args.weight, args.device)
    return predictor.transform_image(image)
def transfer_video(video_input, transfer_style):
    """Cartoonize a video in the selected anime style.

    Args:
        video_input: Path of the uploaded video as delivered by the Gradio
            video input, or ``None`` when nothing has been uploaded.
        transfer_style: ``"Hayao"``, ``"Shinkai"`` or ``"Kon Satoshi"``.
            Any other value (including ``None``) selects no style.

    Returns:
        ``"video.mp4"`` (the path the converted video is written to) for a
        recognised style, or ``None`` when no video was supplied or the
        style is not recognised.
    """
    # Style -> checkpoint override. ``None`` keeps whatever ``--weight``
    # resolved to in parse_args_video() (the Hayao checkpoint by default).
    style_weights = {
        "Hayao": None,
        "Shinkai": 'GeneratorV2_train_photo_Shinkai_init.pt',
        "Kon Satoshi": 'GeneratorV2_train_photo_Paprika_init.pt',
    }

    # A video output accepts a file path or None; returning None (rather
    # than an integer) leaves the output empty when there is nothing to do.
    if video_input is None or transfer_style not in style_weights:
        return None

    args = parse_args_video()
    weight_override = style_weights[transfer_style]
    if weight_override is not None:
        args.weight = weight_override

    args.src = video_input
    # NOTE(review): fixed output path in the working directory, so
    # overlapping requests would write to the same file.
    args.out = "video.mp4"

    Predictor(args.weight).transform_video(
        args.src,
        args.out,
        args.batch_size,
        start=args.start,
        end=args.end,
    )
    return args.out
def clear_output(input_widget):
    """Return ``None`` for the output component wired to a "Clear" button.

    Args:
        input_widget: Current value of the wired input component; ignored.

    Returns:
        Always ``None``.
    """
    # Rebinding the parameter would have no effect outside this function;
    # the ``None`` return value is what reaches the output component.
    return None
# Build the two-tab UI: one tab for images, one for videos. Each tab has an
# input/output pair, a style dropdown, a "Transfer" button and a "Clear" button.
with gr.Blocks() as demo:
    gr.Markdown("Transfer image or video files using this demo.")

    with gr.Tabs():
        with gr.TabItem("Transfer Image"):
            with gr.Row():
                image_input = gr.Image()
                image_output = gr.Image()
            with gr.Row():
                image_dropdown = gr.Dropdown(
                    label="Transfer Style",
                    choices=["Hayao", "Shinkai", "Kon Satoshi"],
                )
                image_button = gr.Button("Transfer")
                clear_image_button = gr.Button("Clear")

        with gr.TabItem("Transfer Video"):
            with gr.Row():
                video_input = gr.Video()
                video_output = gr.Video()
            with gr.Row():
                video_dropdown = gr.Dropdown(
                    label="Transfer Style",
                    choices=["Hayao", "Shinkai", "Kon Satoshi"],
                )
                video_button = gr.Button("Transfer")
                clear_video_button = gr.Button("Clear")

    # Event wiring must happen inside the Blocks context.
    image_button.click(
        fn=transfer,
        inputs=[image_input, image_dropdown],
        outputs=image_output,
    )
    video_button.click(
        fn=transfer_video,
        inputs=[video_input, video_dropdown],
        outputs=video_output,
    )
    # The "Clear" buttons reset only the output component of their tab.
    clear_image_button.click(
        fn=clear_output,
        inputs=image_input,
        outputs=image_output,
    )
    clear_video_button.click(
        fn=clear_output,
        inputs=video_input,
        outputs=video_output,
    )

# Launch the interface.
demo.launch()
# Alternative: bind to an explicit host and port.
# demo.launch(server_name='127.0.0.1', server_port=7788)