# Gradio app for a Hugging Face Space ("Running on Zero"): train a motion
# embedding from a single source video, then generate new videos with it.
import gradio as gr
import os
import torch
import tempfile
import random
import string
import json
import shutil
from omegaconf import OmegaConf, ListConfig
from train import main as train_main
from inference import inference as inference_main
# Train a motion embedding on the uploaded video with the given config.
def train_model(video, config):
    output_dir = 'results'
    os.makedirs(output_dir, exist_ok=True)
    # Each run gets a fresh zero-padded subdirectory: results/00, results/01, ...
    cur_save_dir = os.path.join(output_dir, str(len(os.listdir(output_dir))).zfill(2))
    os.makedirs(cur_save_dir, exist_ok=True)

    config.dataset.single_video_path = video
    config.train.output_dir = cur_save_dir

    # Keep a copy of the source video next to the checkpoints; shutil.copy is
    # safer than shelling out to `cp` for paths containing spaces.
    video_path = os.path.join(cur_save_dir, 'source.mp4')
    shutil.copy(video, video_path)

    train_main(config)
    return cur_save_dir
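
# Example: with runs 00 and 01 already in results/, the next call to
# train_model trains into results/02 and returns 'results/02'.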

# Generate a video from a prompt using a trained motion embedding checkpoint.
def inference_model(text, checkpoint, inference_steps, video_type, seed):
    checkpoint = os.path.join('results', checkpoint)
    # Split '<run_dir>/<checkpoint-name>' into the embedding directory and round.
    embedding_dir = os.path.dirname(checkpoint)
    video_round = os.path.basename(checkpoint)
    video_path = inference_main(
        embedding_dir=embedding_dir,
        prompt=text,
        video_round=video_round,
        save_dir=os.path.join('outputs', os.path.basename(embedding_dir)),
        motion_type=video_type,
        seed=seed,
        inference_steps=inference_steps
    )
    return video_path
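
# Example: checkpoint '00/checkpoint-100' resolves to
# embedding_dir='results/00' and video_round='checkpoint-100'.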

# Collect '<run_dir>/<checkpoint-name>' entries for every trained motion
# embedding (a 'motion_embed.pt' file) found under checkpoint_dir.
def get_checkpoints(checkpoint_dir):
    checkpoints = []
    for root, dirs, files in os.walk(checkpoint_dir):
        for file in files:
            if file == 'motion_embed.pt':
                checkpoints.append('/'.join(root.split('/')[-2:]))
    return checkpoints
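
# Example: results/00/checkpoint-100/motion_embed.pt yields the dropdown
# entry '00/checkpoint-100'.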

def extract_combinations(motion_embeddings_combinations):
    # Parse UI entries like 'down 1280' into [name, resolution] pairs.
    assert len(motion_embeddings_combinations) > 0, "At least one motion embedding combination is required"
    combinations = []
    for combination in motion_embeddings_combinations:
        name, resolution = combination.split(" ")
        combinations.append([name, int(resolution)])
    return combinations
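
# Example: extract_combinations(['down 1280', 'up 1280'])
# returns [['down', 1280], ['up', 1280]].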

def generate_config_train(motion_embeddings_combinations, unet, checkpointing_steps, max_train_steps):
    # Load the default config and apply the UI overrides.
    default_config = OmegaConf.load('configs/config.yaml')
    default_config.model.motion_embeddings.combinations = ListConfig(extract_combinations(motion_embeddings_combinations))
    default_config.model.unet = unet
    default_config.train.checkpointing_steps = checkpointing_steps
    default_config.train.max_train_steps = max_train_steps
    return default_config

# Currently identical to generate_config_train.
def generate_config_inference(motion_embeddings_combinations, unet, checkpointing_steps, max_train_steps):
    default_config = OmegaConf.load('configs/config.yaml')
    default_config.model.motion_embeddings.combinations = ListConfig(extract_combinations(motion_embeddings_combinations))
    default_config.model.unet = unet
    default_config.train.checkpointing_steps = checkpointing_steps
    default_config.train.max_train_steps = max_train_steps
    return default_config

def update_preview_video(checkpoint_dir):
    # Preview the source video from the run directory that owns this checkpoint.
    parent_dir = '/'.join(checkpoint_dir.split('/')[:-1])
    return gr.update(value=f'results/{parent_dir}/source.mp4')
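
# Example: selecting '00/checkpoint-100' previews 'results/00/source.mp4'.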

if __name__ == "__main__":
    inject_motion_embeddings_combinations = ['down 1280', 'up 1280', 'down 640', 'up 640']
    default_motion_embeddings_combinations = ['down 1280', 'up 1280']

    examples_train = [
        'assets/train/car_turn.mp4',
        'assets/train/pan_up.mp4',
        'assets/train/run_up.mp4',
        'assets/train/train_ride.mp4',
        'assets/train/orbit_shot.mp4',
        'assets/train/dolly_zoom_out.mp4',
        'assets/train/santa_dance.mp4',
    ]

    # Each entry: [preview video, prompt, motion type, checkpoint].
    examples_inference = [
        ['results/pan_up/source.mp4', 'A flora garden.', 'camera', 'pan_up/checkpoint'],
        ['results/dolly_zoom/source.mp4', 'A firefighter standing in front of a burning forest captured with a dolly zoom.', 'camera', 'dolly_zoom/checkpoint-100'],
        ['results/orbit_shot/source.mp4', 'A micro garden with orbit shot', 'camera', 'orbit_shot/checkpoint-300'],
        ['results/walk/source.mp4', 'An elephant walking in a desert', 'object', 'walk/checkpoint'],
        ['results/santa_dance/source.mp4', 'A skeleton in a suit is dancing with his hands', 'object', 'santa_dance/checkpoint-200'],
        ['results/car_turn/source.mp4', 'A toy train chugs around a roundabout tree', 'object', 'car_turn/checkpoint'],
        ['results/train_ride/source.mp4', 'A motorbike driving in a forest', 'object', 'train_ride/checkpoint-200'],
    ]

    # Build the Gradio interface.
    with gr.Blocks() as demo:
        with gr.Tab("Train"):
            with gr.Row():
                with gr.Column():
                    video_input = gr.Video(label="Upload Video")
                    train_button = gr.Button("Train")
                with gr.Column():
                    checkpoint_output = gr.Textbox(label="Checkpoint Directory")
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    motion_embeddings_combinations = gr.Dropdown(label="Motion Embeddings Combinations", choices=inject_motion_embeddings_combinations, multiselect=True, value=default_motion_embeddings_combinations)
                    unet_dropdown = gr.Dropdown(label="Unet", choices=["videoCrafter2", "zeroscope_v2_576w"], value="videoCrafter2")
                    checkpointing_steps = gr.Dropdown(label="Checkpointing Steps", choices=[100, 50], value=100)
                    max_train_steps = gr.Slider(label="Max Train Steps", minimum=200, maximum=500, value=200, step=50)

            gr.Examples(examples=examples_train, inputs=[video_input])

            # Build the training config from the UI settings, then train.
            train_button.click(
                lambda video, mec, u, cs, mts: train_model(video, generate_config_train(mec, u, cs, mts)),
                inputs=[video_input, motion_embeddings_combinations, unet_dropdown, checkpointing_steps, max_train_steps],
                outputs=checkpoint_output
            )

        with gr.Tab("Inference"):
            with gr.Row():
                with gr.Column():
                    preview_video = gr.Video(label="Preview Video")
                    text_input = gr.Textbox(label="Input Text")
                    checkpoint_dropdown = gr.Dropdown(label="Select Checkpoint", choices=get_checkpoints('results'))
                    seed = gr.Number(label="Seed", value=0)
                    inference_button = gr.Button("Generate Video")
                with gr.Column():
                    output_video = gr.Video(label="Output Video")
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    inference_steps = gr.Number(label="Inference Steps", value=30)
                    motion_type = gr.Dropdown(label="Motion Type", choices=["camera", "object"], value="object")

            gr.Examples(examples=examples_inference, inputs=[preview_video, text_input, motion_type, checkpoint_dropdown])

        # Refresh the checkpoint dropdown after a training run finishes.
        def update_checkpoints(checkpoint_dir):
            return gr.update(choices=get_checkpoints('results'))

        checkpoint_dropdown.change(fn=update_preview_video, inputs=checkpoint_dropdown, outputs=preview_video)
        checkpoint_output.change(update_checkpoints, inputs=checkpoint_output, outputs=checkpoint_dropdown)
        inference_button.click(inference_model, inputs=[text_input, checkpoint_dropdown, inference_steps, motion_type, seed], outputs=output_video)

    # Launch the Gradio interface.
    demo.launch()