# NOTE(review): lines of non-code page residue (Hugging Face Space build
# status, commit hashes, and a line-number gutter) were captured by the
# scrape and have been removed so the module parses.
import argparse, os, sys, glob
import datetime, time
from omegaconf import OmegaConf
import math
import torch
from decord import VideoReader, cpu
import torchvision
from pytorch_lightning import seed_everything
from lvdm.samplers.ddim import DDIMSampler
from lvdm.utils.common_utils import instantiate_from_config
from lvdm.utils.saving_utils import tensor_to_mp4
from scripts.sample_text2video_adapter import load_model_checkpoint, adapter_guided_synthesis
import torchvision.transforms._transforms_video as transforms_video
from huggingface_hub import hf_hub_download
def load_video(filepath, frame_stride, video_size=(256, 256), video_frames=16):
    """Sample `video_frames` frames from a video file as a normalized tensor.

    Args:
        filepath: path to a video file readable by decord.
        frame_stride: sampling stride between frames; 0 requests an
            adaptive stride that spreads the samples over the whole clip.
        video_size: (height, width) the decoder resizes frames to.
        video_frames: number of frames to sample.

    Returns:
        (frame_tensor, info_str): float tensor of shape [c, t, h, w]
        scaled to [-1, 1], and a human-readable message describing any
        stride adjustments that were applied.
    """
    info_str = ''
    vidreader = VideoReader(filepath, ctx=cpu(0), width=video_size[1], height=video_size[0])
    max_frames = len(vidreader)

    # A user-chosen stride must leave enough frames for the requested
    # sample count; otherwise fall back to the adaptive stride (0).
    if frame_stride != 0:
        if frame_stride * (video_frames - 1) >= max_frames:
            info_str += "Warning: The user-set frame rate makes the current video length not enough, we will set it to an adaptive frame rate.\n"
            frame_stride = 0
    if frame_stride == 0:
        # Adaptive: spread the sampled frames evenly across the clip.
        frame_stride = max_frames / video_frames
    # Cap the stride so at most the first 128 frames are touched
    # (16 frames * stride 8).  NOTE(review): source indentation was lost;
    # this cap is reconstructed at function level -- confirm upstream.
    if frame_stride > 8:
        frame_stride = 8
        info_str += "Warning: The current input video length is longer than 128 frames, we will process only the first 128 frames.\n"
    info_str += f"Frame Stride is set to {frame_stride}"

    frame_indices = [int(frame_stride * i) for i in range(video_frames)]
    frames = vidreader.get_batch(frame_indices)
    ## [t,h,w,c] -> [c,t,h,w]
    frame_tensor = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()
    # Map uint8 [0, 255] pixel values to [-1, 1].
    frame_tensor = (frame_tensor / 255. - 0.5) * 2
    return frame_tensor, info_str
