|
import os |
|
import uuid |
|
from omegaconf import OmegaConf |
|
import spaces |
|
|
|
import random |
|
|
|
import imageio |
|
import torch |
|
import torchvision |
|
import gradio as gr |
|
import numpy as np |
|
from gradio.components import Textbox, Video |
|
|
|
from utils.common_utils import load_model_checkpoint |
|
from utils.utils import instantiate_from_config |
|
from scheduler.t2v_turbo_scheduler import T2VTurboScheduler |
|
from pipeline.t2v_turbo_vc2_pipeline import T2VTurboVC2Pipeline |
|
|
|
# Markdown shown at the top of the Gradio page.
DESCRIPTION = """# T2V-Turbo 🚀

Our model is distilled from [VideoCrafter2](https://ailab-cvc.github.io/videocrafter2/).
T2V-Turbo learns a LoRA on top of the base model by aligning to the reward feedback from [HPSv2.1](https://github.com/tgxs002/HPSv2/tree/master) and [InternVid2 Stage 2 Model](https://huggingface.co/OpenGVLab/InternVideo2-Stage2_1B-224p-f4).
T2V-Turbo-v2 optimizes the training techniques by finetuning the full base model and further aligns to [CLIPScore](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K).

T2V-Turbo trains on pure WebVid-10M data, whereas T2V-Turbo-v2 carefully optimizes different learning objectives with a mixture of VidGen-1M and WebVid-10M data.

Moreover, T2V-Turbo-v2 can distill motion priors from the training videos.

[Project page for T2V-Turbo](https://t2v-turbo.github.io) 😄
[Project page for T2V-Turbo-v2](https://t2v-turbo-v2.github.io) 🛫
"""

# Tell the visitor which accelerator (if any) this process can see.
if torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CUDA 😀</p>"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    # torch.xpu only exists in newer PyTorch builds, hence the hasattr guard.
    DESCRIPTION += "\n<p>Running on XPU 🤓</p>"
else:
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
|
|
|
# Upper bound (inclusive) for seeds offered by the UI: the largest signed 32-bit integer.
MAX_SEED: int = int(np.iinfo(np.int32).max)
|
|
|
|
|
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Pick the seed for a generation request.

    Args:
        seed: Seed supplied by the caller.
        randomize_seed: When truthy, ``seed`` is ignored and a new value is
            drawn with ``random.randint(0, MAX_SEED)`` (both ends inclusive).

    Returns:
        Either the freshly drawn seed or ``seed`` unchanged.
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed
|
|
|
|
|
def save_video(video_array, video_save_path, fps: int = 16):
    """Encode a generated video tensor as an H.264 file.

    Args:
        video_array: Tensor with values nominally in [-1, 1]; anything outside
            is clamped. The permute below treats it as
            (channels, frames, height, width) -- assumption inferred from the
            layout write_video expects; confirm against the pipeline output.
        video_save_path: Destination path handed to ``torchvision.io.write_video``.
        fps: Frame rate written into the container.
    """
    video = video_array.detach().cpu().float().clamp(-1.0, 1.0)
    # (C, T, H, W) -> (T, H, W, C): the frame-major, channels-last layout
    # torchvision.io.write_video expects. Same result as the former
    # permute(1, 0, 2, 3) followed by permute(0, 2, 3, 1).
    video = video.permute(1, 2, 3, 0)
    # Map [-1, 1] -> [0, 255]. Round to nearest before the uint8 cast; a bare
    # cast truncates and darkens every pixel by up to one intensity level.
    video = ((video + 1.0) * 127.5).round().to(torch.uint8)

    torchvision.io.write_video(
        video_save_path, video, fps=fps, video_codec="h264", options={"crf": "10"}
    )
|
|
|
# Prompts offered as clickable examples under the input form.
example_txt = [
    "An astronaut riding a horse.",
    "Darth vader surfing in waves.",
    "light wind, feathers moving, she moves her gaze, 4k",
    "a girl floating underwater.",
    "Pikachu snowboarding.",
    "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
    "A musician strums his guitar, serenading the moonlit night.",
]

