File size: 5,428 Bytes
1f25689
 
2064e3d
e0f15a2
 
 
26bb579
e0f15a2
1f25689
92d9b00
 
 
 
 
139ee29
e0f15a2
 
a62524f
e0f15a2
 
 
 
7232d95
 
 
0155031
 
7232d95
 
 
 
 
 
e0f15a2
 
 
139ee29
d6166cf
e0f15a2
b2929ff
e0f15a2
 
f877ad1
 
 
 
 
e0f15a2
f877ad1
 
 
 
 
e0f15a2
 
1f25689
7232d95
0c394ef
7232d95
 
 
9407709
8afe88f
 
1bd26eb
8afe88f
ae0812f
8afe88f
 
 
 
 
 
 
ae0812f
8afe88f
 
0155031
 
 
 
 
 
7232d95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
838acff
 
 
1f25689
838acff
1f25689
838acff
1f25689
 
 
838acff
e8df8bc
838acff
 
 
 
ca106e1
 
838acff
9791f04
838acff
 
 
e0f15a2
931c498
838acff
 
1f25689
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import argparse
import io
import os
import tempfile

import cv2
import gradio as gr
import numpy as np

from inference import Predictor
#from black import to_black

# os.system("wget https://huggingface.co/YANGYYYY/cartoonize/tree/main/GeneratorV2_train_photo_Hayao_init.pt")
# if os.path.exists("GeneratorV2_train_photo_Hayao_init.pt"):
#     print("下载成功!")
# else:
#     print("下载失败!")

def parse_args(argv=None):
    """Parse the command-line options used for image style transfer.

    Args:
        argv: Argument list to parse. ``None`` (the default) means
            ``sys.argv[1:]``, exactly as with ``argparse``.

    Returns:
        An ``argparse.Namespace`` with ``weight`` (checkpoint path) and
        ``device`` (``'cpu'`` unless overridden).

    Unrecognised options are ignored rather than fatal. This parser and
    ``parse_args_video`` read the same command line but define different
    options. A flag meant for the other one (e.g. ``--src``) must not make
    this call ``sys.exit`` in the middle of handling a request.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--weight', type=str, default='GeneratorV2_train_photo_Hayao_init.pt')
    parser.add_argument('--device', type=str, default='cpu', help='Device, cuda or cpu')

    args, _unknown = parser.parse_known_args(argv)
    return args

def parse_args_video(argv=None):
    """Parse the command-line options used for video style transfer.

    Args:
        argv: Argument list to parse. ``None`` (the default) means
            ``sys.argv[1:]``, exactly as with ``argparse``.

    Returns:
        An ``argparse.Namespace`` with ``weight``, ``src``, ``out``,
        ``batch_size``, ``start`` and ``end`` (the last two in seconds, per
        their help text).

    Unrecognised options are ignored rather than fatal. This parser and
    ``parse_args`` read the same command line but define different
    options. A flag meant for the other one (e.g. ``--device``) must not
    make this call ``sys.exit`` in the middle of handling a request.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--weight', type=str, default='GeneratorV2_train_photo_Hayao_init.pt')
    parser.add_argument('--src', type=str, default='dataset/video/花.mp4', help='Path to input video')
    parser.add_argument('--out', type=str, default='dataset/video_Hayao/hua_hayao.mp4', help='Path to save new video')
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--start', type=int, default=0, help='Start time of video (second)')
    parser.add_argument('--end', type=int, default=10, help='End time of video (second), 0 if not set')

    args, _unknown = parser.parse_known_args(argv)
    return args

def transfer(image, transfer_style):
    if transfer_style == "Hayao":
         #output = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)# 转换为灰度图像
        #os.system("wget https://huggingface.co/YANGYYYY/cartoonize/resolve/main/GeneratorV2_train_photo_Hayao_init.pt")
        args = parse_args()
        predictor = Predictor(args.weight, args.device)
        anime_img = predictor.transform_image(image)
        return anime_img
    elif transfer_style == "Shinkai":
        args = parse_args()
        args.weight = 'GeneratorV2_train_photo_Shinkai_init.pt'
        predictor = Predictor(args.weight, args.device)
        anime_img = predictor.transform_image(image)
        return anime_img
    elif transfer_style == "Kon Satoshi":
        args = parse_args()
        args.weight = 'GeneratorV2_train_photo_Paprika_init.pt'
        predictor = Predictor(args.weight, args.device)
        anime_img = predictor.transform_image(image)
        return anime_img
    else:
        return image

        
