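# Gradio demo app for motion-embedding training and inference.
# The "Train" tab fits motion embeddings to an uploaded video; the "Inference"
# tab generates a new video from a text prompt and a trained checkpoint.
# train_main and inference_main come from the local train.py / inference.py
# modules of this repository.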
import gradio as gr
import os
import shutil
import torch
import tempfile
import random
import string
import json
from omegaconf import OmegaConf, ListConfig

from train import main as train_main
from inference import inference as inference_main

# Training wrapper: prepare a numbered run directory under results/, copy the
# source video into it, and hand the updated config to the training loop.
def train_model(video, config):
    output_dir = 'results'
    os.makedirs(output_dir, exist_ok=True)
    # Next free run index, zero-padded: results/00, results/01, ...
    cur_save_dir = os.path.join(output_dir, str(len(os.listdir(output_dir))).zfill(2))
    os.makedirs(cur_save_dir, exist_ok=True)
    config.dataset.single_video_path = video
    config.train.output_dir = cur_save_dir

    # Copy the uploaded video into the run directory as source.mp4.
    video_name = 'source.mp4'
    video_path = os.path.join(cur_save_dir, video_name)
    shutil.copy(video, video_path)

    train_main(config)
    return cur_save_dir
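
# Expected run-directory layout after training (assumed, inferred from the
# copy above and from get_checkpoints below):
#   results/<NN>/source.mp4
#   results/<NN>/checkpoint-<step>/motion_embed.pt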

# Inference wrapper: resolve the selected checkpoint into an embedding
# directory and run the sampling loop from inference.py.
def inference_model(text, checkpoint, inference_steps, video_type, seed):
    checkpoint = os.path.join('results', checkpoint)
    embedding_dir = '/'.join(checkpoint.split('/')[:-1])
    video_round = checkpoint.split('/')[-1]
    video_path = inference_main(
        embedding_dir=embedding_dir,
        prompt=text,
        video_round=video_round,
        save_dir=os.path.join('outputs', embedding_dir.split('/')[-1]),
        motion_type=video_type,
        seed=seed,
        inference_steps=inference_steps
    )
    return video_path
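
# Example of the path mapping: selecting 'dolly_zoom/checkpoint-100' gives
# embedding_dir='results/dolly_zoom', video_round='checkpoint-100', and
# save_dir='outputs/dolly_zoom'.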

# List available checkpoints: any directory under checkpoint_dir that contains
# a motion_embed.pt file, reported as its last two path components.
def get_checkpoints(checkpoint_dir):
    checkpoints = []
    for root, dirs, files in os.walk(checkpoint_dir):
        for file in files:
            if file == 'motion_embed.pt':
                checkpoints.append('/'.join(root.split('/')[-2:]))
    return checkpoints
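
# E.g. results/dolly_zoom/checkpoint-100/motion_embed.pt is listed as
# 'dolly_zoom/checkpoint-100'.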

def extract_combinations(motion_embeddings_combinations):
    assert len(motion_embeddings_combinations) > 0, "At least one motion embedding combination is required"
    combinations = []
    for combination in motion_embeddings_combinations:
        name, resolution = combination.split(" ")
        combinations.append([name, int(resolution)])
    return combinations
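
# Example: ['down 1280', 'up 640'] -> [['down', 1280], ['up', 640]]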

def generate_config_train(motion_embeddings_combinations, unet, checkpointing_steps, max_train_steps):
    default_config = OmegaConf.load('configs/config.yaml')
    default_config.model.motion_embeddings.combinations = ListConfig(extract_combinations(motion_embeddings_combinations))
    default_config.model.unet = unet
    default_config.train.checkpointing_steps = checkpointing_steps
    default_config.train.max_train_steps = max_train_steps
    return default_config

# Currently identical to generate_config_train; kept as a separate function so
# the inference configuration can diverge from the training one.
def generate_config_inference(motion_embeddings_combinations, unet, checkpointing_steps, max_train_steps):
    default_config = OmegaConf.load('configs/config.yaml')
    default_config.model.motion_embeddings.combinations = ListConfig(extract_combinations(motion_embeddings_combinations))
    default_config.model.unet = unet
    default_config.train.checkpointing_steps = checkpointing_steps
    default_config.train.max_train_steps = max_train_steps
    return default_config
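
# The fields set above imply configs/config.yaml roughly has this shape
# (a minimal sketch; the real file likely contains additional keys):
#
#   model:
#     unet: videoCrafter2
#     motion_embeddings:
#       combinations: [[down, 1280], [up, 1280]]
#   dataset:
#     single_video_path: ...
#   train:
#     output_dir: ...
#     checkpointing_steps: 100
#     max_train_steps: 200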

def update_preview_video(checkpoint_dir):
    # get the parent dir of the checkpoint
    parent_dir = '/'.join(checkpoint_dir.split('/')[:-1])
    return gr.update(value=f'results/{parent_dir}/source.mp4')
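
# Example: a checkpoint value of 'dolly_zoom/checkpoint-100' maps to the
# preview video 'results/dolly_zoom/source.mp4'.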

if __name__ == "__main__":
    inject_motion_embeddings_combinations = ['down 1280', 'up 1280', 'down 640', 'up 640']
    default_motion_embeddings_combinations = ['down 1280', 'up 1280']

    examples_train = [
        'assets/train/car_turn.mp4',
        'assets/train/pan_up.mp4',
        'assets/train/run_up.mp4',
        'assets/train/train_ride.mp4',
        'assets/train/orbit_shot.mp4',
        'assets/train/dolly_zoom_out.mp4',
        'assets/train/santa_dance.mp4',
    ]

    # Each row: [preview video, prompt, motion type, checkpoint]
    examples_inference = [
        ['results/pan_up/source.mp4', 'A flora garden.', 'camera', 'pan_up/checkpoint'],
        ['results/dolly_zoom/source.mp4', 'A firefighter standing in front of a burning forest captured with a dolly zoom.', 'camera', 'dolly_zoom/checkpoint-100'],
        ['results/orbit_shot/source.mp4', 'A micro garden with an orbit shot', 'camera', 'orbit_shot/checkpoint-300'],
        ['results/walk/source.mp4', 'An elephant walking in the desert', 'object', 'walk/checkpoint'],
        ['results/santa_dance/source.mp4', 'A skeleton in a suit is dancing with his hands', 'object', 'santa_dance/checkpoint-200'],
        ['results/car_turn/source.mp4', 'A toy train chugs around a roundabout tree', 'object', 'car_turn/checkpoint'],
        ['results/train_ride/source.mp4', 'A motorbike driving in a forest', 'object', 'train_ride/checkpoint-200'],
    ]

    # Build the Gradio interface.
    with gr.Blocks() as demo:
        with gr.Tab("Train"):
            with gr.Row():
                with gr.Column():
                    video_input = gr.Video(label="Upload Video")
                    train_button = gr.Button("Train")
                with gr.Column():
                    checkpoint_output = gr.Textbox(label="Checkpoint Directory")
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    motion_embeddings_combinations = gr.Dropdown(label="Motion Embeddings Combinations", choices=inject_motion_embeddings_combinations, multiselect=True, value=default_motion_embeddings_combinations)
                    unet_dropdown = gr.Dropdown(label="Unet", choices=["videoCrafter2", "zeroscope_v2_576w"], value="videoCrafter2")
                    checkpointing_steps = gr.Dropdown(label="Checkpointing Steps", choices=[100, 50], value=100)
                    max_train_steps = gr.Slider(label="Max Train Steps", minimum=200, maximum=500, value=200, step=50)

            gr.Examples(examples=examples_train, inputs=[video_input])

            train_button.click(
                lambda video, mec, u, cs, mts: train_model(video, generate_config_train(mec, u, cs, mts)),
                inputs=[video_input, motion_embeddings_combinations, unet_dropdown, checkpointing_steps, max_train_steps],
                outputs=checkpoint_output
            )

        with gr.Tab("Inference"):
            with gr.Row():
                with gr.Column():
                    preview_video = gr.Video(label="Preview Video")
                    text_input = gr.Textbox(label="Input Text")
                    checkpoint_dropdown = gr.Dropdown(label="Select Checkpoint", choices=get_checkpoints('results'))
                    seed = gr.Number(label="Seed", value=0)
                    inference_button = gr.Button("Generate Video")
                with gr.Column():
                    output_video = gr.Video(label="Output Video")
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    inference_steps = gr.Number(label="Inference Steps", value=30)
                    motion_type = gr.Dropdown(label="Motion Type", choices=["camera", "object"], value="object")

            gr.Examples(examples=examples_inference, inputs=[preview_video, text_input, motion_type, checkpoint_dropdown])

        # Refresh the checkpoint dropdown whenever training writes a new run directory.
        def update_checkpoints(checkpoint_dir):
            return gr.update(choices=get_checkpoints('results'))

        checkpoint_dropdown.change(fn=update_preview_video, inputs=checkpoint_dropdown, outputs=preview_video)
        checkpoint_output.change(update_checkpoints, inputs=checkpoint_output, outputs=checkpoint_dropdown)
        inference_button.click(inference_model, inputs=[text_input, checkpoint_dropdown, inference_steps, motion_type, seed], outputs=output_video)

    # Launch the Gradio app.
    demo.launch()