Spaces:
Running
Running
File size: 5,427 Bytes
1f25689 2064e3d e0f15a2 26bb579 e0f15a2 1f25689 92d9b00 139ee29 e0f15a2 a62524f e0f15a2 7232d95 0155031 7232d95 e0f15a2 139ee29 d6166cf e0f15a2 b2929ff e0f15a2 f877ad1 e0f15a2 f877ad1 e0f15a2 1f25689 7232d95 ab3d80a 7232d95 9407709 8afe88f 1bd26eb 8afe88f ae0812f 8afe88f ae0812f 8afe88f 0155031 ab3d80a 0155031 ab3d80a 0155031 7232d95 838acff 1f25689 838acff 1f25689 838acff 1f25689 838acff e8df8bc 838acff ca106e1 838acff 9791f04 838acff e0f15a2 7de1788 838acff 1f25689 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import argparse
import functools
import io
import os

import cv2
import gradio as gr
import numpy as np

from inference import Predictor
#from black import to_black
# os.system("wget https://huggingface.co/YANGYYYY/cartoonize/tree/main/GeneratorV2_train_photo_Hayao_init.pt")
# if os.path.exists("GeneratorV2_train_photo_Hayao_init.pt"):
# print("下载成功!")
# else:
# print("下载失败!")
def parse_args(argv=None):
    """Parse the command-line options used for single-image style transfer.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as argparse does.

    Returns:
        argparse.Namespace with ``weight`` (path of the generator checkpoint,
        Hayao by default) and ``device`` (``'cpu'`` by default).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--weight', type=str, default='GeneratorV2_train_photo_Hayao_init.pt')
    parser.add_argument('--device', type=str, default='cpu', help='Device, cuda or cpu')
    # parse_known_args instead of parse_args: this parser and the video parser
    # read the same command line, so a flag that only the other one defines
    # (e.g. --src) must not abort the request with "unrecognized arguments".
    args, _unknown = parser.parse_known_args(argv)
    return args
def parse_args_video(argv=None):
    """Parse the command-line options used for video style transfer.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as argparse does.

    Returns:
        argparse.Namespace with ``weight``, ``device``, ``src``, ``out``,
        ``batch_size``, ``start`` and ``end`` (seconds; per the help text,
        0 means "not set").
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--weight', type=str, default='GeneratorV2_train_photo_Hayao_init.pt')
    # Mirrors parse_args so a single --device flag configures both image and
    # video transfer (previously it made this parser exit with an error).
    parser.add_argument('--device', type=str, default='cpu', help='Device, cuda or cpu')
    parser.add_argument('--src', type=str, default='dataset/video/花.mp4', help='Path to input video')
    parser.add_argument('--out', type=str, default='dataset/video_Hayao/hua_hayao.mp4', help='Path to save new video')
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--start', type=int, default=0, help='Start time of video (second)')
    parser.add_argument('--end', type=int, default=10, help='End time of video (second), 0 if not set')
    # Tolerate flags that only the image parser defines; both share sys.argv.
    args, _unknown = parser.parse_known_args(argv)
    return args
# Generator checkpoint for each selectable style. ``None`` means "use the
# --weight value from the command line", whose default is the Hayao model.
_STYLE_WEIGHTS = {
    "Hayao": None,
    "Shinkai": "GeneratorV2_train_photo_Shinkai_init.pt",
    "Kon Satoshi": "GeneratorV2_train_photo_Paprika_init.pt",
}


@functools.lru_cache(maxsize=8)
def _get_predictor(weight, device):
    """Return a Predictor for ``(weight, device)``, constructing it only once.

    Constructing a Predictor presumably loads the checkpoint from disk, so
    reusing it avoids reloading the model on every request.
    NOTE(review): confirm Predictor instances are safe to share across
    concurrent Gradio requests.
    """
    return Predictor(weight, device)


