File size: 6,330 Bytes
be791d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import os
import torch
import argparse
import torchvision
from pipeline_videogen import VideoGenPipeline
from diffusers.schedulers import DDIMScheduler
from diffusers.models import AutoencoderKL
from diffusers.models import AutoencoderKLTemporalDecoder
from transformers import CLIPTokenizer, CLIPTextModel
from omegaconf import OmegaConf
import os, sys
sys.path.append(os.path.split(sys.path[0])[0])
from models import get_models
import imageio
from PIL import Image
import numpy as np
from datasets import video_transforms
from torchvision import transforms
from einops import rearrange, repeat
from utils import dct_low_pass_filter, exchanged_mixed_dct_freq
from copy import deepcopy
def prepare_image(path, vae, transform_video, device, dtype=torch.float16):
    """Encode the image stored at ``path`` into a single-frame VAE latent.

    The file is decoded as RGB and turned into a uint8 tensor laid out as
    (1, 3, H, W). It is passed through ``transform_video``, which must return
    exactly five values: the transformed tensor followed by four extras. Judging
    by their names these are original height/width and crop offsets, and they
    are discarded here. The result is cast to ``dtype`` on ``device``, encoded
    with ``vae``, sampled from the latent distribution and multiplied by
    ``vae.config.scaling_factor``.

    Args:
        path: path of the image file to read.
        vae: autoencoder whose ``encode`` returns an object with ``latent_dist``.
        transform_video: preprocessing callable (see above).
        device: device the pixel tensor is moved to before encoding.
        dtype: dtype the pixel tensor is cast to; must match ``vae``'s weights.

    Returns:
        The scaled latent with a new size-1 axis inserted at dim 2.
    """
    with open(path, 'rb') as handle:
        rgb = Image.open(handle).convert('RGB')
        pixels = np.array(rgb, dtype=np.uint8, copy=True)
    # HWC array -> NCHW tensor with a batch of one.
    frames = torch.as_tensor(pixels)[None].permute(0, 3, 1, 2)
    frames, _ori_h, _ori_w, _crop_top, _crop_left = transform_video(frames)
    posterior = vae.encode(frames.to(dtype=dtype, device=device)).latent_dist
    latent = posterior.sample().mul_(vae.config.scaling_factor)
    return latent.unsqueeze(2)
def _load_vaes(args, device, encode_dtype, dtype):
    """Load the image-encoding VAE and the VAE handed to the pipeline.

    The encoding VAE (``vae_for_base_content``) is loaded in ``encode_dtype``.
    That must be the same dtype the input image is cast to before encoding,
    otherwise the conv layers fail on an input/weight dtype mismatch. The
    pipeline VAE is an independent deep copy cast to ``dtype``.

    Args:
        args: sampling config; reads ``pretrained_model_path`` and
            ``enable_vae_temporal_decoder``.
        device: device both models are placed on.
        encode_dtype: dtype of the encoding VAE.
        dtype: dtype of the pipeline VAE.

    Returns:
        ``(vae_for_base_content, vae)``.
    """
    if args.enable_vae_temporal_decoder:
        vae_for_base_content = AutoencoderKLTemporalDecoder.from_pretrained(
            args.pretrained_model_path,
            subfolder="vae_temporal_decoder",
            torch_dtype=encode_dtype,
        ).to(device)
    else:
        vae_for_base_content = AutoencoderKL.from_pretrained(
            args.pretrained_model_path, subfolder="vae",
        ).to(device, dtype=encode_dtype)
    vae = deepcopy(vae_for_base_content).to(dtype=dtype)
    return vae_for_base_content, vae