# Shared settings for every example. The order follows the Interface inputs
# after the prompt: guidance scale, motion-guidance percentage, inference
# steps, frame count, seed, randomize-seed flag, dtype label.
_EXAMPLE_SETTINGS = (7.5, 0.5, 16, 16, 0, True, "bf16")

# One fresh row (list) per prompt: the prompt followed by the shared settings.
examples = [[prompt_text, *_EXAMPLE_SETTINGS] for prompt_text in example_txt]
|
|
|
# Maps the dtype labels offered by the UI radio button to torch dtypes.
_PARAM_DTYPES = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
}


def _resolve_dtype(param_dtype):
    """Return the torch dtype for a UI dtype label, raising ValueError if unknown."""
    try:
        return _PARAM_DTYPES[param_dtype]
    except (KeyError, TypeError):
        # TypeError covers unhashable labels, which an equality-based check
        # would also have rejected with ValueError.
        raise ValueError(f"Unknown dtype: {param_dtype}") from None


@spaces.GPU(duration=120)
@torch.inference_mode()
def generate(
    prompt: str,
    guidance_scale: float = 7.5,
    percentage: float = 0.5,
    num_inference_steps: int = 4,
    num_frames: int = 16,
    seed: int = 0,
    randomize_seed: bool = False,
    param_dtype="bf16",
    motion_gs: float = 0.05,
    fps: int = 8,
):
    """Generate one video for ``prompt`` and save it to a per-request MP4 file.

    Relies on the module-level ``unet``, ``pipeline`` and ``device`` globals
    created in the ``__main__`` section.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.
        percentage: Forwarded to the pipeline; the UI labels it as the
            percentage of steps that apply motion guidance.
        num_inference_steps: Number of sampling steps.
        num_frames: Number of frames to generate.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: Draw a fresh seed instead of using ``seed``.
        param_dtype: One of "bf16", "fp16", "fp32"; the models are cast to it.
        motion_gs: Motion-guidance strength forwarded to the pipeline.
        fps: Frame rate passed to the pipeline and used for the saved video.

    Returns:
        ``(video_path, prompt, model_info, seed)``, matching the four Interface
        outputs; ``seed`` is the seed actually used.

    Raises:
        ValueError: If ``param_dtype`` is not a supported label.
    """
    seed = randomize_seed_fn(seed, randomize_seed)
    torch.manual_seed(seed)

    # Gradio sliders may hand back floats; these are counts.
    num_frames = int(num_frames)
    num_inference_steps = int(num_inference_steps)

    dtype = _resolve_dtype(param_dtype)
    # NOTE(review): this mutates shared model state; concurrent requests with
    # different dtypes can interfere with each other.
    unet.dtype = dtype

    pipeline.unet.to(device, dtype)
    pipeline.text_encoder.to(device, dtype)
    pipeline.vae.to(device, dtype)
    pipeline.to(device, dtype)

    result = pipeline(
        prompt=prompt,
        frames=num_frames,
        fps=fps,
        guidance_scale=guidance_scale,
        motion_gs=motion_gs,
        use_motion_cond=True,
        percentage=percentage,
        num_inference_steps=num_inference_steps,
        lcm_origin_steps=200,
        num_videos_per_prompt=1,
    )

    torch.cuda.empty_cache()

    root_path = "./videos/"
    os.makedirs(root_path, exist_ok=True)
    # A unique name per request. A shared "tmp.mp4" lets concurrent requests
    # (the Interface allows up to 10) overwrite each other's output before it
    # is served.
    # NOTE(review): files accumulate in ./videos; add cleanup if disk matters.
    video_save_path = os.path.join(root_path, f"{uuid.uuid4()}.mp4")

    save_video(result[0], video_save_path, fps=fps)
    display_model_info = f"Video size: {num_frames}x320x512, Sampling Step: {num_inference_steps}, Guidance Scale: {guidance_scale}"
    return video_save_path, prompt, display_model_info, seed