def transfer(image, transfer_style):
    """Stylize one image with the generator selected by ``transfer_style``.

    Args:
        image: Value of the Gradio ``gr.Image`` input, or ``None`` when
            nothing has been uploaded.
        transfer_style: One of ``"Hayao"``, ``"Shinkai"``, ``"Kon Satoshi"``.

    Returns:
        The output of ``Predictor.transform_image``. Returns ``image`` itself
        if it is ``None`` or the style is not recognised.
    """
    if image is None or transfer_style not in _STYLE_WEIGHTS:
        return image
    args = parse_args()
    weight = _STYLE_WEIGHTS[transfer_style] or args.weight
    return _get_predictor(weight, args.device).transform_image(image)


def transfer_video(video_input, transfer_style):
    """Stylize a video with the generator selected by ``transfer_style``.

    Only the ``--start``..``--end`` window from ``parse_args_video`` is
    converted (``--end`` defaults to 10 seconds).

    Args:
        video_input: Path from the Gradio ``gr.Video`` input, or ``None``.
        transfer_style: One of ``"Hayao"``, ``"Shinkai"``, ``"Kon Satoshi"``.

    Returns:
        Path of the written video (``"video.mp4"`` in the working directory).
        Returns ``video_input`` itself if it is ``None`` or the style is not
        recognised.
    """
    if video_input is None or transfer_style not in _STYLE_WEIGHTS:
        return video_input
    args = parse_args_video()
    weight = _STYLE_WEIGHTS[transfer_style] or args.weight
    # Older versions of parse_args_video define no --device; fall back to the
    # same CPU default used for images.
    device = getattr(args, "device", "cpu")
    # NOTE(review): a fixed output name is shared by all requests; concurrent
    # video jobs would overwrite each other's result.
    out_path = "video.mp4"
    _get_predictor(weight, device).transform_video(
        video_input, out_path, args.batch_size, start=args.start, end=args.end
    )
    return out_path
def clear_output(input_widget):
    """Click handler for the "Clear" buttons: empty the paired output component.

    Args:
        input_widget: Current value of the wired input component. It is
            received because the click handler passes it, and is ignored.

    Returns:
        None. Returning None to a Gradio output component clears it.
    """
    return None
# Styles offered in both tabs; they must match the names transfer() and
# transfer_video() recognise.
STYLE_CHOICES = ["Hayao", "Shinkai", "Kon Satoshi"]

# Gradio UI: one tab for still images and one for videos. Each tab has an
# input/output pair, a style selector, a Transfer button and a Clear button.
with gr.Blocks() as demo:
    gr.Markdown("Transfer image or video files using this demo.")
    with gr.Tabs():
        with gr.TabItem("Transfer Image"):
            with gr.Row():
                image_input = gr.Image()
                image_output = gr.Image()
            with gr.Row():
                # Preselect a style so clicking Transfer without choosing one
                # does not silently hand back the unmodified input.
                image_dropdown = gr.Dropdown(label="Transfer Style", choices=STYLE_CHOICES, value=STYLE_CHOICES[0])
                image_button = gr.Button("Transfer")
                clear_image_button = gr.Button("Clear")
        with gr.TabItem("Transfer Video"):
            with gr.Row():
                video_input = gr.Video()
                video_output = gr.Video()
            with gr.Row():
                video_dropdown = gr.Dropdown(label="Transfer Style", choices=STYLE_CHOICES, value=STYLE_CHOICES[0])
                video_button = gr.Button("Transfer")
                clear_video_button = gr.Button("Clear")
    # Event wiring must happen inside the Blocks context.
    image_button.click(transfer, inputs=[image_input, image_dropdown], outputs=image_output)
    video_button.click(transfer_video, inputs=[video_input, video_dropdown], outputs=video_output)
    clear_image_button.click(clear_output, inputs=image_input, outputs=image_output)
    clear_video_button.click(clear_output, inputs=video_input, outputs=video_output)

# Launch the interface only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()
    # To serve on a fixed local address instead:
    # demo.launch(server_name='127.0.0.1', server_port=7788)
|