VidToMe / generate.py
jadechoghari's picture
Update generate.py
b8ac6f6 verified
import torch.nn as nn
import torch
import numpy as np
from tqdm import tqdm
import os
from transformers import logging
from .controlnet_utils import CONTROLNET_DICT
from .config_utils import load_config, save_config
from .utils import get_controlnet_kwargs, get_frame_ids, get_latents_dir, init_model, seed_everything
from .utils import prepare_control, load_latent, load_video, prepare_depth, save_video
from .pnp_utils import register_time, register_attention_control, register_conv_control
# will cause an issue
# from . import vidtome
from .patch import apply_patch, update_patch
# suppress partial model loading warning
logging.set_verbosity_error()
class Generator(nn.Module):
    """Chunk-wise video generation driver built around a diffusion pipeline.

    The class wraps a pipeline object (``pipe``) and a scheduler, patches the
    pipeline through ``apply_patch`` (VidToMe token merging), and denoises a
    stack of per-frame latents chunk by chunk. Optional conditioning modes,
    selected from the configuration, are ControlNet, PnP feature injection
    and a depth-conditioned model.
    """

    def __init__(self, pipe, scheduler, config):
        """Read the configuration, keep handles to the pipeline parts and patch the model.

        Args:
            pipe: Pipeline object exposing ``vae``, ``tokenizer``, ``unet`` and
                ``text_encoder`` (plus ``controlnet`` when ControlNet is used).
            scheduler: Scheduler exposing ``set_timesteps``, ``timesteps``,
                ``alphas_cumprod`` and ``final_alpha_cumprod``.
            config: Configuration object; generation settings are read from
                ``config.generation``.
        """
        super().__init__()

        self.device = config.device
        self.seed = config.seed
        self.model_key = config.model_key

        self.config = config
        gene_config = config.generation

        # A generation-level precision setting overrides the global one.
        float_precision = gene_config.float_precision if "float_precision" in gene_config else config.float_precision
        if float_precision == "fp16":
            self.dtype = torch.float16
            print("[INFO] float precision fp16. Use torch.float16.")
        else:
            self.dtype = torch.float32
            print("[INFO] float precision fp32. Use torch.float32.")

        self.pipe = pipe
        self.vae = pipe.vae
        self.tokenizer = pipe.tokenizer
        self.unet = pipe.unet
        self.text_encoder = pipe.text_encoder
        if config.enable_xformers_memory_efficient_attention:
            try:
                pipe.enable_xformers_memory_efficient_attention()
            except ModuleNotFoundError:
                print("[WARNING] xformers not found. Disable xformers attention.")

        self.n_timesteps = gene_config.n_timesteps
        scheduler.set_timesteps(gene_config.n_timesteps, device=self.device)
        self.scheduler = scheduler

        # Two model inputs per frame (unconditional + conditional) for
        # classifier-free guidance; PnP adds a third (the source branch).
        self.batch_size = 2
        self.control = gene_config.control
        self.use_depth = config.sd_version == "depth"
        self.use_controlnet = self.control in CONTROLNET_DICT
        self.use_pnp = self.control == "pnp"
        if self.use_controlnet:
            self.controlnet = pipe.controlnet
            self.controlnet_scale = gene_config.control_scale
        elif self.use_pnp:
            # Injection thresholds are configured as fractions of the step count.
            pnp_f_t = int(gene_config.n_timesteps * gene_config.pnp_f_t)
            pnp_attn_t = int(gene_config.n_timesteps * gene_config.pnp_attn_t)
            self.batch_size += 1
            self.init_pnp(conv_injection_t=pnp_f_t, qk_injection_t=pnp_attn_t)

        self.chunk_size = gene_config.chunk_size
        self.merge_global = gene_config.merge_global
        self.local_merge_ratio = gene_config.local_merge_ratio
        self.global_merge_ratio = gene_config.global_merge_ratio
        self.global_rand = gene_config.global_rand
        self.align_batch = gene_config.align_batch

        self.prompt = gene_config.prompt
        self.negative_prompt = gene_config.negative_prompt
        self.guidance_scale = gene_config.guidance_scale
        self.save_frame = gene_config.save_frame

        self.frame_height, self.frame_width = config.height, config.width
        self.work_dir = config.work_dir

        # "mix" may carry a divisor suffix ("mix-<div>"); the divisor defaults to 3.
        self.chunk_ord = gene_config.chunk_ord
        if "mix" in self.chunk_ord:
            self.perm_div = float(self.chunk_ord.split("-")[-1]) if "-" in self.chunk_ord else 3.
            self.chunk_ord = "mix"

        # Patch VidToMe to model
        self.activate_vidtome()

        if gene_config.use_lora:
            self.pipe.load_lora_weights(**gene_config.lora)

    def activate_vidtome(self):
        """Apply the VidToMe patch to the pipeline using the configured merge settings."""
        apply_patch(self.pipe, self.local_merge_ratio, self.merge_global, self.global_merge_ratio,
                    seed=self.seed, batch_size=self.batch_size,
                    align_batch=self.use_pnp or self.align_batch, global_rand=self.global_rand)

    @torch.no_grad()
    def get_text_embeds_input(self, prompt, negative_prompt):
        """Return the text embeddings fed to the UNet for one edit prompt.

        The result is ``[negative, prompt]`` embeddings; when PnP is enabled
        the embedding of the empty prompt is prepended for the source branch.
        """
        text_embeds = self.get_text_embeds(
            prompt, negative_prompt, self.device)
        if self.use_pnp:
            pnp_guidance_embeds = self.get_text_embeds("", device=self.device)
            text_embeds = torch.cat(
                [pnp_guidance_embeds, text_embeds], dim=0)
        return text_embeds

    @torch.no_grad()
    def get_text_embeds(self, prompt, negative_prompt=None, device="cuda"):
        """Encode ``prompt`` (and optionally ``negative_prompt``) with the text encoder.

        Args:
            prompt: Text to encode.
            negative_prompt: Optional negative text; when given, its embedding
                is concatenated *before* the prompt embedding.
            device: Device the token ids are moved to before encoding.

        Returns:
            The first output of the text encoder for the prompt, preceded by
            the one for the negative prompt when it is provided.
        """
        max_length = self.tokenizer.model_max_length
        text_input = self.tokenizer(prompt, padding='max_length', max_length=max_length,
                                    truncation=True, return_tensors='pt')
        text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0]
        if negative_prompt is not None:
            # Truncate like the positive prompt: an over-long negative prompt
            # would otherwise produce more tokens than `max_length`.
            uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=max_length,
                                          truncation=True, return_tensors='pt')
            uncond_embeddings = self.text_encoder(
                uncond_input.input_ids.to(device))[0]
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        return text_embeddings

    @torch.no_grad()
    def prepare_data(self, data_path, latent_path, frame_ids):
        """Load frames, the initial noisy latent and any control inputs onto the device.

        Sets ``self.frames`` and ``self.init_noise``, plus ``self.depths`` when
        the depth model is used and ``self.controlnet_images`` when ControlNet
        is used.
        """
        self.frames = load_video(data_path, self.frame_height,
                                 self.frame_width, frame_ids=frame_ids, device=self.device)
        # The starting latent is the one saved for the first (noisiest) timestep.
        self.init_noise = load_latent(
            latent_path, t=self.scheduler.timesteps[0], frame_ids=frame_ids).to(self.dtype).to(self.device)

        if self.use_depth:
            self.depths = prepare_depth(
                self.pipe, self.frames, frame_ids, self.work_dir).to(self.init_noise)

        if self.use_controlnet:
            self.controlnet_images = prepare_control(
                self.control, self.frames, frame_ids, self.work_dir).to(self.init_noise)

    @torch.no_grad()
    def decode_latents(self, latents):
        """Decode latents with the VAE and map the result into ``[0, 1]``."""
        # autocast expects a bare device type ("cuda"), not an indexed device
        # string such as "cuda:0" or a torch.device object.
        device_type = torch.device(self.device).type
        with torch.autocast(device_type=device_type, dtype=self.dtype):
            # 0.18215 is presumably the VAE latent scaling factor -- TODO confirm.
            latents = 1 / 0.18215 * latents
            imgs = self.vae.decode(latents).sample
            imgs = (imgs / 2 + 0.5).clamp(0, 1)
        return imgs

    @torch.no_grad()
    def decode_latents_batch(self, latents):
        """Decode ``latents`` in slices of ``self.batch_size`` along dim 0 and concatenate."""
        imgs = [self.decode_latents(latent)
                for latent in latents.split(self.batch_size, dim=0)]
        return torch.cat(imgs)

    @torch.no_grad()
    def encode_imgs(self, imgs):
        """Encode images with the VAE and return the scaled posterior mean.

        The input is rescaled with ``2 * imgs - 1`` before encoding.
        """
        # See decode_latents: normalise the device to its bare type for autocast.
        device_type = torch.device(self.device).type
        with torch.autocast(device_type=device_type, dtype=self.dtype):
            imgs = 2 * imgs - 1
            posterior = self.vae.encode(imgs).latent_dist
            latents = posterior.mean * 0.18215
        return latents

    @torch.no_grad()
    def encode_imgs_batch(self, imgs):
        """Encode ``imgs`` in slices of ``self.batch_size`` along dim 0 and concatenate."""
        latents = [self.encode_imgs(img)
                   for img in imgs.split(self.batch_size, dim=0)]
        return torch.cat(latents)

    def get_chunks(self, flen):
        """Split frame indices ``0..flen-1`` into randomly offset chunks.

        The first chunk has a random length in ``[1, chunk_size]``; the
        remaining indices are split into pieces of ``chunk_size``. The chunk
        list is reversed with probability 0.5. When global merging is enabled
        the chunks are then reordered according to ``self.chunk_ord``.

        Args:
            flen: Number of frames.

        Returns:
            A list of 1-D index tensors that together cover every frame index once.
        """
        x_index = torch.arange(flen)

        # The first chunk has a random length
        rand_first = np.random.randint(0, self.chunk_size) + 1
        first_chunk = x_index[:rand_first]
        rest = x_index[rand_first:].split(self.chunk_size, dim=0)
        # Drop the remainder when the first chunk already covers all frames.
        chunks = [first_chunk] + list(rest) if len(rest[0]) > 0 else [first_chunk]
        if np.random.rand() > 0.5:
            chunks = chunks[::-1]

        # Chunk order only matter when we do global token merging
        if not self.merge_global:
            return chunks

        # Chunk order. "seq": sequential order. "rand": full permutation. "mix": partial permutation.
        if self.chunk_ord == "rand":
            order = torch.randperm(len(chunks))
        elif self.chunk_ord == "mix":
            randord = torch.randperm(len(chunks)).tolist()
            rand_len = int(len(randord) / self.perm_div)
            seqord = sorted(randord[rand_len:])
            if rand_len > 0:
                randord = randord[:rand_len]
                # Start the sequential part from the end closest to the last
                # random chunk. `seqord` is empty when perm_div <= 1, in which
                # case there is nothing to orient (and nothing to index).
                if seqord and abs(seqord[-1] - randord[-1]) < abs(seqord[0] - randord[-1]):
                    seqord = seqord[::-1]
                order = randord + seqord
            else:
                order = seqord
        else:
            order = torch.arange(len(chunks))
        chunks = [chunks[i] for i in order]
        return chunks

    @torch.no_grad()
    def ddim_sample(self, x, conds):
        """Run the full denoising loop over ``self.scheduler.timesteps``.

        Args:
            x: Noisy latents for all frames (indexed by frame along dim 0).
            conds: Text embeddings as returned by ``get_text_embeds_input``.

        Returns:
            The latents after the last timestep.
        """
        print("[INFO] denoising frames...")
        timesteps = self.scheduler.timesteps
        noises = torch.zeros_like(x)

        for i, t in enumerate(tqdm(timesteps, desc="Sampling")):
            self.pre_iter(x, t)

            # Split video into chunks and denoise
            chunks = self.get_chunks(len(x))
            for chunk in chunks:
                torch.cuda.empty_cache()
                noises[chunk] = self.pred_noise(
                    x[chunk], conds, t, batch_idx=chunk)

            x = self.pred_next_x(x, noises, t, i, inversion=False)

            self.post_iter(x, t)

        return x

    def pre_iter(self, x, t):
        """Per-timestep setup: for PnP, register the timestep and load the source latents."""
        if self.use_pnp:
            # Prepare PnP
            register_time(self, t.item())
            cur_latents = load_latent(self.latent_path, t=t, frame_ids=self.frame_ids)
            self.cur_latents = cur_latents

    def post_iter(self, x, t):
        """Per-timestep teardown: clear the global tokens when global merging is on."""
        if self.merge_global:
            # Reset global tokens
            update_patch(self.pipe, global_tokens=None)

    @torch.no_grad()
    def pred_noise(self, x, cond, t, batch_idx=None):
        """Predict guided noise for one chunk of frames.

        Args:
            x: Latents of the frames in this chunk.
            cond: Text embeddings; each row is repeated once per frame.
            t: Current timestep.
            batch_idx: Optional frame indices of this chunk, used to select
                the matching source latents, depth maps and control images.

        Returns:
            The classifier-free-guidance combination of the unconditional and
            conditional noise predictions.
        """
        flen = len(x)
        text_embed_input = cond.repeat_interleave(flen, dim=0)

        # For classifier-free guidance
        latent_model_input = torch.cat([x, x])
        batch_size = 2

        if self.use_pnp:
            # Cat latents from inverted source frames for PnP operation
            source_latents = self.cur_latents
            if batch_idx is not None:
                source_latents = source_latents[batch_idx]
            latent_model_input = torch.cat([source_latents.to(x), latent_model_input])
            batch_size += 1

        # For sd-depth model
        if self.use_depth:
            depth = self.depths
            if batch_idx is not None:
                depth = depth[batch_idx]
            depth = depth.repeat(batch_size, 1, 1, 1)
            latent_model_input = torch.cat([latent_model_input, depth.to(x)], dim=1)

        kwargs = dict()
        # Compute controlnet outputs
        if self.use_controlnet:
            controlnet_cond = self.controlnet_images
            if batch_idx is not None:
                controlnet_cond = controlnet_cond[batch_idx]
            controlnet_cond = controlnet_cond.repeat(batch_size, 1, 1, 1)
            controlnet_kwargs = get_controlnet_kwargs(
                self.controlnet, latent_model_input, text_embed_input, t, controlnet_cond, self.controlnet_scale)
            kwargs.update(controlnet_kwargs)

        # Pred noise!
        eps = self.unet(latent_model_input, t, encoder_hidden_states=text_embed_input, **kwargs).sample
        # The last two groups are always (unconditional, conditional); a PnP
        # source group, if present, comes first and is discarded here.
        noise_pred_uncond, noise_pred_cond = eps.chunk(batch_size)[-2:]
        # CFG
        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
        return noise_pred

    @torch.no_grad()
    def pred_next_x(self, x, eps, t, i, inversion=False):
        """Advance ``x`` by one step given the predicted noise ``eps``.

        Args:
            x: Current latents.
            eps: Predicted noise for ``x``.
            t: Current timestep (index into ``scheduler.alphas_cumprod``).
            i: Position of ``t`` in the (possibly reversed) timestep list.
            inversion: When True, step along the reversed timestep list;
                otherwise along ``scheduler.timesteps`` as given.

        Returns:
            The latents at the neighbouring timestep.
        """
        if inversion:
            timesteps = reversed(self.scheduler.timesteps)
        else:
            timesteps = self.scheduler.timesteps

        alpha_prod_t = self.scheduler.alphas_cumprod[t]
        # At either end of the schedule there is no neighbouring timestep, so
        # the scheduler's `final_alpha_cumprod` is used instead.
        if inversion:
            alpha_prod_t_prev = (
                self.scheduler.alphas_cumprod[timesteps[i - 1]]
                if i > 0 else self.scheduler.final_alpha_cumprod
            )
        else:
            alpha_prod_t_prev = (
                self.scheduler.alphas_cumprod[timesteps[i + 1]]
                if i < len(timesteps) - 1
                else self.scheduler.final_alpha_cumprod
            )

        mu = alpha_prod_t ** 0.5
        sigma = (1 - alpha_prod_t) ** 0.5
        mu_prev = alpha_prod_t_prev ** 0.5
        sigma_prev = (1 - alpha_prod_t_prev) ** 0.5

        # Estimate the clean sample, then re-noise it to the target level.
        if inversion:
            pred_x0 = (x - sigma_prev * eps) / mu_prev
            x = mu * pred_x0 + sigma * eps
        else:
            pred_x0 = (x - sigma * eps) / mu
            x = mu_prev * pred_x0 + sigma_prev * eps

        return x

    def init_pnp(self, conv_injection_t, qk_injection_t):
        """Register PnP attention and conv controls for the first timesteps.

        Args:
            conv_injection_t: Number of leading timesteps that get conv
                injection; a negative value means none.
            qk_injection_t: Number of leading timesteps that get attention
                injection; a negative value means none.
        """
        qk_injection_timesteps = self.scheduler.timesteps[:qk_injection_t] if qk_injection_t >= 0 else []
        conv_injection_timesteps = self.scheduler.timesteps[:conv_injection_t] if conv_injection_t >= 0 else []
        register_attention_control(
            self, qk_injection_timesteps, num_inputs=self.batch_size)
        register_conv_control(
            self, conv_injection_timesteps, num_inputs=self.batch_size)

    def check_latent_exists(self, latent_path):
        """Return True when every required ``noisy_latents_<t>.pt`` file exists.

        PnP needs a latent for every timestep; otherwise only the latent of
        the first timestep is required.
        """
        if self.use_pnp:
            timesteps = self.scheduler.timesteps
        else:
            timesteps = [self.scheduler.timesteps[0]]

        return all(
            os.path.exists(os.path.join(latent_path, f'noisy_latents_{ts}.pt'))
            for ts in timesteps
        )

    @torch.no_grad()
    def __call__(self, data_path, latent_path, output_path, frame_ids):
        """Generate and save one edited video per configured prompt.

        Args:
            data_path: Path of the input video/frames.
            latent_path: Base directory of the saved inversion latents.
            output_path: Accepted for interface compatibility.
                NOTE(review): currently unused -- results are written under
                "./<edit_name>" instead; confirm this is intended.
            frame_ids: Frame indices to process.

        Raises:
            AssertionError: If the required latents are missing.
        """
        self.scheduler.set_timesteps(self.n_timesteps)
        latent_path = get_latents_dir(latent_path, self.model_key)
        assert self.check_latent_exists(latent_path), (
            f"Required latent not found at {latent_path}. "
            "Note: If using PnP as control, you need inversion latents saved "
            "at each generation timestep."
        )

        self.data_path = data_path
        self.latent_path = latent_path
        self.frame_ids = frame_ids
        self.prepare_data(data_path, latent_path, frame_ids)

        print(f"[INFO] initial noise latent shape: {self.init_noise.shape}")

        for edit_name, edit_prompt in self.prompt.items():
            print(f"[INFO] current prompt: {edit_prompt}")
            conds = self.get_text_embeds_input(edit_prompt, self.negative_prompt)
            clean_latent = self.ddim_sample(self.init_noise, conds)
            # Comment this if you have enough GPU memory
            torch.cuda.empty_cache()
            clean_frames = self.decode_latents_batch(clean_latent)
            # cur_output_path = os.path.join(output_path, edit_name)
            cur_output_path = os.path.join("./", edit_name)
            save_config(self.config, cur_output_path, gene=True)
            save_video(clean_frames, cur_output_path, save_frame=self.save_frame)
if __name__ == "__main__":
    # Script entry point: load the config, build the model, then run generation.
    config = load_config()
    gen_cfg = config.generation

    pipe, scheduler, model_key = init_model(
        config.device,
        config.sd_version,
        config.model_key,
        gen_cfg.control,
        config.float_precision,
    )
    # Keep the model key returned by init_model on the config for later use.
    config.model_key = model_key

    seed_everything(config.seed)
    generator = Generator(pipe, scheduler, config)

    frame_ids = get_frame_ids(gen_cfg.frame_range, gen_cfg.frame_ids)
    generator(
        config.input_path,
        gen_cfg.latents_path,
        gen_cfg.output_path,
        frame_ids=frame_ids,
    )