# MotionInversion / app.py
import gradio as gr
import os
import shutil
import torch
import tempfile
import random
import string
import json
from omegaconf import OmegaConf, ListConfig

from train import main as train_main
from inference import inference as inference_main
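
# Overview of the app below: the "Train" tab fits a motion embedding to a single
# uploaded video via train_main() and writes everything under results/<run-id>/,
# while the "Inference" tab picks a saved motion_embed.pt checkpoint and calls
# inference_main() to synthesize a new video from a text prompt.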

# Train a motion embedding on a single uploaded video.
def train_model(video, config):
    output_dir = 'results'
    os.makedirs(output_dir, exist_ok=True)
    cur_save_dir = os.path.join(output_dir, str(len(os.listdir(output_dir))).zfill(2))

    config.dataset.single_video_path = video
    config.train.output_dir = cur_save_dir

    # Keep a copy of the source video next to the checkpoints.
    video_name = 'source.mp4'
    video_path = os.path.join(cur_save_dir, video_name)
    os.makedirs(cur_save_dir, exist_ok=True)
    shutil.copy(video, video_path)

    train_main(config)
    # cur_save_dir = 'results/06'
    return cur_save_dir
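
# Expected on-disk layout after a training run (assuming train_main saves its
# checkpoints under config.train.output_dir; that code is not shown here):
#
#   results/00/source.mp4                      # copy of the uploaded video
#   results/00/checkpoint-100/motion_embed.pt  # picked up by get_checkpoints()
#   results/00/checkpoint-200/motion_embed.pt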

# Generate a video from a text prompt using a trained motion-embedding checkpoint.
def inference_model(text, checkpoint, inference_steps, video_type, seed):
    checkpoint = os.path.join('results', checkpoint)
    embedding_dir = '/'.join(checkpoint.split('/')[:-1])
    video_round = checkpoint.split('/')[-1]
    video_path = inference_main(
        embedding_dir=embedding_dir,
        prompt=text,
        video_round=video_round,
        save_dir=os.path.join('outputs', embedding_dir.split('/')[-1]),
        motion_type=video_type,
        seed=seed,
        inference_steps=inference_steps
    )
    return video_path
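
# Example of the path handling above: a dropdown value of 'pan_up/checkpoint-100'
# becomes checkpoint='results/pan_up/checkpoint-100', embedding_dir='results/pan_up',
# video_round='checkpoint-100', and the generated video is saved under 'outputs/pan_up'.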

# Collect the available checkpoints as '<run>/<checkpoint-dir>' relative paths.
def get_checkpoints(checkpoint_dir):
    checkpoints = []
    for root, dirs, files in os.walk(checkpoint_dir):
        for file in files:
            if file == 'motion_embed.pt':
                checkpoints.append('/'.join(root.split('/')[-2:]))
    return checkpoints

def extract_combinations(motion_embeddings_combinations):
    assert len(motion_embeddings_combinations) > 0, "At least one motion embedding combination is required"
    combinations = []
    for combination in motion_embeddings_combinations:
        name, resolution = combination.split(" ")
        combinations.append([name, int(resolution)])
    return combinations
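
# For example, the default dropdown selection ['down 1280', 'up 1280'] is turned
# into [['down', 1280], ['up', 1280]] before being written into the OmegaConf
# config below.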

def generate_config_train(motion_embeddings_combinations, unet, checkpointing_steps, max_train_steps):
    default_config = OmegaConf.load('configs/config.yaml')
    default_config.model.motion_embeddings.combinations = ListConfig(extract_combinations(motion_embeddings_combinations))
    default_config.model.unet = unet
    default_config.train.checkpointing_steps = checkpointing_steps
    default_config.train.max_train_steps = max_train_steps
    return default_config

# Note: currently identical to generate_config_train.
def generate_config_inference(motion_embeddings_combinations, unet, checkpointing_steps, max_train_steps):
    default_config = OmegaConf.load('configs/config.yaml')
    default_config.model.motion_embeddings.combinations = ListConfig(extract_combinations(motion_embeddings_combinations))
    default_config.model.unet = unet
    default_config.train.checkpointing_steps = checkpointing_steps
    default_config.train.max_train_steps = max_train_steps
    return default_config
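
# Keys this app overrides in configs/config.yaml (the rest of that file is not
# shown here). A minimal sketch of the relevant structure, inferred only from
# the assignments above:
#
#   dataset:
#     single_video_path: null        # set to the uploaded video
#   model:
#     unet: videoCrafter2            # or zeroscope_v2_576w
#     motion_embeddings:
#       combinations: [[down, 1280], [up, 1280]]
#   train:
#     output_dir: results/00
#     checkpointing_steps: 100
#     max_train_steps: 200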

def update_preview_video(checkpoint_dir):
    # get the parent dir of the checkpoint
    parent_dir = '/'.join(checkpoint_dir.split('/')[:-1])
    return gr.update(value=f'results/{parent_dir}/source.mp4')
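
# Example: selecting 'pan_up/checkpoint-100' in the dropdown previews
# 'results/pan_up/source.mp4', i.e. the video the embedding was trained on.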


if __name__ == "__main__":

    inject_motion_embeddings_combinations = ['down 1280', 'up 1280', 'down 640', 'up 640']
    default_motion_embeddings_combinations = ['down 1280', 'up 1280']

    examples_train = [
        'assets/train/car_turn.mp4',
        'assets/train/pan_up.mp4',
        'assets/train/run_up.mp4',
        'assets/train/train_ride.mp4',
        'assets/train/orbit_shot.mp4',
        'assets/train/dolly_zoom_out.mp4',
        'assets/train/santa_dance.mp4',
    ]

    examples_inference = [
        ['results/pan_up/source.mp4', 'A flora garden.', 'camera', 'pan_up/checkpoint'],
        ['results/dolly_zoom/source.mp4', 'A firefighter standing in front of a burning forest captured with a dolly zoom.', 'camera', 'dolly_zoom/checkpoint-100'],
        ['results/orbit_shot/source.mp4', 'A micro garden with orbit shot', 'camera', 'orbit_shot/checkpoint-300'],
        ['results/walk/source.mp4', 'An elephant walking in desert', 'object', 'walk/checkpoint'],
        ['results/santa_dance/source.mp4', 'A skeleton in suit is dancing with his hands', 'object', 'santa_dance/checkpoint-200'],
        ['results/car_turn/source.mp4', 'A toy train chugs around a roundabout tree', 'object', 'car_turn/checkpoint'],
        ['results/train_ride/source.mp4', 'A motorbike driving in a forest', 'object', 'train_ride/checkpoint-200'],
    ]

    # Build the Gradio interface
    with gr.Blocks() as demo:

        with gr.Tab("Train"):
            with gr.Row():
                with gr.Column():
                    video_input = gr.Video(label="Upload Video")
                    train_button = gr.Button("Train")
                with gr.Column():
                    checkpoint_output = gr.Textbox(label="Checkpoint Directory")

            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    motion_embeddings_combinations = gr.Dropdown(label="Motion Embeddings Combinations", choices=inject_motion_embeddings_combinations, multiselect=True, value=default_motion_embeddings_combinations)
                    unet_dropdown = gr.Dropdown(label="Unet", choices=["videoCrafter2", "zeroscope_v2_576w"], value="videoCrafter2")
                    checkpointing_steps = gr.Dropdown(label="Checkpointing Steps", choices=[100, 50], value=100)
                    max_train_steps = gr.Slider(label="Max Train Steps", minimum=200, maximum=500, value=200, step=50)

            # examples
            gr.Examples(examples=examples_train, inputs=[video_input])

            train_button.click(
                lambda video, mec, u, cs, mts: train_model(video, generate_config_train(mec, u, cs, mts)),
                inputs=[video_input, motion_embeddings_combinations, unet_dropdown, checkpointing_steps, max_train_steps],
                outputs=checkpoint_output
            )
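            # The click handler above builds a fresh OmegaConf config from the
            # advanced settings, runs training, and writes the new run directory
            # (e.g. 'results/00') into the "Checkpoint Directory" textbox.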
with gr.Tab("Inference"):
with gr.Row():
with gr.Column():
preview_video = gr.Video(label="Preview Video")
text_input = gr.Textbox(label="Input Text")
checkpoint_dropdown = gr.Dropdown(label="Select Checkpoint", choices=get_checkpoints('results'))
seed = gr.Number(label="Seed", value=0)
inference_button = gr.Button("Generate Video")
with gr.Column():
output_video = gr.Video(label="Output Video")
with gr.Accordion("Advanced Settings", open=False):
with gr.Row():
inference_steps = gr.Number(label="Inference Steps", value=30)
motion_type = gr.Dropdown(label="Motion Type", choices=["camera", "object"], value="object")
gr.Examples(examples=examples_inference,inputs=[preview_video,text_input,motion_type,checkpoint_dropdown])
def update_checkpoints(checkpoint_dir):
return gr.update(choices=get_checkpoints('results'))
checkpoint_dropdown.change(fn=update_preview_video, inputs=checkpoint_dropdown, outputs=preview_video)
checkpoint_output.change(update_checkpoints, inputs=checkpoint_output, outputs=checkpoint_dropdown)
inference_button.click(inference_model, inputs=[text_input, checkpoint_dropdown,inference_steps,motion_type, seed], outputs=output_video)
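            # Wiring: picking a checkpoint swaps the preview to that run's
            # source.mp4, and a finished training run (which fills the
            # "Checkpoint Directory" textbox) refreshes the checkpoint list.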

    # Launch the Gradio interface
    demo.launch()