def transfer_video(video_input, transfer_style,video_output):
    if transfer_style == "Hayao":
         #output = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)# 转换为灰度图像
        #os.system("wget https://huggingface.co/YANGYYYY/cartoonize/resolve/main/GeneratorV2_train_photo_Hayao_init.pt")
        args = parse_args_video()
        # # 加载视频文件
        # #video_binary = io.BytesIO(video)
            
        # cap = cv2.VideoCapture(video)
        
        # # 读取视频帧并保存到一个列表中
        # video_frames = []
        # while True:
        #     ret, frame = cap.read()
        #     if not ret:
        #         break
        #     video_frames.append(frame)
        
        # # 关闭视频文件
        # cap.release()
        args.src = video_input
        args.out = video_output
        Predictor(args.weight).transform_video(args.src, args.out, args.batch_size, start=args.start, end=args.end)

        #anime_video = Predictor(args.weight).transform_video(video, args.batch_size, args.start, args.end)
        #return anime_video
    elif transfer_style == "Shinkai":
        args = parse_args()
        args.weight = 'GeneratorV2_train_photo_Shinkai_init.pt'
        predictor = Predictor(args.weight, args.device)
        anime_img = predictor.transform_image(image)
        return anime_img
    elif transfer_style == "Kon Satoshi":
        args = parse_args()
        args.weight = 'GeneratorV2_train_photo_Paprika_init.pt'
        predictor = Predictor(args.weight, args.device)
        anime_img = predictor.transform_image(image)
        return anime_img
    else:
        return image

def clear_output(input_widget):
    """Return the value that empties the output component this is wired to.

    Args:
        input_widget: Current value of the input component. It is unused and
            accepted only because the click handler passes it in. Rebinding
            it inside this function could never affect the UI.

    Returns:
        ``None``, which Gradio renders as an empty component.
    """
    return None

STYLE_CHOICES = ["Hayao", "Shinkai", "Kon Satoshi"]

with gr.Blocks() as demo:
    gr.Markdown("Transfer image or video files using this demo.")
    with gr.Tabs():
        with gr.TabItem("Transfer Image"):
            with gr.Row():
                image_input = gr.Image()
                image_output = gr.Image()
            with gr.Row():
                # Preselect a style so "Transfer" does something without an
                # explicit choice; with no value the handler gets None.
                image_dropdown = gr.Dropdown(label="Transfer Style", choices=STYLE_CHOICES, value="Hayao")
                image_button = gr.Button("Transfer")
                clear_image_button = gr.Button("Clear")
        with gr.TabItem("Transfer Video"):
            with gr.Row():
                video_input = gr.Video()
                video_output = gr.Video()
            with gr.Row():
                video_dropdown = gr.Dropdown(label="Transfer Style", choices=STYLE_CHOICES, value="Hayao")
                video_button = gr.Button("Transfer")
                clear_video_button = gr.Button("Clear")

    image_button.click(transfer, inputs=[image_input, image_dropdown], outputs=image_output)
    # Only the source video and the chosen style are UI inputs.
    video_button.click(transfer_video, inputs=[video_input, video_dropdown], outputs=video_output)
    clear_image_button.click(clear_output, inputs=image_input, outputs=image_output)
    clear_video_button.click(clear_output, inputs=video_input, outputs=video_output)

# Launch only when run as a script, so importing the module (e.g. by the
# `gradio` reload CLI, which looks up `demo`) does not start a server.
if __name__ == "__main__":
    demo.launch()

# Launch the interface (alternative: serve on localhost, port 7788):
#demo.launch(server_name='127.0.0.1',server_port=7788)