def _dct_initial_latents(base_content, scheduler, video_length, device, dtype,
                         percentage=0.23, diffuse_timestep=975):
    """Build DCT-mixed initial latents from a single-frame image latent.

    Steps:
    1. Repeat ``base_content`` ``video_length`` times along the frame axis.
    2. Build a filter with ``dct_low_pass_filter``, forwarding ``percentage``.
    3. Noise the repeated latent with ``scheduler.add_noise`` at timestep
       ``diffuse_timestep``.
    4. Combine the fresh noise and the noised content with
       ``exchanged_mixed_dct_freq`` using that filter.

    Args:
        base_content: latent from ``prepare_image``, laid out as
            ``(b, c, f, h, w)`` with ``f == 1``.
        scheduler: scheduler providing ``add_noise``.
        video_length: number of times the frame is repeated.
        device: device for the noise and timestep tensors.
        dtype: dtype of the returned latents.
        percentage: cutoff passed to ``dct_low_pass_filter``.
        diffuse_timestep: timestep at which the repeated content is noised.

    Returns:
        The output of ``exchanged_mixed_dct_freq`` cast to ``dtype``.
    """
    base_content_repeat = repeat(
        base_content, 'b c f h w -> b c (f r) h w', r=video_length
    ).contiguous()
    freq_filter = dct_low_pass_filter(dct_coefficients=base_content,
                                      percentage=percentage)
    # Noise shape follows the real latent instead of a hard-coded
    # (1, 4, 15, 40, 64). It is drawn on the CPU generator and then moved,
    # so a given seed yields the same noise on CPU and CUDA runs.
    noise = torch.randn(base_content_repeat.shape).to(device)
    timesteps = torch.full((1,), int(diffuse_timestep), dtype=torch.long, device=device)
    base_content_noise = scheduler.add_noise(
        original_samples=base_content_repeat,
        noise=noise,
        timesteps=timesteps,
    )
    # NOTE(review): presumably keeps the filter's low-frequency band from the
    # noised content and the rest from noise; see utils.exchanged_mixed_dct_freq.
    latents = exchanged_mixed_dct_freq(noise=noise,
                                       base_content=base_content_noise,
                                       LPF_3d=freq_filter)
    return latents.to(dtype=dtype)


def main(args):
    """Generate one video per (image, prompt) pair listed in the config.

    Config keys read here: ``seed``, ``pretrained_model_path``,
    ``enable_vae_temporal_decoder``, ``use_dct``, ``beta_start``, ``beta_end``,
    ``beta_schedule``, ``save_img_path``, ``image_size``, ``image_prompts``,
    ``video_length``, ``num_sampling_steps``, ``guidance_scale``,
    ``motion_bucket_id`` and ``run_time``, plus whatever ``get_models(args)``
    reads. Input images are read from ``./animated_images/``. Each result is
    written as an mp4 (fps=8) inside ``save_img_path``.
    """
    # 0 is a valid seed; only skip seeding when none is configured.
    if args.seed is not None:
        torch.manual_seed(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16
    # The image is encoded in float64 for the DCT path and float16 otherwise.
    # The encoding VAE is loaded in this same dtype so input and weights agree.
    encode_dtype = torch.float64 if args.use_dct else torch.float16

    unet = get_models(args).to(device, dtype=dtype)
    vae_for_base_content, vae = _load_vaes(args, device, encode_dtype, dtype)
    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_path, subfolder="text_encoder", torch_dtype=dtype
    ).to(device)

    # set eval mode
    unet.eval()
    vae.eval()
    text_encoder.eval()

    scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_path,
                                              subfolder="scheduler",
                                              beta_start=args.beta_start,
                                              beta_end=args.beta_end,
                                              beta_schedule=args.beta_schedule)
    videogen_pipeline = VideoGenPipeline(vae=vae,
                                         text_encoder=text_encoder,
                                         tokenizer=tokenizer,
                                         scheduler=scheduler,
                                         unet=unet).to(device)

    os.makedirs(args.save_img_path, exist_ok=True)

    transform_video = video_transforms.Compose([
        video_transforms.ToTensorVideo(),
        # center crop using the short edge, then resize
        video_transforms.SDXLCenterCrop((args.image_size[0], args.image_size[1])),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ])

    for i, (image, prompt) in enumerate(args.image_prompts):
        base_content = prepare_image(os.path.join("./animated_images", image),
                                     vae_for_base_content, transform_video, device,
                                     dtype=encode_dtype)
        latents = None
        if args.use_dct:
            print("Using DCT!")
            latents = _dct_initial_latents(base_content, scheduler, args.video_length,
                                           device, dtype)
        base_content = base_content.to(dtype=dtype)

        videos = videogen_pipeline(prompt,
                                   latents=latents,
                                   base_content=base_content,
                                   video_length=args.video_length,
                                   height=args.image_size[0],
                                   width=args.image_size[1],
                                   num_inference_steps=args.num_sampling_steps,
                                   guidance_scale=args.guidance_scale,
                                   motion_bucket_id=args.motion_bucket_id,
                                   enable_vae_temporal_decoder=args.enable_vae_temporal_decoder).video
        # os.path.join works whether or not save_img_path ends with a separator.
        video_name = '%s_%04d_%04d-imageio.mp4' % (prompt.replace(' ', '_'), i, args.run_time)
        # imageio quality ranges 0 (lowest) to 10 (highest)
        imageio.mimwrite(os.path.join(args.save_img_path, video_name), videos[0], fps=8, quality=8)
if __name__ == "__main__":
    # Command-line entry: load the OmegaConf YAML given by --config and run sampling.
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, default="./configs/sample.yaml")
    config_path = cli.parse_args().config
    main(OmegaConf.load(config_path))
|