|
|
|
|
|
# Extra CSS passed to gr.Interface(css=...). The rule sets a minimum width on
# buttons inside the element with id "buttons".
# NOTE(review): no component in this file sets elem_id="buttons"; confirm the
# selector still matches something in the rendered Interface, otherwise the
# rule has no effect.
block_css = """
#buttons button {
    min-width: min(120px,100%);
}
"""
|
|
|
|
|
if __name__ == "__main__":
    # Everything below stays at module scope on purpose: `generate` reads the
    # globals `device`, `unet` and `pipeline` defined here.
    device = torch.device("cuda:0")

    # Base VideoCrafter2 model built from its inference config and checkpoint.
    config = OmegaConf.load("configs/inference_t2v_512_v2.0.yaml")
    model_config = config.pop("model", OmegaConf.create())
    pretrained_t2v = instantiate_from_config(model_config)
    pretrained_t2v = load_model_checkpoint(pretrained_t2v, "checkpoints/VideoCrafter2_model.ckpt")

    # The UNet to be swapped in is built from the base UNet config with
    # `use_checkpoint` disabled and 256-dim time/motion conditioning projections.
    unet_config = model_config["params"]["unet_config"]
    unet_config["params"]["use_checkpoint"] = False
    unet_config["params"]["time_cond_proj_dim"] = 256
    unet_config["params"]["motion_cond_proj_dim"] = 256

    unet = instantiate_from_config(unet_config)

    # load_state_dict copies values into the module's existing parameters
    # wherever they live, so loading the file to the CPU is sufficient. Mapping
    # it to the GPU would only add a temporary second copy there.
    unet.load_state_dict(torch.load("checkpoints/unet_mg.pt", map_location="cpu"))
    unet.eval()

    # Replace the base diffusion model's UNet and wrap everything in the pipeline.
    pretrained_t2v.model.diffusion_model = unet
    scheduler = T2VTurboScheduler(
        linear_start=model_config["params"]["linear_start"],
        linear_end=model_config["params"]["linear_end"],
    )
    pipeline = T2VTurboVC2Pipeline(pretrained_t2v, scheduler, model_config)
    pipeline.to(device)

    demo = gr.Interface(
        fn=generate,
        inputs=[
            Textbox(label="", placeholder="Please enter your prompt. \n"),
            gr.Slider(
                label="Guidance scale",
                minimum=2,
                maximum=14,
                step=0.1,
                value=7.5,
            ),
            gr.Slider(
                label="Percentage of steps to apply motion guidance (v2 w/ MG only)",
                minimum=0.0,
                maximum=0.5,
                step=0.05,
                value=0.5,
            ),
            gr.Slider(
                label="Number of inference steps",
                minimum=4,
                maximum=50,
                step=1,
                value=16,
            ),
            gr.Slider(
                label="Number of Video Frames",
                minimum=16,
                maximum=48,
                step=8,
                value=16,
            ),
            gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
                randomize=True,
            ),
            gr.Checkbox(label="Randomize seed", value=True),
            gr.Radio(
                ["bf16", "fp16", "fp32"],
                label="torch.dtype",
                value="bf16",
                interactive=True,
                info="Dtype for inference. Default is bf16.",
            ),
        ],
        outputs=[
            gr.Video(label="Generated Video", width=512, height=320, interactive=False, autoplay=True),
            Textbox(label="input prompt"),
            Textbox(label="model info"),
            # Gradio's default slider range is 0-100. The seed returned here
            # can be as large as MAX_SEED, so match the input slider's range.
            gr.Slider(label="seed", minimum=0, maximum=MAX_SEED, step=1),
        ],
        description=DESCRIPTION,
        theme=gr.themes.Default(),
        css=block_css,
        examples=examples,
        cache_examples=False,
        concurrency_limit=10,
    )
    demo.launch()
|
|