import numpy as np
import gradio as gr
import cv2
import os
import argparse
from inference import Predictor
import io

# NOTE(review): weight files are loaded by bare file name, so presumably they
# must sit in the working directory; earlier revisions of this script fetched
# them from https://huggingface.co/YANGYYYY/cartoonize (TODO confirm).

# Weight file used for each style offered in the UI. ``None`` means "use the
# --weight command-line value", which defaults to the Hayao checkpoint.
STYLE_WEIGHTS = {
    "Hayao": None,
    "Shinkai": 'GeneratorV2_train_photo_Shinkai_init.pt',
    "Kon Satoshi": 'GeneratorV2_train_photo_Paprika_init.pt',
}

# Where the stylized video is written (relative to the working directory).
VIDEO_OUTPUT_PATH = "video.mp4"

# Loaded predictors keyed by (weight, device) so each checkpoint is read from
# disk once instead of on every button click.
_PREDICTORS = {}


def parse_args(argv=None):
    """Parse the command-line options used for image transfer.

    Unknown arguments are ignored on purpose: the same command line is also
    parsed by ``parse_args_video``, which accepts a different set of flags, so
    a strict parse here would raise SystemExit for any video-only flag.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``weight`` and ``device``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--weight', type=str, default='GeneratorV2_train_photo_Hayao_init.pt')
    parser.add_argument('--device', type=str, default='cpu', help='Device, cuda or cpu')
    args, _unknown = parser.parse_known_args(argv)
    return args


def parse_args_video(argv=None):
    """Parse the command-line options used for video transfer.

    Unknown arguments are ignored for the same reason as in ``parse_args``:
    both parsers read one shared command line with different flag sets.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``weight``, ``src``, ``out``, ``batch_size``,
        ``start`` and ``end``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--weight', type=str, default='GeneratorV2_train_photo_Hayao_init.pt')
    parser.add_argument('--src', type=str, default='dataset/video/花.mp4', help='Path to input video')
    parser.add_argument('--out', type=str, default='dataset/video_Hayao/hua_hayao.mp4', help='Path to save new video')
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--start', type=int, default=0, help='Start time of video (second)')
    parser.add_argument('--end', type=int, default=10, help='End time of video (second), 0 if not set')
    args, _unknown = parser.parse_known_args(argv)
    return args


def _get_predictor(weight, device=None):
    """Return a cached ``Predictor`` for ``weight``, loading it on first use.

    ``device=None`` constructs ``Predictor(weight)`` so the Predictor's own
    default device is used, which is how the video path has always called it.
    """
    key = (weight, device)
    predictor = _PREDICTORS.get(key)
    if predictor is None:
        if device is None:
            predictor = Predictor(weight)
        else:
            predictor = Predictor(weight, device)
        _PREDICTORS[key] = predictor
    return predictor


def transfer(image, transfer_style):
    """Stylize a single image with the selected style.

    Args:
        image: value of the ``gr.Image`` input; ``None`` when nothing is loaded.
        transfer_style: one of the ``STYLE_WEIGHTS`` keys.

    Returns:
        The stylized image from ``Predictor.transform_image``. If no image is
        supplied, returns ``None``. If the style is unknown or empty, returns
        ``image`` unchanged.
    """
    if image is None:
        return None
    if transfer_style not in STYLE_WEIGHTS:
        return image
    args = parse_args()
    weight = STYLE_WEIGHTS[transfer_style] or args.weight
    return _get_predictor(weight, args.device).transform_image(image)


def transfer_video(video_input, transfer_style):
    """Stylize a video with the selected style and return the output path.

    Only the ``--start``..``--end`` window (default 0-10 s) is processed.

    Args:
        video_input: path of the uploaded video from ``gr.Video``, or ``None``.
        transfer_style: one of the ``STYLE_WEIGHTS`` keys.

    Returns:
        ``VIDEO_OUTPUT_PATH`` after the video has been written. Returns ``None``
        (clears the output) when no video is supplied or the style is unknown.
    """
    if video_input is None or transfer_style not in STYLE_WEIGHTS:
        return None
    args = parse_args_video()
    weight = STYLE_WEIGHTS[transfer_style] or args.weight
    args.src = video_input
    args.out = VIDEO_OUTPUT_PATH
    _get_predictor(weight).transform_video(
        args.src, args.out, args.batch_size, start=args.start, end=args.end
    )
    return args.out


def clear_output(input_widget):
    """Return ``None`` so the output component bound to a Clear button resets.

    ``input_widget`` receives the paired input component's value only because
    the click handler passes it; it is neither used nor modified.
    """
    return None


with gr.Blocks() as demo:
    gr.Markdown("Transfer image or video files using this demo.")
    with gr.Tabs():
        with gr.TabItem("Transfer Image"):
            with gr.Row():
                image_input = gr.Image()
                image_output = gr.Image()
            with gr.Row():
                image_dropdown = gr.Dropdown(label="Transfer Style", choices=["Hayao", "Shinkai", "Kon Satoshi"])
                image_button = gr.Button("Transfer")
                clear_image_button = gr.Button("Clear")
        with gr.TabItem("Transfer Video"):
            with gr.Row():
                video_input = gr.Video()
                video_output = gr.Video()
            with gr.Row():
                video_dropdown = gr.Dropdown(label="Transfer Style", choices=["Hayao", "Shinkai", "Kon Satoshi"])
                video_button = gr.Button("Transfer")
                clear_video_button = gr.Button("Clear")

    image_button.click(transfer, inputs=[image_input, image_dropdown], outputs=image_output)
    video_button.click(transfer_video, inputs=[video_input, video_dropdown], outputs=video_output)
    # The Clear buttons reset only the output component (clear_output returns None).
    clear_image_button.click(clear_output, inputs=image_input, outputs=image_output)
    clear_video_button.click(clear_output, inputs=video_input, outputs=video_output)

demo.launch()  # Start the interface.
# To bind a specific host/port instead:
# demo.launch(server_name='127.0.0.1', server_port=7788)