# VidToMe / utils.py
# Web-page header captured together with the file (not part of the module):
#   jadechoghari's picture / Update utils.py / ecbaf30 verified
import contextlib
import random
import numpy as np
import os
from glob import glob
from PIL import Image, ImageSequence
import torch
from torchvision.io import read_video, write_video
import torchvision.transforms as T
from diffusers import DDIMScheduler, StableDiffusionControlNetPipeline, StableDiffusionPipeline, StableDiffusionDepth2ImgPipeline, ControlNetModel
from .controlnet_utils import CONTROLNET_DICT, control_preprocess
from einops import rearrange
# Image file extensions collected by glob_frame_paths when listing the
# frame files of a directory.
FRAME_EXT = [".jpg", ".png"]
def init_model(device="cuda", sd_version="1.5", model_key=None, control_type="none", weight_dtype="fp16"):
    """Load a Stable Diffusion pipeline together with its DDIM scheduler.

    Args:
        device: Device the pipeline is moved to before it is returned.
        sd_version: ``'1.5'``, ``'2.0'``, ``'2.1'`` or ``'depth'``. Picks the
            default checkpoint when ``model_key`` is None. ``'depth'`` also
            selects the depth-to-image pipeline class, including when a
            custom ``model_key`` is supplied.
        model_key: Optional checkpoint identifier overriding the default one.
        control_type: ``"none"`` and ``"pnp"`` load a pipeline without
            ControlNet; any other value is looked up in ``CONTROLNET_DICT``
            and a ControlNet pipeline is built (this takes precedence over
            the depth pipeline).
        weight_dtype: ``"fp16"`` selects ``torch.float16``; every other value
            selects ``torch.float32``.

    Returns:
        Tuple ``(pipe, scheduler, model_key)`` where ``pipe`` has already been
        moved to ``device`` and ``model_key`` is the checkpoint actually used.

    Raises:
        ValueError: If ``model_key`` is None and ``sd_version`` is not one of
            the supported versions.
    """
    default_model_keys = {
        '2.1': "stabilityai/stable-diffusion-2-1-base",
        '2.0': "stabilityai/stable-diffusion-2-base",
        '1.5': "runwayml/stable-diffusion-v1-5",
        'depth': "stabilityai/stable-diffusion-2-depth",
    }

    # Derive the pipeline class from sd_version alone so that a custom
    # checkpoint combined with sd_version='depth' still gets the
    # depth-to-image pipeline instead of silently falling back to the plain
    # text-to-image one.
    use_depth = sd_version == 'depth'

    if model_key is None:
        if sd_version not in default_model_keys:
            raise ValueError(
                f'Stable-diffusion version {sd_version} not supported.')
        model_key = default_model_keys[sd_version]
        print(f'[INFO] loading stable diffusion from: {model_key}')
    else:
        print(f'[INFO] loading custom model from: {model_key}')

    scheduler = DDIMScheduler.from_pretrained(
        model_key, subfolder="scheduler")

    # Anything other than "fp16" deliberately falls back to full precision.
    weight_dtype = torch.float16 if weight_dtype == "fp16" else torch.float32

    if control_type not in ["none", "pnp"]:
        controlnet_key = CONTROLNET_DICT[control_type]
        print(f'[INFO] loading controlnet from: {controlnet_key}')
        controlnet = ControlNetModel.from_pretrained(
            controlnet_key, torch_dtype=weight_dtype)
        print('[INFO] loaded controlnet!')
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            model_key, controlnet=controlnet, torch_dtype=weight_dtype
        )
    elif use_depth:
        pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
            model_key, torch_dtype=weight_dtype
        )
    else:
        pipe = StableDiffusionPipeline.from_pretrained(
            model_key, torch_dtype=weight_dtype,
        )

    return pipe.to(device), scheduler, model_key
