import os
import shutil

import gradio as gr
from omegaconf import OmegaConf, ListConfig


from train import main as train_main
from inference import inference as inference_main
# Training: learn a motion embedding from the uploaded video
def train_model(video, config):
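    """Fine-tune a motion embedding on `video` with `config` and return the checkpoint directory."""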
    output_dir = 'results'
    os.makedirs(output_dir, exist_ok=True)
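    # Name each run with the next zero-padded index under results/ (00, 01, ...)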
    cur_save_dir = os.path.join(output_dir, str(len(os.listdir(output_dir))).zfill(2))

    config.dataset.single_video_path = video
    config.train.output_dir = cur_save_dir
    
    # Copy the source video into the run directory so it can be previewed at inference time
    os.makedirs(cur_save_dir, exist_ok=True)
    video_name = 'source.mp4'
    video_path = os.path.join(cur_save_dir, video_name)
    shutil.copy(video, video_path)

    train_main(config)
    # cur_save_dir = 'results/06'
    return cur_save_dir

# Inference: generate a video from a text prompt using a trained motion embedding
def inference_model(text, checkpoint, inference_steps, video_type, seed):
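    """Generate a video with the selected checkpoint and return the path to the result."""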
    
    checkpoint = os.path.join('results', checkpoint)

    # Split the checkpoint path into the run directory and the checkpoint folder name
    embedding_dir = '/'.join(checkpoint.split('/')[:-1])
    video_round = checkpoint.split('/')[-1]

    video_path = inference_main(
        embedding_dir=embedding_dir,
        prompt=text,
        video_round=video_round,
        save_dir=os.path.join('outputs', embedding_dir.split('/')[-1]),
        motion_type=video_type,
        seed=seed,
        inference_steps=inference_steps
    )

    return video_path


# Collect available checkpoints (directories containing motion_embed.pt) under the results directory
def get_checkpoints(checkpoint_dir):
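    """List '<run>/<checkpoint>' paths under checkpoint_dir that contain a motion_embed.pt file."""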
    
    checkpoints = []
    for root, dirs, files in os.walk(checkpoint_dir):
        for file in files:
            if file == 'motion_embed.pt':
                checkpoints.append('/'.join(root.split('/')[-2:]))
    return checkpoints


def extract_combinations(motion_embeddings_combinations):
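    """Parse entries such as 'down 1280' into [name, resolution] pairs."""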
    assert len(motion_embeddings_combinations) > 0, "At least one motion embedding combination is required"
    combinations = []
    for combination in motion_embeddings_combinations:
        name, resolution = combination.split(" ")
        combinations.append([name, int(resolution)])
    return combinations


def generate_config_train(motion_embeddings_combinations, unet, checkpointing_steps, max_train_steps):
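    """Load configs/config.yaml and apply the training options selected in the UI."""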

    default_config = OmegaConf.load('configs/config.yaml')

    default_config.model.motion_embeddings.combinations = ListConfig(extract_combinations(motion_embeddings_combinations))
    default_config.model.unet = unet
    default_config.train.checkpointing_steps = checkpointing_steps
    default_config.train.max_train_steps = max_train_steps

    return default_config


def generate_config_inference(motion_embeddings_combinations, unet, checkpointing_steps, max_train_steps):
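    """Load configs/config.yaml and apply the selected options (currently identical to generate_config_train)."""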

    default_config = OmegaConf.load('configs/config.yaml')

    default_config.model.motion_embeddings.combinations = ListConfig(extract_combinations(motion_embeddings_combinations))
    default_config.model.unet = unet
    default_config.train.checkpointing_steps = checkpointing_steps
    default_config.train.max_train_steps = max_train_steps

    return default_config


def update_preview_video(checkpoint_dir):
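    """Point the preview player at the source video stored in the selected checkpoint's run directory."""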
    # get the parent dir of the checkpoint
    parent_dir = '/'.join(checkpoint_dir.split('/')[:-1])
    return gr.update(value=f'results/{parent_dir}/source.mp4')


if __name__ == "__main__":
    inject_motion_embeddings_combinations = ['down 1280', 'up 1280', 'down 640', 'up 640']
    default_motion_embeddings_combinations = ['down 1280', 'up 1280']

    examples_train = [
        'assets/train/car_turn.mp4',
        'assets/train/pan_up.mp4',
        'assets/train/run_up.mp4',
        'assets/train/train_ride.mp4',
        'assets/train/orbit_shot.mp4',
        'assets/train/dolly_zoom_out.mp4',
        'assets/train/santa_dance.mp4',
    ]

    examples_inference = [
        ['results/pan_up/source.mp4', 'A flora garden.', 'camera', 'pan_up/checkpoint'],
        ['results/dolly_zoom/source.mp4', 'A firefighter standing in front of a burning forest captured with a dolly zoom.', 'camera', 'dolly_zoom/checkpoint-100'],
        ['results/orbit_shot/source.mp4', 'A micro garden with an orbit shot', 'camera', 'orbit_shot/checkpoint-300'],

        ['results/walk/source.mp4', 'An elephant walking in the desert', 'object', 'walk/checkpoint'],
        ['results/santa_dance/source.mp4', 'A skeleton in a suit is dancing with his hands', 'object', 'santa_dance/checkpoint-200'],
        ['results/car_turn/source.mp4', 'A toy train chugs around a roundabout tree', 'object', 'car_turn/checkpoint'],
        ['results/train_ride/source.mp4', 'A motorbike driving in a forest', 'object', 'train_ride/checkpoint-200'],
    ]

    # Build the Gradio interface
    with gr.Blocks() as demo:
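        # Two tabs: "Train" learns a motion embedding from a video, "Inference" applies it to a new prompt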
        with gr.Tab("Train"):
            with gr.Row():
                with gr.Column():
                    video_input = gr.Video(label="Upload Video")
                    train_button = gr.Button("Train")
                with gr.Column():
                    checkpoint_output = gr.Textbox(label="Checkpoint Directory")
            
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    motion_embeddings_combinations = gr.Dropdown(label="Motion Embeddings Combinations", choices=inject_motion_embeddings_combinations, multiselect=True, value=default_motion_embeddings_combinations)
                    unet_dropdown = gr.Dropdown(label="Unet", choices=["videoCrafter2", "zeroscope_v2_576w"], value="videoCrafter2")
                    checkpointing_steps = gr.Dropdown(label="Checkpointing Steps", choices=[100, 50], value=100)
                    max_train_steps = gr.Slider(label="Max Train Steps", minimum=200, maximum=500, value=200, step=50)
            
            # Example training videos
            gr.Examples(examples=examples_train, inputs=[video_input])


            # Build a config from the advanced settings and launch training on the uploaded video
            train_button.click(
                lambda video, mec, u, cs, mts: train_model(video, generate_config_train(mec, u, cs, mts)),
                inputs=[video_input, motion_embeddings_combinations, unet_dropdown, checkpointing_steps, max_train_steps],
                outputs=checkpoint_output
            )

        with gr.Tab("Inference"):
            with gr.Row():
                with gr.Column():
                    preview_video = gr.Video(label="Preview Video")
                    text_input = gr.Textbox(label="Input Text")
                    checkpoint_dropdown = gr.Dropdown(label="Select Checkpoint", choices=get_checkpoints('results'))
                    seed = gr.Number(label="Seed", value=0)
                    inference_button = gr.Button("Generate Video")
                
                with gr.Column():
                    
                    output_video = gr.Video(label="Output Video")

            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    inference_steps = gr.Number(label="Inference Steps", value=30)
                    motion_type = gr.Dropdown(label="Motion Type", choices=["camera", "object"], value="object")

            gr.Examples(examples=examples_inference, inputs=[preview_video, text_input, motion_type, checkpoint_dropdown])


            def update_checkpoints(checkpoint_dir):
                # Refresh the checkpoint list whenever training writes a new directory
                return gr.update(choices=get_checkpoints('results'))

            # Selecting a checkpoint loads its source video; finishing a training run refreshes the list
            checkpoint_dropdown.change(fn=update_preview_video, inputs=checkpoint_dropdown, outputs=preview_video)
            checkpoint_output.change(update_checkpoints, inputs=checkpoint_output, outputs=checkpoint_dropdown)
            inference_button.click(inference_model, inputs=[text_input, checkpoint_dropdown, inference_steps, motion_type, seed], outputs=output_video)

    # Launch the Gradio interface
    demo.launch()