class VideoControl:
    """Depth-adapter-guided text-to-video generation (VideoCrafter t2v v1.1).

    Downloads the base T2V model and depth adapter from the Hugging Face
    Hub on construction, then synthesizes videos conditioned on an input
    video's depth and a text prompt.
    """

    def __init__(self, result_dir='./tmp/') -> None:
        self.savedir = result_dir
        self.download_model()
        config_path = "models/adapter_t2v_depth/model_config.yaml"
        ckpt_path = "models/base_t2v/model.ckpt"
        adapter_ckpt = "models/adapter_t2v_depth/adapter.pth"
        config = OmegaConf.load(config_path)
        model_config = config.pop("model", OmegaConf.create())
        model = instantiate_from_config(model_config)
        model = model.to('cuda')
        assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
        model = load_model_checkpoint(model, ckpt_path, adapter_ckpt)
        model.eval()
        self.model = model
        # Spatial resolution of the square center crop fed to the model.
        self.resolution = 256
        self.spatial_transform = transforms_video.CenterCropVideo(self.resolution)

    def get_video(self, input_video, input_prompt, frame_stride=0, vc_steps=50, vc_cfg_scale=15.0, vc_eta=1.0):
        """Generate a video guided by `input_video`'s depth and a prompt.

        Args:
            input_video: path to the conditioning video file.
            input_prompt: text prompt describing the desired video.
            frame_stride: frame sampling stride (0 = adaptive).
            vc_steps: DDIM sampling steps.
            vc_cfg_scale: classifier-free guidance scale.
            vc_eta: DDIM eta.

        Returns:
            (info_str, origin_path, depth_path, video_path); on failure
            the last three are None and info_str holds the error message.
        """
        ## load video
        print("input video", input_video)
        info_str = ''
        try:
            h, w, c = VideoReader(input_video, ctx=cpu(0))[0].shape
        except Exception:
            # Unreadable upload: remove it and report the failure.
            os.remove(input_video)
            return 'please input video', None, None, None

        # Scale so the SHORT side becomes self.resolution; the long side
        # is then center-cropped back to a square below.
        if h < w:
            scale = h / self.resolution
        else:
            scale = w / self.resolution
        h = math.ceil(h / scale)
        w = math.ceil(w / scale)
        try:
            video, info_str = load_video(input_video, frame_stride, video_size=(h, w), video_frames=16)
        except Exception:
            os.remove(input_video)
            return 'load video error', None, None, None
        video = self.spatial_transform(video)
        print('video shape', video.shape)

        # Latent spatial size (256 pixels / 8x downsampling = 32).
        h, w = 32, 32
        bs = 1
        channels = self.model.channels
        frames = self.model.temporal_length
        noise_shape = [bs, channels, frames, h, w]

        ## inference
        start = time.time()
        prompt = input_prompt
        video = video.unsqueeze(0).to("cuda")
        with torch.no_grad():
            batch_samples, batch_conds = adapter_guided_synthesis(self.model, prompt, video, noise_shape, n_samples=1, ddim_steps=vc_steps, ddim_eta=vc_eta, unconditional_guidance_scale=vc_cfg_scale)
        batch_samples = batch_samples[0]

        os.makedirs(self.savedir, exist_ok=True)
        # Derive a filesystem-safe file stem from the prompt.
        filename = prompt.replace("/", "_slash_").replace(" ", "_")
        if len(filename) > 200:
            filename = filename[:200]
        # BUGFIX: the scraped source had the literal '(unknown)' here;
        # the sanitized prompt stem is clearly the intended placeholder.
        video_path = os.path.join(self.savedir, f'{filename}_sample.mp4')
        depth_path = os.path.join(self.savedir, f'{filename}_depth.mp4')
        origin_path = os.path.join(self.savedir, f'{filename}.mp4')
        tensor_to_mp4(video=video.detach().cpu(), savepath=origin_path, fps=8)
        tensor_to_mp4(video=batch_conds.detach().cpu(), savepath=depth_path, fps=8)
        tensor_to_mp4(video=batch_samples.detach().cpu(), savepath=video_path, fps=8)
        print(f"Saved in {video_path}. Time used: {(time.time() - start):.2f} seconds")

        # Delete the uploaded video (keep the bundled demo clip).
        (path, input_filename) = os.path.split(input_video)
        if input_filename != 'flamingo.mp4':
            os.remove(input_video)
            print('delete input video')
        return info_str, origin_path, depth_path, video_path

    def download_model(self):
        """Fetch model weights from the Hub if not already present locally."""
        REPO_ID = 'VideoCrafter/t2v-version-1-1'
        filename_list = ['models/base_t2v/model.ckpt',
                         "models/adapter_t2v_depth/adapter.pth",
                         "models/adapter_t2v_depth/dpt_hybrid-midas.pt",
                         ]
        for filename in filename_list:
            if not os.path.exists(filename):
                hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./', local_dir_use_symlinks=False)
if __name__ == "__main__":
    vc = VideoControl('./result')
    # BUGFIX: get_video returns a 4-tuple (info_str, origin_path,
    # depth_path, video_path); the original unpacked only two names,
    # which raises ValueError at runtime.
    info_str, origin_path, depth_path, video_path = vc.get_video(
        'input/flamingo.mp4',
        "An ostrich walking in the desert, photorealistic, 4k",
    )