def seed_everything(seed):
    """Seed the PyTorch (CPU and CUDA), Python and NumPy random generators.

    Args:
        seed: Value handed unchanged to every seeding function.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        random.seed,
        np.random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def load_image(image_path):
    """Read an image file as an RGB tensor with a leading batch dimension.

    Args:
        image_path: Path of the image file to open.

    Returns:
        The ``T.ToTensor()`` conversion of the RGB image with one extra
        leading dimension added by ``unsqueeze(0)``.
    """
    # The context manager releases the underlying file handle as soon as the
    # pixel data has been converted, instead of relying on garbage collection.
    with Image.open(image_path) as source:
        rgb_image = source.convert('RGB')
    return T.ToTensor()(rgb_image).unsqueeze(0)
def process_frames(frames, h, w):
    """Resize frames to cover a target size, then center-crop to it.

    The requested ``h`` and ``w`` are first rounded down to multiples of 64.
    The resize keeps the source aspect ratio and is chosen so that the
    resized frame is at least as large as the crop in both directions.

    Args:
        frames: A single frame (3 dimensions) or a batch of frames (more than
            3 dimensions); the last two dimensions are read as height, width.
        h: Requested output height.
        w: Requested output width.

    Returns:
        The processed frames stacked with ``torch.stack``.
    """
    src_h, src_w = frames.shape[-2:]
    h = int(np.floor(h / 64.0)) * 64
    w = int(np.floor(w / 64.0)) * 64

    # Scale to the target height first; if that leaves the frame too narrow,
    # scale to the target width instead.
    width_at_target_h = int(src_w / src_h * h)
    size = (h, width_at_target_h) if width_at_target_h >= w else (int(src_h / src_w * w), w)

    rank = len(frames.shape)
    assert rank >= 3
    if rank == 3:
        frames = [frames]

    print(
        f"[INFO] frame size {(src_h, src_w)} resize to {size} and centercrop to {(h, w)}")

    resize = T.Resize(size)
    center_crop = T.CenterCrop([h, w])
    return torch.stack([center_crop(resize(frame)) for frame in frames])
def glob_frame_paths(video_path):
    """List the frame files inside a directory.

    Args:
        video_path: Directory that is searched (non-recursively).

    Returns:
        Sorted list of the paths matching ``*<ext>`` for every extension in
        ``FRAME_EXT``.
    """
    return sorted(
        path
        for ext in FRAME_EXT
        for path in glob(os.path.join(video_path, f"*{ext}"))
    )
def load_video(video_path, h, w, frame_ids=None, device="cuda"):
    """Load a video as a batch of frames, resized and cropped to ``(h, w)``.

    The source type is chosen from the file extension of ``video_path``
    (case-insensitive): ``.mp4`` is decoded with ``read_video``, ``.gif`` is
    read frame by frame with PIL, and anything else is treated as a directory
    of frame images (see ``glob_frame_paths``).

    Args:
        video_path: Path to an ``.mp4`` file, a ``.gif`` file or a directory
            of frame images.
        h: Requested frame height, forwarded to ``process_frames``.
        w: Requested frame width, forwarded to ``process_frames``.
        frame_ids: Optional index used to select a subset of the loaded
            frames (``frames[frame_ids]``).
        device: Device the resulting tensor is moved to.

    Returns:
        The output of ``process_frames`` moved to ``device``.

    Raises:
        RuntimeError: If ``video_path`` is treated as a frame directory and
            no frame files are found in it.
    """
    # Compare the real extension rather than searching for a substring, so a
    # frame directory such as "clip.mp4_frames/" is not mistaken for a video
    # file and upper-case extensions are recognised.
    extension = os.path.splitext(video_path)[1].lower()

    if extension == ".mp4":
        frames, _, _ = read_video(
            video_path, output_format="TCHW", pts_unit="sec")
        frames = frames / 255
    elif extension == ".gif":
        to_tensor = T.ToTensor()
        # Close the GIF file handle once every frame has been decoded.
        with Image.open(video_path) as gif:
            frame_ls = [
                to_tensor(frame.convert("RGB"))
                for frame in ImageSequence.Iterator(gif)
            ]
        frames = torch.stack(frame_ls)
    else:
        frame_paths = glob_frame_paths(video_path)
        if not frame_paths:
            # torch.cat on an empty list raises a RuntimeError with an
            # unhelpful message; keep the type but say what went wrong.
            raise RuntimeError(
                f"No frame images found in: {video_path}")
        frames = torch.cat([load_image(frame_path) for frame_path in frame_paths])

    if frame_ids is not None:
        frames = frames[frame_ids]

    print(f"[INFO] loaded video with {len(frames)} frames from: {video_path}")

    frames = process_frames(frames, h, w)
    return frames.to(device)
def save_video(frames: torch.Tensor, path, frame_ids=None, save_frame=False):
    """Write frames to ``<path>/output.mp4`` and optionally as image files.

    Args:
        frames: Frame tensor laid out as ``T C H W``; values are expected to
            lie in ``[0, 1]`` and are clamped to that range before saving.
        path: Output directory; created if it does not exist.
        frame_ids: Optional index applied to ``frames`` (``frames[frame_ids]``)
            and also forwarded to ``save_frames`` for naming the image files.
            Defaults to every frame in order.
        save_frame: When true, the selected frames are additionally written
            to ``<path>/frames`` through ``save_frames``.
    """
    os.makedirs(path, exist_ok=True)
    if frame_ids is None:
        frame_ids = list(range(len(frames)))

    # Clamp before the uint8 conversion: values outside [0, 1] would otherwise
    # wrap around when cast and show up as speckle artefacts.
    frames = frames[frame_ids].clamp(0, 1)

    video_file = os.path.join(path, "output.mp4")
    proc_frames = (rearrange(frames, "T C H W -> T H W C") * 255).to(torch.uint8).cpu()
    write_video(video_file, proc_frames, fps=30, video_codec="h264")
    print(f"[INFO] save video to {video_file}")

    if save_frame:
        save_frames(frames, os.path.join(path, "frames"), frame_ids=frame_ids)
def save_frames(frames: torch.Tensor, path, ext="png", frame_ids=None):
    """Save every frame as an individual image file.

    Files are named ``<id>.<ext>`` with the id zero-padded to four digits.

    Args:
        frames: Iterable of frames accepted by ``T.ToPILImage()``.
        path: Output directory; created if it does not exist.
        ext: File extension (without the dot) used for the image files.
        frame_ids: Ids used in the file names, paired with ``frames`` in
            order. Defaults to ``0 .. len(frames) - 1``.
    """
    os.makedirs(path, exist_ok=True)
    if frame_ids is None:
        frame_ids = range(len(frames))

    to_pil = T.ToPILImage()
    for frame_id, frame in zip(frame_ids, frames):
        file_name = f'{frame_id:04}.{ext}'
        to_pil(frame).save(os.path.join(path, file_name))
def load_latent(latent_path, t, frame_ids=None):
    """Load the saved noisy latents of one timestep.

    Args:
        latent_path: Directory containing ``noisy_latents_<t>.pt`` files.
        t: Timestep whose latent file is loaded.
        frame_ids: Optional index applied to the loaded object
            (``latents[frame_ids]``).

    Returns:
        The object stored in the file, indexed by ``frame_ids`` when given.

    Raises:
        AssertionError: If the latent file for ``t`` does not exist.
    """
    lp = os.path.join(latent_path, f'noisy_latents_{t}.pt')
    assert os.path.exists(lp), f"Latent at timestep {t} not found in {latent_path}."

    latents = torch.load(lp)
    return latents if frame_ids is None else latents[frame_ids]
@torch.no_grad()
def prepare_depth(pipe, frames, frame_ids, work_dir):
    """Get one depth map per frame, using ``<work_dir>/depth`` as a cache.

    Each frame is paired with its id and handed to ``load_depth`` together
    with the cache file ``<work_dir>/depth/<id>.pt`` (id zero-padded to four
    digits).

    Args:
        pipe: Pipeline object forwarded to ``load_depth``.
        frames: Iterable of frames, paired in order with ``frame_ids``.
        frame_ids: Ids used to name the per-frame cache files.
        work_dir: Directory under which the ``depth`` folder is created.

    Returns:
        The per-frame depth maps concatenated with ``torch.cat``.
    """
    depth_dir = os.path.join(work_dir, "depth")
    os.makedirs(depth_dir, exist_ok=True)

    depth_ls = [
        load_depth(pipe, os.path.join(depth_dir, "{:04}.pt".format(frame_id)), frame)
        for frame, frame_id in zip(frames, frame_ids)
    ]

    # Report the directory rather than the loop's last file path: the loop
    # variable only named the final frame and was undefined when there were
    # no frames at all.
    print(f"[INFO] loaded depth images from {depth_dir}")
    return torch.cat(depth_ls)
# From pix2video: code/file_utils.py
def load_depth(model, depth_path, input_image, dtype=torch.float32):
    """Return the depth map for one frame, computing and caching it if needed.

    If ``depth_path`` exists, its content is loaded and returned. Otherwise
    the depth map is computed with ``prepare_depth_map``, saved to
    ``depth_path``, and a grayscale PNG preview is written next to it (same
    path with the extension replaced by ``.png``).

    Args:
        model: Object forwarded to ``prepare_depth_map``; its ``device``
            attribute selects the computation device.
        depth_path: Cache file for the depth map.
        input_image: Frame tensor; squeezed and converted to a PIL image
            before depth estimation.
        dtype: Data type forwarded to ``prepare_depth_map``.

    Returns:
        The loaded or freshly computed depth map.
    """
    if os.path.exists(depth_path):
        return torch.load(depth_path)

    pil_image = T.ToPILImage()(input_image.squeeze())
    depth_map = prepare_depth_map(
        model, pil_image, dtype=dtype, device=model.device)
    torch.save(depth_map, depth_path)

    # Map the depth values from [-1, 1] to [0, 255] for the preview image.
    depth_image = (((depth_map + 1.0) / 2.0) * 255).to(torch.uint8)

    # Swap only the file extension. A plain str.replace(".pt", ".png") would
    # also rewrite ".pt" inside directory names, and would overwrite the
    # tensor file itself when the path contains no ".pt" at all.
    preview_path = os.path.splitext(depth_path)[0] + ".png"
    T.ToPILImage()(depth_image.squeeze()).convert("L").save(preview_path)
    return depth_map
@torch.no_grad()
def prepare_depth_map(model, image, depth_map=None, batch_size=1, do_classifier_free_guidance=False, dtype=torch.float32, device="cuda"):
    """Estimate (or post-process) a depth map and normalise it to [-1, 1].

    Args:
        model: Object providing ``vae_scale_factor`` and, when ``depth_map``
            is None, ``feature_extractor`` and ``depth_estimator``.
        image: A PIL image or an iterable of images (PIL images, NumPy
            arrays or tensors). Only the first element is used to determine
            the spatial size.
        depth_map: Optional precomputed depth map. Entries equal to ``-1``
            are treated as background and replaced by
            ``min(valid depth) - 10``. The caller's tensor is not modified.
        batch_size: The depth map is repeated along the batch dimension when
            it holds fewer than ``batch_size`` entries.
        do_classifier_free_guidance: When true, the result is concatenated
            with itself along the batch dimension.
        dtype: Data type of the returned depth map (also used for autocast
            on CUDA).
        device: Device used for the computation; a string or a
            ``torch.device``.

    Returns:
        Depth map with a channel dimension inserted at position 1, resized to
        ``(height // vae_scale_factor, width // vae_scale_factor)`` and
        normalised per sample to ``[-1, 1]``.
    """
    # The default is the string "cuda", but `.type` is read below, so
    # normalise to a torch.device first (a torch.device passes through).
    device = torch.device(device)

    if isinstance(image, Image.Image):
        image = [image]
    else:
        image = list(image)

    first_image = image[0]
    if isinstance(first_image, Image.Image):
        width, height = first_image.size
    elif isinstance(first_image, np.ndarray):
        # Dropping the last (channel) axis leaves (rows, columns), i.e.
        # (height, width) - assumes the usual channel-last NumPy layout.
        height, width = first_image.shape[:-1]
    else:
        height, width = first_image.shape[-2:]

    if depth_map is None:
        pixel_values = model.feature_extractor(
            images=image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(device=device)
        # The DPT-Hybrid model uses batch-norm layers which are not compatible with fp16.
        # So we use `torch.autocast` here for half precision inference.
        context_manager = torch.autocast(
            "cuda", dtype=dtype) if device.type == "cuda" else contextlib.nullcontext()
        with context_manager:
            depth_map = model.depth_estimator(pixel_values).predicted_depth
    else:
        depth_map = depth_map.to(device=device, dtype=dtype)
        background = depth_map == -1
        if background.any():
            # Push background pixels behind the farthest valid value.
            # torch.where builds a new tensor, so the caller's depth map is
            # left untouched even when `.to` returned it without copying.
            min_d = depth_map[~background].min()
            depth_map = torch.where(background, min_d - 10, depth_map)

    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=(height // model.vae_scale_factor,
              width // model.vae_scale_factor),
        mode="bicubic",
        align_corners=False,
    )

    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    # Guard against a constant depth map, which would otherwise divide by
    # zero and fill the result with NaN.
    depth_range = (depth_max - depth_min).clamp_min(
        torch.finfo(depth_map.dtype).eps)
    depth_map = 2.0 * (depth_map - depth_min) / depth_range - 1.0
    depth_map = depth_map.to(dtype)

    # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
    if depth_map.shape[0] < batch_size:
        repeat_by = batch_size // depth_map.shape[0]
        depth_map = depth_map.repeat(repeat_by, 1, 1, 1)

    if do_classifier_free_guidance:
        depth_map = torch.cat([depth_map] * 2)
    return depth_map
def get_latents_dir(latents_path, model_key):
    """Return the latent directory for a model.

    Args:
        latents_path: Base directory.
        model_key: Model identifier; only the part after the last ``/`` is
            used.

    Returns:
        ``latents_path`` joined with the last ``/``-separated component of
        ``model_key``.
    """
    _, _, model_name = model_key.rpartition("/")
    return os.path.join(latents_path, model_name)
def get_controlnet_kwargs(controlnet, x, cond, t, controlnet_cond, controlnet_scale=1.0):
    """Run the ControlNet and package its scaled residuals as keyword arguments.

    Args:
        controlnet: Callable invoked as ``controlnet(x, t,
            encoder_hidden_states=cond, controlnet_cond=controlnet_cond,
            return_dict=False)`` and returning
            ``(down_block_residuals, mid_block_residual)``.
        x: First positional input passed to ``controlnet``.
        cond: Passed as ``encoder_hidden_states``.
        t: Second positional input passed to ``controlnet``.
        controlnet_cond: Conditioning input passed to ``controlnet``.
        controlnet_scale: Factor every residual is multiplied by.

    Returns:
        Dict with the keys ``down_block_additional_residuals`` (list of
        scaled residuals) and ``mid_block_additional_residual``.
    """
    down_residuals, mid_residual = controlnet(
        x,
        t,
        encoder_hidden_states=cond,
        controlnet_cond=controlnet_cond,
        return_dict=False,
    )

    scaled_down_residuals = [residual * controlnet_scale for residual in down_residuals]
    # The mid-block residual is scaled in place.
    mid_residual *= controlnet_scale

    return {
        "down_block_additional_residuals": scaled_down_residuals,
        "mid_block_additional_residual": mid_residual,
    }
def get_frame_ids(frame_range, frame_ids=None):
    """Resolve the sorted list of frame indexes and print a short summary.

    Args:
        frame_range: Arguments for ``range`` (unpacked with ``*``); used only
            when ``frame_ids`` is None.
        frame_ids: Optional explicit frame indexes.

    Returns:
        The frame indexes as a sorted list.
    """
    if frame_ids is None:
        frame_ids = list(range(*frame_range))
    frame_ids = sorted(frame_ids)

    if len(frame_ids) > 4:
        # Long lists are abbreviated to the first two and last two ids.
        first, second = frame_ids[:2]
        second_last, last = frame_ids[-2:]
        frame_ids_str = f"{first} {second} ... {second_last} {last}"
    else:
        frame_ids_str = " ".join(f"{frame_id}" for frame_id in frame_ids)

    print("[INFO] frame indexes: ", frame_ids_str)
    return frame_ids
def prepare_control(control, frames, frame_ids, save_path):
    """Load cached control images, or compute and cache them.

    Control images are stored as ``<save_path>/<control>_image/<id>.png``
    (id zero-padded to four digits). The cache is used only when that
    directory exists and contains a file for every id in ``frame_ids``;
    otherwise the images are recomputed with ``control_preprocess`` and
    written to the directory.

    Args:
        control: ControlNet type; must be a key of ``CONTROLNET_DICT``.
        frames: Frames handed to ``control_preprocess`` when recomputing.
        frame_ids: Ids used for the cache file names.
        save_path: Base directory of the cache.

    Returns:
        The control images, or None (after printing a warning) when
        ``control`` is not a known ControlNet type.
    """
    if control not in CONTROLNET_DICT:
        print(f"[WARNING] unknown controlnet type {control}")
        return None

    control_subdir = f'{save_path}/{control}_image'

    def cache_file(frame_id):
        return os.path.join(control_subdir, "{:04}.png".format(frame_id))

    # `cached_images` stays None unless every requested frame was found.
    cached_images = None
    if os.path.exists(control_subdir):
        print(f"[INFO] load control image from {control_subdir}.")
        cached_images = []
        for frame_id in frame_ids:
            image_path = cache_file(frame_id)
            if not os.path.exists(image_path):
                # Incomplete cache: discard what was read and recompute.
                cached_images = None
                break
            cached_images.append(load_image(image_path))

    if cached_images is not None:
        return torch.cat(cached_images)

    print("[INFO] preprocessing control images...")
    control_images = control_preprocess(frames, control)
    print(f"[INFO] save control images to {control_subdir}.")
    os.makedirs(control_subdir, exist_ok=True)
    for image, frame_id in zip(control_images, frame_ids):
        T.ToPILImage()(image).save(cache_file(frame_id))
    return control_images
def isinstance_str(x: object, cls_name: str):
    """
    Name-based variant of isinstance.

    Returns True when any class in the method resolution order of x's class
    is called cls_name, False otherwise. The class object itself is never
    needed, which makes this handy when patching third-party modules.
    """
    return any(klass.__name__ == cls_name for klass in x.__class__.__mro__)
def init_generator(device: torch.device, fallback: torch.Generator = None):
    """
    Forks the current default random generator of the given device.

    Args:
        device: Target device (a torch.device or a device string).
        fallback: Generator returned as-is when the device is neither CPU
            nor CUDA.

    Returns:
        A new torch.Generator initialised with the current default RNG state
        of the device (CPU or CUDA). For any other device type, `fallback`
        if given, otherwise a fork of the CPU generator.
    """
    device = torch.device(device)

    if device.type == "cpu":
        return torch.Generator(device="cpu").set_state(torch.get_rng_state())

    if device.type == "cuda":
        # Read the state of the requested device; without the argument,
        # torch.cuda.get_rng_state() returns the state of the *current* CUDA
        # device, which is wrong for e.g. "cuda:1" on a multi-GPU machine.
        return torch.Generator(device=device).set_state(
            torch.cuda.get_rng_state(device))

    if fallback is None:
        return init_generator(torch.device("cpu"))
    return fallback
def join_frame(x, fsize):
    """ Merge the tokens of every group of `fsize` frames: "(B F) N C -> B (F N) C" """
    return rearrange(x, "(B F) N C -> B (F N) C", F=fsize)
def split_frame(x, fsize):
    """ Separate joined tokens back into `fsize` frames: "B (F N) C -> (B F) N C" """
    return rearrange(x, "B (F N) C -> (B F) N C", F=fsize)
def func_warper(funcs):
    """ Chain a sequence of functions into a single callable.

    The returned function feeds its input through every function in `funcs`
    in order, passing the same keyword arguments to each one, and returns the
    final result.
    """
    def fn(x, **kwarg):
        result = x
        for step in funcs:
            result = step(result, **kwarg)
        return result

    return fn
def join_warper(fsize):
    """ Build a one-argument function that calls `join_frame` with `fsize` fixed """
    def fn(x):
        return join_frame(x, fsize)

    return fn
def split_warper(fsize):
    """ Build a one-argument function that calls `split_frame` with `fsize` fixed """
    def fn(x):
        return split_frame(x, fsize)

    return fn